From 45824d3a83d2f00c7233b31a830ddf121a83121a Mon Sep 17 00:00:00 2001 From: maxtext authors Date: Fri, 26 Jun 2026 11:40:54 -0700 Subject: [PATCH] Internal Change PiperOrigin-RevId: 938695270 --- src/maxtext/configs/pyconfig.py | 14 ++++++++++++++ src/maxtext/inference/decode.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index f4a8e3965f..6ba5443cde 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -531,6 +531,20 @@ def _initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConf pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict) + # Resolve relative tokenizer_path against the config directory (fileset root on Borg) + if pydantic_kwargs.get("tokenizer_path"): + fileset_root = os.path.dirname(config_path) + candidate_path = os.path.join( + fileset_root, pydantic_kwargs["tokenizer_path"] + ) + if os.path.exists(candidate_path): + logger.info( + "Resolved tokenizer_path %s to %s under fileset root", + pydantic_kwargs["tokenizer_path"], + candidate_path, + ) + pydantic_kwargs["tokenizer_path"] = candidate_path + if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"): raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.") diff --git a/src/maxtext/inference/decode.py b/src/maxtext/inference/decode.py index 68cd450c99..877904cd18 100644 --- a/src/maxtext/inference/decode.py +++ b/src/maxtext/inference/decode.py @@ -177,7 +177,7 @@ def main(argv: Sequence[str]) -> None: # Prefill rng, rng_prefill = jax.random.split(rng) # Split RNG before calling prefill for i in range(_NUM_STREAMS): - with jax.profiler.StepTraceAnnotation("prefill", stream=i): + with jax.profiler.StepTraceAnnotation("prefill", step_num=i): prefill_result, first_token = engine.prefill( params=params, padded_tokens=tokens, @@ -207,7 +207,7 @@ def main(argv: Sequence[str]) -> None: sampled_tokens_list.append(_batch_first_result_token(first_token_list, batch_size)) for i in steps: rng, rng_generate = jax.random.split(rng) - with jax.profiler.StepTraceAnnotation("generate", step=i): + with jax.profiler.StepTraceAnnotation("generate", step_num=i): decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate) # Automatically deactivate profiler after profiler_steps steps