Skip to content

Commit d368fb6

Browse files
committed
feat: add bucket reference to artifact
1 parent d5d442a commit d368fb6

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tools/train/train.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,12 @@ def get_opt_state(self):
135135
artifact = wandb.run.use_artifact(state_artifact)
136136
else:
137137
artifact = wandb.Api().artifact(state_artifact)
138-
artifact_dir = artifact.download(tmp_dir)
139138
if artifact.metadata.get("bucket_path"):
139+
# we will read directly file contents
140140
self.restore_state = artifact.metadata["bucket_path"]
141141
else:
142-
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
142+
artifact_dir = artifact.download(tmp_dir)
143+
self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
143144

144145
if self.restore_state.startswith("gs://"):
145146
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
@@ -1130,7 +1131,9 @@ def run_save_model(state, eval_metrics=None):
11301131
type="DalleBart_model",
11311132
metadata=metadata,
11321133
)
1133-
if not use_bucket:
1134+
if use_bucket:
1135+
artifact.add_reference(metadata["bucket_path"])
1136+
else:
11341137
for filename in [
11351138
"config.json",
11361139
"flax_model.msgpack",
@@ -1153,7 +1156,9 @@ def run_save_model(state, eval_metrics=None):
11531156
type="DalleBart_state",
11541157
metadata=metadata,
11551158
)
1156-
if not use_bucket:
1159+
if use_bucket:
1160+
artifact_state.add_reference(metadata["bucket_path"])
1161+
else:
11571162
artifact_state.add_file(
11581163
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
11591164
)

0 commit comments

Comments
 (0)