Tracking follow-ups from PR #2817. Detailed implementation notes are already in code comments and xfails.

## Recipe / validation

- [ ] Validate per-GEMM scaling-mode compatibility for hybrid qfactories.
  - See `HybridQuantizer._get_compatible_recipe()` in `transformer_engine/pytorch/tensor/hybrid_tensor.py`.
- [ ] Support delayed-scaling requests inside `HybridQuantizer`.
  - See `HybridQuantizer.__init__()` in `transformer_engine/pytorch/tensor/hybrid_tensor.py`.

## TP/SP

- [ ] Add native hybrid dispatch to `gather_along_first_dim`.
  - See `HybridQuantizer.supports_only_rowwise_all_gather()` in `transformer_engine/pytorch/tensor/hybrid_tensor.py`.
- [ ] Enable hybrid SP amax-reduction setup after native quantized all-gather.
  - See the `set_meta_tensor()` comments in:
    - `transformer_engine/pytorch/module/linear.py`
    - `transformer_engine/pytorch/module/layernorm_linear.py`
    - `transformer_engine/pytorch/module/layernorm_mlp.py`

## FSDP2

- [ ] Optimize hybrid FSDP2 communication buffers.
  - See `HybridQuantizedTensor.fsdp_pre_all_gather()` in `transformer_engine/pytorch/tensor/hybrid_tensor.py`.
- [ ] Fix the `HybridFloat8BlockScaling` FSDP2 xfail.
  - See `_HYBRID_FLOAT8_BLOCK_FSDP2_XFAIL_REASON` in `tests/pytorch/distributed/fsdp2_tests/conftest.py`.
- [ ] Add NVFP4 hybrid sub-storage FSDP2 hooks.
  - See `TestHybridFsdpPreAllGatherProtocol.test_nvfp4_sub_storage_raises_on_pre_all_gather()` in `tests/pytorch/test_hybrid_quantization.py`.

## GEMM / quantization

- [ ] Support `HybridQuantizer` / `IdentityQuantizer` as GEMM output quantizers.
  - See `_reject_unsupported_output_quantizer()` in `transformer_engine/pytorch/cpp_extensions/gemm.py`.

## Distributed optimizer / Megatron

- [ ] Support per-block hybrid sub-quantizers in `quantize_master_weights`.
  - See `_route_hybrid_to_buckets()` in `transformer_engine/pytorch/tensor/utils.py`.
  - Negative tests are in `TestHybridQuantizeMasterWeights` in `tests/pytorch/test_hybrid_quantization.py`.
- [ ] Complete Megatron-LM `quantized_model_init + --fp{4,8}-param-gather + dist opt`.
  - See PR #2817 integration notes.
- [ ] Complete Megatron-FSDP + `--fp{4,8}-param-gather`.
  - See PR #2817 integration notes.
- [ ] Complete Torch FSDP2 + `--fp{4,8}-param-gather`.
  - See PR #2817 integration notes and `tests/pytorch/distributed/fsdp2_tests/`.

## Activation recompute

- [ ] Investigate vanilla `torch.utils.checkpoint(use_reentrant=False)` with the TE weight-workspace cache.
  - See the xfails in `TestHybridActivationRecompute` in `tests/pytorch/test_hybrid_quantization.py`.
  - The `te.checkpoint` path is already covered and works.

## Validation

- [ ] Validate convergence of the base non-hybrid recipes.