Skip to content

Commit 02b2308

Browse files
committed
feat(train): google-cloud-storage is optional
1 parent 955dc20 commit 02b2308

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tools/train/train.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from flax.serialization import from_bytes, to_bytes
4343
from flax.training import train_state
4444
from flax.training.common_utils import onehot
45-
from google.cloud import storage
4645
from jax.experimental import PartitionSpec, maps
4746
from jax.experimental.compilation_cache import compilation_cache as cc
4847
from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -58,6 +57,11 @@
5857
set_partitions,
5958
)
6059

60+
try:
61+
from google.cloud import storage
62+
except:
63+
storage = None
64+
6165
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
6266

6367
logger = logging.getLogger(__name__)
@@ -144,6 +148,9 @@ def get_opt_state(self):
144148
if self.restore_state.startswith("gs://"):
145149
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
146150
bucket, blob_name = str(bucket_path).split("/", 1)
151+
assert (
152+
storage is not None
153+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
147154
client = storage.Client()
148155
bucket = client.bucket(bucket)
149156
blob = bucket.blob(blob_name)
@@ -456,6 +463,10 @@ def __post_init__(self):
456463
assert (
457464
jax.local_device_count() == 8
458465
), "TPUs in use, please check running processes"
466+
if self.output_dir.startswith("gs://"):
467+
assert (
468+
storage is not None
469+
), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
459470
assert self.optim in [
460471
"distributed_shampoo",
461472
"adam",

0 commit comments

Comments
 (0)