diff --git a/src/maxtext/kernels/ragged/ragged_gather_reduce.py b/src/maxtext/kernels/ragged/ragged_gather_reduce.py index 49f1346bad..d5bb3314b6 100644 --- a/src/maxtext/kernels/ragged/ragged_gather_reduce.py +++ b/src/maxtext/kernels/ragged/ragged_gather_reduce.py @@ -385,7 +385,7 @@ def _preprocess( reduce_group_size: int, num_row_partitions: int, num_simd_lanes: int, -) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]: +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: """Preprocesses indices for ragged gather reduce.""" assert indices.ndim == 1, "Ragged scatter only supports 1d indices." @@ -416,12 +416,16 @@ def _preprocess( num_src_rows_per_row_partition.astype(jnp.int32), (0, num_simd_lanes - num_row_partitions), ) + # If there is no valid source row in a reduce group, we set the mask to + # False, so that the output for that group is set to zero. + mask = jnp.any(valid_rows_mask.reshape(-1, reduce_group_size), axis=-1) return ( src_indices, dst_indices, topk_weights, num_src_rows_per_row_partition, + mask, ) @@ -481,11 +485,6 @@ def ragged_gather_reduce( # Heuristic threshold on whether to fallback for small inputs. dtype = x.dtype dtype_bytes = jax.dtypes.itemsize_bits(dtype) // 8 - if jnp.size(x) * dtype_bytes * 2 < pltpu.get_tpu_info().vmem_capacity_bytes * 0.6: - # For small {input + output}, it's likely that both can be put in TC VMEM, - # so it's likely faster to run TC-based implementation on it than going - # through SC, without data movement to/from HBM. - return _fallback_implementation(x, indices, topk_weights, valid_rows_mask, reduce_group_size) hidden_size = x.shape[-1] input_size = indices.size @@ -533,6 +532,7 @@ def ragged_gather_reduce( dst_indices, topk_weights, num_src_rows_per_row_partition, + mask, ) = _preprocess( indices, topk_weights, @@ -588,4 +588,10 @@ def ragged_gather_reduce( }, )(num_src_rows_per_row_partition, x, src_indices, dst_indices, topk_weights) - return out.astype(x.dtype) + # If there is no valid source row in a reduce group, set that group's output + # to zero. + return jnp.where( + mask[:, None], + out.astype(x.dtype), + jnp.zeros_like(out, dtype=x.dtype), + )[: (input_size // reduce_group_size), :hidden_size] diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index a6a0485918..cabccf0dc7 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -186,6 +186,8 @@ def _ring_ragged_sort_bwd(res, g_out): valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) else: # Buffering: g_x has size `local_buffer_size` (packed). @@ -209,6 +211,8 @@ def _ring_ragged_sort_bwd(res, g_out): valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) return grad_hidden_states, None @@ -303,6 +307,8 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) else: # Shift indices so they map to the packed local buffer [0, local_num_tokens). @@ -319,6 +325,8 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort valid_rows_mask=valid_rows_mask, reduce_group_size=topk, enforce_fallback=enforce_gather_reduce_fallback, + flops_override=gather_reduce_flops_override, + bytes_accessed_override=gather_reduce_bytes_accessed_override, ) res = ( @@ -375,6 +383,8 @@ def _ring_ragged_unsort_bwd(res, g_out): weights=weight_for_sorted, has_weights=True, enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) else: # Slice the inverse permutation to match the packed local buffer. @@ -392,6 +402,8 @@ def _ring_ragged_unsort_bwd(res, g_out): weights=sliced_weights, has_weights=True, enforce_fallback=enforce_gather_fallback, + flops_override=gather_flops_override, + bytes_accessed_override=gather_bytes_accessed_override, ) return grad_sorted_tokens, None, None, None