Skip to content

Commit c8e86a8

Browse files
committed
- ED: Documentation
1 parent 8c6a1a0 commit c8e86a8

File tree

5 files changed

+538
-60
lines changed

5 files changed

+538
-60
lines changed

Fires/_utilities/utils_mlflow.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,44 +37,44 @@ def setup_mlflow_experiment():
3737
@export
3838
@debug(log=_log)
3939
def load_model_from_mlflow_registry(model_name, version=1, tag=None):
40-
# set tracking uri
41-
mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))
40+
# set tracking uri
41+
mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))
4242

43-
if version:
44-
# Load by specific version
45-
model_uri = f"models:/{model_name}/{version}"
46-
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{model_name}/{version}")
47-
elif tag:
48-
# Load by tag (if the tag is set in the UI)
49-
model_uri = f"models:/{model_name}/{tag}"
50-
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{model_name}/{tag}")
51-
else:
52-
raise ValueError("Either version or tag must be specified for model loading.")
53-
54-
os.makedirs(local_path, exist_ok=True)
55-
model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
56-
return model
43+
if version:
44+
# Load by specific version
45+
model_uri = f"models:/{model_name}/{version}"
46+
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{model_name}/{version}")
47+
elif tag:
48+
# Load by tag (if the tag is set in the UI)
49+
model_uri = f"models:/{model_name}/{tag}"
50+
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{model_name}/{tag}")
51+
else:
52+
raise ValueError("Either version or tag must be specified for model loading.")
53+
54+
os.makedirs(local_path, exist_ok=True)
55+
model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
56+
return model
5757

5858
@export
5959
@debug(log=_log)
6060
def load_model_from_mlflow(run_name, scaler=True, provenance=False):
61-
# set tracking uri
62-
mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))
61+
# set tracking uri
62+
mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))
6363

64-
run_id = mlflow.search_runs(filter_string=f"run_name='{run_name}'")['run_id'].values[0]
64+
run_id = mlflow.search_runs(filter_string=f"run_name='{run_name}'")['run_id'].values[0]
6565

66-
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{run_name}")
67-
os.makedirs(local_path, exist_ok=True)
68-
print(f"Data from MLFlow downloaded in: {local_path}")
66+
local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{run_name}")
67+
os.makedirs(local_path, exist_ok=True)
68+
print(f"Data from MLFlow downloaded in: {local_path}")
6969

70-
client = mlflow.MlflowClient()
71-
if scaler:
72-
artifact_path = client.download_artifacts(run_id=run_id, path="scaler", dst_path=local_path)
73-
if provenance:
74-
artifact_path = client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.svg", dst_path=local_path)
75-
artifact_path = client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.json", dst_path=local_path)
70+
client = mlflow.MlflowClient()
71+
if scaler:
72+
artifact_path = client.download_artifacts(run_id=run_id, path="scaler", dst_path=local_path)
73+
if provenance:
74+
artifact_path = client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.svg", dst_path=local_path)
75+
artifact_path = client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.json", dst_path=local_path)
7676

77-
model_uri = f'runs:/{run_id}/last_model'
78-
model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
77+
model_uri = f'runs:/{run_id}/last_model'
78+
model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
7979

80-
return model
80+
return model

docs/data.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ All variables have:
3232

3333
## Overview
3434

35-
<p align="justify"> The dataset used to carry on the development of the code is composed by 9 climate variables. </p>
35+
<p align="justify"> The dataset used to carry on the development of the code is composed by 8 climate variables. </p>
3636

3737
### Drivers
3838

@@ -44,7 +44,6 @@ All variables have:
4444
| _Land-Sea Mask_ | `sftlf` | `lsm` | (lat, lon) | 0-1 | ERA5 |
4545
| _Land Surface Temperature at Day_ | `ts` | `lst_day` | (**time**, lat, lon) | K | Nasa MODIS MOD11C1, MOD13C1, MCD15A2 |
4646
| _Relative Humidity_ | `hur` | `rel_hum` | (**time**, lat, lon) | % | ERA5 |
47-
| _Surface net Solar Radiation_ | `rss` | `ssr` | (**time**, lat, lon) | $\small{MJm^{-2}}$ |ERA5 |
4847
| _Sea Surface Temperature_ | `tos` | `sst` | (**time**, lat, lon) | K | ERA5 |
4948
| _Temperature at 2 meters - Min_ | `tasmin` | `t2m_min` | (**time**, lat, lon) | K | ERA5 |
5049
| _Total Precipitation_ | `pr` | `tp` | (**time**, lat, lon) | m | ERA5 |

0 commit comments

Comments
 (0)