Skip to content

feat: sample on the OpenXLA serve path (temperature/top-p/top-k) (#449 M3 Stage 2d)#471

Merged
inureyes merged 1 commit into
mainfrom
feature/449-stage2d-xla-sampling
Jun 28, 2026
Merged

feat: sample on the OpenXLA serve path (temperature/top-p/top-k) (#449 M3 Stage 2d)#471
inureyes merged 1 commit into
mainfrom
feature/449-stage2d-xla-sampling

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Stage 2d of the OpenXLA/IREE milestone (#449 M3): the serve path (Stage 2c) served greedily regardless of the request's sampling parameters. This makes the continuous-batching engine work from logits and sample on the host, so it honors temperature / top-k / top-p / min-p / seed. Greedy (temperature 0) is host argmax of the logits, token-exact with the single-sequence argmax path, so existing greedy behavior is unchanged.

What's here

  • Assets — the batched engine uses logits graphs: prefill_logits.mlir returns [V] logits, decode_ragged_logits_b{4,8}.mlir return [B, V] logits (the feat: evaluate a Rust-native StableHLO emitter as the compiler-family authoring path (spike) #451 emitter's non-argmax variants). The single-sequence session keeps the argmax graphs.
  • C shim (csrc/xla_iree.c) — xla_llama_prefill_slot_logits / xla_llama_decode_ragged_logits run the logits graphs and copy the per-row logits to host (the prefill variant also does the device-side KV slot write); the argmax ragged functions are replaced.
  • sampler.rsSampleParams { temperature, top_k, top_p, min_p, seed } + a host sampler (stable softmax, top-k, nucleus top-p, min-p, categorical draw from a seeded splitmix64 PRNG). Greedy short-circuits to argmax. Unit-tested.
  • EngineIreeRaggedLlama returns logits; XlaBatchEngine holds per-slot params + PRNG state, samples the first token from the prefill logits and per-row from the decode logits. submit takes SampleParams.
  • Serve worker — translates the request's SamplingConfig to SampleParams (warns once that history-based penalties / DRY are not applied); the bench submits greedy.

Cost

Reading the full [B, V] logits per step is a small fraction of the decode matmuls (the greedy bench throughput is unchanged within noise: CUDA B=8 N=16 at ~24.5 tok/s, vs ~23.4 in Stage 2c). On-device sampling is a later optimization.

Validation

  • Greedy reference-exact on CPU local-task (B=4 N=8) and CUDA GB10 (B=8 N=16): host argmax of the logits graphs matches the single-seq argmax, so the switch to logits graphs preserves greedy output token for token, on both devices. (This also confirms the new logits graphs lower + run on both targets.)
  • Sampling E2E on GB10 (/v1/completions):
    • temperature=0 -> matches greedy.
    • temperature=0.7, seed=42 run twice -> identical (reproducible).
    • seed=42 vs seed=7 -> different; temperature=0 vs 1.2 -> different. All coherent.
  • Unit tests: 5 sampler tests (greedy=argmax, top-k=1=argmax, seeded determinism, top-p nucleus, empty-safe). cargo fmt + cargo clippy -D warnings clean (default + xla-iree).

The MLX serving path is untouched (no trait / scheduler changes); default and CI builds compile with the XLA worker absent.

Scope / non-goals

Honors temperature / top-k / top-p / min-p / seed. Not applied: repetition / frequency / presence penalties, DRY (warned once); stop strings; logprobs / structured / multimodal (rejected). These remain Stage 2d/2e follow-ups.

Refs #449 (epic). Follows #470 (Stage 2c serving), #469 (2b engine), #468/#466 (2a), #462 (Stage 1).

…M3 Stage 2d)

The OpenXLA serve path was greedy regardless of the request's sampling
parameters. Make the continuous-batching engine work from logits and sample on
the host, so it honors temperature / top-k / top-p / min-p / seed. Greedy
(temperature 0) is host argmax of the logits, token-exact with the single-sequence
argmax path, so existing greedy behavior is unchanged.

- Assets: the batched engine now uses LOGITS graphs (prefill_logits.mlir returns
  [V] logits; decode_ragged_logits_b{4,8}.mlir return [B, V] logits), emitted by
  the #451 emitter's non-argmax variants. The single-sequence session keeps the
  argmax graphs.
- C shim: xla_llama_prefill_slot_logits / xla_llama_decode_ragged_logits run the
  logits graphs and copy the per-row logits to host (the prefill variant also does
  the device-side KV slot write); the argmax ragged functions are replaced.
- sampler.rs: SampleParams { temperature, top_k, top_p, min_p, seed } + a host
  sampler (stable softmax, top-k, nucleus top-p, min-p, categorical draw from a
  seeded splitmix64 PRNG). Greedy short-circuits to argmax. Unit-tested.
- Engine: IreeRaggedLlama returns logits; XlaBatchEngine holds per-slot params +
  PRNG state, samples the first token from the prefill logits and per-row from the
  decode logits. submit takes SampleParams.
- Serve worker: translates the request's SamplingConfig to SampleParams (warning
  once that history-based penalties / DRY are not applied); the bench submits
  greedy.

Reading the full [B, V] logits per step is a small fraction of the decode matmuls
(on-device sampling is a later optimization). Validated:
- Greedy reference-exact on CPU local-task (B=4 N=8) and CUDA GB10 (B=8 N=16):
  host argmax of the logits graphs matches the single-seq argmax, so the switch to
  logits graphs preserves greedy output token for token, on both devices.
- Sampling E2E on GB10 (/v1/completions): a seed makes output reproducible (same
  seed -> identical), different seeds and temperatures diverge, temperature 0
  matches greedy.
- 5 sampler unit tests; fmt + clippy -D warnings clean (default + xla-iree).

The MLX serving path is untouched; default and CI builds compile with the XLA
worker absent.
@inureyes inureyes added area:architecture Architecture and code structure changes priority:medium Medium priority status:done Completed type:enhancement New features, capabilities, or significant additions labels Jun 28, 2026
@inureyes inureyes merged commit d39be11 into main Jun 28, 2026
5 checks passed
@inureyes inureyes deleted the feature/449-stage2d-xla-sampling branch June 28, 2026 16:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:architecture Architecture and code structure changes priority:medium Medium priority status:done Completed type:enhancement New features, capabilities, or significant additions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant