diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py index 09fc0875fa..24f5ec4dd1 100644 --- a/MaxText/multihost_dataloading.py +++ b/MaxText/multihost_dataloading.py @@ -26,6 +26,7 @@ import os from typing import Callable, Any, Dict, List, Tuple, Optional import tensorflow as tf # pylint: disable=g-import-not-at-top +import time import numpy as np import jax @@ -107,7 +108,12 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset, global_mesh: Mesh) -> jax.Array: """Splits the host loaded data equally over all devices.""" - local_data = local_dataset.next() + try: + local_data = local_dataset.next() + except tf.errors.FailedPreconditionError: + max_logging.log("Failed to get next data batch, retrying") + time.sleep(10) + local_data = local_dataset.next() # local_devices = jax.local_devices() local_devices = global_mesh.local_devices