Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion extension/llm/sampler/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,56 @@ int32_t Sampler::sample_mult(T* probabilities, float coin) {
return vocab_size_ - 1; // in case of rounding errors
}

template <typename T>
int32_t Sampler::sample_topk(T* probabilities, float coin) {
// top-k sampling samples from the k highest-probability tokens.
// coin is a random number in [0, 1), usually from random_f32().
//
// TODO: probindex is allocated on every call; lifting it to a member
// would avoid per-token heap allocation in autoregressive loops.
const int n = vocab_size_;
const int k = std::min(topk_, n);
// Defensive: callers gate on topk_ > 0, but a private helper should not
// rely on external invariants. Fall back to a deterministic index.
if (k <= 0) {
return 0;
}

std::unique_ptr<ProbIndex<T>[]> probindex =
std::make_unique<ProbIndex<T>[]>(n);
for (int i = 0; i < n; i++) {
probindex[i].index = i;
probindex[i].prob = probabilities[i];
}

auto compare = [](const ProbIndex<T>& a, const ProbIndex<T>& b) {
return a.prob > b.prob;
};
// Partial sort: only the top-k entries need to be sorted in descending order.
std::partial_sort(
probindex.get(), probindex.get() + k, probindex.get() + n, compare);

// Sum of the top-k probabilities. Used to scale `coin` instead of
// explicitly renormalizing the k probs — mathematically equivalent and
// saves k divisions. Accumulate in float so FP16/BF16 inputs don't lose
// precision over k summands.
float topk_sum = 0.0f;
for (int i = 0; i < k; i++) {
topk_sum += static_cast<float>(probindex[i].prob);
}

// Sample from the (implicitly renormalized) top-k distribution.
const float r = coin * topk_sum;
float cdf = 0.0f;
for (int i = 0; i < k; i++) {
cdf += static_cast<float>(probindex[i].prob);
if (r < cdf) {
return probindex[i].index;
}
}
return probindex[k - 1].index; // in case of rounding errors
}

