From 225f1974e20df6176fa516640da4dae1a78ab39e Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Mon, 1 Jul 2024 17:52:02 +0000 Subject: [PATCH] Handle cases where memstats are not available for the device. Memstats are not guaranteed to be available and can throw an error or return None. This change will handle both `jaxlib.xla_extension.XlaRuntimeError` if the device is not a PjRt addressable device or `KeyError` if the memstats returns None if they are not available. --- MaxText/max_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 71fcb873e1..05bd3b4cf3 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -914,8 +914,11 @@ def save_quantized_checkpoint_if_configured(config, params): def print_mem_stats(label:str): print(f'\nMemstats: {label}:') - for d in jax.local_devices(): - stats = d.memory_stats() - used = round(stats['bytes_in_use']/2**30, 2) - limit = round(stats['bytes_limit']/2**30, 2) - print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + try: + for d in jax.local_devices(): + stats = d.memory_stats() + used = round(stats['bytes_in_use']/2**30, 2) + limit = round(stats['bytes_limit']/2**30, 2) + print(f"\tUsing (GB) {used} / {limit} ({used/limit:%}) on {d}") + except (RuntimeError, KeyError): + print("\tMemstats unavailable.")