@@ -135,8 +135,21 @@ def get_opt_state(self):
135135 else :
136136 artifact = wandb .Api ().artifact (state_artifact )
137137 artifact_dir = artifact .download (tmp_dir )
138- self .restore_state = Path (artifact_dir ) / "opt_state.msgpack"
139- return Path (self .restore_state ).open ("rb" )
138+ if artifact .metadata .get ("bucket_path" ):
139+ self .restore_state = artifact .metadata ["bucket_path" ]
140+ else :
141+ self .restore_state = Path (artifact_dir ) / "opt_state.msgpack"
142+
143+ if self .restore_state .startswith ("gs://" ):
144+ bucket_path = Path (self .restore_state [5 :]) / "opt_state.msgpack"
145+ bucket , blob_name = str (bucket_path ).split ("/" , 1 )
146+ client = storage .Client ()
147+ bucket = client .bucket (bucket )
148+ blob = bucket .blob (blob_name )
149+ return blob .download_as_bytes ()
150+
151+ with Path (self .restore_state ).open ("rb" ) as f :
152+ return f .read ()
140153
141154
142155@dataclass
@@ -788,9 +801,7 @@ def init_state(params):
788801
789802 else :
790803 # load opt_state
791- opt_state_file = model_args .get_opt_state ()
792- opt_state = from_bytes (opt_state_shape , opt_state_file .read ())
793- opt_state_file .close ()
804+ opt_state = from_bytes (opt_state_shape , model_args .get_opt_state ())
794805
795806 # restore other attributes
796807 attr_state = {
@@ -1060,15 +1071,15 @@ def run_save_model(state, eval_metrics=None):
10601071 client = storage .Client ()
10611072 bucket = client .bucket (bucket )
10621073 for filename in Path (output_dir ).glob ("*" ):
1063- blob_name = str (Path (dir_path ) / filename .name )
1074+ blob_name = str (Path (dir_path ) / "model" / filename .name )
10641075 blob = bucket .blob (blob_name )
10651076 blob .upload_from_filename (str (filename ))
10661077 tmp_dir .cleanup ()
10671078
10681079 # save state
10691080 opt_state = jax .device_get (state .opt_state )
10701081 if use_bucket :
1071- blob_name = str (Path (dir_path ) / "opt_state.msgpack" )
1082+ blob_name = str (Path (dir_path ) / "state" / " opt_state.msgpack" )
10721083 blob = bucket .blob (blob_name )
10731084 blob .upload_from_file (io .BytesIO (to_bytes (opt_state )))
10741085 else :
@@ -1088,10 +1099,10 @@ def run_save_model(state, eval_metrics=None):
10881099 metadata ["num_params" ] = num_params
10891100 if eval_metrics is not None :
10901101 metadata ["eval" ] = eval_metrics
1091- if use_bucket :
1092- metadata ["bucket_path" ] = bucket_path
10931102
10941103 # create model artifact
1104+ if use_bucket :
1105+ metadata ["bucket_path" ] = f"gs://{ bucket_path } /model"
10951106 artifact = wandb .Artifact (
10961107 name = f"model-{ wandb .run .id } " ,
10971108 type = "DalleBart_model" ,
@@ -1113,6 +1124,8 @@ def run_save_model(state, eval_metrics=None):
11131124 wandb .run .log_artifact (artifact )
11141125
11151126 # create state artifact
1127+ if use_bucket :
1128+ metadata ["bucket_path" ] = f"gs://{ bucket_path } /state"
11161129 artifact_state = wandb .Artifact (
11171130 name = f"state-{ wandb .run .id } " ,
11181131 type = "DalleBart_state" ,
0 commit comments