template <typename T>
int32_t Sampler::sample_topp(T* probabilities, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
Expand Down Expand Up @@ -186,7 +236,10 @@ int32_t Sampler::sample(T* logits) {
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&rng_state_);
// we sample from this distribution to get the next token
if (topp_ <= 0 || topp_ >= 1) {
if (topk_ > 0 && topk_ < vocab_size_) {
// top-k sampling, restrict to the k most likely tokens
next = sample_topk(logits, coin);
} else if (topp_ <= 0 || topp_ >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, coin);
} else {
Expand Down
11 changes: 11 additions & 0 deletions extension/llm/sampler/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,22 @@ class ET_EXPERIMENTAL Sampler {

Sampler(int32_t vocab_size, float temperature);

// Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k.
// When top-k is enabled, top-p is ignored — the two modes are mutually
// exclusive in this implementation.
// No validation is performed here; out-of-range values are interpreted
// as "disabled" at sampling time rather than rejected.
void set_topk(int32_t topk) {
  topk_ = topk;
}

template <typename T>
int32_t sample(T* logits);

private:
template <typename T>
int32_t sample_topp(T* probabilities, float coin);
template <typename T>
int32_t sample_topk(T* probabilities, float coin);
template <typename T>
int32_t sample_mult(T* probabilities, float coin);
template <typename T>
int32_t sample_argmax(T* probabilities);
Expand All @@ -60,6 +69,8 @@ class ET_EXPERIMENTAL Sampler {
// reciprocal of temperature, or 0 if temperature == 0.
float inv_temperature_;
float topp_;
// 0 (or >= vocab_size_) means top-k is disabled.
int32_t topk_ = 0;
unsigned long long rng_state_;
};

Expand Down
113 changes: 113 additions & 0 deletions extension/llm/sampler/test/test_sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <executorch/extension/llm/sampler/sampler.h>

#include <set>

#include <gtest/gtest.h>
#include <torch/torch.h>

Expand Down Expand Up @@ -39,3 +41,114 @@ TEST(SamplerTest, TestArgMaxWithFP16) {
input[0][0][396] = 1.0f;
EXPECT_EQ(sampler.sample(input.data_ptr<c10::Half>()), 396);
}

TEST(SamplerTest, TestTopKRestrictsToCandidates) {
  // With topk=3 every draw must land on one of the three dominant
  // indices, no matter what the RNG produces.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f, // disable top-p so we exercise top-k alone
      /*rng_seed*/ 42};
  sampler.set_topk(3);

  // Build logits where indices {7, 13, 42} dominate everything else.
  torch::Tensor base = torch::full({100}, -10.0f, at::kFloat);
  base[7] = 5.0f;
  base[13] = 4.5f;
  base[42] = 4.0f;

  const std::set<int32_t> allowed = {7, 13, 42};
  for (int trial = 0; trial < 50; ++trial) {
    // sample() transforms its input in place, so hand it a fresh copy.
    torch::Tensor probe = base.clone();
    const int32_t picked = sampler.sample(probe.data_ptr<float>());
    EXPECT_TRUE(allowed.count(picked) != 0)
        << "trial " << trial << " got " << picked;
  }
}

TEST(SamplerTest, TestTopKDisabledByZero) {
  // topk=0 disables top-k. With topp also disabled, sample() falls back
  // to multinomial sampling over the full vocab; the dominant token
  // should still win nearly every draw.
  Sampler sampler{
      /*vocab_size*/ 50,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 7};
  sampler.set_topk(0); // disabled

  torch::Tensor base = torch::full({50}, -10.0f, at::kFloat);
  base[11] = 20.0f; // dominant

  int wins = 0;
  for (int trial = 0; trial < 20; ++trial) {
    torch::Tensor probe = base.clone();
    wins += (sampler.sample(probe.data_ptr<float>()) == 11) ? 1 : 0;
  }
  EXPECT_GE(wins, 18); // dominant token should win nearly every time
}

TEST(SamplerTest, TestTopKWithFP16) {
  // Smoke test: exercise the FP16 template instantiation of the top-k
  // path and confirm draws stay within the two dominant tokens.
  Sampler sampler{
      /*vocab_size*/ 50,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 99};
  sampler.set_topk(2);

  torch::Tensor base = torch::full({50}, -10.0f, at::kHalf);
  base[3] = 5.0f;
  base[8] = 4.5f;

  const std::set<int32_t> allowed = {3, 8};
  for (int trial = 0; trial < 30; ++trial) {
    torch::Tensor probe = base.clone();
    const int32_t picked = sampler.sample(probe.data_ptr<c10::Half>());
    EXPECT_TRUE(allowed.count(picked) != 0)
        << "trial " << trial << " got " << picked;
  }
}

TEST(SamplerTest, TestTopKEqualsOneIsArgmax) {
  // With topk=1 the sampler must behave like greedy argmax even though
  // temperature > 0 would otherwise randomize the draw.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.0f,
      /*rng_seed*/ 123};
  sampler.set_topk(1);

  torch::Tensor base = torch::rand({100}, at::kFloat);
  base[57] = 100.0f; // make 57 the unambiguous max

  for (int trial = 0; trial < 10; ++trial) {
    torch::Tensor probe = base.clone();
    EXPECT_EQ(sampler.sample(probe.data_ptr<float>()), 57);
  }
}

TEST(SamplerTest, TestTopKTakesPrecedenceOverTopP) {
  // When both filters are configured, top-k must restrict the candidate
  // set: top-p on its own would admit a third token that topk=2 excludes.
  Sampler sampler{
      /*vocab_size*/ 100,
      /*temperature*/ 1.0f,
      /*topp*/ 0.99f, // would keep nearly the whole vocab on its own
      /*rng_seed*/ 99};
  sampler.set_topk(2);

  torch::Tensor base = torch::full({100}, -10.0f, at::kFloat);
  base[3] = 5.0f;
  base[8] = 4.5f;
  base[19] = 4.0f; // would be in the top-p set but is excluded by top-k=2

  const std::set<int32_t> allowed = {3, 8};
  for (int trial = 0; trial < 50; ++trial) {
    torch::Tensor probe = base.clone();
    const int32_t picked = sampler.sample(probe.data_ptr<float>());
    EXPECT_TRUE(allowed.count(picked) != 0)
        << "trial " << trial << " got " << picked;
  }
}
Loading