|
42 | 42 | from flax.serialization import from_bytes, to_bytes |
43 | 43 | from flax.training import train_state |
44 | 44 | from flax.training.common_utils import onehot |
45 | | -from google.cloud import storage |
46 | 45 | from jax.experimental import PartitionSpec, maps |
47 | 46 | from jax.experimental.compilation_cache import compilation_cache as cc |
48 | 47 | from jax.experimental.pjit import pjit, with_sharding_constraint |
|
58 | 57 | set_partitions, |
59 | 58 | ) |
60 | 59 |
|
| 60 | +try: |
| 61 | + from google.cloud import storage |
| 62 | +except: |
| 63 | + storage = None |
| 64 | + |
61 | 65 | cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30) |
62 | 66 |
|
63 | 67 | logger = logging.getLogger(__name__) |
@@ -144,6 +148,9 @@ def get_opt_state(self): |
144 | 148 | if self.restore_state.startswith("gs://"): |
145 | 149 | bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack" |
146 | 150 | bucket, blob_name = str(bucket_path).split("/", 1) |
| 151 | + assert ( |
| 152 | + storage is not None |
| 153 | + ), 'Could not find google.storage. Install with "pip install google-cloud-storage"' |
147 | 154 | client = storage.Client() |
148 | 155 | bucket = client.bucket(bucket) |
149 | 156 | blob = bucket.blob(blob_name) |
@@ -456,6 +463,10 @@ def __post_init__(self): |
456 | 463 | assert ( |
457 | 464 | jax.local_device_count() == 8 |
458 | 465 | ), "TPUs in use, please check running processes" |
| 466 | + if self.output_dir.startswith("gs://"): |
| 467 | + assert ( |
| 468 | + storage is not None |
| 469 | + ), 'Could not find google.storage. Install with "pip install google-cloud-storage"' |
459 | 470 | assert self.optim in [ |
460 | 471 | "distributed_shampoo", |
461 | 472 | "adam", |
|
0 commit comments