@export
@debug(log=_log)
def load_model_from_mlflow_registry(model_name, version=1, tag=None):
    """Load a PyTorch model from the MLflow Model Registry.

    Args:
        model_name: Registered model name in the MLflow Model Registry.
        version: Registry version to load; defaults to 1.
        tag: Registry tag/stage to load by (e.g. one set in the MLflow UI).
            When given, it takes precedence over ``version``.

    Returns:
        The loaded PyTorch model, mapped onto the device returned by
        ``check_backend()``.

    Raises:
        ValueError: If neither ``version`` nor ``tag`` is provided
            (i.e. both are falsy).
    """
    # Point the MLflow client at the configured tracking server.
    mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))

    # Bug fix: ``version`` defaults to 1 and was checked first, which made
    # the ``tag`` branch unreachable unless callers passed version=None
    # explicitly. An explicitly supplied tag now wins.
    if tag:
        selector = tag
    elif version:
        selector = version
    else:
        raise ValueError("Either version or tag must be specified for model loading.")

    model_uri = f"models:/{model_name}/{selector}"
    # Cache the downloaded artifacts under ./MLFLOW/<model_name>/<selector>.
    local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{model_name}/{selector}")
    os.makedirs(local_path, exist_ok=True)

    model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
    return model
5757
@export
@debug(log=_log)
def load_model_from_mlflow(run_name, scaler=True, provenance=False):
    """Download a run's model (and optional extra artifacts) from MLflow.

    Args:
        run_name: Name of the MLflow run to load from.
        scaler: When True, also download the run's ``scaler`` artifact.
        provenance: When True, also download the provenance graph artifacts
            (``.svg`` and ``.json``) named after the configured experiment.

    Returns:
        The PyTorch model logged under the run's ``last_model`` artifact,
        mapped onto the device returned by ``check_backend()``.

    Raises:
        ValueError: If no run with the given name is found.
    """
    # Point the MLflow client at the configured tracking server.
    mlflow.set_tracking_uri(os.getenv('MLFLOW_TRACKING_URI'))

    # NOTE(review): search_runs without experiment_ids searches only the
    # currently active experiment — confirm that is the intent.
    runs = mlflow.search_runs(filter_string=f"run_name='{run_name}'")
    if runs.empty:
        # Robustness fix: the original indexed .values[0] unconditionally,
        # raising an opaque IndexError when the run name did not match.
        raise ValueError(f"No MLflow run found with run_name='{run_name}'")
    run_id = runs['run_id'].values[0]

    # All artifacts are downloaded under ./MLFLOW/<run_name>.
    local_path = os.path.join(os.getcwd(), 'MLFLOW', f"{run_name}")
    os.makedirs(local_path, exist_ok=True)
    print(f"Data from MLFlow downloaded in: {local_path}")

    client = mlflow.MlflowClient()
    if scaler:
        client.download_artifacts(run_id=run_id, path="scaler", dst_path=local_path)
    if provenance:
        client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.svg", dst_path=local_path)
        client.download_artifacts(run_id=run_id, path=f"provgraph_{CONFIG.mlflow.EXPERIMENT_NAME}.json", dst_path=local_path)

    model_uri = f'runs:/{run_id}/last_model'
    model = mlflow.pytorch.load_model(model_uri, map_location=torch.device(check_backend()), dst_path=local_path)
    return model
0 commit comments