Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching#317
Merged
Conversation
gshtras
reviewed
Dec 10, 2024
Collaborator
gshtras
left a comment
Overall looks good, aside from a few minor questions and comments.
Also pending conflict resolution
vllm/envs.py
Outdated
```diff
 VLLM_USE_ROCM_SKINNY_GEMM: bool = True
 VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
-VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
+VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False
 VLLM_MOE_PADDING: bool = False
 VLLM_FP8_PADDING: bool = True
 VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
+K_SCALE_CONSTANT: int = 200
```
Collaborator
Do we want different values?
```diff
 for field in dataclasses.fields(attn_backend.get_metadata_cls()):
-    if field.name in tensor_dict:
+    if field.name in tensor_dict and field.name != \
+            'enable_kv_scales_calculation':
```
Collaborator
I'm not sure why we filter it out here?
benchmarks/P3L.py
Outdated
```diff
 engine_args = EngineArgs.from_cli_args(args)
 llm = LLM(**dataclasses.asdict(engine_args))

 llm = LLM(
```
Collaborator
This is not needed now with `**dataclasses.asdict(engine_args)`.
```diff
-self._k_scale = 1.0
-self._v_scale = 1.0
+self.calculate_kv_scales = calculate_kv_scales
+self._k_scale = torch.tensor(1.0, dtype=torch.float32)
```
Collaborator
Possibly `torch.ones` is better.
shajrawi
approved these changes
Dec 17, 2024
gshtras
added a commit
that referenced
this pull request
Jan 7, 2025
Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching (#317)

* Changed _k_scale and _v_scale to tensors
* Fixed ROCm paged attention with tensor kv scales
* Added on the fly scale factor calculation
* Trying to fix attn metadata
* Fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py
* Changed K and V scale constants
* Removed unneeded comment
* Changes to pass format.sh, also fixed lingering k_scale/v_scale : float
* Fix for TP > 1
* Ran format.sh
* Removed legacy kv_scale loading from the json file
* Removed the outdated kv cache docs
* Revert some unwanted changes

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This PR implements a simple method for calculating `k_scale` and `v_scale` on the fly in the attention layer. This is especially useful when a model checkpoint ships without scale factors, where the previous behavior was to default both to 1.0.
This feature required changing `k_scale` and `v_scale` from floats to tensors, which should also be useful for exploring other kinds of key and value scaling in the future (e.g. per-channel).
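For context, the on-the-fly idea can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual code: the function name and the exact formula are assumptions, and only the `K_SCALE_CONSTANT = 200` value comes from the `vllm/envs.py` diff above.

```python
import torch

# Headroom constant; the value 200 is the default added to vllm/envs.py in
# this PR. The formula below is an illustrative sketch of deriving a scale
# from observed key/value magnitudes rather than defaulting it to 1.0.
K_SCALE_CONSTANT = 200

def dynamic_kv_scale(x: torch.Tensor) -> torch.Tensor:
    """Derive a per-tensor scale from the observed max magnitude (sketch)."""
    amax = x.detach().abs().max().to(torch.float32)
    # Divide by a headroom constant so occasional outliers do not
    # saturate the quantized FP8 range; clamp to keep the scale nonzero.
    return torch.clamp(amax / K_SCALE_CONSTANT, min=1e-12)
```

Returning the scale as a float32 tensor (rather than a Python float) matches the PR's change of `_k_scale`/`_v_scale` to tensors, which keeps the door open for per-channel scales later.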
Here are a few PPL measurements taken with Llama 3.1 70B, showing better accuracy than using a scale factor of 1.0 for both `k_scale` and `v_scale`.