From 0cd505c88089168b94379ca242166a9d28b17b21 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 16 Apr 2026 10:34:22 -0700 Subject: [PATCH] Add Qwen 3.6 MoE model and switch CI to Qwen3.6-35B-A3B-HQQ-INT4 Qwen 3.6 MoE shares architecture and runner with Qwen 3.5 MoE. Add a stub README pointing to the existing qwen3_5_moe example. Update CI scripts and cuda.yml to use the Qwen 3.6 prequantized checkpoint. Improve qwen3_5_moe README: add quick-start section for prequantized weights, list available prequantized checkpoints, and clean up terminology. --- .ci/scripts/export_model_artifact.sh | 4 +- .ci/scripts/test_model_e2e.sh | 4 +- .github/workflows/cuda.yml | 28 +- examples/models/qwen3_5_moe/README.md | 44 +- examples/models/qwen3_5_moe/export.py | 204 ++++++-- examples/models/qwen3_5_moe/model.md | 456 ++++++++++++++---- .../models/qwen3_5_moe/quantize_and_save.py | 32 +- .../qwen3_5_moe/test_quantize_roundtrip.py | 73 ++- examples/models/qwen3_6_moe/README.md | 15 + 9 files changed, 711 insertions(+), 149 deletions(-) create mode 100644 examples/models/qwen3_6_moe/README.md diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index f19df233628..15e7c292264 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -184,7 +184,7 @@ case "$HF_MODEL" in PREPROCESSOR_FEATURE_SIZE="" PREPROCESSOR_OUTPUT="" ;; - SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4) + SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4|SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4) MODEL_NAME="qwen3_5_moe" TASK="" MAX_SEQ_LEN="" @@ -194,7 +194,7 @@ case "$HF_MODEL" in ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4" exit 1 ;; esac diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 5cee37b19cf..26961ddc3f3 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -216,7 +216,7 @@ case "$HF_MODEL" in AUDIO_FILE="test_audio.wav" IMAGE_PATH="" ;; - SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4) + SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4|SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4) MODEL_NAME="qwen3_5_moe" RUNNER_TARGET="qwen3_5_moe_runner" RUNNER_PATH="qwen3_5_moe" @@ -230,7 +230,7 @@ case "$HF_MODEL" in ;; *) echo "Error: Unsupported model '$HF_MODEL'" - echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, 
nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4, SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4" exit 1 ;; esac diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 68ded356b99..3e2fba8e66f 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -181,6 +181,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -191,7 +193,7 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" - # Qwen3.5 MoE uses a prequantized checkpoint, only tile-packed + # Qwen MoE uses prequantized checkpoints, only tile-packed - model: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" @@ -200,6 +202,14 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + - model: + repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -254,7 +264,7 @@ jobs: with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'Qwen3.6-35B-A3B-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false @@ -311,6 +321,8 @@ jobs: name: "dinov2-small-imagenet1k-1-layer" - repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" + - repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" quant: - "non-quantized" - "quantized-int4-tile-packed" @@ -321,7 +333,7 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" - # Qwen3.5 MoE uses a prequantized checkpoint, only tile-packed + # Qwen MoE uses prequantized checkpoints, only tile-packed - model: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" @@ -330,6 +342,14 @@ jobs: repo: "SocialLocalMobile" name: "Qwen3.5-35B-A3B-HQQ-INT4" quant: "quantized-int4-weight-only" + - model: + repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" + quant: "non-quantized" + - model: + repo: "SocialLocalMobile" + name: "Qwen3.6-35B-A3B-HQQ-INT4" + quant: "quantized-int4-weight-only" # Voxtral Realtime only supports int4-tile-packed on CUDA - model: repo: "mistralai" @@ -378,7 +398,7 @@ jobs: quant: "non-quantized" with: timeout: 90 - runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} + runner: ${{ (matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' || matrix.model.name == 'Qwen3.6-35B-A3B-HQQ-INT4') && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }} gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 83373a804f4..b237a7ef01e 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -30,6 +30,24 @@ Export 
produces a `model.pte` and `aoti_cuda_blob.ptd` containing the compiled CUDA kernels and quantized weights. Int4 quantization is recommended — the model is too large to fit in VRAM at bf16. +### Quick start: prequantized weights + +The fastest path is to export from prequantized weights, which skips +the slow quantization step entirely. + +Prequantized checkpoints are available for download: +- [SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4](https://huggingface.co/SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4) +- [SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4](https://huggingface.co/SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4) + +```bash +python export.py --prequantized +``` + +See [Generating Prequantized Weights](#generating-prequantized-weights) +to create your own. + +### Quantize and Export + ```bash python export.py \ --model-id Qwen/Qwen3.5-35B-A3B \ @@ -60,7 +78,8 @@ python export.py \ | `--qlinear-group-size` | `32` | Group size for linear quantization | | `--qembedding` | (none) | Embedding quantization: `8w` | | `--hqq` | off | Use HQQ scale-only optimization for expert quantization (slower, better accuracy) | -| `--prequantized` | (none) | Path to prequantized bundle directory (skips quantization) | +| `--sensitive` | off | Sensitivity-aware mixed precision (bf16/INT8/INT4). Required for models without quantization-aware training (e.g. Qwen3.6) | +| `--prequantized` | (none) | Path to prequantized checkpoint directory (skips quantization) | | `--turboquant` | off | Enable TurboQuant TQ4 KV cache compression (3.8x cache savings) | ### TurboQuant KV Cache Compression @@ -72,11 +91,11 @@ KV cache compression (3.8x savings) on the 10 full-attention layers. python export.py --prequantized qwen35_moe_int4_hqq --turboquant ``` -### Prequantized Export +### Generating Prequantized Weights Quantization is slow (~30 min with HQQ). To avoid re-quantizing on every -export, use `quantize_and_save.py` to create a self-contained bundle, then -export from it: +export, use `quantize_and_save.py` to create a prequantized checkpoint +directory, then export from it: ```bash # Step 1: Quantize once (slow) @@ -88,13 +107,24 @@ python quantize_and_save.py \ --hqq \ --output qwen35_moe_int4_hqq -# Step 2: Export from bundle (fast, no --model-dir needed) +# Step 2: Export from prequantized checkpoint (fast, no --model-dir needed) python export.py \ --prequantized qwen35_moe_int4_hqq ``` -The bundle contains `model.safetensors`, `config.json`, and tokenizer files. -It can be uploaded to HuggingFace Hub for easy sharing. +For models without quantization-aware training (e.g. Qwen 3.6), use +`--sensitive` for mixed-precision quantization: + +```bash +python quantize_and_save.py \ + --model-dir ~/models/Qwen3.6-35B-A3B \ + --sensitive \ + --hqq \ + --output qwen36_moe_int4_hqq +``` + +The output directory contains `model.safetensors`, `config.json`, and +tokenizer files. It can be uploaded to HuggingFace Hub for easy sharing. 
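+
+For example, a minimal upload sketch using the `huggingface_hub` Python API
+(the repo id is a placeholder; assumes you are logged in via `huggingface-cli login`):
+
+```python
+from huggingface_hub import HfApi
+
+api = HfApi()
+api.create_repo("your-org/Qwen3.6-35B-A3B-HQQ-INT4", exist_ok=True)
+api.upload_folder(
+    folder_path="qwen36_moe_int4_hqq",  # output of quantize_and_save.py
+    repo_id="your-org/Qwen3.6-35B-A3B-HQQ-INT4",
+)
+```
+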
## Build diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index d4839cb5e42..435054b6d2f 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -7,7 +7,7 @@ python export.py --model-id Qwen/Qwen3.5-35B-A3B python export.py --model-dir /path/to/Qwen3.5-MoE-A3B python export.py --model-dir /path/to/model --qlinear 4w - python export.py --prequantized /path/to/quantized_bundle/ + python export.py --prequantized /path/to/prequantized_dir/ python export.py --model-id Qwen/Qwen3.5-35B-A3B --backend mlx --qlinear 4w """ @@ -161,7 +161,9 @@ def load_and_quantize(args): # noqa: C901 ) # CUDA: quantize experts with packed INT4 for Triton kernel - if args.qlinear or args.qembedding: + if getattr(args, "sensitive", False): + _quantize_sensitive(model, config, args) + elif args.qlinear or args.qembedding: _quantize(model, config, args) else: model.to(dtype=torch.bfloat16) @@ -236,7 +238,6 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096, use_splitk_decod ".conv_state", ".recurrent_state", ".cache_positions", - ".inv_freq", ) expected_missing = {k for k in missing if any(p in k for p in runtime_prefixes)} weight_missing = set(missing) - expected_missing @@ -278,16 +279,16 @@ def _quantize_experts_int4(model, config, group_size=32, use_hqq=False): w1 [E, N, K//2] int8 packed, w1_scale [E, N, K//gs] bf16 w2 [E, N, K//2] int8 packed, w2_scale [E, N, K//gs] bf16 """ + from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + MappingType, + quantize_affine, + ) + if use_hqq: from torchao.quantization.quant_primitives import ( _choose_qparams_and_quantize_scale_only_hqq, ) - else: - from torchao.quantization.quant_primitives import ( - choose_qparams_affine, - MappingType, - quantize_affine, - ) method = "HQQ" if use_hqq else "min/max" @@ -433,6 +434,127 @@ def _quantize(model, config, args): print(f"Quantized linear layers ({args.qlinear})") +def _quantize_sensitive(model, config, args): + """Sensitivity-aware quantization using mixed precision. + + Based on GGUF Q4_K_M analysis and per-layer error + profiling of Qwen3.6. Matches GGUF Q4_K_M bit budget while using + HQQ to compensate where we use fewer bits. + + Precision assignment (bpw shown for default group_size=32): + - GatedDeltaNet internals (conv1d, dt_bias, A_log, norm): bf16 + - MoE gate routing, shared expert gate: bf16 + - Norms: bf16 + - Attention projections (qkv/in_proj, o/out_proj): INT8 + - Shared expert (gate_up, down): INT8 + - lm_head: INT8 + - Expert gate_up (w1) and down (w2): INT4, HQQ recommended + - Embeddings: INT8 + + Group size is controlled by --qlinear-group-size (default 32, matching + GGUF Q4_K granularity). Smaller group sizes improve accuracy at the + cost of more scale storage. 
+ """ + from executorch.extension.llm.export.quantize import quantize_model_ + + group_size = args.qlinear_group_size + + # Expert weights: INT4 gs=32 with HQQ for all layers + _quantize_experts_int4(model, config, group_size, use_hqq=args.hqq) + + # Untie lm_head/embedding + if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr(): + model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone()) + + # Per-layer quantization with sensitivity-aware precision + for i, layer in enumerate(model.layers): + _to_device_skip_meta(layer, device="cuda", dtype=torch.bfloat16) + + layer_type = config.layer_types[i] + + # GatedDeltaNet internals stay bf16: conv1d, dt_bias, A_log, norm + # are already bf16 from the device move above. + + # Attention projections: INT8 + if layer_type == "full_attention": + attn_wrapper = nn.ModuleDict( + { + "attn": nn.ModuleDict( + { + "qkv_proj": layer.attn.qkv_proj, + "o_proj": layer.attn.o_proj, + } + ) + } + ) + quantize_model_( + attn_wrapper, + qlinear_config="8w", + qlinear_group_size=group_size, + ) + layer.attn.qkv_proj = attn_wrapper.attn.qkv_proj + layer.attn.o_proj = attn_wrapper.attn.o_proj + else: + # GatedDeltaNet: quantize in_proj and out_proj to INT8, + # leave conv1d/dt_bias/A_log/norm at bf16 + attn_wrapper = nn.ModuleDict( + { + "attn": nn.ModuleDict( + { + "in_proj": layer.attn.in_proj, + "out_proj": layer.attn.out_proj, + } + ) + } + ) + quantize_model_( + attn_wrapper, + qlinear_config="8w", + qlinear_group_size=group_size, + ) + layer.attn.in_proj = attn_wrapper.attn.in_proj + layer.attn.out_proj = attn_wrapper.attn.out_proj + + # MoE gate routing: stays bf16 (no quantization) + # Shared expert gate: stays bf16 (no quantization) + + # Shared expert projections: INT8 + shared_wrapper = nn.ModuleDict({"shared": layer.mlp.shared_expert}) + quantize_model_( + shared_wrapper, + qlinear_config="8w", + qlinear_group_size=group_size, + ) + layer.mlp.shared_expert = shared_wrapper.shared + + _to_device_skip_meta(layer, device="cpu") + torch.cuda.empty_cache() + print( + f" Quantized layer {i + 1}/{config.num_hidden_layers} ({layer_type})", + end="\r", + ) + print() + + # lm_head: INT8 + print("Quantizing lm_head (INT8)...") + model.lm_head.to(device="cuda", dtype=torch.bfloat16) + wrapper = nn.ModuleDict({"lm_head": model.lm_head}) + quantize_model_(wrapper, qlinear_config="8w", qlinear_group_size=group_size) + model.lm_head = wrapper.lm_head + model.lm_head.to(device="cpu") + torch.cuda.empty_cache() + + # Embeddings: INT8 with same group size + print("Quantizing embeddings (INT8)...") + model.embed_tokens.to(dtype=torch.bfloat16) + quantize_model_(model, qembedding_config="8w", qembedding_group_size=group_size) + + # Norms stay bf16 + model.norm.to(dtype=torch.bfloat16) + + print("Quantized with sensitivity-aware mixed precision") + + def _materialize_buffers(model, config): """Materialize meta-device buffers before torch.export. 
@@ -739,6 +861,38 @@ def _export_cuda(model, config, args): # --------------------------------------------------------------------------- +def _validate_args(parser, args): + """Validate CLI argument combinations.""" + if args.model_id: + if args.model_dir is not None: + raise ValueError("Cannot specify model_dir when model_id is provided.") + from huggingface_hub import snapshot_download + + args.model_dir = snapshot_download(repo_id=args.model_id) + + if not args.prequantized and not args.model_dir and not args.tiny_test: + parser.error( + "--model-dir is required unless --prequantized or --tiny-test is provided." + ) + + if args.hqq and not args.qlinear and not args.sensitive: + parser.error("--hqq requires --qlinear or --sensitive") + + if args.sensitive and (args.qlinear or args.qembedding): + parser.error( + "--sensitive manages its own precision; " + "do not combine with --qlinear or --qembedding" + ) + + if args.backend == "mlx": + if args.prequantized: + parser.error("--prequantized is not supported with --backend mlx") + if getattr(args, "sensitive", False): + parser.error("--sensitive is not supported with --backend mlx") + if args.turboquant: + parser.error("--turboquant is not supported with --backend mlx") + + def main(): parser = argparse.ArgumentParser( description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)" @@ -809,38 +963,26 @@ def main(): "No checkpoint download needed. Tests all architectural features " "(GQA, GDN head ratio, mixed attention, MoE routing) at small scale.", ) + parser.add_argument( + "--sensitive", + action="store_true", + help="Use sensitivity-aware mixed precision quantization. " + "Keeps GatedDeltaNet internals and MoE gates at bf16, " + "uses INT8 for attention/shared experts/lm_head, and INT4 " + "for routed experts. Recommended for models without " + "quantization-aware training (e.g. Qwen3.6).", + ) parser.add_argument( "--no-splitk", action="store_true", help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.", ) args = parser.parse_args() - - if args.model_id: - if args.model_dir is not None: - raise ValueError("Cannot specify model_dir when model_id is provided.") - from huggingface_hub import snapshot_download - - args.model_dir = snapshot_download(repo_id=args.model_id) - - if not args.prequantized and not args.model_dir and not args.tiny_test: - parser.error( - "--model-dir is required unless --prequantized or --tiny-test is provided." - ) - - if args.hqq and not args.qlinear: - parser.error("--hqq requires --qlinear") + _validate_args(parser, args) if args.backend == "cuda": - # Register FLA Triton kernel (CUDA only) import executorch.backends.cuda.triton.kernels # noqa: F401 - if args.backend == "mlx": - if args.prequantized: - parser.error("--prequantized is not supported with --backend mlx") - if args.turboquant: - parser.error("--turboquant is not supported with --backend mlx") - model, config = load_and_quantize(args) if args.backend == "cuda": diff --git a/examples/models/qwen3_5_moe/model.md b/examples/models/qwen3_5_moe/model.md index 32510859b28..560eecd9572 100644 --- a/examples/models/qwen3_5_moe/model.md +++ b/examples/models/qwen3_5_moe/model.md @@ -1,6 +1,6 @@ -# Qwen 3.5 MoE — Architecture & Design Notes +# Qwen 3.5 MoE — Architecture & Implementation Reference -Developer reference for `model.py` and `export.py`. For export/usage +Developer reference for the qwen3_5_moe example. For export/usage instructions see [README.md](README.md). 
## Architecture @@ -9,132 +9,400 @@ instructions see [README.md](README.md). Input tokens | v -Token Embedding (no learned position embedding — RoPE is inside attention) +embed_tokens: nn.Embedding(248320, 2048) | v -+--- Decoder Layer x40 -----------------------------------------+ -| | -| GemmaRMSNorm -> Attention (hybrid) -> residual add | -| +- 75% of layers: GatedDeltaNet (linear, O(n)) | -| +- 25% of layers: Full Attention (softmax, O(n^2)) | -| | -| GemmaRMSNorm -> Sparse MoE -> residual add | -| +- Router: top-8 expert selection + softmax weights | -| +- 256 routed experts: independent SwiGLU MLPs | -| +- Shared expert: always-on SwiGLU with sigmoid gate | -| | -+----------------------------------------------------------------+ ++--- Block x40 (layers[i]) ----------------------------------------+ +| | +| ln_1: GemmaRMSNorm(2048) -> attn -> residual add | +| +- 30 layers (i % 4 != 3): GatedDeltaNet (linear, O(n)) | +| +- 10 layers (i % 4 == 3): FullAttention (softmax, O(n^2)) | +| | +| ln_2: GemmaRMSNorm(2048) -> mlp: SparseMoE -> residual add | +| +- gate: nn.Linear(2048, 256) -> top-8 + softmax | +| +- experts: FusedMoEExperts (256 routed SwiGLU experts) | +| +- shared_expert: SwiGLU(2048, 512) always-on | +| +- shared_expert_gate: nn.Linear(2048, 1) -> sigmoid | +| | ++--------------------------------------------------------------------+ | v -GemmaRMSNorm -> LM Head -> logits +norm: GemmaRMSNorm(2048) -> lm_head: nn.Linear(2048, 248320) -> logits ``` Layer pattern (`full_attention_interval=4`): +``` +L L L F L L L F L L L F ... L L L F (L = GatedDeltaNet, F = FullAttention) +0 1 2 3 4 5 6 7 8 9 ... 39 +``` + +## Model Config + +| Field | Value | Notes | +|-------|-------|-------| +| `hidden_size` | 2048 | | +| `num_hidden_layers` | 40 | 30 linear + 10 full attention | +| `num_attention_heads` | 16 | Full attention Q heads | +| `num_kv_heads` | 2 | Full attention KV heads (GQA 8:1) | +| `head_dim` | 256 | | +| `partial_rotary_factor` | 0.25 | 64 of 256 dims get RoPE | +| `rope_theta` | 10,000,000 | | +| `linear_num_key_heads` | 16 | GatedDeltaNet K heads | +| `linear_num_value_heads` | 32 | GatedDeltaNet V heads (head_repeat=2) | +| `linear_key_head_dim` | 128 | | +| `linear_value_head_dim` | 128 | | +| `linear_conv_kernel_dim` | 4 | Causal depthwise conv1d kernel | +| `num_experts` | 256 | | +| `num_experts_per_tok` | 8 | Top-k routing | +| `moe_intermediate_size` | 512 | Per-expert hidden dim | +| `shared_expert_intermediate_size` | 512 | | +| `vocab_size` | 248320 | | +| `rms_norm_eps` | 1e-6 | | +| Total parameters | ~35B | ~3B active per token | +## Component Details + +### GemmaRMSNorm + +Unit-offset RMSNorm. Weight initialized to zeros, formula uses `(1 + weight)`: +```python +normed = x / sqrt(mean(x^2) + eps) +return normed * (1.0 + weight) ``` -L L L F L L L F L L L F ... L L L F (L = GatedDeltaNet, F = Full Attention) + +### RMSNormGated + +Used only in GatedDeltaNet output. 
Combines RMSNorm with SiLU gating: +```python +return (weight * RMSNorm(x)) * silu(z) ``` -## Model Parameters - -| Parameter | Value | -|-----------|-------| -| `hidden_size` | 2048 | -| `num_hidden_layers` | 40 | -| `num_attention_heads` / `num_kv_heads` | 16 / 2 | -| `head_dim` | 256 | -| `partial_rotary_factor` | 0.25 (64 of 256 dims rotated) | -| `linear_num_key_heads` / `linear_num_value_heads` | 16 / 32 | -| `linear_key_head_dim` / `linear_value_head_dim` | 128 / 128 | -| `num_experts` / `num_experts_per_tok` | 256 / 8 | -| `moe_intermediate_size` | 512 | -| `shared_expert_intermediate_size` | 512 | -| `vocab_size` | 248320 | -| Total parameters | ~35B (~3B active per token) | - -## Key Components - -| Component | Description | -|-----------|-------------| -| **GemmaRMSNorm** | `x / sqrt(mean(x^2) + eps) * (1 + weight)` — unit-offset variant, weight init to zeros | -| **RMSNormGated** | `weight * RMSNorm(x) * silu(z)` — used in GatedDeltaNet output | -| **Full Attention** | GQA with output gate (sigmoid), QK-norm (GemmaRMSNorm), partial RoPE (25% of dims). `q_proj` produces Q + gate (2x heads). | -| **GatedDeltaNet** | Linear attention via recurrent state. Mamba-style gating: `g = -exp(A_log) * softplus(a + dt_bias)`. Causal conv1d, L2-normalized Q/K, delta rule recurrence. Uses FLA Triton kernel on CUDA. | -| **Sparse MoE** | Router selects top-8 of 256 experts per token. Shared expert with sigmoid gate always runs. | +### FullAttention -## Memory-Efficient Loading +GQA with output gate, QK-norm, and partial RoPE. + +**Submodules:** +- `qkv_proj`: `nn.Linear(2048, 9216)` — fused Q (with gate) + K + V + - Q + gate: `n_heads * head_dim * 2 = 16 * 256 * 2 = 8192` + - K: `n_kv_heads * head_dim = 2 * 256 = 512` + - V: `n_kv_heads * head_dim = 2 * 256 = 512` +- `o_proj`: `nn.Linear(4096, 2048)` — output projection +- `q_norm`, `k_norm`: `GemmaRMSNorm(256)` — applied before RoPE +- `rotary_emb`: partial RoPE on first 64 of 256 dims +- `kv_cache`: `KVCache` — `[1, 2, max_seq_len, 256]` for K and V +- `cache_positions`: `arange(max_seq_len)` — for causal mask + +**Forward (decode, T=1):** +``` +x -> qkv_proj -> split Q+gate, K, V +Q, K -> q_norm, k_norm -> partial RoPE +K, V -> kv_cache.update +Q, K_cached, V_cached -> SDPA (split-K or tiled) -> output +output * sigmoid(gate) -> o_proj +``` -`from_hf_checkpoint()` uses the voxtral_realtime pattern to minimize peak -memory (~1x model size instead of ~3x): +**Forward (prefill, T>1):** Same but uses `sdpa` instead of `sdpa_decode_splitk`. -1. **Meta device construction** — `with torch.device("meta"):` builds the - model with zero-storage parameter tensors (shape/dtype metadata only). -2. **safetensors lazy access** — `safe_open` loads tensors on demand from - each shard, remapping checkpoint keys inline. -3. **`assign=True` state dict loading** — replaces meta tensors by reference - instead of copying into pre-allocated storage. No duplication. -4. **Buffers stay on meta** — KV caches, conv/recurrent state, causal masks, - and RoPE tables remain on meta device. They are materialized in - `export.py` before `torch.export` (which requires real tensors for - in-place buffer ops). +### GatedDeltaNet -## Expert Weight Structure +Linear attention with delta rule recurrence. Mamba-style gating. -Expert weights are stored as grouped `nn.Linear` modules for quantization -compatibility. 
Each group of 16 experts shares a single `nn.Linear`: +**Submodules:** +- `in_proj`: `nn.Linear(2048, 12352)` — fused projection, split into: + - `qkv` (conv_dim = k_dim*2 + v_dim = 2048*2 + 4096 = 8192): goes through conv1d + - `z` (value_dim = 4096): gating signal for output norm + - `b` (num_v_heads = 32): beta for delta rule + - `a` (num_v_heads = 32): decay parameter +- `conv1d`: depthwise `Conv1d(8192, 8192, kernel=4, groups=8192)` — no bias +- `dt_bias`: `Parameter([32])` — bias for decay computation +- `A_log`: `Parameter([32])` — log of decay base +- `norm`: `RMSNormGated(128)` — output norm with SiLU gate +- `out_proj`: `nn.Linear(4096, 2048)` +- `conv_state`: buffer `[1, 8192, 4]` — causal conv1d state +- `recurrent_state`: buffer `[1, 32, 128, 128]` — delta rule state (H, K, V) -- `gate_up_projs[g]`: `nn.Linear(2048, 16 * 512 * 2)` — fused gate+up -- `down_projs[g]`: `nn.Linear(512, 16 * 2048)` — down projection +**Forward (decode, T=1):** +``` +x -> in_proj -> split qkv, z, b, a +qkv -> causal conv1d (manual, with state) -> silu -> split Q, K, V +Q, K -> L2 normalize -> repeat_interleave (16 heads -> 32 heads) +beta = sigmoid(b) +g = -exp(A_log) * softplus(a + dt_bias) +state = state * exp(g) # decay +Sk = einsum(state, k) # project state by key +delta = beta * (v - Sk) # delta rule +state = state + einsum(k, delta) # update state +output = einsum(state, q) * scale # query state +output -> RMSNormGated(output, z) -> out_proj +``` -16 experts per group keeps each `nn.Linear` under ~32K output features, -within tinygemm int4 packing limits. 256 experts / 16 = 16 groups, giving -32 matmul nodes per layer instead of 768 with per-expert linears. +**Forward (prefill, T>1):** Uses chunked FLA Triton kernel +`torch.ops.triton.chunk_gated_delta_rule` instead of the recurrent loop. -Forward pass: compute all groups → cat → gather top-k → SwiGLU → compute -all groups → cat → gather correct expert per slot. +**State reset:** When `input_pos[0] == 0`, both `conv_state` and +`recurrent_state` are zeroed (multiplied by 0). -## Quantization +### FusedMoEExperts -`export.py` is split into `load_and_quantize()` and `export_and_lower()`. +Stores all expert weights as stacked tensors. -Quantization is done layer-by-layer on CUDA: each layer's parameters (not -meta buffers) are moved to CUDA, quantized (tinygemm int4 packing requires -CUDA), then moved back to CPU. Peak GPU memory is ~1 bf16 layer at a time. -The model stays on CPU — `torch.export` traces the graph without executing -ops. +**Before quantization (nn.Parameters):** +- `w1_weight`: `[256, 1024, 2048]` — fused gate+up (2 * 512 = 1024) +- `w2_weight`: `[256, 2048, 512]` — down projection -With `--qlinear 4w --qembedding 8w`: +**After quantization (registered buffers):** +- `w1`: `[256, 1024, 1024]` int8 — packed INT4 (two values per byte) +- `w1_scale`: `[256, 1024, 2048//gs]` bf16 +- `w2`: `[256, 2048, 256]` int8 — packed INT4 +- `w2_scale`: `[256, 2048, 512//gs]` bf16 +- `group_size`: int — inferred from weight/scale shape ratio -| Component | Quantization | -|-----------|-------------| -| 40 layers (attention + MoE linears) | 4w (int4 weight-only) | -| `lm_head` | 4w | -| `embed_tokens` | 8w (int8 weight-only) | -| Conv1d, norm weights, `A_log`, `dt_bias` | unquantized (bf16) | +**INT4 packing:** `uint4 = int4 + 8` (shift to [0,15]), then +`packed = low_nibble | (high_nibble << 4)` stored as int8. 
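+
+A minimal sketch of this packing scheme (illustrative only; whether the even or
+odd element of each pair lands in the low nibble is an assumption, not taken
+from the export code):
+
+```python
+import torch
+
+def pack_int4(q: torch.Tensor) -> torch.Tensor:
+    """Pack signed int4 values in [-8, 7] along the last dim, two per byte."""
+    u = (q + 8).to(torch.uint8)              # shift [-8, 7] -> [0, 15]
+    low, high = u[..., 0::2], u[..., 1::2]   # assumed pairing of adjacent values
+    return (low | (high << 4)).view(torch.int8)
+```
+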
-Embedding and lm_head are untied before quantization since they require -different quantization formats (embedding uses index lookup, lm_head uses -matmul). +**Forward (decode):** `torch.ops.triton.fused_moe` — vec-mat MoE kernel. +**Forward (prefill):** `torch.ops.triton.fused_moe_batched_gemm` — batched +tensor-core MoE kernel. Toggled via `use_batched_moe` flag. -## Weight Mapping +### SparseMoE -| Checkpoint prefix | Model prefix | -|-------------------|-------------| +```python +scores = gate(x) # [B*T, 256] +expert_weights, expert_indices = topk(scores, 8) +expert_weights = softmax(expert_weights) # normalize top-k +routed_out = experts(x, expert_weights, expert_indices, top_k=8) +shared_out = shared_expert(x) # SwiGLU always runs +shared_gate_val = sigmoid(shared_expert_gate(x)) +output = routed_out + shared_gate_val * shared_out +``` + +### SwiGLU + +Used for shared expert. Fused gate+up projection: +```python +gate_up = gate_up_proj(x) # [B, 2*intermediate] +gate, up = split(gate_up) +return down_proj(silu(gate) * up) +``` + +## State Buffers + +All stateful buffers are registered buffers with in-place updates (no +state in/out function args). Shared across decode/prefill methods via +`share_mutable_buffers=True` in ExecuTorch export. + +| Buffer | Shape | Per | Purpose | +|--------|-------|-----|---------| +| `kv_cache.k_cache` | `[1, 2, max_seq_len, 256]` | full_attn layer (10) | Key cache | +| `kv_cache.v_cache` | `[1, 2, max_seq_len, 256]` | full_attn layer (10) | Value cache | +| `conv_state` | `[1, 8192, 4]` | GDN layer (30) | Causal conv1d state | +| `recurrent_state` | `[1, 32, 128, 128]` | GDN layer (30) | Delta rule recurrent state | +| `cache_positions` | `[max_seq_len]` | full_attn layer (10) | For causal mask computation | + +## Memory-Efficient Loading + +`from_hf_checkpoint()` minimizes peak memory (~1x model size): + +1. **Meta device construction** — `with torch.device("meta"):` allocates + no storage. +2. **safetensors lazy access** — `safe_open` loads one shard at a time, + remapping checkpoint keys inline via `_process_checkpoint_key`. +3. **Weight fusion** — separate Q/K/V projections are concatenated into + fused `qkv_proj`; GDN projections fused into `in_proj`; shared expert + gate+up fused into `gate_up_proj`. Done in `_fuse_projection_weights`. +4. **Expert stacking** — per-expert weights `experts.{N}.{gate,up,down}_proj` + are stacked into `[E, N, K]` tensors. Fused format + `experts.gate_up_proj` / `experts.down_proj` loaded directly. +5. **`assign=True` state dict loading** — replaces meta tensors by + reference, no duplication. +6. **Buffers stay on meta** — KV caches, conv/recurrent state, masks, and + RoPE tables materialized later in `_materialize_buffers`. + +## Weight Mapping (HuggingFace -> Model) + +Checkpoint keys may have `model.language_model.` prefix (multimodal +config). This is stripped in `_process_checkpoint_key`. 
+ +### Embeddings and head +| Checkpoint | Model | +|------------|-------| | `model.embed_tokens.weight` | `embed_tokens.weight` | | `model.norm.weight` | `norm.weight` | +| `lm_head.weight` | `lm_head.weight` (cloned from `embed_tokens` if absent) | + +### Per-layer norms +| Checkpoint | Model | +|------------|-------| | `model.layers.{N}.input_layernorm.weight` | `layers.{N}.ln_1.weight` | | `model.layers.{N}.post_attention_layernorm.weight` | `layers.{N}.ln_2.weight` | -| `model.layers.{N}.self_attn.{q,k,v,o}_proj.weight` | `layers.{N}.attn.{q,k,v,o}_proj.weight` | -| `model.layers.{N}.self_attn.{q,k}_norm.weight` | `layers.{N}.attn.{q,k}_norm.weight` | -| `model.layers.{N}.linear_attn.*` | `layers.{N}.attn.*` | -| `model.layers.{N}.mlp.experts.gate_up_proj` | `layers.{N}.mlp.cond_ffn.gate_up_projs.{G}.weight` (split into groups) | -| `model.layers.{N}.mlp.experts.down_proj` | `layers.{N}.mlp.cond_ffn.down_projs.{G}.weight` (split into groups) | + +### Full attention (layers where N % 4 == 3) +| Checkpoint | Model | +|------------|-------| +| `model.layers.{N}.self_attn.q_proj.weight` | fused into `layers.{N}.attn.qkv_proj.weight` | +| `model.layers.{N}.self_attn.k_proj.weight` | fused into `layers.{N}.attn.qkv_proj.weight` | +| `model.layers.{N}.self_attn.v_proj.weight` | fused into `layers.{N}.attn.qkv_proj.weight` | +| `model.layers.{N}.self_attn.o_proj.weight` | `layers.{N}.attn.o_proj.weight` | +| `model.layers.{N}.self_attn.q_norm.weight` | `layers.{N}.attn.q_norm.weight` | +| `model.layers.{N}.self_attn.k_norm.weight` | `layers.{N}.attn.k_norm.weight` | + +### GatedDeltaNet (layers where N % 4 != 3) +| Checkpoint | Model | +|------------|-------| +| `model.layers.{N}.linear_attn.in_proj_qkv.weight` | fused into `layers.{N}.attn.in_proj.weight` | +| `model.layers.{N}.linear_attn.in_proj_z.weight` | fused into `layers.{N}.attn.in_proj.weight` | +| `model.layers.{N}.linear_attn.in_proj_b.weight` | fused into `layers.{N}.attn.in_proj.weight` | +| `model.layers.{N}.linear_attn.in_proj_a.weight` | fused into `layers.{N}.attn.in_proj.weight` | +| `model.layers.{N}.linear_attn.conv1d.weight` | `layers.{N}.attn.conv1d.weight` | +| `model.layers.{N}.linear_attn.dt_bias` | `layers.{N}.attn.dt_bias` | +| `model.layers.{N}.linear_attn.A_log` | `layers.{N}.attn.A_log` | +| `model.layers.{N}.linear_attn.norm.weight` | `layers.{N}.attn.norm.weight` | +| `model.layers.{N}.linear_attn.out_proj.weight` | `layers.{N}.attn.out_proj.weight` | + +### MoE +| Checkpoint | Model | +|------------|-------| | `model.layers.{N}.mlp.gate.weight` | `layers.{N}.mlp.gate.weight` | -| `model.layers.{N}.mlp.shared_expert.*` | `layers.{N}.mlp.shared_expert.*` | | `model.layers.{N}.mlp.shared_expert_gate.weight` | `layers.{N}.mlp.shared_expert_gate.weight` | +| `model.layers.{N}.mlp.shared_expert.gate_proj.weight` | fused into `layers.{N}.mlp.shared_expert.gate_up_proj.weight` | +| `model.layers.{N}.mlp.shared_expert.up_proj.weight` | fused into `layers.{N}.mlp.shared_expert.gate_up_proj.weight` | +| `model.layers.{N}.mlp.shared_expert.down_proj.weight` | `layers.{N}.mlp.shared_expert.down_proj.weight` | +| `model.layers.{N}.mlp.experts.gate_up_proj` | `layers.{N}.mlp.experts.w1_weight` [E, 2*I, H] | +| `model.layers.{N}.mlp.experts.down_proj` | `layers.{N}.mlp.experts.w2_weight` [E, H, I] | +| `model.layers.{N}.mlp.experts.{E}.gate_proj.weight` | stacked into `w1_weight` (alt format) | +| `model.layers.{N}.mlp.experts.{E}.up_proj.weight` | stacked into `w1_weight` (alt format) | +| 
`model.layers.{N}.mlp.experts.{E}.down_proj.weight` | stacked into `w2_weight` (alt format) | + +Ignored keys: `rotary_emb.inv_freq`, `linear_attn.conv1d.bias`, +visual/MTP prefixed keys. + +## Quantization + +### Standard mode (`_quantize`, `--qlinear 4w --qembedding 8w`) + +Uniform INT4 — works for models with quantization-aware training (Qwen 3.5). + +| Component | Method | Format | +|-----------|--------|--------| +| Expert w1/w2 | `_quantize_experts_int4` | Packed INT4 buffers + bf16 scales | +| All `nn.Linear` in layers | `quantize_model_("4w")` | `Int4TilePackedTo4dTensor` (tinygemm) | +| `lm_head` | `quantize_model_("4w")` | `Int4TilePackedTo4dTensor` | +| `embed_tokens` | `quantize_model_(qembedding="8w")` | `IntxUnpackedToInt8Tensor` | +| Norms, conv1d, dt_bias, A_log | Unquantized | bf16 | + +Layer-by-layer on CUDA: move layer to GPU, quantize, move back to CPU. +Peak GPU memory ~1 layer at a time. + +### Sensitive mode (`_quantize_sensitive`, `--sensitive`) + +Mixed-precision — required for models without QAT (Qwen 3.6). +Determined by per-layer error profiling and GGUF Q4_K_M analysis. + +| Component | Method | Format | bpw (gs=32) | +|-----------|--------|--------|-------------| +| Expert w1/w2 | `_quantize_experts_int4` | Packed INT4 + bf16 scales | 4.5 | +| Attention projections | `quantize_model_("8w")` | `IntxUnpackedToInt8Tensor` | 8.5 | +| Shared expert | `quantize_model_("8w")` | `IntxUnpackedToInt8Tensor` | 8.5 | +| `lm_head` | `quantize_model_("8w")` | `IntxUnpackedToInt8Tensor` | 8.5 | +| `embed_tokens` | `quantize_model_(qembedding="8w")` | `IntxUnpackedToInt8Tensor` | 8.5 | +| MoE gate, shared expert gate | Unquantized | bf16 | 16 | +| GDN conv1d, dt_bias, A_log, norm | Unquantized | bf16 | 16 | +| Layer norms, QK norms, final norm | Unquantized | bf16 | 16 | + +Selective quantization uses `nn.ModuleDict` wrappers to pass only +specific submodules to `quantize_model_`, leaving the rest at bf16. + +`--hqq` enables HQQ (Half-Quadratic Quantization) for expert INT4 — +iterative least-squares scale refinement. Only affects expert w1/w2, +not the INT8 layers. + +### Prequantized checkpoints (`quantize_and_save.py`) + +Saves quantized model to safetensors for fast reload via +`--prequantized`. Tensor subclasses (`Int4TilePackedTo4dTensor`, +`IntxUnpackedToInt8Tensor`) are flattened into inner tensors with +`.__qdata`, `.__scale`, `.__scale_and_zero`, `.__zero_point` suffixes. +Reconstruction metadata stored in safetensors header under `"quantization"`. -Visual and MTP keys are skipped. `lm_head.weight` is cloned from -`embed_tokens.weight` if not present in checkpoint (tied embeddings). +`load_prequantized_model` reconstructs subclasses via +`__tensor_unflatten__`, replaces `FusedMoEExperts` parameters with +quantized buffers, and infers `group_size` from weight/scale shape ratio. + +## Export + +`export_and_lower()` produces two methods sharing mutable state buffers: + +| Method | Shape | MoE kernel | Use | +|--------|-------|------------|-----| +| `decode` | T=1, static | `fused_moe` (vec-mat) | Token-by-token generation | +| `prefill` | T>=2, dynamic | `fused_moe_batched_gemm` (tensor-core) | Prompt processing | + +Both share KV cache, conv_state, and recurrent_state via +`share_mutable_buffers=True`. The prefill example uses +`T=max_seq_len-1` so AOTI compiles kernels for the full sequence range. + +Output: `model.pte` (program) + `aoti_cuda_blob.ptd` (CUDA kernels + weights). 
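+
+To make the decode/prefill split concrete, here is an eager-mode sketch of a
+greedy generation loop against the PyTorch module (it assumes the model returns
+logits of shape `[B, T, vocab]`; the exported `prefill` and `decode` methods are
+driven the same way, with the shared buffers carrying state between calls):
+
+```python
+import torch
+
+@torch.no_grad()
+def greedy_generate(model, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
+    device = prompt.device
+    T = prompt.size(1)
+    # prefill: T > 1, fills KV caches, conv_state and recurrent_state
+    logits = model(prompt, torch.arange(T, device=device))
+    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
+    out = [next_tok]
+    # decode: one token at a time, reusing the shared mutable state
+    for step in range(max_new_tokens - 1):
+        pos = torch.tensor([T + step], device=device)
+        logits = model(next_tok, pos)
+        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
+        out.append(next_tok)
+    return torch.cat(out, dim=1)
+```
+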
+ +## Implementation Gotchas + +Things that will break if you change them without understanding why: + +### Manual conv1d implementation +GatedDeltaNet implements conv1d as a manual loop over kernel taps instead +of using `nn.Conv1d.forward()`. This is because `torch.export` decomposes +`nn.Conv1d` into `conv2d` ops, which lack AOTI fallback kernels. The +manual loop produces simple `mul` + `add` ops that AOTI handles natively. +The `conv1d.weight` is still an `nn.Conv1d` module (for correct weight +loading), but only `.weight` is accessed directly in forward. + +### `assign=True` in load_state_dict +Both `from_hf_checkpoint` and `load_prequantized_model` use +`model.load_state_dict(state_dict, strict=False, assign=True)`. +`assign=True` replaces meta tensors by reference — without it, PyTorch +tries to copy data into meta storage, which fails. For quantized models, +removing `assign=True` silently converts tensor subclasses +(`IntxUnpackedToInt8Tensor`, `Int4TilePackedTo4dTensor`) to regular +Parameters, losing quantization. + +### Two expert checkpoint formats +HuggingFace checkpoints come in two formats for expert weights: +1. **Fused**: `model.layers.{N}.mlp.experts.gate_up_proj` — single + `[E, 2*I, H]` tensor. Loaded directly as `w1_weight`. +2. **Per-expert**: `model.layers.{N}.mlp.experts.{E}.gate_proj.weight` — + individual `[I, H]` tensors per expert. Stacked in + `_load_and_remap_checkpoint` into `[E, 2*I, H]`. + +Both produce the same `w1_weight`/`w2_weight` tensors. The format depends +on how the checkpoint was saved upstream. `_process_checkpoint_key` +handles both via `_FUSED_EXPERT_RE` and `_EXPERT_RE` regex patterns. + +### `_to_device_skip_meta` for quantization +During quantization, layers are moved to CUDA one at a time. But some +submodules have meta-device buffers (KV cache, conv_state, +recurrent_state) that can't be moved. `_to_device_skip_meta` walks the +module tree and only moves submodules that have no meta buffers. Without +this, `layer.to("cuda")` crashes on meta buffers. + +### `torch.split` vs slicing in GatedDeltaNet +The forward uses explicit slicing (`proj[..., :cd]`) instead of +`torch.split` because `torch.split` produces `split_copy` ops in the +export graph, which lack AOTI fallback. Slicing produces `slice` ops +that AOTI handles. + +### Sensitive quantization wrapping pattern +`_quantize_sensitive` uses `nn.ModuleDict` wrappers to selectively +quantize specific submodules: +```python +wrapper = nn.ModuleDict({"attn": nn.ModuleDict({ + "in_proj": layer.attn.in_proj, + "out_proj": layer.attn.out_proj, +})}) +quantize_model_(wrapper, qlinear_config="8w", ...) +layer.attn.in_proj = wrapper.attn.in_proj +``` +This is necessary because `quantize_model_` quantizes every `nn.Linear` +it finds. The wrapper exposes only the linears we want quantized, leaving +GDN internals (conv1d, dt_bias, A_log, norm) and routing gates at bf16. ## References diff --git a/examples/models/qwen3_5_moe/quantize_and_save.py b/examples/models/qwen3_5_moe/quantize_and_save.py index 3f58ce969df..40eb8a89df5 100644 --- a/examples/models/qwen3_5_moe/quantize_and_save.py +++ b/examples/models/qwen3_5_moe/quantize_and_save.py @@ -1,4 +1,4 @@ -"""Quantize Qwen 3.5 MoE and save as a self-contained safetensors bundle. +"""Quantize Qwen 3.5 MoE and save as a self-contained safetensors checkpoint. Runs quantization once and saves the result so export.py can skip re-quantizing via --prequantized. 
The output directory contains everything @@ -16,6 +16,7 @@ Usage: python quantize_and_save.py --model-dir /path/to/Qwen3.5-MoE-A3B --qlinear 4w python quantize_and_save.py --model-dir /path/to/model --qlinear 4w --hqq + python quantize_and_save.py --model-dir /path/to/model --sensitive --hqq """ import argparse @@ -25,7 +26,7 @@ import torch -from executorch.examples.models.qwen3_5_moe.export import _quantize +from executorch.examples.models.qwen3_5_moe.export import _quantize, _quantize_sensitive from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE from safetensors.torch import save_file @@ -197,10 +198,27 @@ def main(): action="store_true", help="Use HQQ scale-only optimization for expert quantization.", ) + parser.add_argument( + "--sensitive", + action="store_true", + help="Use sensitivity-aware mixed precision quantization. " + "Recommended for models without quantization-aware training.", + ) args = parser.parse_args() - if not args.qlinear and not args.qembedding: - parser.error("At least one of --qlinear or --qembedding is required.") + if not args.qlinear and not args.qembedding and not args.sensitive: + parser.error( + "At least one of --qlinear, --qembedding, or --sensitive is required." + ) + + if args.sensitive and (args.qlinear or args.qembedding): + parser.error( + "--sensitive manages its own precision; " + "do not combine with --qlinear or --qembedding" + ) + + if args.hqq and not args.qlinear and not args.sensitive: + parser.error("--hqq requires --qlinear or --sensitive") # Load model print("Loading model...") @@ -214,7 +232,10 @@ def main(): ) # Quantize (includes expert INT4 + linear + embedding quantization) - _quantize(model, config, args) + if args.sensitive: + _quantize_sensitive(model, config, args) + else: + _quantize(model, config, args) # Save bundle os.makedirs(args.output, exist_ok=True) @@ -230,6 +251,7 @@ def main(): "tokenizer_config.json", "merges.txt", "vocab.json", + "LICENSE", ]: src = os.path.join(args.model_dir, filename) if os.path.exists(src): diff --git a/examples/models/qwen3_5_moe/test_quantize_roundtrip.py b/examples/models/qwen3_5_moe/test_quantize_roundtrip.py index db77a9c01dd..560d9585463 100644 --- a/examples/models/qwen3_5_moe/test_quantize_roundtrip.py +++ b/examples/models/qwen3_5_moe/test_quantize_roundtrip.py @@ -27,6 +27,7 @@ from executorch.examples.models.qwen3_5_moe.export import ( _materialize_buffers, _quantize, + _quantize_sensitive, load_prequantized_model, ) from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE, Qwen35MoEConfig @@ -59,8 +60,8 @@ ) -def _make_quantized_model(qlinear, qembedding, group_size, hqq=False): - """Create a tiny model with random weights and quantize it.""" +def _make_tiny_model(): + """Create a tiny model with deterministic random weights.""" torch.manual_seed(42) model = Qwen35MoE(TINY_CONFIG) model.to(dtype=torch.bfloat16) @@ -68,6 +69,12 @@ def _make_quantized_model(qlinear, qembedding, group_size, hqq=False): if p.device.type != "meta": p.data.normal_(0, 0.02) model.eval() + return model + + +def _make_quantized_model(qlinear, qembedding, group_size, hqq=False): + """Create a tiny model and quantize with _quantize.""" + model = _make_tiny_model() class Args: pass @@ -77,15 +84,29 @@ class Args: args.qembedding = qembedding args.qlinear_group_size = group_size args.qlinear_packing_format = "tile_packed_to_4d" - args.hqq = hqq _quantize(model, TINY_CONFIG, args) return model +def _make_sensitive_model(group_size, hqq=False): + """Create a tiny model and quantize with 
_quantize_sensitive.""" + model = _make_tiny_model() + + class Args: + pass + + args = Args() + args.qlinear_group_size = group_size + args.hqq = hqq + _quantize_sensitive(model, TINY_CONFIG, args) + + return model + + def _save_bundle(model, output_dir): - """Save a quantized model as a bundle (model.safetensors + config.json). + """Save a quantized model (model.safetensors + config.json). Uses the production save_quantized_model for weights, and writes a config.json from TINY_CONFIG so load_prequantized_model can read it. @@ -179,6 +200,50 @@ def test_4w_8w_gs128_hqq(self): """Roundtrip: 4w linear + 8w embedding, group_size=128, HQQ experts.""" self._test_roundtrip("4w", "8w", 128, hqq=True) + def _test_sensitive_roundtrip(self, group_size, hqq=False): + """Test: sensitive quantize -> save -> load -> forward + produces same output as sensitive quantize -> forward. + """ + import executorch.backends.cuda.triton.kernels # noqa: F401 + + model_a = _make_sensitive_model(group_size, hqq) + with tempfile.TemporaryDirectory() as tmpdir: + _save_bundle(model_a, tmpdir) + + _materialize_buffers(model_a, TINY_CONFIG) + model_a.to(device="cuda") + + torch.manual_seed(99) + tokens = torch.randint(0, TINY_CONFIG.vocab_size, (1, 4), device="cuda") + input_pos = torch.arange(4, device="cuda") + + with torch.no_grad(): + output_a = model_a(tokens, input_pos) + del model_a + + model_b, _ = load_prequantized_model( + tmpdir, max_seq_len=TINY_CONFIG.max_seq_len + ) + + _materialize_buffers(model_b, TINY_CONFIG) + model_b.to(device="cuda") + + with torch.no_grad(): + output_b = model_b(tokens, input_pos) + + self.assertTrue( + torch.equal(output_a, output_b), + f"Outputs differ: max diff = {(output_a - output_b).abs().max().item()}", + ) + + def test_sensitive_gs32(self): + """Roundtrip: sensitive quantization, group_size=32.""" + self._test_sensitive_roundtrip(32) + + def test_sensitive_gs32_hqq(self): + """Roundtrip: sensitive quantization, group_size=32, HQQ experts.""" + self._test_sensitive_roundtrip(32, hqq=True) + def test_load_rejects_corrupted_checkpoint(self): """load_prequantized_model raises on corrupted/mismatched checkpoint. diff --git a/examples/models/qwen3_6_moe/README.md b/examples/models/qwen3_6_moe/README.md new file mode 100644 index 00000000000..e566766febf --- /dev/null +++ b/examples/models/qwen3_6_moe/README.md @@ -0,0 +1,15 @@ +# Qwen 3.6 MoE + +Qwen 3.6 MoE uses the same architecture and runner as Qwen 3.5 MoE. +See [examples/models/qwen3_5_moe](../qwen3_5_moe/) for export, build, +and inference instructions. + +Prequantized weights are available at +[SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4](https://huggingface.co/SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4). + +Qwen 3.6 does not have quantization-aware training, so it requires +`--sensitive` for quantization. `--hqq` is recommended for better +expert weight accuracy. See the model card for details. + +**Note:** This model has not been tested or evaluated. It is provided +mainly for development purposes.
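+
+For example, a typical Qwen 3.6 flow, run from the `qwen3_5_moe` example
+directory (paths are illustrative):
+
+```bash
+# Quantize once with sensitivity-aware mixed precision (slow)
+python quantize_and_save.py \
+  --model-dir /path/to/Qwen3.6-35B-A3B \
+  --sensitive \
+  --hqq \
+  --output qwen36_moe_int4_hqq
+
+# Export from the prequantized checkpoint (fast)
+python export.py --prequantized qwen36_moe_int4_hqq
+```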