diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 71fcb873e1..05bd3b4cf3 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -914,8 +914,11 @@ def save_quantized_checkpoint_if_configured(config, params): def print_mem_stats(label:str): print(f'\nMemstats: {label}:') - for d in jax.local_devices(): - stats = d.memory_stats() - used = round(stats['bytes_in_use']/2**30, 2) - limit = round(stats['bytes_limit']/2**30, 2) - print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + try: + for d in jax.local_devices(): + stats = d.memory_stats() + used = round(stats['bytes_in_use']/2**30, 2) + limit = round(stats['bytes_limit']/2**30, 2) + print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + except (RuntimeError, KeyError): + print("\tMemstats unavailable.")