Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion MaxText/multihost_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os
from typing import Callable, Any, Dict, List, Tuple, Optional
import tensorflow as tf # pylint: disable=g-import-not-at-top
import time
import numpy as np

import jax
Expand Down Expand Up @@ -107,7 +108,12 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset,
global_mesh: Mesh) -> jax.Array:
"""Splits the host loaded data equally over all devices."""

local_data = local_dataset.next()
try:
local_data = local_dataset.next()
except tf.errors.FailedPreconditionError:
max_logging.log("Failed to get next data batch, retrying")
time.sleep(10)
local_data = local_dataset.next()

# local_devices = jax.local_devices()
local_devices = global_mesh.local_devices
Expand Down