
[quantization] Fix attention_mask computation in QuantQwen3VLTextModel#630

Open
dvsav wants to merge 1 commit into Samsung:main from dvsav:fix_qwen_text_model_attention_mask

Conversation


@dvsav dvsav commented Apr 14, 2026

This PR fixes the divergence between the original Qwen3VL model and the wrapped one (after tico.prepare).

Symptoms

The divergence between the original (unquantized) model and the wrapped model (after tico.prepare) was detected by the tico/quantization/wrapq/examples/qwen/trace_qwen.py script. In the trace below, most submodules are skipped for brevity:

--------------------------------------------------------------------------------
MODULE NAME                                            DIFFERENCE
--------------------------------------------------------------------------------
model.language_model.embed_tokens                      {'mean': 0.0, 'min': 0.0, 'max': 0.0, 'stddev': 0.0, 'peir': 0.0}
model.visual.patch_embed.proj                          {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
model.visual.patch_embed                               {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
...
model.language_model.layers.0.self_attn.v_proj         {'mean': 1.054959852808679e-07, 'min': 0.0, 'max': 7.171183824539185e-07, 'stddev': 1.0697030461415125e-07, 'peir': 7.13171060995813e-07}
model.language_model.layers.0.self_attn.o_proj         {'mean': 0.005488143302500248, 'min': 0.0, 'max': 0.08425191044807434, 'stddev': 0.006922183558344841, 'peir': 0.5775794506934978}
...
lm_head                                                {'mean': 0.022689497098326683, 'min': 1.1920928955078125e-07, 'max': 0.7876284718513489, 'stddev': 0.04204004257917404, 'peir': 0.6178593831941964}

As you can see, the difference between the submodules' outputs jumps sharply at the model.language_model.layers.0.self_attn.o_proj submodule (PEIR grows from 7.13e-07 to 0.58).
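The exact metric definitions used by trace_qwen.py are not shown in this PR. As a minimal stdlib sketch, assuming 'peir' here means peak-error-to-interval ratio (maximum absolute difference divided by the reference output's value range), per-module statistics like those in the table could be computed as follows:

```python
import math

def diff_stats(ref, other):
    """Elementwise absolute-difference statistics between a reference
    module output and the wrapped module's output (both flat lists).

    NOTE: 'peir' uses an ASSUMED definition (max|diff| divided by the
    reference value range); trace_qwen.py's real metric may differ.
    """
    diffs = [abs(a - b) for a, b in zip(ref, other)]
    mean = sum(diffs) / len(diffs)
    var = sum((d - mean) ** 2 for d in diffs) / len(diffs)
    interval = max(ref) - min(ref)
    return {
        "mean": mean,
        "min": min(diffs),
        "max": max(diffs),
        "stddev": math.sqrt(var),
        "peir": max(diffs) / interval if interval else float("inf"),
    }

stats = diff_stats([0.0, 1.0, 2.0, 4.0], [0.0, 1.1, 1.9, 4.0])
print(stats)  # max|diff| is ~0.1 over a value range of 4.0, so peir ~0.025
```

In a tracing script these statistics would typically be collected via forward hooks registered on matching submodules of the two models.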

Root Cause Analysis

Debugging was aided by the trace_qwen.py script:

python tico/quantization/wrapq/examples/qwen/trace_qwen.py \
    --model "Qwen/Qwen3-VL-2B-Instruct" \
    --no-trace-unquantized \
    --no-trace-quantized \
    --interesting-modules model.language_model.layers.0.self_attn.o_proj \
    --breakpoint-on-interesting-modules

Debugging localized the root cause of the divergence: the attention_mask used in Qwen3VLTextAttention.forward differed from the one used in QuantQwen3VLTextAttention.forward:

428     class Qwen3VLTextAttention(nn.Module):
...
459         def forward(
...
487             attn_output, attn_weights = attention_interface(
488                 self,
489                 query_states,
490                 key_states,
491                 value_states,
492  ->             attention_mask,
493                 dropout=0.0 if not self.training else self.attention_dropout,
494                 scaling=self.scaling,
495                 **kwargs,
496             )

(Pdb) pp attention_mask
tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
          ...,
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
(Pdb) pp attention_mask.shape
torch.Size([1, 1, 84, 84])
30      class QuantQwen3VLTextAttention(QuantModuleBase):
...
165         def forward(
...
229                 # mask add: broadcast on head axis (1 -> G)
230  ->             logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)

(Pdb) pp attention_mask
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
(Pdb) pp attention_mask.shape
torch.Size([1, 84])
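The consequence of the mismatch: the wrapper adds the raw 0/1 padding mask to the attention logits, which merely shifts every logit by one instead of driving future positions to a large negative value. A stdlib-only illustration on a single softmax row (not the wrapper's actual code; the names here are illustrative):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

logits = [2.0, 1.0, 0.5]   # attention logits for one query position
NEG_INF = -3.4028e38       # the fill value seen in the pdb dump above

# Correct additive causal mask: 0 for visible, -inf for future positions.
causal_row = [0.0, NEG_INF, NEG_INF]
masked = softmax([l + m for l, m in zip(logits, causal_row)])

# Buggy behaviour: adding a 0/1 padding mask shifts all logits uniformly,
# so future positions remain fully attended.
padding_row = [1, 1, 1]
buggy = softmax([l + m for l, m in zip(logits, padding_row)])

print(masked)  # [1.0, 0.0, 0.0] -- future positions suppressed
print(buggy)   # identical to softmax(logits): nothing is masked
```

Because softmax is shift-invariant, adding a constant 1 to every logit changes nothing, which is exactly why the divergence only shows up after the mask is applied (at o_proj) rather than at q/k/v projections.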

In the original (unquantized) model, the attention mask is computed in Qwen3VLTextModel:

class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
...
    def forward(
...
        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=text_position_ids,
        )

In the current implementation of QuantQwen3VLTextModel a similar mask computation is conditional:

class QuantQwen3VLTextModel(QuantModuleBase):
...
    def forward(
...
        # Build causal mask if not provided (or provided as bool)
        if attention_mask is None or attention_mask.dtype == torch.bool:
            attention_mask = self._slice_causal(
                inputs_embeds.shape[1], inputs_embeds.device
            )

The incoming attention_mask here is the 2-D 0/1 padding mask from the tokenizer (integer dtype, as the pdb dump above shows), so it is neither None nor bool: the condition is not met, the causal mask is never built, and the raw padding mask is passed down to the attention layers, where it is added to the logits.

What create_causal_mask returns

create_causal_mask calls several more functions:

transformers/models/qwen3_vl/modeling_qwen3_vl.py(896)forward()
-> attention_mask = create_causal_mask(
    transformers/masking_utils.py(854)create_causal_mask()
    -> causal_mask = mask_interface(
        transformers/masking_utils.py(510)eager_mask()
        -> mask = sdpa_mask(
            transformers/masking_utils.py(414)sdpa_mask()
            returns tensor([[[[ True, False, False,  ..., False, False, False],
                              [ True,  True, False,  ..., False, False, False],
                              [ True,  True,  True,  ..., False, False, False],
                              ...,
                              [ True,  True,  True,  ...,  True, False, False],
                              [ True,  True,  True,  ...,  True,  True, False],
                              [ True,  True,  True,  ...,  True,  True,  True]]]])
        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
        returns tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
                          [ 0.0000e+00,  0.0000e+00, -3.4028e+38,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
                          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -3.4028e+38, -3.4028e+38, -3.4028e+38],
                          ...,
                          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00, -3.4028e+38, -3.4028e+38],
                          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,  0.0000e+00, -3.4028e+38],
                          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])
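
The two-step transform shown in the trace — a lower-triangular boolean mask followed by torch.where(mask, 0.0, min_dtype) — can be mirrored with plain lists. This is an illustration of the conversion only, not the transformers implementation:

```python
# Mirror of the bool -> additive mask conversion from masking_utils,
# using nested lists instead of tensors (illustration only).
T = 4
MIN_DTYPE = -3.4028e38  # ~torch.finfo(torch.float32).min

# Step 1: lower-triangular boolean mask (True = key position k is
# visible to query position q).
bool_mask = [[q >= k for k in range(T)] for q in range(T)]

# Step 2: where(visible, 0.0, min_dtype), as in the trace above.
additive = [[0.0 if visible else MIN_DTYPE for visible in row]
            for row in bool_mask]

print(additive[0])      # 0.0 followed by three MIN_DTYPE entries
print(additive[T - 1])  # all zeros: the last query sees every key
```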

The Fix

Since we are also aiming to create a causal mask (not a sliding-window or chunked attention mask, whose computation would not be convertible to Circle), we can do so unconditionally:

class QuantQwen3VLTextModel(QuantModuleBase):
...
    def forward(
...
        # Build causal mask
        attention_mask = self._slice_causal(
            inputs_embeds.shape[1], inputs_embeds.device
        )
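
The helper name suggests slicing a precomputed maximum-length causal mask down to the current sequence length. The real _slice_causal implementation is not shown in this PR, so the following is a hypothetical stdlib sketch of that pattern (MAX_LEN and slice_causal are illustrative names):

```python
MIN_DTYPE = -3.4028e38  # ~float32 lowest, matching the mask dumps above
MAX_LEN = 8             # hypothetical precomputed maximum length

# Precompute once: additive causal mask of shape [MAX_LEN, MAX_LEN].
FULL_CAUSAL = [[0.0 if q >= k else MIN_DTYPE for k in range(MAX_LEN)]
               for q in range(MAX_LEN)]

def slice_causal(seq_len):
    """Slice the precomputed mask to [1, 1, seq_len, seq_len], the
    broadcastable shape the attention layers expect."""
    sliced = [row[:seq_len] for row in FULL_CAUSAL[:seq_len]]
    return [[sliced]]  # add batch and head axes

mask = slice_causal(3)
# mask[0][0] is a 3x3 lower-triangular additive causal mask
```

Precomputing and slicing (rather than rebuilding the mask per call) keeps the computation static and simple to convert.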

Tracing Submodules' Divergence After The Fix

After the fix the divergence at model.language_model.layers.0.self_attn.o_proj drops to normal values (PEIR 4.72e-07):

--------------------------------------------------------------------------------
MODULE NAME                                            DIFFERENCE
--------------------------------------------------------------------------------
model.language_model.embed_tokens                      {'mean': 0.0, 'min': 0.0, 'max': 0.0, 'stddev': 0.0, 'peir': 0.0}
model.visual.patch_embed.proj                          {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
model.visual.patch_embed                               {'mean': 4.663015715777874e-07, 'min': 2.9802322387695312e-08, 'max': 3.337860107421875e-06, 'stddev': 5.199021302360052e-07, 'peir': 1.105467087481138e-06}
...
model.language_model.layers.0.self_attn.v_proj         {'mean': 1.054959852808679e-07, 'min': 0.0, 'max': 7.171183824539185e-07, 'stddev': 1.0697030461415125e-07, 'peir': 7.13171060995813e-07}
model.language_model.layers.0.self_attn.o_proj         {'mean': 1.4910252588151707e-08, 'min': 0.0, 'max': 6.891787052154541e-08, 'stddev': 1.1672808497564802e-08, 'peir': 4.7245867289065846e-07}
...
lm_head                                                {'mean': 1.4132734804661595e-07, 'min': 0.0, 'max': 8.344650268554688e-07, 'stddev': 1.1461531101986111e-07, 'peir': 6.546005702132053e-07}

As you can see, the PEIR now stays on the order of 1e-07.

Unit Tests

$ python -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py -v
================================================================================= test session starts =================================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 15 items                                                                                                                                                                    

test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_deepstack_injection          PASSED  [  6%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_batch_sizes        PASSED  [ 13%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_different_sequence_lengths   PASSED  [ 20%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_embedding_layer_quantization PASSED  [ 26%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_forward_diff                 PASSED  [ 33%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_inputs_embeds_path           PASSED  [ 40%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_layers_wrapped               PASSED  [ 46%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_mode_transitions             PASSED  [ 53%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_no_cache_mode                PASSED  [ 60%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_norm_wrapped                 PASSED  [ 66%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_observer_count               PASSED  [ 73%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_output_shape                 PASSED  [ 80%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_per_module_override          PASSED  [ 86%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_registration_in_registry     PASSED  [ 93%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_model.py::TestQuantQwen3VLTextModel::test_rotary_emb_not_wrapped       PASSED  [100%]

=========================================================================== 15 passed, 2 warnings in 6.30s ============================================================================

Conversion to Circle

$ python tico/quantization/wrapq/examples/qwen/quantize_text_model.py
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.132488
│ PEIR       : 8.692164 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 4.0┤                                            │
    │                                       •••  │
    │                                    •••••   │
 2.5┤                                  ••••••    │
    │                                ••••••      │
    │                              •••••••       │
    │                           ••••••••         │
 1.0┤                         ••••••••           │
    │                       ••••••••             │
    │                     ••••••••               │
-0.6┤                   ••••••••                 │
    │                 ••••••••                   │
    │               •••••••                      │
    │             ••••••••                       │
-2.1┤            •••••••                         │
    │         •••••••                            │
    │        ••••••                              │
-3.6┤        •••                                 │
    │     ••••                                   │
    │                                            │
    │  •                                         │
-5.1┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.1       -2.9       -0.6       1.7       4.0 

Circle model saved as 'qwen3vl_text_model.q.circle'

This PR fixes divergence between the original Qwen3VL model and the wrapped one by correcting attention_mask computation in QuantQwen3VLTextModel.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
dvsav force-pushed the fix_qwen_text_model_attention_mask branch from 6769bdc to 0ae8649 on April 15, 2026 at 06:53.
