1818Script adapted from run_summarization_flax.py
1919"""
2020
21- import json
21+ import io
2222import logging
2323import os
2424import sys
4141from flax .serialization import from_bytes , to_bytes
4242from flax .training import train_state
4343from flax .training .common_utils import onehot
44+ from google .cloud import storage
4445from jax .experimental import PartitionSpec , maps
4546from jax .experimental .compilation_cache import compilation_cache as cc
4647from jax .experimental .pjit import pjit , with_sharding_constraint
5960 "/home/boris/dalle-mini/jax_cache" , max_cache_size_bytes = 5 * 2 ** 30
6061)
6162
62-
6363logger = logging .getLogger (__name__ )
6464
6565
@@ -123,17 +123,20 @@ def get_metadata(self):
123123 else :
124124 return dict ()
125125
def get_opt_state(self):
    """Return a binary file-like object with the serialized optimizer state.

    If ``self.restore_state`` is ``True``, the state is downloaded from the
    W&B "state" artifact whose name is derived from ``self.model_name_or_path``
    (only the main process records artifact usage on the run); otherwise
    ``self.restore_state`` is assumed to already be a path to
    ``opt_state.msgpack``.

    Returns:
        io.BytesIO: the raw ``opt_state.msgpack`` bytes. The data is read
        while the temporary download directory still exists, so the caller
        may ``.read()`` / ``.close()`` it at any later time.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
        if self.restore_state is True:
            # wandb artifact: the state artifact mirrors the model artifact name
            state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
            if jax.process_index() == 0:
                artifact = wandb.run.use_artifact(state_artifact)
            else:
                artifact = wandb.Api().artifact(state_artifact)
            artifact_dir = artifact.download(tmp_dir)
            self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
        # Read the bytes before the TemporaryDirectory is cleaned up.
        # Returning an open handle into tmp_dir (the previous behavior) leaves
        # the caller reading from a deleted directory once this `with` exits:
        # it only works on POSIX by fd semantics and breaks on Windows.
        with Path(self.restore_state).open("rb") as f:
            return io.BytesIO(f.read())
137140
138141
139142@dataclass
@@ -785,10 +788,9 @@ def init_state(params):
785788
786789 else :
787790 # load opt_state
788- with tempfile .TemporaryDirectory () as tmp_dir : # avoid multiple artifact copies
789- opt_state_file = model_args .get_opt_state (tmp_dir )
790- opt_state = from_bytes (opt_state_shape , opt_state_file .read ())
791- opt_state_file .close ()
791+ opt_state_file = model_args .get_opt_state ()
792+ opt_state = from_bytes (opt_state_shape , opt_state_file .read ())
793+ opt_state_file .close ()
792794
793795 # restore other attributes
794796 attr_state = {
@@ -1034,59 +1036,80 @@ def run_evaluation():
10341036
def run_save_model(state, eval_metrics=None):
    """Save model, tokenizer and optimizer state, then optionally log W&B artifacts.

    Runs only on the main process. When ``training_args.output_dir`` starts
    with ``gs://``, everything is written to a temporary directory first and
    uploaded to the bucket under ``<wandb run id>/step_<step>/``; otherwise
    files are written directly to the output directory.

    Args:
        state: the current TrainState (params, opt_state, step/epoch counters).
        eval_metrics: optional dict of evaluation metrics to attach to the
            artifact metadata.
    """
    if jax.process_index() != 0:
        return

    output_dir = training_args.output_dir
    use_bucket = output_dir.startswith("gs://")
    if use_bucket:
        bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
        # first path component is the bucket name, the rest is the blob prefix
        bucket_name, dir_path = str(bucket_path).split("/", 1)
        tmp_dir = tempfile.TemporaryDirectory()
        output_dir = tmp_dir.name

    # save model
    params = jax.device_get(state.params)
    model.save_pretrained(
        output_dir,
        params=params,
    )

    # save tokenizer
    tokenizer.save_pretrained(output_dir)

    # copy to bucket
    if use_bucket:
        client = storage.Client()
        # keep the name and the Bucket object distinct (previously the same
        # variable was reused for both, shadowing the bucket name)
        bucket = client.bucket(bucket_name)
        for filename in Path(output_dir).glob("*"):
            blob_name = str(Path(dir_path) / filename.name)
            bucket.blob(blob_name).upload_from_filename(str(filename))
        tmp_dir.cleanup()

    # save optimizer state
    opt_state = jax.device_get(state.opt_state)
    if use_bucket:
        blob_name = str(Path(dir_path) / "opt_state.msgpack")
        bucket.blob(blob_name).upload_from_file(io.BytesIO(to_bytes(opt_state)))
    else:
        with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
            f.write(to_bytes(opt_state))

    # save to W&B
    if training_args.log_model:
        # save some space in the local artifact cache
        c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
        c.cleanup(wandb.util.from_human_size("20GB"))

        metadata = {
            k: jax.device_get(getattr(state, k)).item()
            for k in ["step", "epoch", "train_time", "train_samples"]
        }
        metadata["num_params"] = num_params
        if eval_metrics is not None:
            metadata["eval"] = eval_metrics
        if use_bucket:
            # store as str: a Path object is not JSON-serializable in
            # wandb artifact metadata
            metadata["bucket_path"] = str(bucket_path)

        # create model artifact
        artifact = wandb.Artifact(
            name=f"model-{wandb.run.id}",
            type="DalleBart_model",
            metadata=metadata,
        )
        if not use_bucket:
            for filename in [
                "config.json",
                "flax_model.msgpack",
                "merges.txt",
                "special_tokens_map.json",
                "tokenizer.json",
                "tokenizer_config.json",
                "vocab.json",
            ]:
                artifact.add_file(
                    f"{Path(training_args.output_dir) / filename}"
                )
        wandb.run.log_artifact(artifact)

        # create state artifact
        # NOTE(review): the Artifact(name=...) line fell between diff hunks;
        # name reconstructed by symmetry with the model artifact — confirm.
        artifact_state = wandb.Artifact(
            name=f"state-{wandb.run.id}",
            type="DalleBart_state",
            metadata=metadata,
        )
        if not use_bucket:
            artifact_state.add_file(
                f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
            )
        wandb.run.log_artifact(artifact_state)
11031126
0 commit comments