diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 43984aeb15..73cfb01419 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -354,10 +354,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights): return output.reshape(int(self.config.per_device_batch_size), -1, self.config.emb_dim).astype(self.dtype) def megablox(self, inputs, gate_logits, config, w0_kernel, w1_kernel, wo_kernel): - # TODO(ranran): need to changes in JAX repo to enable optimized tile_size - # instead of the static default tile_size (512, 512, 512) - tile_size = (512, 512, 512) - + tile_size = (512, 1024, 1024) def gmm(inputs, kernel, group_sizes): hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call