From 9f9a6b20da687200c456c750d38ea004771f5657 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Wed, 13 May 2026 10:31:46 -0700 Subject: [PATCH 1/2] Add fused MoE FFN1: fuse wi_0 and wi_1 into one grouped GEMM When prefuse_moe_weights=True (requires sparse_matmul=True), the two FFN1 expert weight matrices [G,K,N] are concatenated into [G,K,2N] and dispatched as a single grouped GEMM, then split. This halves FFN1 kernel launches and reads input activations from HBM once instead of twice. Backend-agnostic: works with Megablox, Tokamax, and jax.lax.ragged_dot. When attention=vllm_rpa the fused tensor is passed directly to the vLLM-TPU serving kernel. correct sort kernel convergence --- .../kernels/ragged/ragged_gather_reduce.py | 20 ++++++++++++------- src/maxtext/kernels/ragged/ragged_sort.py | 8 ++++++++ src/maxtext/layers/moe.py | 1 - 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/maxtext/kernels/ragged/ragged_gather_reduce.py b/src/maxtext/kernels/ragged/ragged_gather_reduce.py index 49f1346bad..d5bb3314b6 100644 --- a/src/maxtext/kernels/ragged/ragged_gather_reduce.py +++ b/src/maxtext/kernels/ragged/ragged_gather_reduce.py @@ -385,7 +385,7 @@ def _preprocess( reduce_group_size: int, num_row_partitions: int, num_simd_lanes: int, -) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: """Preprocesses indices for ragged gather reduce.""" assert indices.ndim == 1, "Ragged scatter only supports 1d indices." @@ -416,12 +416,16 @@ def _preprocess( num_src_rows_per_row_partition.astype(jnp.int32), (0, num_simd_lanes - num_row_partitions), ) + # If there is no valid source row in a reduce group, we set the mask to + # False, so that the output for that group is set to zero. + mask = jnp.any(valid_rows_mask.reshape(-1, reduce_group_size), axis=-1) return ( src_indices, dst_indices, topk_weights, num_src_rows_per_row_partition, + mask, ) @@ -481,11 +485,6 @@ def ragged_gather_reduce( # Heuristic threshold on whether to fallback for small inputs. dtype = x.dtype dtype_bytes = jax.dtypes.itemsize_bits(dtype) // 8 - if jnp.size(x) * dtype_bytes * 2 < pltpu.get_tpu_info().vmem_capacity_bytes * 0.6: - # For small {input + output}, it's likely that both can be put in TC VMEM, - # so it's likely faster to run TC-based implementation on it than going - # through SC, without data movement to/from HBM. - return _fallback_implementation(x, indices, topk_weights, valid_rows_mask, reduce_group_size) hidden_size = x.shape[-1] input_size = indices.size @@ -533,6 +532,7 @@ def ragged_gather_reduce( dst_indices, topk_weights, num_src_rows_per_row_partition, + mask, ) = _preprocess( indices, topk_weights, @@ -588,4 +588,10 @@ def ragged_gather_reduce( }, )(num_src_rows_per_row_partition, x, src_indices, dst_indices, topk_weights) - return out.astype(x.dtype) + # If there is no valid source row in a reduce group, set that group's output + # to zero. + return jnp.where( + mask[:, None], + out.astype(x.dtype), + jnp.zeros_like(out, dtype=x.dtype), + )[: (input_size // reduce_group_size), :hidden_size] diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index a6a0485918..38f50f29f4 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -303,6 +303,8 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) else: # Shift indices so they map to the packed local buffer [0, local_num_tokens). @@ -319,6 +321,8 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) res = ( @@ -375,6 +379,8 @@ def _ring_ragged_unsort_bwd(res, g_out): weights=weight_for_sorted, has_weights=True, enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) else: # Slice the inverse permutation to match the packed local buffer. @@ -392,6 +398,8 @@ def _ring_ragged_unsort_bwd(res, g_out): weights=sliced_weights, has_weights=True, enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) return grad_sorted_tokens, None, None, None diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index fdea21981c..474e286ad1 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -918,7 +918,6 @@ def permute( repeats=group_size, total_repeat_length=math.prod(selected_experts.shape), ) - return ( sorted_inputs, sorted_selected_experts, From bd32f80d5b1ed04d62faf0cba7cc7bbae81d8f6e Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Fri, 26 Jun 2026 00:22:17 +0000 Subject: [PATCH 2/2] revert optimization from PR#4166 --- src/maxtext/kernels/ragged/ragged_sort.py | 4 ++++ src/maxtext/layers/moe.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index 38f50f29f4..cabccf0dc7 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -186,6 +186,8 @@ def _ring_ragged_sort_bwd(res, g_out): valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) else: # Buffering: g_x has size `local_buffer_size` (packed). @@ -209,6 +211,8 @@ def _ring_ragged_sort_bwd(res, g_out): valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) return grad_hidden_states, None diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 474e286ad1..fdea21981c 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -918,6 +918,7 @@ def permute( repeats=group_size, total_repeat_length=math.prod(selected_experts.shape), ) + return ( sorted_inputs, sorted_selected_experts,