diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index af3db1e529..4608067e89 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -262,6 +262,7 @@ def is_gpu_backend(raw_keys): def get_coordinator_ip_address(): """Get coordinator IP Address with retries""" coordinator_address = "" + coordinator_ip_address = "" if os.environ.get("JAX_COORDINATOR_ADDRESS") is not None: coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") coordinator_found = False diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index 642e17787c..7f5134f73b 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -47,7 +47,6 @@ def setUpClass(cls): cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) train_tokenizer.train_tokenizer( cls.dataset, - assets_path=assets_path, vocab_path=cls.tokenizer_path, vocab_size=cls.vocab_size, max_corpus_chars=cls.max_corpus_chars, diff --git a/MaxText/train.py b/MaxText/train.py index bfe5a32094..68ad04a97c 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -457,6 +457,8 @@ def train_loop(config, state=None): print("Loading the compiled function...", flush=True) # Need to pass train signature and state to determine i/o shapes of train_state for now. p_train_step = maxtext_utils.load_compiled(config, functional_train, state) + # TODO: p_eval_step is not yet supported in load_compiled + p_eval_step = None print("Loaded compiled function!", flush=True) else: p_train_step = jax.jit( @@ -475,6 +477,8 @@ def train_loop(config, state=None): static_argnums=static_argnums_eval, donate_argnums=donate_argnums_eval, ) + else: + p_eval_step = None local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None running_gcs_metrics = [] if config.gcs_metrics else None diff --git a/MaxText/train_tokenizer.py b/MaxText/train_tokenizer.py index 03c0d5269c..0a440a388f 100644 --- a/MaxText/train_tokenizer.py +++ b/MaxText/train_tokenizer.py @@ -63,7 +63,6 @@ def _train_sentencepiece( *, vocab_size: int, maxchars: int = int(1e7), - assets_path: str, model_path: str, model_type: str = "unigram", character_coverage: float = 1.0, @@ -87,7 +86,6 @@ def _train_sentencepiece( abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) - abs_assets_path = os.path.abspath(os.path.expanduser(assets_path)) fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/sp_tmp") as model_fp: pass # we just want a prefix'd tmp-filename @@ -103,8 +101,6 @@ def _train_sentencepiece( # Use an intermediate filename that is renamed to the target name to address # create and fill delays. copy_rename_path = abs_model_path + ".rntmp" - if not model_path.startswith("gs://"): - tf.io.gfile.makedirs(abs_assets_path) tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path) @@ -118,7 +114,6 @@ def _train_sentencepiece( def train_tokenizer( dataset: tf.data.Dataset, *, - assets_path: str, vocab_path: str, vocab_size: int, max_corpus_chars: int, @@ -130,7 +125,6 @@ def train_tokenizer( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, - assets_path=assets_path, model_path=vocab_path, data_keys=data_keys, ) @@ -148,7 +142,6 @@ def main(argv): train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) train_tokenizer( train_ds, - assets_path=_ASSETS_PATH.value, vocab_path=os.path.join(_ASSETS_PATH.value, _VOCAB_MODEL_NAME.value), vocab_size=_VOCAB_SIZE.value, max_corpus_chars=_MAX_CORPUS_CHARS.value,