From 85c105e71a0c3d8f4c9fef186366efff8660ed9f Mon Sep 17 00:00:00 2001 From: Vipan Nalla Date: Mon, 8 Jul 2024 22:19:51 +0000 Subject: [PATCH] Fix decode.py to also use first_token from prefill_call --- MaxText/decode.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MaxText/decode.py b/MaxText/decode.py index 1697cdd10f..19b76a2173 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -35,7 +35,7 @@ def main(config): ) assert true_length <= config.max_prefill_predict_length, "can't take too many tokens" assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" - prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) slot = 0 decode_state = engine.init_decode_state() @@ -43,6 +43,7 @@ def main(config): steps = range(config.max_prefill_predict_length, config.max_target_length) sampled_tokens_list = [] + sampled_tokens_list.append(first_token) for _ in steps: decode_state, sampled_tokens = engine.generate(params, decode_state) sampled_tokens_list.append(sampled_tokens)