Skip to content

Commit 1c4e839

Browse files
committed
feat: load from bucket
1 parent 50498e6 commit 1c4e839

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

src/dalle_mini/model/utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import tempfile
3+
from pathlib import Path
34

45
import wandb
56

@@ -8,19 +9,41 @@ class PretrainedFromWandbMixin:
89
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    """
    Initializes from a wandb artifact, a Google bucket path, or delegates loading to the superclass.

    Resolution of `pretrained_model_name_or_path`:
      * a reference containing ":" that is neither an existing local directory nor a
        "gs://" URL is treated as a wandb artifact and downloaded to a temporary
        directory; if the artifact metadata carries a "bucket_path", that path is
        used instead of the downloaded files;
      * a "gs://" URL (given directly or taken from the artifact metadata) is copied
        into the temporary directory with `copy_blobs`;
      * anything else is handed to the superclass unchanged.

    The temporary directory is removed when this method returns, so the superclass
    must finish reading the files during its own `from_pretrained` call.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
        if (
            ":" in pretrained_model_name_or_path
            and not os.path.isdir(pretrained_model_name_or_path)
            # test the full scheme (same check as the bucket branch below) so that an
            # artifact name merely beginning with "gs" is still resolved through wandb
            and not pretrained_model_name_or_path.startswith("gs://")
        ):
            # wandb artifact
            if wandb.run is not None:
                artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
            else:
                artifact = wandb.Api().artifact(pretrained_model_name_or_path)
            pretrained_model_name_or_path = artifact.download(tmp_dir)
            # the artifact may only point at files stored in a bucket
            if artifact.metadata.get("bucket_path"):
                pretrained_model_name_or_path = artifact.metadata["bucket_path"]

        if pretrained_model_name_or_path.startswith("gs://"):
            # fetch the bucket files locally, then load from the local copy
            copy_blobs(pretrained_model_name_or_path, tmp_dir)
            pretrained_model_name_or_path = tmp_dir

        return super(PretrainedFromWandbMixin, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
36+
37+
38+
def copy_blobs(source_path, dest_path):
    """
    Download the blobs stored under a Google Cloud Storage path into a local directory.

    Args:
        source_path: location of the form "gs://<bucket>/<dir>"; a trailing slash is
            accepted. With no <dir>, every blob of the bucket is listed.
        dest_path: local directory that receives the files.

    Files are written flat: only the last component of each blob name is kept, so
    blobs from nested prefixes sharing a file name overwrite one another.

    Raises:
        ValueError: if `source_path` does not start with "gs://" or names no bucket.
    """
    scheme = "gs://"
    # explicit check instead of `assert`, which is stripped when running with -O
    if not source_path.startswith(scheme):
        raise ValueError(f"Expected a path starting with {scheme!r}, got {source_path!r}")

    # bucket paths always use "/", so split the string directly rather than going
    # through pathlib (whose separator depends on the platform); empty components
    # from duplicate or trailing slashes are dropped
    parts = [part for part in source_path[len(scheme):].split("/") if part]
    if not parts:
        raise ValueError(f"No bucket name found in {source_path!r}")
    bucket_name, dir_path = parts[0], "/".join(parts[1:])

    # kept as a function-level import, as before, so that importing this module
    # does not require google-cloud-storage
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(bucket_name)
    # the trailing slash restricts the listing to blobs inside the "directory"
    prefix = f"{dir_path}/" if dir_path else None
    for blob in client.list_blobs(bucket, prefix=prefix):
        file_name = blob.name.rsplit("/", 1)[-1]
        if not file_name:
            # blob name ends with "/": a "directory" placeholder, nothing to download
            continue
        blob.download_to_filename(str(Path(dest_path) / file_name))

tools/train/train.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,21 @@ def get_opt_state(self):
135135
else:
136136
artifact = wandb.Api().artifact(state_artifact)
137137
artifact_dir = artifact.download(tmp_dir)
138-
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
139-
return Path(self.restore_state).open("rb")
138+
if artifact.metadata.get("bucket_path"):
139+
self.restore_state = artifact.metadata["bucket_path"]
140+
else:
141+
# keep restore_state a str: the check below calls .startswith("gs://") on it, which a Path does not provide
self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
142+
143+
if self.restore_state.startswith("gs://"):
144+
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
145+
bucket, blob_name = str(bucket_path).split("/", 1)
146+
client = storage.Client()
147+
bucket = client.bucket(bucket)
148+
blob = bucket.blob(blob_name)
149+
return blob.download_as_bytes()
150+
151+
with Path(self.restore_state).open("rb") as f:
152+
return f.read()
140153

141154

142155
@dataclass
@@ -788,9 +801,7 @@ def init_state(params):
788801

789802
else:
790803
# load opt_state
791-
opt_state_file = model_args.get_opt_state()
792-
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
793-
opt_state_file.close()
804+
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
794805

795806
# restore other attributes
796807
attr_state = {
@@ -1060,15 +1071,15 @@ def run_save_model(state, eval_metrics=None):
10601071
client = storage.Client()
10611072
bucket = client.bucket(bucket)
10621073
for filename in Path(output_dir).glob("*"):
1063-
blob_name = str(Path(dir_path) / filename.name)
1074+
blob_name = str(Path(dir_path) / "model" / filename.name)
10641075
blob = bucket.blob(blob_name)
10651076
blob.upload_from_filename(str(filename))
10661077
tmp_dir.cleanup()
10671078

10681079
# save state
10691080
opt_state = jax.device_get(state.opt_state)
10701081
if use_bucket:
1071-
blob_name = str(Path(dir_path) / "opt_state.msgpack")
1082+
blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
10721083
blob = bucket.blob(blob_name)
10731084
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
10741085
else:
@@ -1088,10 +1099,10 @@ def run_save_model(state, eval_metrics=None):
10881099
metadata["num_params"] = num_params
10891100
if eval_metrics is not None:
10901101
metadata["eval"] = eval_metrics
1091-
if use_bucket:
1092-
metadata["bucket_path"] = bucket_path
10931102

10941103
# create model artifact
1104+
if use_bucket:
1105+
metadata["bucket_path"] = f"gs://{bucket_path}/model"
10951106
artifact = wandb.Artifact(
10961107
name=f"model-{wandb.run.id}",
10971108
type="DalleBart_model",
@@ -1113,6 +1124,8 @@ def run_save_model(state, eval_metrics=None):
11131124
wandb.run.log_artifact(artifact)
11141125

11151126
# create state artifact
1127+
if use_bucket:
1128+
metadata["bucket_path"] = f"gs://{bucket_path}/state"
11161129
artifact_state = wandb.Artifact(
11171130
name=f"state-{wandb.run.id}",
11181131
type="DalleBart_state",

0 commit comments

Comments
 (0)