diff --git a/MaxText/decode.py b/MaxText/decode.py index 1697cdd10f..19b76a2173 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -35,7 +35,7 @@ def main(config): ) assert true_length <= config.max_prefill_predict_length, "can't take too many tokens" assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" - prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) slot = 0 decode_state = engine.init_decode_state() @@ -43,6 +43,7 @@ def main(config): steps = range(config.max_prefill_predict_length, config.max_target_length) sampled_tokens_list = [] + sampled_tokens_list.append(first_token) for _ in steps: decode_state, sampled_tokens = engine.generate(params, decode_state) sampled_tokens_list.append(sampled_tokens)