feat: sample on the OpenXLA serve path (temperature/top-p/top-k) (#449 M3 Stage 2d) #471
Merged
The OpenXLA serve path was greedy regardless of the request's sampling
parameters. Make the continuous-batching engine work from logits and sample on
the host, so it honors temperature / top-k / top-p / min-p / seed. Greedy
(temperature 0) is host argmax of the logits, token-exact with the single-sequence
argmax path, so existing greedy behavior is unchanged.
- Assets: the batched engine now uses LOGITS graphs (prefill_logits.mlir returns
[V] logits; decode_ragged_logits_b{4,8}.mlir return [B, V] logits), emitted by
the #451 emitter's non-argmax variants. The single-sequence session keeps the
argmax graphs.
- C shim: xla_llama_prefill_slot_logits / xla_llama_decode_ragged_logits run the
logits graphs and copy the per-row logits to host (the prefill variant also does
the device-side KV slot write); the argmax ragged functions are replaced.
- sampler.rs: SampleParams { temperature, top_k, top_p, min_p, seed } + a host
sampler (stable softmax, top-k, nucleus top-p, min-p, categorical draw from a
seeded splitmix64 PRNG). Greedy short-circuits to argmax. Unit-tested.
- Engine: IreeRaggedLlama returns logits; XlaBatchEngine holds per-slot params +
PRNG state, samples the first token from the prefill logits and per-row from the
decode logits. submit takes SampleParams.
- Serve worker: translates the request's SamplingConfig to SampleParams (warning
once that history-based penalties / DRY are not applied); the bench submits
greedy.
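To make the sampler.rs bullet above concrete, here is a minimal sketch of a host sampler with the stages it lists (stable softmax, top-k, nucleus top-p, min-p, categorical draw from splitmix64, greedy short-circuit). It is not the PR's code: the field types, the "0 / 1.0 means disabled" conventions, applying temperature before the filters, and the 24-bit uniform conversion are all assumptions made for illustration.

```rust
/// Hypothetical mirror of the PR's `SampleParams`; field types are assumed.
#[derive(Clone, Copy, Debug)]
pub struct SampleParams {
    pub temperature: f32, // <= 0.0 means greedy
    pub top_k: usize,     // 0 = disabled (assumed convention)
    pub top_p: f32,       // >= 1.0 = disabled (assumed convention)
    pub min_p: f32,       // <= 0.0 = disabled (assumed convention)
    pub seed: u64,
}

/// One splitmix64 step: advances `state` and returns the next 64-bit output.
pub fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Index of the largest logit; ties resolve to the lowest index.
pub fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &x) in logits.iter().enumerate() {
        if x > logits[best] {
            best = i;
        }
    }
    best
}

/// Samples one token id from a non-empty row of logits.
pub fn sample(logits: &[f32], p: &SampleParams, rng: &mut u64) -> usize {
    // Greedy short-circuit: plain argmax, and the PRNG is not advanced.
    if p.temperature <= 0.0 {
        return argmax(logits);
    }
    // Stable softmax: subtract the max logit before exponentiating.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut cand: Vec<(usize, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &x)| (i, ((x - max) / p.temperature).exp()))
        .collect();
    let total: f32 = cand.iter().map(|c| c.1).sum();
    for c in cand.iter_mut() {
        c.1 /= total;
    }
    // Descending probability; the sort is stable, so ties keep index order.
    cand.sort_by(|a, b| b.1.total_cmp(&a.1));
    // top-k: keep the k most probable tokens.
    if p.top_k > 0 {
        cand.truncate(p.top_k);
    }
    // top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    if p.top_p < 1.0 {
        let mut cum = 0.0;
        let mut keep = cand.len();
        for (n, c) in cand.iter().enumerate() {
            cum += c.1;
            if cum >= p.top_p {
                keep = n + 1;
                break;
            }
        }
        cand.truncate(keep);
    }
    // min-p: drop tokens below min_p times the top probability.
    if p.min_p > 0.0 {
        let floor = p.min_p.min(1.0) * cand[0].1;
        cand.retain(|c| c.1 >= floor);
    }
    // Categorical draw over the surviving (unnormalized) mass.
    let mass: f32 = cand.iter().map(|c| c.1).sum();
    let u = (splitmix64(rng) >> 40) as f32 / (1u64 << 24) as f32; // in [0, 1)
    let mut target = u * mass;
    for c in &cand {
        if target < c.1 {
            return c.0;
        }
        target -= c.1;
    }
    cand[cand.len() - 1].0 // rounding fallback: last surviving candidate
}
```

In this sketch the caller owns the PRNG state (initialized from `seed`) and passes it in on every call, which is one way to get the "same seed -> identical" behavior the validation notes describe.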
Reading the full [B, V] logits per step is a small fraction of the decode matmuls
(on-device sampling is a later optimization). Validated:
- Greedy reference-exact on CPU local-task (B=4 N=8) and CUDA GB10 (B=8 N=16):
host argmax of the logits graphs matches the single-seq argmax, so the switch to
logits graphs preserves greedy output token for token, on both devices.
- Sampling E2E on GB10 (/v1/completions): a seed makes output reproducible (same
seed -> identical), different seeds and temperatures diverge, temperature 0
matches greedy.
- 5 sampler unit tests; fmt + clippy -D warnings clean (default + xla-iree).
The MLX serving path is untouched; default and CI builds compile with the XLA
worker absent.
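On the seed-reproducibility check above: keeping PRNG state per slot, as the Engine bullet describes, is what lets a seeded request produce the same draws regardless of which other requests share its batch. The toy below sketches that property under loud assumptions: `Slot`, `ToyEngine`, `submit`, and `step` are hypothetical names, and the "token" is just a PRNG output reduced modulo a made-up vocabulary size rather than a sample from real logits.

```rust
/// One splitmix64 step: advances `state` and returns the next 64-bit output.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Hypothetical per-slot state: each request carries its own PRNG stream.
struct Slot {
    rng: u64,
    tokens: Vec<u64>,
}

/// Hypothetical stand-in for a batch engine with a fixed number of slots.
struct ToyEngine {
    slots: Vec<Option<Slot>>,
}

impl ToyEngine {
    fn new(batch: usize) -> Self {
        ToyEngine { slots: (0..batch).map(|_| None).collect() }
    }

    /// Claims the first free slot and seeds its private PRNG.
    fn submit(&mut self, seed: u64) -> Option<usize> {
        let idx = self.slots.iter().position(|s| s.is_none())?;
        self.slots[idx] = Some(Slot { rng: seed, tokens: Vec::new() });
        Some(idx)
    }

    /// One decode step: every active slot draws from its own stream only.
    fn step(&mut self, vocab: u64) {
        for slot in self.slots.iter_mut().flatten() {
            let tok = splitmix64(&mut slot.rng) % vocab; // stand-in for sampling
            slot.tokens.push(tok);
        }
    }

    fn tokens(&self, idx: usize) -> &[u64] {
        self.slots[idx].as_ref().map(|s| s.tokens.as_slice()).unwrap_or(&[])
    }
}
```

With a single engine-wide PRNG instead, the draws a request sees would depend on how many neighbors were active at each step, so a fixed seed would no longer pin down the output.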
Summary
Stage 2d of the OpenXLA/IREE milestone (#449 M3): the serve path (Stage 2c) served greedily regardless of the request's sampling parameters. This makes the continuous-batching engine work from logits and sample on the host, so it honors temperature / top-k / top-p / min-p / seed. Greedy (temperature 0) is host argmax of the logits, token-exact with the single-sequence argmax path, so existing greedy behavior is unchanged.
What's here
- Assets: prefill_logits.mlir returns [V] logits, decode_ragged_logits_b{4,8}.mlir return [B, V] logits (the #451 emitter's non-argmax variants). The single-sequence session keeps the argmax graphs.
- C shim (csrc/xla_iree.c): xla_llama_prefill_slot_logits / xla_llama_decode_ragged_logits run the logits graphs and copy the per-row logits to host (the prefill variant also does the device-side KV slot write); the argmax ragged functions are replaced.
- sampler.rs: SampleParams { temperature, top_k, top_p, min_p, seed } + a host sampler (stable softmax, top-k, nucleus top-p, min-p, categorical draw from a seeded splitmix64 PRNG). Greedy short-circuits to argmax. Unit-tested.
- Engine: IreeRaggedLlama returns logits; XlaBatchEngine holds per-slot params + PRNG state, samples the first token from the prefill logits and per-row from the decode logits. submit takes SampleParams.
- Serve worker: translates the request's SamplingConfig to SampleParams (warns once that history-based penalties / DRY are not applied); the bench submits greedy.
Cost
Reading the full [B, V] logits per step is a small fraction of the decode matmuls (the greedy bench throughput is unchanged within noise: CUDA B=8 N=16 at ~24.5 tok/s, vs ~23.4 in Stage 2c). On-device sampling is a later optimization.
Validation
- Greedy reference-exact on CPU local-task (B=4 N=8) and CUDA GB10 (B=8 N=16): host argmax of the logits graphs matches the single-seq argmax, so the switch to logits graphs preserves greedy output token for token, on both devices. (This also confirms the new logits graphs lower + run on both targets.)
- Sampling E2E on GB10 (/v1/completions): temperature=0 -> matches greedy. temperature=0.7, seed=42 run twice -> identical (reproducible). seed=42 vs seed=7 -> different; temperature=0 vs 1.2 -> different. All coherent.
- 5 sampler unit tests; cargo fmt + cargo clippy -D warnings clean (default + xla-iree).
The MLX serving path is untouched (no trait / scheduler changes); default and CI builds compile with the XLA worker absent.
Scope / non-goals
Honors temperature / top-k / top-p / min-p / seed. Not applied: repetition / frequency / presence penalties, DRY (warned once); stop strings; logprobs / structured / multimodal (rejected). These remain Stage 2d/2e follow-ups.
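The scope rules above (honor some fields, warn once about ignored penalties, reject unsupported features) can be sketched as a single translation function. Everything here is hypothetical: the PR does not show `SamplingConfig`'s fields, so the field names, the `Option` wrapping, the defaults used when a field is unset, and the `String` error type are all assumptions for illustration.

```rust
use std::sync::Once;

/// Hypothetical request-side config; the real `SamplingConfig` is not shown in the PR.
#[derive(Default)]
pub struct SamplingConfig {
    pub temperature: Option<f32>,
    pub top_k: Option<usize>,
    pub top_p: Option<f32>,
    pub min_p: Option<f32>,
    pub seed: Option<u64>,
    pub repetition_penalty: Option<f32>, // accepted but not applied
    pub logprobs: bool,                  // rejected
}

/// Hypothetical mirror of the PR's `SampleParams`; field types are assumed.
#[derive(Debug, PartialEq)]
pub struct SampleParams {
    pub temperature: f32,
    pub top_k: usize,
    pub top_p: f32,
    pub min_p: f32,
    pub seed: u64,
}

/// Ensures the "penalties not applied" warning is printed at most once per process.
static PENALTY_WARNING: Once = Once::new();

pub fn to_sample_params(cfg: &SamplingConfig) -> Result<SampleParams, String> {
    // Unsupported features fail the request instead of being silently dropped.
    if cfg.logprobs {
        return Err("logprobs are not supported on this serve path".to_string());
    }
    // Ignored-but-tolerated fields produce a single warning.
    if cfg.repetition_penalty.is_some() {
        PENALTY_WARNING.call_once(|| {
            eprintln!("warning: history-based penalties are not applied on this path");
        });
    }
    Ok(SampleParams {
        // The defaults below are assumptions, not the PR's documented behavior.
        temperature: cfg.temperature.unwrap_or(0.0),
        top_k: cfg.top_k.unwrap_or(0),
        top_p: cfg.top_p.unwrap_or(1.0),
        min_p: cfg.min_p.unwrap_or(0.0),
        seed: cfg.seed.unwrap_or(0),
    })
}
```

Using `std::sync::Once` keeps the warning from flooding logs under load while still surfacing that a requested parameter had no effect.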
Refs #449 (epic). Follows #470 (Stage 2c serving), #469 (2b engine), #468/#466 (2a), #462 (Stage 1).