diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index ef54005b83..8bb262958c 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -1642,6 +1642,62 @@ def __init__(self, **kwargs): } qwen3_vl_4b_config = PTConfig(**qwen3_vl_4b_dict) +qwen3_vl_2b_dict = { + "architectures": ["Qwen3VLForConditionalGeneration"], + "image_token_id": 151655, + "model_type": "qwen3_vl", + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "pad_token_id": None, + "rms_norm_eps": 1e-06, + "rope_parameters": { + "mrope_interleaved": True, + "mrope_section": [24, 20, 20], + "rope_theta": 5000000, + "rope_type": "default", + }, + "tie_word_embeddings": True, + "use_cache": True, + "vocab_size": 151936, + }, + "tie_word_embeddings": True, + "transformers_version": "4.57.0.dev0", + "video_token_id": 151656, + "vision_config": { + "deepstack_visual_indexes": [5, 11, 17], + "depth": 24, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1024, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "qwen3_vl_vision", + "num_heads": 16, + "num_position_embeddings": 2304, + "out_hidden_size": 2048, + "patch_size": 16, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, +} +qwen3_vl_2b_config = PTConfig(**qwen3_vl_2b_dict) + # {maxtext model name: hf model config} HF_MODEL_CONFIGS = { @@ -1669,6 +1725,7 @@ def __init__(self, **kwargs): "qwen3-14b": qwen3_14b_config, "qwen3-14b-base": qwen3_14b_config, "qwen3-32b": qwen3_32b_config, + "qwen3-vl-2b": qwen3_vl_2b_config, "qwen3-vl-4b": qwen3_vl_4b_config, "llama3.1-8b": llama31_8b_config, "llama3.1-8b-Instruct": llama31_8b_config, diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py index 1849c971c6..82683780fc 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py @@ -1182,6 +1182,7 @@ def QWEN3_VL_HF_WEIGHTS_TO_SHAPE(config): "qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-vl-2b": QWEN3_VL_HF_WEIGHTS_TO_SHAPE, "qwen3-vl-4b": QWEN3_VL_HF_WEIGHTS_TO_SHAPE, "llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE, "llama3.1-8b-Instruct": LLAMA31_HF_WEIGHTS_TO_SHAPE, diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index 0f2aae1c03..80707a2abd 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -3883,6 +3883,7 @@ def reshape_vision_attn_out(input_tensor, target_shape): "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-vl-2b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -3934,6 +3935,7 @@ def reshape_vision_attn_out(input_tensor, target_shape): "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-14b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-vl-2b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-vl-4b": QWEN3_VL_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-8b-Instruct": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, diff --git a/src/maxtext/configs/models/qwen3-vl-2b.yml b/src/maxtext/configs/models/qwen3-vl-2b.yml new file mode 100644 index 0000000000..d4ab791276 --- /dev/null +++ b/src/maxtext/configs/models/qwen3-vl-2b.yml @@ -0,0 +1,56 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for Qwen/Qwen3-VL-2B-Instruct + +# Core Architectural Parameters +decoder_block: "qwen3" +base_emb_dim: 2048 +base_mlp_dim: 6144 +base_num_query_heads: 16 +base_num_kv_heads: 8 +base_num_decoder_layers: 28 +head_dim: 128 +mlp_activations: ["silu", "linear"] +vocab_size: 151936 +normalization_layer_epsilon: 1.0e-6 +use_qk_norm: true +logits_via_embedding: true +normalize_embedding_logits: false + +# RoPE Settings +rope_max_timescale: 5000000 + +# General Model Settings +enable_dropout: false + +# Vision Encoder Configuration +# Based on HuggingFace AutoConfig for Qwen/Qwen3-VL-2B-Instruct +use_multimodal: true +image_size_for_vit: 768 +hidden_size_for_vit: 1024 +intermediate_size_for_vit: 4096 +num_attention_heads_for_vit: 16 +num_hidden_layers_for_vit: 24 +num_channels_for_vit: 3 +patch_size_for_vit: 16 +temporal_patch_size_for_vit: 2 +spatial_merge_size_for_vit: 2 +out_hidden_size_for_vit: 2048 +num_position_embeddings_for_vit: 2304 +deepstack_visual_indexes_for_vit: [5, 11, 17] + +# MRoPE Settings +use_mrope: true +mrope_section: [24, 20, 20] diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 757e04e515..ef127ec30b 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -261,6 +261,7 @@ class ProfilerType(str, Enum): "qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-480b-a35b", + "qwen3-vl-2b", "qwen3-vl-4b", "qwen3-next-80b-a3b", "qwen3-omni-30b-a3b", @@ -3156,6 +3157,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b", diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 2ff3ff3315..0ea2913265 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -728,6 +728,7 @@ def _apply_embedding( "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", + "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b", @@ -743,7 +744,7 @@ def _apply_embedding( raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") if video_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: y = mm_utils.merge_mm_embeddings( text_embeddings=y, multimodal_embeddings=video_embeddings, diff --git a/src/maxtext/layers/encoders.py b/src/maxtext/layers/encoders.py index 93d44452d4..19a7490016 100644 --- a/src/maxtext/layers/encoders.py +++ b/src/maxtext/layers/encoders.py @@ -80,7 +80,7 @@ def _setup_vision_encoder_layers(self): ) setattr(self, projector_name, qwen3_5_vision.Qwen3_5MoeVisionProjector(config=self.config, rngs=self.rngs)) return encoder_name, projector_name - elif self.config.model_name in ["qwen3-vl-4b"]: + elif self.config.model_name in ["qwen3-vl-4b", "qwen3-vl-2b"]: from maxtext.models import qwen3_vl_vision # pylint: disable=import-outside-toplevel encoder_name = "Qwen3VLVisionEncoder_0" diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py index 3d7c5ee72e..381695915a 100644 --- a/src/maxtext/multimodal/processor.py +++ b/src/maxtext/multimodal/processor.py @@ -44,7 +44,7 @@ def preprocess_mm_data(config): images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")] processor_outputs = preprocess_mm_data_llama4(images) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel processor_outputs = preprocess_mm_data_qwen3_omni(config) @@ -68,7 +68,7 @@ def preprocess_image_for_training(image, config): from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel return preprocess_mm_data_llama4(image) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel return preprocess_mm_data_qwen3_omni_for_training(image, config) @@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel return get_image_offsets_llama4(processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel return get_mm_offsets_qwen3_omni(config, processor_output) @@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel return reformat_prompt_llama4(prompt, image_placeholder, num_images) - elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel return reformat_prompt_qwen3_omni( @@ -137,7 +137,7 @@ def reformat_response(response, model_name): elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]: formatted_response = f"{response}" return formatted_response - elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: formatted_response = f"{response}<|im_end|>" return formatted_response else: @@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None): from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel return add_extra_tokens_for_images_llama4(tokens, processor_output) - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output) @@ -222,7 +222,7 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool = from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN - elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: + elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3-vl-2b", "qwen3-vl-4b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]: from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel tokens = QwenTokens(config) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index 4ad0f1d9c4..4167eb0e88 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -63,6 +63,7 @@ "qwen3-8b": "Qwen/Qwen3-8B", "qwen3-14b": "Qwen/Qwen3-14B", "qwen3-32b": "Qwen/Qwen3-32B", + "qwen3-vl-2b": "Qwen/Qwen3-VL-2B-Instruct", "qwen3-vl-4b": "Qwen/Qwen3-VL-4B-Instruct", "llama3.1-8b": "meta-llama/Llama-3.1-8B", "llama3.1-8b-Instruct": "meta-llama/Llama-3.1-8B-Instruct",