diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp
index 3beda885d6f..d41da96f07e 100644
--- a/extension/llm/sampler/sampler.cpp
+++ b/extension/llm/sampler/sampler.cpp
@@ -69,6 +69,56 @@ int32_t Sampler::sample_mult(T* probabilities, float coin) {
   return vocab_size_ - 1; // in case of rounding errors
 }
 
+template <typename T>
+int32_t Sampler::sample_topk(T* probabilities, float coin) {
+  // top-k sampling samples from the k highest-probability tokens.
+  // coin is a random number in [0, 1), usually from random_f32().
+  //
+  // TODO: probindex is allocated on every call; lifting it to a member
+  // would avoid per-token heap allocation in autoregressive loops.
+  const int n = vocab_size_;
+  const int k = std::min(topk_, n);
+  // Defensive: callers gate on topk_ > 0, but a private helper should not
+  // rely on external invariants. Fall back to a deterministic index.
+  if (k <= 0) {
+    return 0;
+  }
+
+  std::unique_ptr<ProbIndex<T>[]> probindex =
+      std::make_unique<ProbIndex<T>[]>(n);
+  for (int i = 0; i < n; i++) {
+    probindex[i].index = i;
+    probindex[i].prob = probabilities[i];
+  }
+
+  auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
+    return a.prob > b.prob;
+  };
+  // Partial sort: only the top-k entries need to be sorted in descending order.
+  std::partial_sort(
+      probindex.get(), probindex.get() + k, probindex.get() + n, compare);
+
+  // Sum of the top-k probabilities. Used to scale `coin` instead of
+  // explicitly renormalizing the k probs — mathematically equivalent and
+  // saves k divisions. Accumulate in float so FP16/BF16 inputs don't lose
+  // precision over k summands.
+  float topk_sum = 0.0f;
+  for (int i = 0; i < k; i++) {
+    topk_sum += static_cast<float>(probindex[i].prob);
+  }
+
+  // Sample from the (implicitly renormalized) top-k distribution.
+  const float r = coin * topk_sum;
+  float cdf = 0.0f;
+  for (int i = 0; i < k; i++) {
+    cdf += static_cast<float>(probindex[i].prob);
+    if (r < cdf) {
+      return probindex[i].index;
+    }
+  }
+  return probindex[k - 1].index; // in case of rounding errors
+}
+
 template <typename T>
 int32_t Sampler::sample_topp(T* probabilities, float coin) {
   // top-p sampling (or "nucleus sampling") samples from the smallest set of
@@ -186,7 +236,10 @@ int32_t Sampler::sample(T* logits) {
   // flip a (float) coin (this is our source of entropy for sampling)
   float coin = random_f32(&rng_state_);
   // we sample from this distribution to get the next token
-  if (topp_ <= 0 || topp_ >= 1) {
+  if (topk_ > 0 && topk_ < vocab_size_) {
+    // top-k sampling, restrict to the k most likely tokens
+    next = sample_topk(logits, coin);
+  } else if (topp_ <= 0 || topp_ >= 1) {
     // simply sample from the predicted probability distribution
     next = sample_mult(logits, coin);
   } else {
diff --git a/extension/llm/sampler/sampler.h b/extension/llm/sampler/sampler.h
index 1525f38692a..4a480edc1ef 100644
--- a/extension/llm/sampler/sampler.h
+++ b/extension/llm/sampler/sampler.h
@@ -44,6 +44,13 @@ class ET_EXPERIMENTAL Sampler {
 
   Sampler(int32_t vocab_size, float temperature);
 
+  // Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k.
+  // When top-k is enabled, top-p is ignored — the two modes are mutually
+  // exclusive in this implementation.
+  void set_topk(int32_t topk) {
+    topk_ = topk;
+  }
+
   template <typename T>
   int32_t sample(T* logits);
 
@@ -51,6 +58,8 @@ class ET_EXPERIMENTAL Sampler {
   template <typename T>
   int32_t sample_topp(T* probabilities, float coin);
   template <typename T>
+  int32_t sample_topk(T* probabilities, float coin);
+  template <typename T>
   int32_t sample_mult(T* probabilities, float coin);
   template <typename T>
   int32_t sample_argmax(T* probabilities);
@@ -60,6 +69,8 @@ class ET_EXPERIMENTAL Sampler {
   // reciprocal of temperature, or 0 if temperature == 0.
   float inv_temperature_;
   float topp_;
+  // 0 (or >= vocab_size_) means top-k is disabled.
+  int32_t topk_ = 0;
   unsigned long long rng_state_;
 };
 
diff --git a/extension/llm/sampler/test/test_sampler.cpp b/extension/llm/sampler/test/test_sampler.cpp
index 044a39458ea..8463c2e9678 100644
--- a/extension/llm/sampler/test/test_sampler.cpp
+++ b/extension/llm/sampler/test/test_sampler.cpp
@@ -8,6 +8,8 @@
 
 #include <executorch/extension/llm/sampler/sampler.h>
 
+#include <set>
+
 #include <gtest/gtest.h>
 #include <torch/torch.h>
 
@@ -39,3 +41,114 @@ TEST(SamplerTest, TestArgMaxWithFP16) {
   input[0][0][396] = 1.0f;
   EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), 396);
 }
+
+TEST(SamplerTest, TestTopKRestrictsToCandidates) {
+  // With topk=3, sampling must always return one of the top-3 indices,
+  // regardless of the random draw.
+  Sampler sampler{
+      /*vocab_size*/ 100,
+      /*temperature*/ 1.0f,
+      /*topp*/ 0.0f, // disable top-p so we exercise top-k alone
+      /*rng_seed*/ 42};
+  sampler.set_topk(3);
+
+  // Construct logits where indices {7, 13, 42} dominate.
+  torch::Tensor input = torch::full({100}, -10.0f, at::kFloat);
+  input[7] = 5.0f;
+  input[13] = 4.5f;
+  input[42] = 4.0f;
+
+  std::set<int32_t> allowed = {7, 13, 42};
+  for (int trial = 0; trial < 50; ++trial) {
+    // Re-fill logits each trial because sample() mutates them in place.
+    torch::Tensor logits = input.clone();
+    int32_t out = sampler.sample(logits.data_ptr<float>());
+    EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out;
+  }
+}
+
+TEST(SamplerTest, TestTopKDisabledByZero) {
+  // topk=0 means disabled. With topp disabled, sampling collapses to
+  // multinomial over the full vocab, but the dominant token should still
+  // win the vast majority of the time.
+  Sampler sampler{
+      /*vocab_size*/ 50,
+      /*temperature*/ 1.0f,
+      /*topp*/ 0.0f,
+      /*rng_seed*/ 7};
+  sampler.set_topk(0); // disabled
+
+  torch::Tensor input = torch::full({50}, -10.0f, at::kFloat);
+  input[11] = 20.0f; // dominant
+
+  int hits = 0;
+  for (int trial = 0; trial < 20; ++trial) {
+    torch::Tensor logits = input.clone();
+    if (sampler.sample(logits.data_ptr<float>()) == 11) {
+      hits++;
+    }
+  }
+  EXPECT_GE(hits, 18); // dominant token should win nearly every time
+}
+
+TEST(SamplerTest, TestTopKWithFP16) {
+  // Smoke test the FP16 template instantiation of the top-k path.
+  Sampler sampler{
+      /*vocab_size*/ 50,
+      /*temperature*/ 1.0f,
+      /*topp*/ 0.0f,
+      /*rng_seed*/ 99};
+  sampler.set_topk(2);
+
+  torch::Tensor input = torch::full({50}, -10.0f, at::kHalf);
+  input[3] = 5.0f;
+  input[8] = 4.5f;
+
+  std::set<int32_t> allowed = {3, 8};
+  for (int trial = 0; trial < 30; ++trial) {
+    torch::Tensor logits = input.clone();
+    int32_t out = sampler.sample(logits.data_ptr<c10::Half>());
+    EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out;
+  }
+}
+
+TEST(SamplerTest, TestTopKEqualsOneIsArgmax) {
+  // topk=1 should behave like greedy argmax even with temperature > 0.
+  Sampler sampler{
+      /*vocab_size*/ 100,
+      /*temperature*/ 1.0f,
+      /*topp*/ 0.0f,
+      /*rng_seed*/ 123};
+  sampler.set_topk(1);
+
+  torch::Tensor input = torch::rand({100}, at::kFloat);
+  input[57] = 100.0f; // make 57 the unambiguous max
+
+  for (int trial = 0; trial < 10; ++trial) {
+    torch::Tensor logits = input.clone();
+    EXPECT_EQ(sampler.sample(logits.data_ptr<float>()), 57);
+  }
+}
+
+TEST(SamplerTest, TestTopKTakesPrecedenceOverTopP) {
+  // When both top-k and top-p are set, top-k should restrict the candidate
+  // set; top-p alone would admit a third token that top-k=2 must exclude.
+  Sampler sampler{
+      /*vocab_size*/ 100,
+      /*temperature*/ 1.0f,
+      /*topp*/ 0.99f, // would keep nearly the whole vocab on its own
+      /*rng_seed*/ 99};
+  sampler.set_topk(2);
+
+  torch::Tensor input = torch::full({100}, -10.0f, at::kFloat);
+  input[3] = 5.0f;
+  input[8] = 4.5f;
+  input[19] = 4.0f; // would be in the top-p set but is excluded by top-k=2
+
+  std::set<int32_t> allowed = {3, 8};
+  for (int trial = 0; trial < 50; ++trial) {
+    torch::Tensor logits = input.clone();
+    int32_t out = sampler.sample(logits.data_ptr<float>());
+    EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out;
+  }
+}