Add top-k sampling support to llm Sampler (#19122)

meta-codesync[bot] merged 1 commit into main.
Conversation
Dr. CI: As of commit b7688fe with merge base 56da964, there are 2 unrelated failures (broken trunk); you can merge normally. Test results at hud.pytorch.org/pr/pytorch/executorch/19122.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102385104. |
Pull request overview
Adds top-k sampling mode to the existing LLM Sampler, complementing current greedy, multinomial, and top-p sampling options.
Changes:
- Added `Sampler::set_topk(int32_t)` API and internal `topk_` state to enable/disable top-k filtering.
- Implemented `sample_topk` using `std::partial_sort` and integrated it into `Sampler::sample` with precedence over top-p.
- Added unit tests covering top-k candidate restriction, disabling behavior, the FP16 path, and `topk=1` argmax behavior.
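The mechanics in the list above can be sketched as a standalone function. This is illustrative only: `ProbIndex` and `sample_topk_sketch` are hypothetical names (not the actual ExecuTorch API), `coin` is assumed to be a uniform random draw in [0, 1), and `k` is assumed to satisfy `0 < k <= vocab_size`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Probability/index pair, mirroring the shape used by top-p sampling.
struct ProbIndex {
  float prob;
  int32_t index;
};

// Pick the k most probable tokens with a heap-based partial sort
// (O(n log k)), then CDF-sample from that slice. Scaling the random
// coin by the top-k mass is equivalent to renormalizing the slice.
int32_t sample_topk_sketch(
    const float* probs, int32_t vocab_size, int32_t k, float coin) {
  std::vector<ProbIndex> probindex(vocab_size);
  for (int32_t i = 0; i < vocab_size; ++i) {
    probindex[i] = {probs[i], i};
  }
  // Only the first k entries end up sorted (descending by probability).
  std::partial_sort(
      probindex.begin(),
      probindex.begin() + k,
      probindex.end(),
      [](const ProbIndex& a, const ProbIndex& b) { return a.prob > b.prob; });

  // Accumulate the top-k mass; float accumulation keeps this stable
  // even if the probabilities arrived in a lower-precision dtype.
  float topk_sum = 0.0f;
  for (int32_t i = 0; i < k; ++i) {
    topk_sum += probindex[i].prob;
  }

  // Sample from the (implicitly renormalized) top-k distribution.
  const float r = coin * topk_sum;
  float cdf = 0.0f;
  for (int32_t i = 0; i < k; ++i) {
    cdf += probindex[i].prob;
    if (r < cdf) {
      return probindex[i].index;
    }
  }
  return probindex[k - 1].index; // defensive fallback for rounding error
}
```

With `k == 1` this degenerates to argmax over the distribution, which is what the `topk=1` unit test exercises.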
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `extension/llm/sampler/sampler.h` | Adds `set_topk` API, declares `sample_topk`, and stores the `topk_` configuration. |
| `extension/llm/sampler/sampler.cpp` | Implements `sample_topk` and wires top-k selection into `Sampler::sample`. |
| `extension/llm/sampler/test/test_sampler.cpp` | Adds test coverage for top-k behavior (including FP16 smoke coverage). |
```cpp
  T topk_sum = 0;
  for (int i = 0; i < k; i++) {
    topk_sum += probindex[i].prob;
  }

  // Sample from the (implicitly renormalized) top-k distribution.
  const T r = coin * topk_sum;
  T cdf = 0;
  for (int i = 0; i < k; i++) {
    cdf += probindex[i].prob;
```
`sample_topk` accumulates `topk_sum` / `cdf` and stores the random threshold `r` in `T`. Since `coin` is `float`, keeping these accumulators as `float` would avoid extra quantization for low-precision dtypes (FP16/BF16) and make the sampling thresholding more numerically stable. This is optional, but low-cost to do here.
Suggested change:

```diff
-  T topk_sum = 0;
-  for (int i = 0; i < k; i++) {
-    topk_sum += probindex[i].prob;
-  }
-  // Sample from the (implicitly renormalized) top-k distribution.
-  const T r = coin * topk_sum;
-  T cdf = 0;
-  for (int i = 0; i < k; i++) {
-    cdf += probindex[i].prob;
+  float topk_sum = 0.0f;
+  for (int i = 0; i < k; i++) {
+    topk_sum += static_cast<float>(probindex[i].prob);
+  }
+  // Sample from the (implicitly renormalized) top-k distribution.
+  const float r = coin * topk_sum;
+  float cdf = 0.0f;
+  for (int i = 0; i < k; i++) {
+    cdf += static_cast<float>(probindex[i].prob);
```
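To see why float accumulators help here, consider a float-vs-double analogy (standing in for FP16-vs-float, since standard C++ has no portable half-precision type; `sum_narrow` and `sum_wide` are illustrative names, not from the PR): once the running sum in the narrow type grows large enough, further small summands are rounded away entirely.

```cpp
#include <cassert>

// Accumulate n copies of v in a narrow (float) accumulator. Once the
// sum reaches 2^24, adding 1.0f no longer changes it: the increment is
// smaller than the spacing between adjacent floats at that magnitude.
float sum_narrow(float v, int n) {
  float s = 0.0f;
  for (int i = 0; i < n; ++i) {
    s += v;
  }
  return s;
}

// The same loop with a wide (double) accumulator: exact for these inputs.
double sum_wide(float v, int n) {
  double s = 0.0;
  for (int i = 0; i < n; ++i) {
    s += static_cast<double>(v);
  }
  return s;
}
```

The same effect, at much smaller magnitudes, is why accumulating FP16 probabilities in an FP16 accumulator can distort the top-k CDF.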
```cpp
  // topk=0 means disabled. With low temperature and topp disabled, sampling
  // collapses to multinomial over the full vocab, but the dominant token
  // should still win the vast majority of the time.
```
The test comment says "With low temperature" but the sampler is constructed with temperature = 1.0f. Update the comment to match the actual setup (or change the temperature) to avoid confusion about what behavior is being exercised.
Suggested change:

```diff
-  // topk=0 means disabled. With low temperature and topp disabled, sampling
-  // collapses to multinomial over the full vocab, but the dominant token
-  // should still win the vast majority of the time.
+  // topk=0 means disabled. With topp disabled, sampling uses multinomial
+  // over the full vocab, but the dominant token should still win the vast
+  // majority of the time because its logit is much larger than the rest.
```
@claude Review this PR
Claude finished @mergennachin's task in 1m 2s.

Review of PR #19122: Add top-k sampling support to llm Sampler

Overall this is a clean, well-structured addition that follows the existing patterns in the sampler. The implementation is correct, the tests are reasonable, and the API is backward-compatible. A few observations below, ranging from a real numerical concern to minor nits.

Issues

1. Per-call heap allocation (acknowledged but worth emphasizing).
2. Numerical precision with low-precision types (FP16/BF16).
3. …

Minor / Nits

4. Missing …: the only way to enable top-k is via `set_topk`.
5. Test comment mismatch: the comment says "With low temperature" but temperature is 1.0f.
6. No test for top-k + top-p interaction. The PR says "when top-k is enabled it takes precedence over top-p — the two modes are mutually exclusive." It would strengthen the test coverage to have a test that sets both.
7. Test uses …

What looks good

…
@claude review
Claude finished @kirklandsign's task in 55s.

Review of PR #19122: Add top-k sampling support to llm Sampler

The implementation is clean and follows the existing patterns well.

1. Accumulator precision for FP16/BF16: …
2. Test comment mismatch: the comment says "With low temperature" but temperature is 1.0f.
3. No test for top-k + top-p interaction: the PR states top-k takes precedence over top-p when both are enabled. Adding a test that sets both would strengthen coverage.
4. Per-call heap allocation: the TODO at line 77 acknowledges this. For completeness: since top-k doesn't need the full array sorted, an alternative is to use …

Everything else looks good — the defensive …
Force-pushed from b48387f to 64121b7.
Summary: The existing `Sampler` in `extension/llm/sampler` supports greedy (temperature=0), full multinomial, and top-p (nucleus) sampling, but had no top-k mode. Top-k is a common companion / alternative to top-p and is trivial to add given the existing infrastructure.

This change adds:

- `Sampler::set_topk(int32_t)` to enable top-k filtering. `k <= 0` or `k >= vocab_size` disables it (default).
- A new private `sample_topk` template that uses `std::partial_sort` to pick the k highest-probability tokens (O(n log k) — heap-based partial sort), and CDF-samples from the top-k slice (scaling `coin` by the top-k sum is mathematically equivalent to renormalizing). Same shape as the existing `sample_topp`. Sums and CDF are accumulated in `float` so FP16/BF16 inputs don't lose precision over k summands.
- Wiring in `Sampler::sample` so that when top-k is enabled it takes precedence over top-p — the two modes are mutually exclusive in this implementation. With both disabled, falls back to multinomial, matching today's behaviour.

The existing constructors are unchanged, so this is fully backward-compatible. Float, FP16, BF16, and uint16_t logit dtypes are all supported via the existing template instantiations.

A follow-up diff will wire `set_topk` into the `text_llm_runner` `GenerationConfig` so callers can opt into top-k from the runner.

Differential Revision: D102385104
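The precedence described in the summary (top-k over top-p, multinomial when both are disabled) can be sketched as a small dispatch helper. `pick_mode` is a hypothetical name, and the `0 < topp < 1` condition for top-p is an assumption based on llama2.c-style samplers, not taken from the diff.

```cpp
#include <cassert>
#include <cstdint>

enum class Mode { kTopK, kTopP, kMultinomial };

// Top-k is enabled for 0 < k < vocab_size and wins over top-p; top-p is
// assumed active for 0 < topp < 1; otherwise fall back to multinomial.
Mode pick_mode(int32_t topk, int32_t vocab_size, float topp) {
  if (topk > 0 && topk < vocab_size) {
    return Mode::kTopK;
  }
  if (topp > 0.0f && topp < 1.0f) {
    return Mode::kTopP;
  }
  return Mode::kMultinomial;
}
```

Note that `k >= vocab_size` falls through to top-p (or multinomial), matching the "disables it" behavior from the summary.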
Force-pushed from 64121b7 to b7688fe.
@claude review
Claude finished @kirklandsign's task in 56s.

Review of PR #19122: Add top-k sampling support to llm Sampler

Clean addition that follows the existing sampler patterns well. The use of …

1. Pre-existing: the new …
2. Pre-existing: same issue — …
3. The member has …
4. Potential future improvement: combined top-k + top-p. The PR description says the two modes are mutually exclusive with top-k taking precedence. This is a reasonable choice for a first implementation. Many inference frameworks (e.g., vLLM, HuggingFace `transformers`) support applying both together.
5. Tests look solid: the test suite covers the key behaviors well, and this addresses the gap the previous review noted about top-k + top-p interaction testing.

Overall this is a straightforward, well-implemented addition. LGTM with no blocking issues.
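The combined top-k + top-p mode flagged as a potential future improvement could look roughly like the following sketch: filter to the k most probable tokens first, then keep the smallest prefix whose renormalized cumulative probability reaches `topp`. `combined_topk_topp` is a hypothetical helper, not part of this PR.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Return the surviving token indices (most probable first) after
// applying top-k and then top-p, renormalizing within the top-k slice.
std::vector<int> combined_topk_topp(
    const std::vector<float>& probs, int k, float topp) {
  std::vector<int> idx(probs.size());
  for (int i = 0; i < static_cast<int>(idx.size()); ++i) {
    idx[i] = i;
  }
  std::sort(idx.begin(), idx.end(), [&](int a, int b) {
    return probs[a] > probs[b];
  });
  if (k > 0 && k < static_cast<int>(idx.size())) {
    idx.resize(k); // top-k filter
  }
  float mass = 0.0f; // top-k probability mass, for renormalization
  for (int i : idx) {
    mass += probs[i];
  }
  std::vector<int> kept;
  float cum = 0.0f;
  for (int i : idx) {
    kept.push_back(i);
    cum += probs[i] / mass;
    if (cum >= topp) {
      break; // top-p cutoff reached within the top-k slice
    }
  }
  return kept;
}
```

Sampling would then draw from the `kept` indices renormalized once more, so this composes cleanly with the existing CDF-sampling loop.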