Merged
Changes from 1 commit
42 commits
5465d8b
Replace chunked FLA with recurrent gated delta rule for T=1 decode
Gasoonjia Apr 2, 2026
a6ebe8a
Runtime dispatch: recurrent (T=1) vs chunked (T>1) inside triton_op
Gasoonjia Apr 3, 2026
fc5018e
Revert model.py, export.py, main.cpp to main branch
Gasoonjia Apr 3, 2026
c90a8e8
Add tests for recurrent (T=1) and multi-T dispatch
Gasoonjia Apr 3, 2026
ce3e9ca
lint fix - 2
Gasoonjia Apr 3, 2026
8d35c65
lint fix - 2
Gasoonjia Apr 3, 2026
709deb0
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 3, 2026
eff976d
lint fix - 3
Gasoonjia Apr 3, 2026
7dd4280
Optimize recurrent kernel: parallelize over V tiles
Gasoonjia Apr 3, 2026
3a1ee31
Dual-method PTE with GPU-resident state for Qwen3.5 MoE
Apr 5, 2026
63c162e
Use share_mutable_buffers to eliminate select_scatter overhead
Apr 6, 2026
47d6b98
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 6, 2026
375e5c0
lint
Gasoonjia Apr 6, 2026
2b36797
remove redundant updates
Gasoonjia Apr 6, 2026
c06d58b
Cross-method AOTI constant sharing for KV cache
Apr 7, 2026
6945b2a
Fix cross-method AOTI constant sharing and add dual-method runner
Gasoonjia Apr 7, 2026
ea51d0d
Remove debug printf and decode_only flag
Gasoonjia Apr 7, 2026
a0a62f1
Lint formatting fixes
Gasoonjia Apr 7, 2026
ca69871
Improve CUDA backend error handling and add dual-method runner fallback
Apr 9, 2026
7c148f7
Add CUDA graph capture/replay for decode method
Apr 10, 2026
ee75c2e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 10, 2026
10e7aad
lint and reformat
Gasoonjia Apr 13, 2026
9042f36
Merge branch 'main' into cuda-graph
Gasoonjia Apr 13, 2026
84d1587
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
e00a499
solve claude
Gasoonjia Apr 15, 2026
aa7bb82
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
cef386b
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
2d32422
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
1270870
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
8fc7355
solve stride out of scope
Gasoonjia Apr 17, 2026
2c46ed2
Merge branch 'main' into cuda-graph
Gasoonjia Apr 21, 2026
855eb93
Merge branch 'main' into cuda-graph
Gasoonjia Apr 22, 2026
4237d17
remove unused env var
Gasoonjia Apr 22, 2026
9b4705e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 23, 2026
0492e8d
Add GPU-side Gumbel-max sampling for CUDA graph compatibility
Apr 13, 2026
8c0bbf3
lintrunner
Gasoonjia Apr 13, 2026
5245f64
remove git info
Gasoonjia Apr 23, 2026
880391d
reintro llm headers
Gasoonjia Apr 23, 2026
6f411af
lint
Gasoonjia Apr 24, 2026
eff4294
add top-p and top-k arg
Gasoonjia Apr 24, 2026
61d47aa
move top-p and top-k support into an individual PR
Gasoonjia Apr 24, 2026
3e185c0
Merge branch 'main' into cuda-graph-sampling
Gasoonjia Apr 27, 2026
move top-p and top-k support into an individual PR
Gasoonjia committed Apr 24, 2026
commit 61d47aa5ed2c0c1e19d06a80abbec49c7b66e5ac
examples/models/qwen3_5_moe/export.py (27 changes: 2 additions & 25 deletions)
@@ -770,23 +770,10 @@ def _export_cuda(model, config, args):
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
     decode_temperature = torch.tensor([1.0], dtype=torch.float32)
-    # top_k / top_p are runtime scalar tensors (parallel to temperature) so
-    # the same .pte can be re-driven with different sampling configurations
-    # without re-export. Default examples are no-op values: top_k=V (keep
-    # all tokens), top_p=1.0 (keep full nucleus). Callers override them at
-    # runtime by binding different scalar tensors.
-    decode_top_k = torch.tensor(config.vocab_size, dtype=torch.int64)
-    decode_top_p = torch.tensor(1.0, dtype=torch.float32)
     with torch.no_grad():
         decode_ep = export(
             model,
-            (
-                decode_tokens,
-                decode_pos,
-                decode_temperature,
-                decode_top_k,
-                decode_top_p,
-            ),
+            (decode_tokens, decode_pos, decode_temperature),
             strict=True,
         )
     print("Decode export successful!")
@@ -803,26 +790,16 @@ def _export_cuda(model, config, args):
     prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
     prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
     prefill_temperature = torch.tensor([1.0], dtype=torch.float32)
-    prefill_top_k = torch.tensor(config.vocab_size, dtype=torch.int64)
-    prefill_top_p = torch.tensor(1.0, dtype=torch.float32)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
         {0: seq_dim},  # input_pos
         None,  # temperature (static scalar tensor)
-        None,  # top_k (static scalar tensor — runtime-bindable)
-        None,  # top_p (static scalar tensor — runtime-bindable)
     )
     with torch.no_grad():
         prefill_ep = export(
             model,
-            (
-                prefill_tokens,
-                prefill_pos,
-                prefill_temperature,
-                prefill_top_k,
-                prefill_top_p,
-            ),
+            (prefill_tokens, prefill_pos, prefill_temperature),
             dynamic_shapes=prefill_dynamic_shapes,
             strict=True,
         )
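A note on the surviving pattern: temperature remains a scalar example input because tensors passed to export() stay runtime-bindable, so one .pte can be driven with different sampling settings without re-export. Below is a minimal, self-contained sketch of that pattern; TinyHead and its shapes are hypothetical, not code from this PR.

import torch
from torch.export import export


class TinyHead(torch.nn.Module):
    # Hypothetical module, only to illustrate the runtime-tensor pattern.
    def forward(self, logits: torch.Tensor, temperature: torch.Tensor):
        # Dividing by a tensor input (rather than a Python float) keeps
        # `temperature` a runtime input of the exported program instead of
        # baking it in as a constant.
        return torch.argmax(logits / temperature.clamp(min=1e-6), dim=-1)


ep = export(TinyHead(), (torch.randn(1, 32), torch.tensor([1.0])), strict=True)
# Re-drive the same exported program with a different temperature:
out = ep.module()(torch.randn(1, 32), torch.tensor([0.5]))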
examples/models/qwen3_5_moe/main.cpp (28 changes: 0 additions & 28 deletions)
@@ -37,14 +37,6 @@ DEFINE_string(
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
-DEFINE_int64(
-    top_k,
-    -1,
-    "Top-k sampling cutoff (<=0 = no-op default of vocab_size, keeps all tokens).");
-DEFINE_double(
-    top_p,
-    1.0,
-    "Top-p (nucleus) sampling threshold. 1.0 = no-op (keeps full nucleus).");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -206,22 +198,6 @@ int main(int argc, char** argv) {
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
 
-  // top_k / top_p are 0-D scalar tensors matching the export-time signature
-  // (see examples/models/qwen3_5_moe/export.py). The default flag values
-  // (top_k = vocab_size, top_p = 1.0) are mathematical no-ops: the sort+
-  // scatter subgraph still runs (it was traced into the graph at export
-  // time), but produces all-False filter masks so logits pass through
-  // unchanged. Override at runtime to enable real filtering.
-  int64_t vocab_size = metadata.count(llm::kVocabSize)
-      ? metadata[llm::kVocabSize]
-      : static_cast<int64_t>(tokenizer->vocab_size());
-  int64_t top_k_val = (FLAGS_top_k <= 0) ? vocab_size : FLAGS_top_k;
-  float top_p_val = static_cast<float>(FLAGS_top_p);
-  auto top_k_tensor =
-      from_blob(&top_k_val, {}, executorch::aten::ScalarType::Long);
-  auto top_p_tensor =
-      from_blob(&top_p_val, {}, executorch::aten::ScalarType::Float);
-
   // ---------------------------------------------------------------
   // Prefill
   // ---------------------------------------------------------------
@@ -252,8 +228,6 @@ int main(int argc, char** argv) {
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
   prefill_inputs.push_back(temp_tensor);
-  prefill_inputs.push_back(top_k_tensor);
-  prefill_inputs.push_back(top_p_tensor);
 
   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -302,8 +276,6 @@ int main(int argc, char** argv) {
   decode_inputs.push_back(EValue(decode_tokens));
   decode_inputs.push_back(EValue(decode_pos));
   decode_inputs.push_back(EValue(temp_tensor));
-  decode_inputs.push_back(EValue(top_k_tensor));
-  decode_inputs.push_back(EValue(top_p_tensor));
 
   auto decode_result = module->execute("decode", decode_inputs);
   if (decode_result.error() != Error::Ok) {
examples/models/qwen3_5_moe/model.py (13 changes: 6 additions & 7 deletions)
@@ -631,8 +631,6 @@ def forward(
         tokens: torch.LongTensor,
         input_pos: torch.LongTensor,
         temperature: Optional[torch.Tensor] = None,
-        top_k: Optional[torch.Tensor] = None,
-        top_p: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x = self.embed_tokens(tokens)
         for layer in self.layers:
@@ -642,16 +640,17 @@ def forward(
         # logits so callers (eval, custom samplers) can inspect every
         # position. Otherwise apply the prefill optimization and only
         # materialize ``[B, V]`` for the last token.
-        if temperature is None and top_k is None and top_p is None:
-            return self.lm_head(x)
+        if temperature is None:
+            return self.lm_head(x).float()  # [B, T, V] float32
         logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
         # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
         # equivalent to drawing from softmax(logits/T) but stays entirely
-        # on-device.
+        # on-device. Algorithm reference:
+        # https://huggingface.co/blog/cxdu/fastsampling
         # TODO(gasoonjia): once the on-device sampling stack lands, promote
         # ``sample`` into a shared CUDA sampling utility reusable by other
-        # models.
-        return sample(logits, temperature, top_k, top_p)  # [B, 1]
+        # models, and add top-k / top-p filtering support.
+        return sample(logits, temperature)  # [B, 1]
 
     @staticmethod
     def from_hf_checkpoint(model_dir, max_seq_len=4096):
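The Gumbel-max equivalence cited in the comment (argmax(logits/T + gumbel_noise) draws from softmax(logits/T)) is easy to check empirically. A standalone sketch, illustrative only and not part of this PR:

import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.0])
n = 200_000

# Reference: draw directly from softmax(logits).
ref = torch.multinomial(torch.softmax(logits, dim=-1), n, replacement=True)

# Gumbel-max: argmax(logits + g) with g ~ Gumbel(0, 1) via -log(-log(U)).
u = torch.rand(n, 3)
g = -torch.log(-torch.log(u + 1e-20) + 1e-20)
gm = torch.argmax(logits + g, dim=-1)

# Both empirical distributions converge to softmax(logits),
# approximately [0.665, 0.245, 0.090].
for k in range(3):
    print(k, (ref == k).float().mean().item(), (gm == k).float().mean().item())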
examples/models/qwen3_5_moe/sampler.py (73 changes: 15 additions & 58 deletions)
@@ -1,12 +1,12 @@
 """
-GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
+GPU-side Gumbel-max sampler.
 
 Self-contained sampling utility that can be imported by other models. Lives
 in its own file so it can be reused without pulling in the heavy MoE module.
 
-All sampling parameters (``temperature``, ``top_k``, ``top_p``) are
-**runtime tensors** so a single exported program can be re-driven with
-different sampling configurations without re-export.
+``temperature`` is a runtime tensor so a single exported program can be
+re-driven with different sampling configurations without re-export.
+
 """
 
 from typing import Optional
@@ -17,20 +17,12 @@
 def sample(
     logits: torch.Tensor,
     temperature: Optional[torch.Tensor] = None,
-    top_k: Optional[torch.Tensor] = None,
-    top_p: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    """GPU-side Gumbel-max sampler with optional top-k / top-p filtering.
+    """GPU-side Gumbel-max sampler.
 
-    All three sampling knobs are *runtime* scalar tensors so the caller can
-    change them between calls without re-exporting the graph. The Python-
-    level ``is None`` checks are static (decided at trace time) and select
-    which subgraph is emitted; once provided, the actual values are pure
-    tensors and the kernels are fully data-driven.
-
-    When ``temperature``, ``top_k`` and ``top_p`` are all ``None`` (the
-    eager / eval default), the function is a no-op and returns ``logits``
-    unchanged — useful for callers that just want to inspect raw logits.
+    When ``temperature`` is ``None`` (the eager / eval default) the function
+    is a no-op and returns ``logits`` unchanged — useful for callers that
+    just want to inspect raw logits.
 
     Otherwise it draws from ``softmax(logits / temperature)`` entirely
     on-device using the Gumbel-max trick:
@@ -41,58 +33,23 @@ def sample(
     float32 logits. The contract is documented as ``[B, V]`` float32 and
     callers are expected to ``.float()``-cast before invoking ``sample``.
 
+    TODO(gasoonjia): add top-k / top-p filtering support in a follow-up PR.
+
     Args:
         logits: ``[B, V]`` float32 logits.
         temperature: 0-D or 1-D float tensor (clamped to >= 1e-6 to avoid
-            divide-by-zero). ``None`` skips temperature scaling.
-        top_k: 0-D or 1-D int tensor — keep only the top ``k`` logits.
-            ``None`` skips top-k filtering. ``k >= V`` is also a no-op.
-        top_p: 0-D or 1-D float tensor — nucleus threshold; keep the
-            smallest set of logits whose cumulative softmax probability
-            is >= ``top_p``. ``None`` (or ``>= 1.0``) disables top-p.
+            divide-by-zero). ``None`` skips temperature scaling and the
+            sampler returns the unmodified ``logits`` tensor.
 
     Returns:
         ``[B, 1]`` float32 tensor of sampled token IDs, or the unmodified
-        ``logits`` tensor when all sampling parameters are ``None``.
+        ``logits`` tensor when ``temperature`` is ``None``.
     """
     # No sampling configured — return raw logits.
-    if temperature is None and top_k is None and top_p is None:
+    if temperature is None:
         return logits
 
-    if temperature is not None:
-        logits = logits / temperature.clamp(min=1e-6)
-
-    # Single sort handles both top-k and top-p filtering — both branches
-    # need descending logits anyway, so we share the sort to keep the
-    # graph small.
-    if top_k is not None or top_p is not None:
-        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
-        sorted_remove = torch.zeros_like(sorted_logits, dtype=torch.bool)
-
-        if top_k is not None:
-            # Position >= k → drop. Works for any tensor k via broadcast;
-            # k >= V naturally becomes a no-op (mask is all-False).
-            pos = torch.arange(sorted_logits.size(-1), device=sorted_logits.device)
-            sorted_remove = sorted_remove | (pos >= top_k.to(pos.dtype))
-
-        if top_p is not None:
-            cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-            p_remove = cum_probs > top_p
-            # Shift right by one so the highest-prob token is always kept,
-            # even when its single-token prob already exceeds top_p.
-            p_remove = torch.cat(
-                [torch.zeros_like(p_remove[..., :1]), p_remove[..., :-1]],
-                dim=-1,
-            )
-            sorted_remove = sorted_remove | p_remove
-
-        sorted_logits = torch.where(
-            sorted_remove,
-            torch.full_like(sorted_logits, float("-inf")),
-            sorted_logits,
-        )
-        # Scatter the masked sorted logits back into original token order.
-        logits = torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)
+    logits = logits / temperature.clamp(min=1e-6)
 
     # Gumbel-max sampling — equivalent to sampling from softmax(logits)
     # but fully on-device and CUDA-graph friendly. The 1e-20 epsilons are
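For reference while the follow-up PR is pending, the deleted filtering logic condenses to the standalone sketch below. It is reassembled from the removed lines above; filter_logits is a hypothetical name, and it takes plain Python scalars where the original took runtime tensors.

import torch


def filter_logits(logits, top_k=None, top_p=None):
    # Hypothetical standalone helper; the original inlined this in sample()
    # with tensor-valued top_k / top_p.
    sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
    remove = torch.zeros_like(sorted_logits, dtype=torch.bool)
    if top_k is not None:
        # Drop every position past the k-th largest logit.
        pos = torch.arange(sorted_logits.size(-1), device=logits.device)
        remove |= pos >= top_k
    if top_p is not None:
        cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        p_remove = cum > top_p
        # Shift right by one so the highest-probability token always survives.
        p_remove = torch.cat(
            [torch.zeros_like(p_remove[..., :1]), p_remove[..., :-1]], dim=-1
        )
        remove |= p_remove
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Scatter the masked sorted logits back into original vocabulary order.
    return torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)


print(filter_logits(torch.tensor([[0.1, 3.0, 1.2, -0.5]]), top_k=2))
# tensor([[-inf, 3.0000, 1.2000, -inf]])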