Skip to content

Commit 50498e6

Browse files
committed
feat(train): save to bucket
1 parent 34cf91c commit 50498e6

File tree

1 file changed

+70
-47
lines changed

1 file changed

+70
-47
lines changed

tools/train/train.py

Lines changed: 70 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Script adapted from run_summarization_flax.py
1919
"""
2020

21-
import json
21+
import io
2222
import logging
2323
import os
2424
import sys
@@ -41,6 +41,7 @@
4141
from flax.serialization import from_bytes, to_bytes
4242
from flax.training import train_state
4343
from flax.training.common_utils import onehot
44+
from google.cloud import storage
4445
from jax.experimental import PartitionSpec, maps
4546
from jax.experimental.compilation_cache import compilation_cache as cc
4647
from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -59,7 +60,6 @@
5960
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
6061
)
6162

62-
6363
logger = logging.getLogger(__name__)
6464

6565

@@ -123,17 +123,20 @@ def get_metadata(self):
123123
else:
124124
return dict()
125125

126-
def get_opt_state(self, tmp_dir):
127-
if self.restore_state is True:
128-
# wandb artifact
129-
state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
130-
if jax.process_index() == 0:
131-
artifact = wandb.run.use_artifact(state_artifact)
132-
else:
133-
artifact = wandb.Api().artifact(state_artifact)
134-
artifact_dir = artifact.download(tmp_dir)
135-
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
136-
return Path(self.restore_state).open("rb")
126+
def get_opt_state(self):
127+
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
128+
if self.restore_state is True:
129+
# wandb artifact
130+
state_artifact = self.model_name_or_path.replace(
131+
"/model-", "/state-", 1
132+
)
133+
if jax.process_index() == 0:
134+
artifact = wandb.run.use_artifact(state_artifact)
135+
else:
136+
artifact = wandb.Api().artifact(state_artifact)
137+
artifact_dir = artifact.download(tmp_dir)
138+
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
139+
return Path(self.restore_state).open("rb")
137140

138141

139142
@dataclass
@@ -785,10 +788,9 @@ def init_state(params):
785788

786789
else:
787790
# load opt_state
788-
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
789-
opt_state_file = model_args.get_opt_state(tmp_dir)
790-
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
791-
opt_state_file.close()
791+
opt_state_file = model_args.get_opt_state()
792+
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
793+
opt_state_file.close()
792794

793795
# restore other attributes
794796
attr_state = {
@@ -1034,59 +1036,80 @@ def run_evaluation():
10341036

10351037
def run_save_model(state, eval_metrics=None):
10361038
if jax.process_index() == 0:
1039+
1040+
output_dir = training_args.output_dir
1041+
use_bucket = output_dir.startswith("gs://")
1042+
if use_bucket:
1043+
bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
1044+
bucket, dir_path = str(bucket_path).split("/", 1)
1045+
tmp_dir = tempfile.TemporaryDirectory()
1046+
output_dir = tmp_dir.name
1047+
1048+
# save model
10371049
params = jax.device_get(state.params)
1038-
# save model locally
10391050
model.save_pretrained(
1040-
training_args.output_dir,
1051+
output_dir,
10411052
params=params,
10421053
)
10431054

10441055
# save tokenizer
1045-
tokenizer.save_pretrained(training_args.output_dir)
1056+
tokenizer.save_pretrained(output_dir)
1057+
1058+
# copy to bucket
1059+
if use_bucket:
1060+
client = storage.Client()
1061+
bucket = client.bucket(bucket)
1062+
for filename in Path(output_dir).glob("*"):
1063+
blob_name = str(Path(dir_path) / filename.name)
1064+
blob = bucket.blob(blob_name)
1065+
blob.upload_from_filename(str(filename))
1066+
tmp_dir.cleanup()
10461067

10471068
# save state
10481069
opt_state = jax.device_get(state.opt_state)
1049-
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
1050-
f.write(to_bytes(opt_state))
1051-
state_dict = {
1052-
k: jax.device_get(getattr(state, k)).item()
1053-
for k in ["step", "epoch", "train_time", "train_samples"]
1054-
}
1055-
with (Path(training_args.output_dir) / "training_state.json").open(
1056-
"w"
1057-
) as f:
1058-
json.dump(
1059-
state_dict,
1060-
f,
1061-
)
1070+
if use_bucket:
1071+
blob_name = str(Path(dir_path) / "opt_state.msgpack")
1072+
blob = bucket.blob(blob_name)
1073+
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1074+
else:
1075+
with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
1076+
f.write(to_bytes(opt_state))
10621077

10631078
# save to W&B
10641079
if training_args.log_model:
10651080
# save some space
10661081
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
1067-
c.cleanup(wandb.util.from_human_size("10GB"))
1082+
c.cleanup(wandb.util.from_human_size("20GB"))
10681083

1069-
metadata = dict(state_dict)
1084+
metadata = {
1085+
k: jax.device_get(getattr(state, k)).item()
1086+
for k in ["step", "epoch", "train_time", "train_samples"]
1087+
}
10701088
metadata["num_params"] = num_params
10711089
if eval_metrics is not None:
10721090
metadata["eval"] = eval_metrics
1091+
if use_bucket:
1092+
metadata["bucket_path"] = bucket_path
10731093

10741094
# create model artifact
10751095
artifact = wandb.Artifact(
10761096
name=f"model-{wandb.run.id}",
10771097
type="DalleBart_model",
10781098
metadata=metadata,
10791099
)
1080-
for filename in [
1081-
"config.json",
1082-
"flax_model.msgpack",
1083-
"merges.txt",
1084-
"special_tokens_map.json",
1085-
"tokenizer.json",
1086-
"tokenizer_config.json",
1087-
"vocab.json",
1088-
]:
1089-
artifact.add_file(f"{Path(training_args.output_dir) / filename}")
1100+
if not use_bucket:
1101+
for filename in [
1102+
"config.json",
1103+
"flax_model.msgpack",
1104+
"merges.txt",
1105+
"special_tokens_map.json",
1106+
"tokenizer.json",
1107+
"tokenizer_config.json",
1108+
"vocab.json",
1109+
]:
1110+
artifact.add_file(
1111+
f"{Path(training_args.output_dir) / filename}"
1112+
)
10901113
wandb.run.log_artifact(artifact)
10911114

10921115
# create state artifact
@@ -1095,9 +1118,9 @@ def run_save_model(state, eval_metrics=None):
10951118
type="DalleBart_state",
10961119
metadata=metadata,
10971120
)
1098-
for filename in ["opt_state.msgpack", "training_state.json"]:
1121+
if not use_bucket:
10991122
artifact_state.add_file(
1100-
f"{Path(training_args.output_dir) / filename}"
1123+
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
11011124
)
11021125
wandb.run.log_artifact(artifact_state)
11031126

0 commit comments

Comments (0)