Hybrid quantization follow-ups #3158

Description

@negvet

Tracking follow-ups from PR #2817. Detailed implementation notes are already in code comments / xfails.

Recipe / validation

  • Validate per-GEMM scaling-mode compatibility for hybrid qfactories.
    • See HybridQuantizer._get_compatible_recipe() in transformer_engine/pytorch/tensor/hybrid_tensor.py.
  • Support delayed-scaling requests inside HybridQuantizer.
    • See HybridQuantizer.__init__() in transformer_engine/pytorch/tensor/hybrid_tensor.py.

TP/SP

  • Add native hybrid dispatch to gather_along_first_dim.
    • See HybridQuantizer.supports_only_rowwise_all_gather() in transformer_engine/pytorch/tensor/hybrid_tensor.py.
  • Enable hybrid SP amax-reduction setup after native quantized all-gather.
    • See set_meta_tensor() comments in:
      • transformer_engine/pytorch/module/linear.py
      • transformer_engine/pytorch/module/layernorm_linear.py
      • transformer_engine/pytorch/module/layernorm_mlp.py

FSDP2

  • Optimize hybrid FSDP2 communication buffers.
    • See HybridQuantizedTensor.fsdp_pre_all_gather() in transformer_engine/pytorch/tensor/hybrid_tensor.py.
  • Fix HybridFloat8BlockScaling FSDP2 xfail.
    • See _HYBRID_FLOAT8_BLOCK_FSDP2_XFAIL_REASON in tests/pytorch/distributed/fsdp2_tests/conftest.py.
  • Add NVFP4 hybrid sub-storage FSDP2 hooks.
    • See TestHybridFsdpPreAllGatherProtocol.test_nvfp4_sub_storage_raises_on_pre_all_gather() in tests/pytorch/test_hybrid_quantization.py.

GEMM / quantization

  • Support HybridQuantizer / IdentityQuantizer as GEMM output quantizers.
    • See _reject_unsupported_output_quantizer() in transformer_engine/pytorch/cpp_extensions/gemm.py.

Distributed optimizer / Megatron

  • Support per-block hybrid sub-quantizers in quantize_master_weights.
    • See _route_hybrid_to_buckets() in transformer_engine/pytorch/tensor/utils.py.
    • Negative tests are in TestHybridQuantizeMasterWeights in tests/pytorch/test_hybrid_quantization.py.
  • Complete Megatron-LM quantized_model_init + --fp{4,8}-param-gather + distributed optimizer.
  • Complete Megatron-FSDP + --fp{4,8}-param-gather.
  • Complete Torch FSDP2 + --fp{4,8}-param-gather.

Activation recompute

  • Investigate vanilla torch.utils.checkpoint(use_reentrant=False) with TE weight-workspace cache.
    • See xfails in TestHybridActivationRecompute in tests/pytorch/test_hybrid_quantization.py.
    • te.checkpoint path is already covered and works.

Validation

  • Convergence validation of base non-hybrid recipes.

Metadata

Labels

bug (Something isn't working)
