From 2dbfb58320f1faeb1f160d93377d6d5154d8056a Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Wed, 1 Jul 2026 21:28:34 +0000 Subject: [PATCH] update --- src/maxtext/layers/nnx_decoders.py | 36 +++++++++--------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 1a1ed65924..4b7f6761a1 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1111,38 +1111,22 @@ def get_deepseek(): DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], - DecoderBlockType.GEMMA4: get_scannable( - gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock - ), + DecoderBlockType.GEMMA4: get_scannable(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock), DecoderBlockType.GEMMA4_SMALL: [gemma4_small.Gemma4SmallDecoderLayer], DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer], DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer], - DecoderBlockType.QWEN3_CUSTOM_MOE: [ - qwen3_custom.Qwen3CustomMoeDecoderLayer - ], + DecoderBlockType.QWEN3_CUSTOM_MOE: [qwen3_custom.Qwen3CustomMoeDecoderLayer], DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], DecoderBlockType.DEEPSEEK: get_deepseek(), - DecoderBlockType.DEEPSEEK4: get_scannable( - deepseek4.DeepSeek4DecoderLayer, deepseek4.DeepSeek4ScannableBlock - ), - DecoderBlockType.GPT_OSS: get_scannable( - gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock - ), - DecoderBlockType.QWEN3_NEXT: get_scannable( - qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock - ), - DecoderBlockType.QWEN3_5: get_scannable( - qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock - ), - DecoderBlockType.LLAMA4: get_scannable( - llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock - ), - DecoderBlockType.OLMO3: get_scannable( - olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock - ), + DecoderBlockType.DEEPSEEK4: get_scannable(deepseek4.DeepSeek4DecoderLayer, deepseek4.DeepSeek4ScannableBlock), + DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock), + DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock), + DecoderBlockType.QWEN3_5: get_scannable(qwen3_5.Qwen3_5DecoderLayer, qwen3_5.Qwen3_5ScannableBlock), + DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock), + DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock), } if cfg.decoder_block not in layer_map: @@ -1360,6 +1344,8 @@ def _apply_embedding( "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3-vl-2b", + "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b", }: @@ -1373,7 +1359,7 @@ def _apply_embedding( raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") if video_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in {"qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"}: + if cfg.model_name in {"qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"}: y = mm_utils.merge_mm_embeddings( text_embeddings=y, multimodal_embeddings=video_embeddings,