File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed
Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -135,11 +135,12 @@ def get_opt_state(self):
135135 artifact = wandb .run .use_artifact (state_artifact )
136136 else :
137137 artifact = wandb .Api ().artifact (state_artifact )
138- artifact_dir = artifact .download (tmp_dir )
139138 if artifact .metadata .get ("bucket_path" ):
139+ # we will read directly file contents
140140 self .restore_state = artifact .metadata ["bucket_path" ]
141141 else :
142- self .restore_state = Path (artifact_dir ) / "opt_state.msgpack"
142+ artifact_dir = artifact .download (tmp_dir )
143+ self .restore_state = str (Path (artifact_dir ) / "opt_state.msgpack" )
143144
144145 if self .restore_state .startswith ("gs://" ):
145146 bucket_path = Path (self .restore_state [5 :]) / "opt_state.msgpack"
@@ -1130,7 +1131,9 @@ def run_save_model(state, eval_metrics=None):
11301131 type = "DalleBart_model" ,
11311132 metadata = metadata ,
11321133 )
1133- if not use_bucket :
1134+ if use_bucket :
1135+ artifact .add_reference (metadata ["bucket_path" ])
1136+ else :
11341137 for filename in [
11351138 "config.json" ,
11361139 "flax_model.msgpack" ,
@@ -1153,7 +1156,9 @@ def run_save_model(state, eval_metrics=None):
11531156 type = "DalleBart_state" ,
11541157 metadata = metadata ,
11551158 )
1156- if not use_bucket :
1159+ if use_bucket :
1160+ artifact_state .add_reference (metadata ["bucket_path" ])
1161+ else :
11571162 artifact_state .add_file (
11581163 f"{ Path (training_args .output_dir ) / 'opt_state.msgpack' } "
11591164 )
You can’t perform that action at this time.
0 commit comments