Add top-k sampling support to llm Sampler (#19122)

Merged
meta-codesync[bot] merged 1 commit into main from export-D102385104
Apr 26, 2026
Conversation

Contributor

@kirklandsign kirklandsign commented Apr 24, 2026

Summary:

The existing Sampler in extension/llm/sampler supports greedy
(temperature=0), full multinomial, and top-p (nucleus) sampling, but had
no top-k mode. Top-k is a common companion / alternative to top-p and is
trivial to add given the existing infrastructure.

This change adds:

  • Sampler::set_topk(int32_t) to enable top-k filtering. k <= 0 or
    k >= vocab_size disables it (default).
  • A new private sample_topk template that uses std::partial_sort to
    pick the k highest-probability tokens (O(n log k) — heap-based partial
    sort), and CDF-samples from the top-k slice (scaling coin by the
    top-k sum is mathematically equivalent to renormalizing). Same shape
    as the existing sample_topp. Sums and CDF are accumulated in float
    so FP16/BF16 inputs don't lose precision over k summands.
  • Wiring in Sampler::sample so that when top-k is enabled it takes
    precedence over top-p — the two modes are mutually exclusive in this
    implementation. With both disabled, falls back to multinomial,
    matching today's behaviour.

The existing constructors are unchanged, so this is fully
backward-compatible. Float, FP16, BF16, and uint16_t logit dtypes are
all supported via the existing template instantiations.

A follow-up diff will wire set_topk into the text_llm_runner
GenerationConfig so callers can opt into top-k from the runner.

Differential Revision: D102385104

Copilot AI review requested due to automatic review settings April 24, 2026 21:31

pytorch-bot Bot commented Apr 24, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19122

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit b7688fe with merge base 56da964:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 24, 2026
Contributor

meta-codesync Bot commented Apr 24, 2026

@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102385104.

@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Adds top-k sampling mode to the existing LLM Sampler, complementing current greedy, multinomial, and top-p sampling options.

Changes:

  • Added Sampler::set_topk(int32_t) API and internal topk_ state to enable/disable top-k filtering.
  • Implemented sample_topk using std::partial_sort and integrated it into Sampler::sample with precedence over top-p.
  • Added unit tests covering top-k candidate restriction, disabling behavior, FP16 path, and topk=1 argmax behavior.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| extension/llm/sampler/sampler.h | Adds set_topk API, declares sample_topk, and stores topk_ configuration. |
| extension/llm/sampler/sampler.cpp | Implements sample_topk and wires top-k selection into Sampler::sample. |
| extension/llm/sampler/test/test_sampler.cpp | Adds test coverage for top-k behavior (including FP16 smoke coverage). |


Comment thread: extension/llm/sampler/sampler.cpp (Outdated)
Comment on lines +104 to +113

```cpp
T topk_sum = 0;
for (int i = 0; i < k; i++) {
  topk_sum += probindex[i].prob;
}

// Sample from the (implicitly renormalized) top-k distribution.
const T r = coin * topk_sum;
T cdf = 0;
for (int i = 0; i < k; i++) {
  cdf += probindex[i].prob;
```

Copilot AI Apr 24, 2026


sample_topk accumulates topk_sum / cdf and stores the random threshold r in T. Since coin is float, keeping these accumulators as float would avoid extra quantization for low-precision dtypes (FP16/BF16) and make the sampling thresholding more numerically stable. This is optional, but low-cost to do here.

Suggested change:

```diff
-T topk_sum = 0;
-for (int i = 0; i < k; i++) {
-  topk_sum += probindex[i].prob;
-}
-// Sample from the (implicitly renormalized) top-k distribution.
-const T r = coin * topk_sum;
-T cdf = 0;
-for (int i = 0; i < k; i++) {
-  cdf += probindex[i].prob;
+float topk_sum = 0.0f;
+for (int i = 0; i < k; i++) {
+  topk_sum += static_cast<float>(probindex[i].prob);
+}
+// Sample from the (implicitly renormalized) top-k distribution.
+const float r = coin * topk_sum;
+float cdf = 0.0f;
+for (int i = 0; i < k; i++) {
+  cdf += static_cast<float>(probindex[i].prob);
```
Comment on lines +71 to +73

```cpp
// topk=0 means disabled. With low temperature and topp disabled, sampling
// collapses to multinomial over the full vocab, but the dominant token
// should still win the vast majority of the time.
```

Copilot AI Apr 24, 2026


The test comment says "With low temperature" but the sampler is constructed with temperature = 1.0f. Update the comment to match the actual setup (or change the temperature) to avoid confusion about what behavior is being exercised.

Suggested change:

```diff
-// topk=0 means disabled. With low temperature and topp disabled, sampling
-// collapses to multinomial over the full vocab, but the dominant token
-// should still win the vast majority of the time.
+// topk=0 means disabled. With topp disabled, sampling uses multinomial
+// over the full vocab, but the dominant token should still win the vast
+// majority of the time because its logit is much larger than the rest.
```

@mergennachin mergennachin requested a review from Gasoonjia April 24, 2026 21:35
@mergennachin
Contributor

@claude Review this PR

@claude

claude Bot commented Apr 24, 2026

Claude finished @mergennachin's task in 1m 2s


Review of PR #19122: Add top-k sampling support to llm Sampler

Overall this is a clean, well-structured addition that follows the existing patterns in the sampler. The implementation is correct, the tests are reasonable, and the API is backward-compatible. A few observations below, ranging from a real numerical concern to minor nits.


Issues

1. Per-call heap allocation (acknowledged but worth emphasizing)

sampler.cpp:87-88: sample_topk allocates vocab_size ProbIndex<T> entries on every call via std::make_unique. For autoregressive decoding this means one heap allocation per generated token. The existing sample_topp has the same pattern, so this is pre-existing tech debt, but the TODO comment at line 78 is good. Worth noting that for top-k specifically, std::nth_element + selection over the original array (no copy) could avoid the allocation entirely, though that would mutate the input probabilities array (which sample() already does for temperature scaling/softmax, so it may be acceptable).

2. Numerical precision with low-precision types (FP16/BF16)

sampler.cpp:104-111 — The accumulators topk_sum, cdf, and the threshold r are all typed as T. When T is Half or BFloat16, accumulating even a small number of probabilities can lose precision. The existing sample_topp has the same issue, so this is consistent, but Copilot's suggestion to use float accumulators here is a sound improvement that would be low-cost. Since this is new code, it's a good opportunity to get it right from the start even if sample_topp doesn't do it yet.


3. const T& r in existing sample_topp is a dangling reference

sampler.cpp:161 (pre-existing, not introduced by this PR) — const T& r = coin * cumulative_prob; binds a reference to a temporary. In practice this works because the temporary's lifetime is extended to the reference's scope, but it's misleading style. The new sample_topk correctly uses const T r (by value) at line 110. Just flagging for awareness.


Minor / Nits

4. Missing set_topk in constructors

The only way to enable top-k is via set_topk() after construction. This is fine for now and keeps the constructors unchanged, but the PR description mentions a follow-up to wire this into GenerationConfig — at that point, consider adding topk as an optional constructor parameter (defaulting to 0) to allow one-shot initialization.

5. Test comment mismatch (test_sampler.cpp:71)

The comment says "With low temperature" but temperature is 1.0f. Copilot caught this too. Minor but worth fixing to avoid confusion.

6. No test for top-k + top-p interaction

The PR says "when top-k is enabled it takes precedence over top-p — the two modes are mutually exclusive." It would strengthen the test coverage to have a test that sets both topk and topp to valid values and verifies that only the top-k candidates are selected (proving top-k takes precedence). Currently this invariant is only implicitly tested because all tests set topp=0.0.

7. Test uses std::set::count — prefer contains (C++20) or find

test_sampler.cpp:66: allowed.count(out) works, but allowed.find(out) != allowed.end() is more idiomatic for a membership check in pre-C++20 code. Very minor.


What looks good

  • Use of std::partial_sort (O(n log k)) instead of full sort is the right choice for top-k.
  • The "scale coin by top-k sum" trick to avoid explicit renormalization is clean and matches the pattern in sample_topp.
  • Defensive k <= 0 guard in the private helper is good practice.
  • Tests cover the key behaviors: candidate restriction, disabled state, FP16, and k=1 degeneracy.
  • Backward compatibility is preserved — no constructor changes, topk_ defaults to 0.

@kirklandsign
Contributor Author

@claude review

@claude

claude Bot commented Apr 24, 2026

Claude finished @kirklandsign's task in 55s


Review of PR #19122: Add top-k sampling support to llm Sampler

  • Read changed files and gather context
  • Analyze implementation for correctness and style
  • Post review feedback

The implementation is clean and follows the existing patterns well. std::partial_sort is the right choice, the coin-scaling trick avoids explicit renormalization, and backward compatibility is preserved. A few items worth considering:

1. Accumulator precision for FP16/BF16 (sampler.cpp:104-111)

topk_sum, cdf, and r are typed as T. When T is Half or BFloat16, accumulating probabilities loses precision. Since coin is already float, using float accumulators here is low-cost and strictly more correct. The existing sample_topp has the same issue, but since this is new code it's a good opportunity to do it right.


2. Test comment mismatch (test_sampler.cpp:71)

Comment says "With low temperature" but temperature = 1.0f. The body of the comment is actually accurate ("topk=0 means disabled... sampling collapses to multinomial over the full vocab"), so just the "With low temperature" fragment is stale/wrong. Minor but worth cleaning up.


3. No test for top-k + top-p interaction

The PR states top-k takes precedence over top-p when both are enabled. Adding a test that sets both topk and topp to valid values and asserts only top-k candidates appear would explicitly verify this invariant. Currently all tests set topp=0.0, so the precedence branch in sample() (lines 238-240) is never explicitly exercised.

4. Per-call heap allocation (sampler.cpp:87-88)

The TODO at line 77 acknowledges this. For completeness: since top-k doesn't need the full array sorted, an alternative is to use std::nth_element directly on the probabilities array (O(n) average), avoiding both the copy and the allocation. The sample() function already mutates logits in place (temperature scaling + softmax), so mutating probabilities further would be acceptable. This could be a follow-up optimization.

Everything else looks good — the defensive k <= 0 guard, the k=1 degeneracy test, and the FP16 smoke test are all sensible additions.


@meta-codesync meta-codesync Bot force-pushed the export-D102385104 branch from b48387f to 64121b7 Compare April 24, 2026 22:05
@meta-codesync meta-codesync Bot changed the title from "Add top-k sampling support to llm Sampler" to "Add top-k sampling support to llm Sampler (#19122)" Apr 24, 2026
Copilot AI review requested due to automatic review settings April 24, 2026 22:35
@meta-codesync meta-codesync Bot force-pushed the export-D102385104 branch from 64121b7 to b7688fe Compare April 24, 2026 22:35
@kirklandsign kirklandsign review requested due to automatic review settings April 24, 2026 22:35
@kirklandsign
Contributor Author

@claude review

@claude

claude Bot commented Apr 24, 2026

Claude finished @kirklandsign's task in 56s


Review of PR #19122: Add top-k sampling support to llm Sampler

  • Read changed files and gather context
  • Analyze implementation correctness and style
  • Review tests for coverage and correctness
  • Post review feedback

Clean addition that follows the existing sampler patterns well. The use of std::partial_sort, the coin-scaling trick, and the float accumulators are all the right choices. Tests are thorough. A few items:

1. Pre-existing: sample_topp still uses T-typed accumulators (sampler.cpp:151-163)

The new sample_topk correctly uses float accumulators (lines 105-114), which is good. But the existing sample_topp at lines 151-163 still accumulates cumulative_prob and cdf as T, and line 162 computes const T& r = coin * cumulative_prob which binds a reference to a temporary of type T. This means FP16/BF16 precision loss in sample_topp is still present. Not introduced by this PR, but worth a follow-up to align sample_topp with the improved pattern established here.

2. Pre-existing: sample_mult also uses T accumulators (sampler.cpp:62)

Same issue — T cdf = 0.0; will lose precision for FP16/BF16 vocabs. Again not from this PR, but noting for completeness since the new code sets a better standard.

3. topk_ is not initialized in the 4-arg constructor (sampler.cpp:173-181)

The member has int32_t topk_ = 0; as a default member initializer in the header (sampler.h:73), which covers both constructors. This is fine in C++11 and later since the default member initializer applies when the constructor's member-initializer-list doesn't mention topk_. Just confirming this is intentional and correct — no issue here.

4. Potential future improvement: combined top-k + top-p

The PR description says the two modes are mutually exclusive with top-k taking precedence. This is a reasonable choice for a first implementation. Many inference frameworks (e.g., vLLM, HuggingFace transformers) apply top-k first to narrow the candidate set, then top-p within that set. If that becomes desirable, the structure here makes it straightforward — apply partial_sort for top-k, then truncate by cumulative probability within the top-k slice. Not a blocker, just noting for the follow-up.

5. Tests look solid

The test suite covers the key behaviors well:

  • TestTopKRestrictsToCandidates — verifies only top-k tokens are sampled (50 trials)
  • TestTopKDisabledByZero — confirms k=0 falls back to multinomial
  • TestTopKWithFP16 — exercises the Half template path
  • TestTopKEqualsOneIsArgmax — degeneracy test
  • TestTopKTakesPrecedenceOverTopP — verifies the precedence invariant with topp=0.99

This addresses the gap the previous review noted about top-k + top-p interaction testing.

Overall this is a straightforward, well-implemented addition. LGTM with no blocking issues.


@meta-codesync meta-codesync Bot merged commit 2d9bbc1 into main Apr 26, 2026
280 of 288 checks passed
@meta-codesync meta-codesync Bot deleted the export-D102385104 branch April 26, 2026 03:02

Labels

CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) · fb-exported · meta-exported

Projects

None yet


4 participants