[Pytorch][Common] Hybrid quantization#2817
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces hybrid (per-direction) quantization to TransformerEngine, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., NVFP4 forward + MXFP8 backward) via a
Confidence Score: 5/5Safe to merge — the new hybrid and identity quantization paths are well-isolated and thoroughly tested; existing FP8 and MXFP8 paths are unchanged in behavior. The core dispatch logic (GEMM unwrapping, grouped-linear split-quantize, FSDP2 pre/post all-gather, master-weight update) follows the same contracts as the existing per-format paths. Previous review concerns about mixed hybrid/None quantizer lists, make_empty exception safety, and FSDP2 direction-awareness for Float8 sub-storages are all addressed in this revision. The two remaining notes are defensive/forward-compatibility concerns that do not affect any currently exercised code path. transformer_engine/pytorch/cpp_extensions/gemm.py (None passthrough from dropped hybrid sub-storage) and transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py (fsdp_extract_buffers single-direction coverage vs fsdp_buffer_fields bidirectional reporting) Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
HQ[HybridQuantizer
rowwise_quantizer
columnwise_quantizer
columnwise_source] -->|quantize_impl| HQT[HybridQuantizedTensor
_rowwise_storage
_columnwise_storage]
HQT -->|GEMM dispatch| UA[_unwrap_hybrid_A
layout0==T → rowwise
layout0==N → columnwise]
HQT -->|GEMM dispatch| UB[_unwrap_hybrid_B
layout1==N → rowwise
layout1==T → columnwise]
UA --> UI1[_unwrap_identity_tensor
IdentityStorage → dequantize]
UB --> UI2[_unwrap_identity_tensor
IdentityStorage → dequantize]
UI1 --> GEMM[general_gemm / general_grouped_gemm]
UI2 --> GEMM
HQT -->|FSDP2 pre-AG| EPR[rowwise_storage
.fsdp_extract_buffers]
HQT -->|FSDP2 pre-AG| EPC[columnwise_storage
.fsdp_extract_buffers]
EPR --> AG[all_gather_into_tensor
per-direction buffers]
EPC --> AG
AG -->|fsdp_post_all_gather| RAG[fsdp_assign_gathered
+ _sync_usage]
RAG --> HQT2[Reconstructed
HybridQuantizedTensor]
HQ -->|GroupedLinear| HSQ[_hybrid_split_quantize
tex.split_quantize x2
row + col passes]
HSQ --> HQS[HybridQuantizedTensorStorage
per-expert]
subgraph Sub-storage types
F8[Float8Tensor
delayed/current]
MX[MXFP8Tensor]
FBW[Float8BlockwiseQTensor]
ID[IdentityTensor
high-precision]
end
HQT --> F8
HQT --> MX
HQT --> FBW
HQT --> ID
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
HQ[HybridQuantizer
rowwise_quantizer
columnwise_quantizer
columnwise_source] -->|quantize_impl| HQT[HybridQuantizedTensor
_rowwise_storage
_columnwise_storage]
HQT -->|GEMM dispatch| UA[_unwrap_hybrid_A
layout0==T → rowwise
layout0==N → columnwise]
HQT -->|GEMM dispatch| UB[_unwrap_hybrid_B
layout1==N → rowwise
layout1==T → columnwise]
UA --> UI1[_unwrap_identity_tensor
IdentityStorage → dequantize]
UB --> UI2[_unwrap_identity_tensor
IdentityStorage → dequantize]
UI1 --> GEMM[general_gemm / general_grouped_gemm]
UI2 --> GEMM
HQT -->|FSDP2 pre-AG| EPR[rowwise_storage
.fsdp_extract_buffers]
HQT -->|FSDP2 pre-AG| EPC[columnwise_storage
.fsdp_extract_buffers]
EPR --> AG[all_gather_into_tensor
per-direction buffers]
EPC --> AG
AG -->|fsdp_post_all_gather| RAG[fsdp_assign_gathered
+ _sync_usage]
RAG --> HQT2[Reconstructed
HybridQuantizedTensor]
HQ -->|GroupedLinear| HSQ[_hybrid_split_quantize
tex.split_quantize x2
row + col passes]
HSQ --> HQS[HybridQuantizedTensorStorage
per-expert]
subgraph Sub-storage types
F8[Float8Tensor
delayed/current]
MX[MXFP8Tensor]
FBW[Float8BlockwiseQTensor]
ID[IdentityTensor
high-precision]
end
HQT --> F8
HQT --> MX
HQT --> FBW
HQT --> ID
Reviews (20): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
timmoon10
left a comment
There was a problem hiding this comment.
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | ||
| columnwise_result = self.columnwise_quantizer.quantize(tensor) |
There was a problem hiding this comment.
Do we handle the case where not all usages are needed? I'd expect something like:
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) | |
| rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None |
| requires_grad: bool = False, | ||
| pin_memory: bool = False, | ||
| ) -> HybridQuantizedTensor: | ||
| self.rowwise_quantizer.internal = True |
There was a problem hiding this comment.
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
There was a problem hiding this comment.
This would not work under FSDP2.
| def factory(role): | ||
| if role == "linear_weight": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if role == "linear_input": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if role in ("linear_grad_output", "linear_grad_input"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_mxfp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| return None |
There was a problem hiding this comment.
This is horrifying. Good test.
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
| # DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories | ||
| # (lambdas, inner functions referencing captured state) are not picklable, | ||
| # so the qfactory must live at module scope. See | ||
| # ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``. |
There was a problem hiding this comment.
This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?
| for param in model.parameters(): | ||
| state = optimizer.state[param] | ||
| assert state["exp_avg"].dtype == torch.float32 | ||
| assert state["exp_avg_sq"].dtype == torch.float32 | ||
| if "master_param" in state: | ||
| assert state["master_param"].dtype == torch.float32 | ||
|
|
||
| assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" |
There was a problem hiding this comment.
That's not a very strict test, is there a way for us to do some numerical correctness comparisons?
There was a problem hiding this comment.
Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
|
Enable columnwise_source and hybrid recipes Respect quantizer veto for save_original_inp |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
| elif name == "_columnwise_scale_inv" and t is not None: | ||
| expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE | ||
| if t.size(0) != expected: | ||
| t = t[:expected] | ||
| buffers.append(t) | ||
| return tuple(buffers), {"field_names": names} |
There was a problem hiding this comment.
The columnwise scale truncation uses floor division (
flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE) instead of ceiling. For a sharded tensor where M is not a multiple of 32, ceil(M/32) scale entries are valid but M//32 are retained — the entry covering the last partial block is silently dropped. After all-gather, dequantization for those boundary rows uses a stale or zero scale. For example with M=48: 2 scale entries valid, but 48//32=1 is used, discarding row 32–47's scale.
| elif name == "_columnwise_scale_inv" and t is not None: | |
| expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE | |
| if t.size(0) != expected: | |
| t = t[:expected] | |
| buffers.append(t) | |
| return tuple(buffers), {"field_names": names} | |
| elif name == "_columnwise_scale_inv" and t is not None: | |
| expected = math.ceil(flattened_in_shape0 / MXFP8_BLOCK_SCALING_SIZE) | |
| if t.size(0) != expected: | |
| t = t[:expected] | |
| buffers.append(t) | |
| return tuple(buffers), {"field_names": names} |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
kwyss-nvidia
left a comment
There was a problem hiding this comment.
Thanks Evgeny for this expansive PR!
I'm excited to see the columnwise_source options in hybrid quantizer and that the edge cases for the FSDP protocol are considered and captured in the new tensor types.
LGTM!
| assert copied.rowwise_usage is False | ||
| assert copied.columnwise_usage is True | ||
|
|
||
| def test_rowwise_dequantized_identity_columnwise_matches_rowwise(self, input_tensor): |
There was a problem hiding this comment.
Thanks for adding the coverage for double quantization and the field so that the quantized tensor tracks describes columnwise source.
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| # Module-level qfactories (picklable, required for checkpoint serialization). |
There was a problem hiding this comment.
Which qfactories go into the checkpoint serialization?
While it's useful to have provenance of how the checkpoint was created, does the pickling of qfactories mean that the resulting checkpoints won't be read by transformer engine's with the same classes for custom quantization.
Is there any way to override this requirement and load the checkpoint as BF16, ignoring the pickled qfactories?
There was a problem hiding this comment.
Good point. I clarified the comment: the module-level qfactory is only needed so TE-to-TE quantized-param checkpoints have a stable importable reference for any pickled quantizer/recipe metadata.
For portability: if quantized_model_init is disabled, model weights are normal BF16 tensors and the CustomRecipe extra state is not needed for stateless recipes, so an external consumer can ignore TE _extra_state. With quantized_model_init, the model state_dict stores TE quantized tensor subclasses for any recipe, not just hybrid/CustomRecipe, so TE -> third-party runtime portability would need a separate high-precision/BF16 export path for quantized primary weights.
Would it be useful to plan that quantized_model_init BF16-weight state_dict/export support as a follow-up, or is running without quantized_model_init for portable checkpoints sufficient for your use case?
There was a problem hiding this comment.
Running without quantized_model_init is sufficient. I did not know about that option!
| transa = layout[0] == "T" | ||
| transb = layout[1] == "T" | ||
|
|
||
| A = _materialize_high_precision(_unwrap_hybrid_A(A, layout)) |
There was a problem hiding this comment.
The naming convention could be revisited. It can be interpreted easily but falsely to mean all quantized tensors will be materialized into high precision tensors. Perhaps "_unwrap_if_high_precision"?
There was a problem hiding this comment.
Replaced with _unwrap_identity_tensor, since we are doing isinstance(tensor, IdentityTensorStorage)
|
|
||
| # Linear-only recipe (no attention quantization): the qfactory is the only knob. | ||
| recipe = CustomRecipe(qfactory=mxfp8_fwd_nvfp4_bwd_quantizer_factory) | ||
| with autocast(recipe=recipe): |
There was a problem hiding this comment.
This is pleasantly simple as an API. Thanks Evgeny.
| ``HybridQuantizer`` terms, that source choice is expressed with | ||
| ``columnwise_source="rowwise_dequantized"``. | ||
|
|
||
| All non-weight roles keep the standard NVFP4 factory behavior, including RHT |
There was a problem hiding this comment.
This is useful. We have trained with an equivalent recipe for several experiments and I'm looking forward to trying this implementation.
| # Return early if recipe state matches recipe | ||
| if self.fp8_meta_tensors_initialized: | ||
| recipe_state = self.fp8_meta[fp8_meta_tensor_key] | ||
| # Follow-up: Match built-in recipes by full config, not just RecipeState type, so |
There was a problem hiding this comment.
If this follow up liked an issue or this pull request ID, it would be easier to grep for all related follow ups.
There was a problem hiding this comment.
Create an umbrella tracker #3158 and referenced it
| ------- | ||
| MXFP8 forward plus high-precision backward from the rowwise-dequantized | ||
| forward value can be expressed as:: | ||
|
|
There was a problem hiding this comment.
In the zoo, there's an example where the weight tensor is specialized with double quantization. It would be helpful to illustrate that it's possible to customize along the weight/activation/grad axis as well as the rowwise/colwise abstraction and reference the zoo.
| for sub in (self.rowwise_quantizer, self.columnwise_quantizer): | ||
| group = getattr(sub, "amax_reduction_group", None) | ||
| if group is not None: | ||
| return group |
There was a problem hiding this comment.
Arguably, this should assert if there are two groups, they are consistent. Is that possible?
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Description
Hybrid (per-direction) quantization. Hybrid means rowwise/colwise can use different formats via CustomRecipe(qfactory).
This is an experimental feature.
The main problem that it tries to solve is that precision requirements are non-uniform.
Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.
Composer-style (Composer 2 paper) grouped GEMM recipe, e.g. row-scaled NVFP4 fwd + MXFP8 bwd:
By default, the above factory uses
columnwise_source="original", so MXFP8 backward operands are quantized from the original high-precision tensor. Usecolumnwise_source="rowwise_dequantized"when the backward operand should be derived from the dequantized rowwise NVFP4 forward value.C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong
TODO:
Follow-up issue tracker #3158.
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather+ dist opt (persistent low-precision params viaquantized_model_init+ sharded-master FP32 → quantized cast viaquantize_master_weights.)- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by
quantize_master_weights; unblocker is TE-side cast-helper / kernel.--fp{4,8}-param-gather(fix private attribute access)--fp{4,8}-param-gather- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantizeunder Megatron MoE)Review
Total diff +14000
New hybrid source (
hybrid_tensor.py,hybrid_tensor_storage.py,identity_tensor.py,identity_tensor_storage.py) ~1800Adjacent modifications ~1500
Tests are the rest (~10K)
Suggested reading order
-columnwise_source controls whether columnwise quantization uses the original input or the rowwise-dequantized value.
1.1 Identity passthrough — b99277a
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: