CarbonTracker model training
This notebook outlines the complete workflow for loading and preprocessing the following datasets in order to train a machine learning (ML) model on CarbonTracker's carbon flux:
- CarbonTracker
- ERA5 (monthly)
- ERA5-Land (monthly)
- SPEI (monthly)
- MODIS (monthly)
- Biomass (yearly)
- Copernicus Landcover (yearly)
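The list above mixes monthly and yearly products, while the CarbonTracker fluxes are monthly. One simple way to place a yearly field such as biomass or land cover on the monthly time axis is to repeat each year's value for all twelve months. The sketch below shows this for a single time series with pandas. The helper name yearly_to_monthly is hypothetical, and whether excited_workflow aligns yearly data this way (rather than, for example, interpolating between years) is an assumption.

```python
import pandas as pd


def yearly_to_monthly(yearly: pd.Series, monthly_index: pd.DatetimeIndex) -> pd.Series:
    """Repeat each yearly value for every month of the same calendar year.

    Hypothetical helper (not part of excited_workflow): `yearly` is indexed by
    one timestamp per year, `monthly_index` holds month-start timestamps.
    Months whose year has no yearly value become NaN.
    """
    # Re-key the yearly values by calendar year (e.g. 2000, 2001, ...).
    by_year = pd.Series(yearly.to_numpy(), index=yearly.index.year)
    # Look up the value for each month's year; repeated years are allowed here.
    values = by_year.reindex(monthly_index.year).to_numpy()
    return pd.Series(values, index=monthly_index, name=yearly.name)


# Example: two years of a yearly field (e.g. biomass) spread over 24 months.
months = pd.date_range("2000-01-01", "2001-12-01", freq="MS")
biomass_yearly = pd.Series(
    [10.0, 12.0], index=pd.to_datetime(["2000-01-01", "2001-01-01"]), name="biomass"
)
biomass_monthly = yearly_to_monthly(biomass_yearly, months)
print(biomass_monthly.iloc[[0, 11, 12, 23]].tolist())  # [10.0, 10.0, 12.0, 12.0]
```

For gridded data the same idea applies per grid cell; with xarray one would typically select or reindex along the time dimension instead of working on a single series.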
First, follow the instructions for downloading the data and setting up the configuration.
If you run this notebook on Surf Research Cloud, these steps should already have been done.
import datetime
from pathlib import Path
from dask.distributed import Client
from excited_workflow import train_carbontracker_model
client = Client()  # start a local Dask cluster; dask computations in this session use it by default
Define the paths to the CarbonTracker dataset and the regions dataset, and create an output directory. Also define the datasets to include, the model's input variables (X_keys) and its target variable (y_key).
carbontracker_file = Path("/data/volume_2/EXCITED_prepped_data/CT2022.flux1x1-monthly.nc")
regions_file = Path("/data/volume_2/EXCITED_prepped_data/regions.nc")
# Write all outputs to a new, timestamped directory in the home folder.
output_path = Path.home()
time = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
output_dir = output_path / f"carbon_tracker-{time}"
output_dir.mkdir(parents=True, exist_ok=True)
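# Datasets to merge with the CarbonTracker fluxes.
required_datasets = [
    "biomass",
    "spei",
    "modis",
    "era5_monthly",
    "era5_land_monthly",
    "copernicus_landcover",
]
# Model inputs. ERA5/ERA5-Land short names: d2m = 2 m dewpoint temperature,
# mslhf/msshf = mean surface latent/sensible heat flux, ssr/str = surface net solar/thermal
# radiation, t2m = 2 m temperature, skt = skin temperature, stl1 = soil temperature level 1,
# swvl1 = volumetric soil water layer 1. Plus SPEI, MODIS NIRv and the land cover class.
X_keys = ["d2m", "mslhf", "msshf", "ssr", "str", "t2m", "spei", "NIRv", "skt",
          "stl1", "swvl1", "lccs_class"]
# Target: CarbonTracker's optimized biosphere flux.
y_key = "bio_flux_opt"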
Merge the required datasets into a single xr.Dataset with the same dimensions as the CarbonTracker dataset. To limit the analysis to TransCom region 2 (North America), the regions.nc file is also needed:
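Conceptually, limiting the merged data to one region means setting every grid cell whose region ID differs from the requested one to NaN, then flattening the (time, latitude, longitude) grid into a table with one row per cell and month, similar to df_na. The sketch below does this with NumPy and pandas on a toy grid. mask_to_region is a hypothetical helper, not the signature of train_carbontracker_model.mask_region, and dropping rows with missing values is an assumption about how the real function prepares df_na.

```python
import numpy as np
import pandas as pd


def mask_to_region(fields, regions, region_id, times, lats, lons):
    """Keep only grid cells in `region_id` and flatten them into a table.

    Hypothetical stand-in for the masking step. `fields` maps variable names
    to arrays shaped (time, latitude, longitude); `regions` is a
    (latitude, longitude) array of integer region IDs on the same grid.
    """
    inside = regions == region_id  # boolean (latitude, longitude) mask
    index = pd.MultiIndex.from_product(
        [times, lats, lons], names=["time", "latitude", "longitude"]
    )
    columns = {}
    for name, values in fields.items():
        # Broadcast the 2D mask over time; cells outside the region become NaN.
        masked = np.where(inside[np.newaxis, :, :], values, np.nan)
        # C-order flattening matches the (time, latitude, longitude) index order.
        columns[name] = masked.reshape(-1)
    # Drop cells outside the region and cells with any missing variable.
    return pd.DataFrame(columns, index=index).dropna()


# Example on a 2 x 2 grid with two months; one cell belongs to another region.
times = pd.date_range("2000-01-01", periods=2, freq="MS")
lats = np.array([16.5, 17.5])
lons = np.array([-97.5, -96.5])
regions = np.array([[2, 2], [2, 5]])
flux = np.arange(8, dtype=float).reshape(2, 2, 2)
t2m = np.full((2, 2, 2), 290.0)
t2m[1, 0, 0] = np.nan  # one missing input value inside the region
df_region = mask_to_region({"bio_flux_opt": flux, "t2m": t2m}, regions, 2, times, lats, lons)
print(len(df_region))  # 5 (2 months x 3 region cells, minus 1 row with a NaN)
```

With xarray the same steps map onto Dataset.where on a region mask followed by to_dataframe and dropna.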
# The fourth argument (2) is the TransCom region ID to keep.
ds_na, df_na = train_carbontracker_model.mask_region(regions_file, carbontracker_file, required_datasets, 2, y_key)
ds_na
<xarray.Dataset>
Dimensions: (time: 240, latitude: 180, longitude: 360)
Coordinates:
* longitude (longitude) float64 -179.5 -178.5 -177.5 ... 178.5 179.5
* latitude (latitude) float64 -89.5 -88.5 -87.5 ... 87.5 88.5 89.5
* time (time) datetime64[ns] 2000-01-01 2000-02-01 ... 2019-12-01
Data variables: (12/26)
bio_flux_opt (time, latitude, longitude) float64 nan nan ... nan nan
transcom_regions (latitude, longitude, time) float64 nan nan ... nan nan
biomass (time, latitude, longitude) float64 nan nan ... nan nan
spei (time, latitude, longitude) float64 nan nan ... nan nan
NDVI (time, latitude, longitude) float64 nan nan ... nan nan
NIRv (time, latitude, longitude) float64 nan nan ... nan nan
... ...
stl4 (time, latitude, longitude) float64 nan nan ... nan nan
swvl1 (time, latitude, longitude) float64 nan nan ... nan nan
swvl2 (time, latitude, longitude) float64 nan nan ... nan nan
swvl3 (time, latitude, longitude) float64 nan nan ... nan nan
swvl4 (time, latitude, longitude) float64 nan nan ... nan nan
lccs_class (time, latitude, longitude) float32 nan nan ... nan nan
Attributes:
averaging_period_length_hours: 744
email: carbontracker.team@noaa.gov
url: http://carbontracker.noaa.gov
institution: NOAA Earth System Research Laboratory
Conventions: CF-1.5
history: Time-stamp: <Orion-login-1.HPC.MsState.Ed...
NCO: netCDF Operators version 4.9.3 (Homepage ...
version: CT2022 1x1 3-hourly fluxes as of 2023-01-...
Validate the model with 5-fold cross-validation: split the dataset into 5 groups, train the model on 4 of them and predict the remaining group, repeating until every group has been held out once. The resulting RMSE netCDF files and scatter plots are stored in the output directory.
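The validation loop can be sketched as a grouped k-fold cross-validation with k = 5. How excited_workflow forms the 5 groups is not stated here, so the random fold labels below are an assumption, and an ordinary least-squares model stands in for LightGBM purely to keep the example self-contained. The example also shows a plausible reason why the MAE, MSE and RMSE columns in the tables below read 0.00000: with targets of order 1e-6, errors this small round to zero at five decimal places.

```python
import numpy as np


def grouped_kfold_rmse(X, y, groups, fit, predict):
    """Hold out each group once, train on the others, return RMSE per group.

    `groups` assigns every row to a fold label (e.g. 0..4); `fit` and
    `predict` are placeholders for the real model (LightGBM in the notebook).
    """
    scores = {}
    for g in np.unique(groups):
        test = groups == g
        params = fit(X[~test], y[~test])  # train on the remaining groups
        residual = predict(params, X[test]) - y[test]  # predict the held-out group
        scores[int(g)] = float(np.sqrt(np.mean(residual**2)))
    return scores


def fit_ols(X, y):
    """Stand-in model: ordinary least squares with an intercept (NOT LightGBM)."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return coef[0] + X @ coef[1:]


# Synthetic data with targets of order 1e-6, like the flux predictions in this notebook.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = 1e-6 * (X @ np.array([0.5, -0.2, 0.1])) + 1e-8 * rng.normal(size=500)
groups = rng.integers(0, 5, size=500)  # random fold labels (assumption)
scores = grouped_kfold_rmse(X, y, groups, fit_ols, predict_ols)
for fold, rmse in scores.items():
    # Scientific notation keeps tiny errors visible; "{rmse:.5f}" would print 0.00000.
    print(f"fold {fold}: RMSE = {rmse:.3e}")
```

Rescaling the target (for example to µmol m⁻² s⁻¹) before training is another way to make such metrics readable at fixed precision.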
train_carbontracker_model.validate_model(ds_na, 5, X_keys, y_key, required_datasets, output_dir)
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.77820 | 0.00000 | 3.84560 | 1.16000 |
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.76360 | 0.00000 | 4.67300 | 1.15000 |
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.77220 | 0.00000 | 4.60550 | 1.13000 |
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.75880 | 0.00000 | 6.70420 | 1.13000 |
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.75560 | 0.00000 | 4.28130 | 1.13000 |
Train the final model on the entire dataset.
pycs, model = train_carbontracker_model.train_model(df_na, X_keys, y_key)
| | Model | MAE | MSE | RMSE | R2 | RMSLE | MAPE | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| lightgbm | Light Gradient Boosting Machine | 0.00000 | 0.00000 | 0.00000 | 0.76670 | 0.00000 | 4.41510 | 1.53000 |
Save the trained model in ONNX format to the output directory.
train_carbontracker_model.save_model(pycs, model, output_dir)
The maximum opset needed by this model is only 8.
Create a dataframe containing only the model input variables (X_keys) to run the model on.
df = df_na[X_keys]
df.head()
| time | latitude | longitude | d2m | mslhf | msshf | ssr | str | t2m | spei | NIRv | skt | stl1 | swvl1 | lccs_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2000-08-01 | 16.5 | -97.5 | 291.160736 | -123.552444 | -29.256828 | 17632392.0 | -4539167.0 | 294.455078 | 0.402637 | 0.264352 | 293.800140 | 294.488708 | 0.405656 | 90.0 |
| 2000-08-01 | 16.5 | -96.5 | 287.451385 | -98.636139 | -44.350342 | 17875046.0 | -5688396.0 | 291.743347 | 0.859432 | 0.186314 | 291.346344 | 292.329987 | 0.441099 | 120.0 |
| 2000-08-01 | 16.5 | -95.5 | 292.602661 | -88.933502 | -76.770813 | 18379498.0 | -4277693.5 | 298.020874 | -0.002091 | 0.245570 | 298.976379 | 299.503906 | 0.340303 | 60.0 |
| 2000-08-01 | 17.5 | -100.5 | 290.568329 | -118.364532 | -51.205177 | 17707054.0 | -3109625.0 | 293.402252 | 0.364373 | 0.270905 | 293.096253 | 293.790924 | 0.412808 | 90.0 |
| 2000-08-01 | 17.5 | -99.5 | 290.078156 | -119.800644 | -42.374344 | 17862124.0 | -3914028.0 | 293.219940 | 0.007593 | 0.249498 | 291.973236 | 292.737579 | 0.406332 | 120.0 |
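Because df is built as df_na[X_keys], its columns follow the order of X_keys. That order matters once the model is exported: an ONNX model receives a plain matrix and matches features by position, not by name. The hypothetical helper below makes the requirement explicit by reordering columns to a given feature list and failing loudly when one is missing. Its float32 default is an assumption; the cell below passes df.to_numpy() unchanged, so check the input type the model was exported with before converting.

```python
import numpy as np
import pandas as pd


def to_model_input(frame: pd.DataFrame, feature_order, dtype=np.float32) -> np.ndarray:
    """Return a 2D array whose columns follow `feature_order` exactly.

    An ONNX model receives a plain matrix, so features are matched by
    position, not by name. The default dtype is an assumption; use the
    element type the model was exported with.
    """
    missing = [key for key in feature_order if key not in frame.columns]
    if missing:
        raise KeyError(f"missing input columns: {missing}")
    # Selecting with a list both filters and reorders the columns.
    return frame[list(feature_order)].to_numpy(dtype=dtype)


# Example: the frame's columns are in a different order than the feature list.
example_keys = ["t2m", "spei"]
frame = pd.DataFrame({"spei": [0.5, -0.25], "t2m": [294.5, 291.75]})
matrix = to_model_input(frame, example_keys)
print(matrix.dtype, matrix[0].tolist())  # float32 [294.5, 0.5]
```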
Load the saved ONNX model and run it on the dataframe to check that it was saved correctly.
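from onnxruntime import InferenceSession

# Read the serialized model under a new name so the trained model object stored in model is not overwritten.
with open(output_dir / "lightgbm.onnx", "rb") as f:
    onnx_bytes = f.read()
sess = InferenceSession(onnx_bytes, providers=["CPUExecutionProvider"])
# "X" is the input name assigned when the model was exported.
predictions_onnx = sess.run(None, {"X": df.to_numpy()})[0]
predictions_onnx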
array([[-1.2222255e-06],
[-7.1354458e-07],
[-3.5551579e-07],
...,
[ 4.7076261e-07],
[ 4.6052673e-07],
[ 1.2324770e-07]], dtype=float32)
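Printing predictions_onnx only shows that the session runs. A stronger check is to compare it with predictions from the trained model returned by train_model, assuming that object exposes a scikit-learn style predict() method. Because the ONNX output is float32 (see the dtype above), exact equality is not expected, and for tree ensembles a few samples lying very close to a split threshold can even end up in a different leaf. The hypothetical helper below compares both sets of predictions with a tolerance; the rtol and atol values are assumptions to adjust to the flux magnitudes.

```python
import numpy as np


def onnx_matches_reference(reference, onnx_output, rtol=1e-4, atol=1e-12):
    """Compare ONNX predictions with the original model's predictions.

    ONNX output here is shaped (n, 1) and float32, whereas a scikit-learn style
    regressor's predict() returns a 1D array, so both are flattened and
    compared with a tolerance rather than exact equality.
    """
    ref = np.asarray(reference, dtype=np.float64).ravel()
    got = np.asarray(onnx_output, dtype=np.float64).ravel()
    if ref.shape != got.shape:
        return False
    return bool(np.allclose(got, ref, rtol=rtol, atol=atol))


# Illustrative values of the same magnitude as the predictions above.
reference = np.array([-1.2e-06, -7.1e-07, 4.7e-07])
onnx_like = reference.astype(np.float32).reshape(-1, 1)  # mimics the ONNX output layout
print(onnx_matches_reference(reference, onnx_like))  # True
print(onnx_matches_reference(reference, onnx_like * 1.01))  # False
```

If the check fails for only a handful of rows, inspecting those rows' feature values against the model's split thresholds is a reasonable next step before suspecting the export itself.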