Pull request overview
This PR fixes the scale format (scale_fmt) parameter handling across the FP8 quantization pipeline. The changes enable proper propagation of the scale_fmt parameter to control whether scale rounding is applied during quantization operations.
Key changes include:
- Added scale_fmt parameter to fill_kv_cache_blocked_fp8 function and its test coverage
- Implemented fast scale rounding functions (fast_log2_ceil, fast_pow2, fast_round_scale) in fill_kv_cache.py
- Updated callers to pass scale_fmt through the call chain (nsa.py, attention.py)
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | Added scale_fmt parameter support with ROUND_SCALE flag and fast rounding utility functions for FP8 quantization |
| tests/pytorch/kernel/test_fill_kv_cache.py | Added scale_fmt fixture and parametrized tests to verify both None and 'ue8m0' scale formats |
| lmdeploy/pytorch/backends/cuda/nsa.py | Propagated scale_fmt parameter to quant_fp8 and fill_kv_cache_blocked_fp8 calls |
| lmdeploy/pytorch/backends/cuda/attention.py | Added hardcoded scale_fmt='ue8m0' to fill_kv_cache_blocked_fp8 call |
| lmdeploy/pytorch/backends/cuda/blockedf8_modules.py | Removed scale_fmt parameter from warmup method call (not accessible in that context) |
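As a rough illustration of the test parametrization described above (a sketch only — the names and structure are assumptions, not copied from the PR), a `scale_fmt` fixture covering both formats might look like:

```python
import pytest

# Hypothetical parametrization: each test runs once per supported scale format.
SCALE_FMTS = [None, 'ue8m0']


@pytest.fixture(params=SCALE_FMTS)
def scale_fmt(request):
    return request.param


def test_scale_fmt_is_supported(scale_fmt):
    # Each parametrized run sees exactly one of the two supported formats.
    assert scale_fmt in (None, 'ue8m0')
```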
```python
        quantization."""
```
The `scale_fmt` parameter is missing documentation. Consider adding a docstring parameter description explaining which values are valid (`None` or `'ue8m0'`) and what this parameter controls.
```suggestion
        quantization.

    Args:
        scale_fmt (str, optional): Format of the fp8 scaling factors. Valid
            values are ``None`` and ``'ue8m0'``. When set to ``'ue8m0'``,
            scaling factors are stored/interpreted using the UE8M0 fp8 scale
            format; when ``None``, the default scale layout for this kernel
            is used.
    """
```
```python
    max_q_seqlen=max_q_seqlen,
    block_offsets=block_offsets,
    group_size=128,
    scale_fmt='ue8m0',
```
The `scale_fmt` value is hardcoded as `'ue8m0'` here. For consistency with other parts of the codebase, consider making this value configurable through a class attribute or parameter, similar to how it is handled in the `NSAOp` class.
```suggestion
    scale_fmt=getattr(self, 'scale_fmt', 'ue8m0'),
```
```diff
     group_size: int = 128,
-    kv_layout: str = 'bshd'):
+    kv_layout: str = 'bshd',
+    scale_fmt: str = None):
```
Missing input validation for the `scale_fmt` parameter. It should be validated to ensure it is either `None` or `'ue8m0'`, similar to the validation in `_quant_fp8_launcher` (line 107 in `blocked_gemm_fp8.py`: `assert scale_fmt in (None, 'ue8m0')`).
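A minimal sketch of such a guard (the helper name is hypothetical), mirroring the assertion quoted from `_quant_fp8_launcher`:

```python
def validate_scale_fmt(scale_fmt):
    # Reject anything other than None or 'ue8m0', matching the
    # assert used in _quant_fp8_launcher.
    assert scale_fmt in (None, 'ue8m0'), f'unsupported scale_fmt: {scale_fmt}'
    return scale_fmt
```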
```diff
     group_size: int = 128,
-    kv_layout: str = 'bshd'):
+    kv_layout: str = 'bshd',
+    scale_fmt: str = None):
```
The type annotation for `scale_fmt` should be `Optional[str]` for consistency with other functions in the codebase (e.g., `quant_fp8` and `quant_fp8_tma` in `blocked_gemm_fp8.py`). Change the annotation from `scale_fmt: str = None` to `scale_fmt: Optional[str] = None`.
```suggestion
    scale_fmt: Optional[str] = None):
```
```python
@triton.jit
def fast_log2_ceil(x):
    bits_x = tl.cast(x, tl.uint32, bitcast=True)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    tmp = exp_x - 127 + tl.where(man_bits != 0, 1, 0)
    return tl.cast(tmp, tl.int32)


@triton.jit
def fast_pow2(x):
    bits_x = (x + 127) << 23
    return tl.cast(bits_x, tl.float32, bitcast=True)


@triton.jit
def fast_round_scale(amax, fp8_max_inv):
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
```
The functions `fast_log2_ceil`, `fast_pow2`, and `fast_round_scale` are duplicated from `lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py`. Consider extracting these utility functions into a shared module to avoid code duplication and improve maintainability.
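For readers unfamiliar with the bit tricks, the same rounding can be sketched in plain Python (for positive, normal float32 inputs): the float's biased exponent gives `floor(log2(x))`, a nonzero mantissa bumps it to the ceiling, and writing an exponent back into the bit pattern yields an exact power of two — the only values a UE8M0 scale can represent.

```python
import struct


def fast_log2_ceil(x: float) -> int:
    # ceil(log2(x)) read straight off the float32 exponent/mantissa fields.
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    exp = (bits >> 23) & 0xFF
    man = bits & ((1 << 23) - 1)
    return exp - 127 + (1 if man != 0 else 0)


def fast_pow2(e: int) -> float:
    # Build 2**e by placing the biased exponent directly in the bit pattern.
    return struct.unpack('<f', struct.pack('<I', (e + 127) << 23))[0]


def fast_round_scale(amax: float, fp8_max_inv: float) -> float:
    # Round the per-block scale up to the nearest power of two.
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
```

For example, with `amax = 100.0` and an fp8 max of 448, the raw scale `100/448 ≈ 0.223` rounds up to `0.25`, the next power of two.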
```diff
 @pytest.fixture
-def gt(self, k_states, v_states, group_size, quant_dtype):
+def gt(self, k_states, v_states, group_size, quant_dtype, scale_fmt):
```
This method takes 6 positional arguments, whereas the `gt` method it overrides in `TestFillKVCache` takes 9.
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry: just open the pull request and seek help from the maintainers.
Motivation
Please describe the motivation for this PR and the goal you want to achieve through it.
Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, please list some use cases here and update the documentation accordingly.
Checklist