diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 73cfb01419..7382fdd85f 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -364,7 +364,7 @@ def gmm(inputs, kernel, group_sizes): inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)]) inputs = inputs.astype(self.dtype) - kernel = kernel.astype(self.weight_dtype) + kernel = kernel.astype(self.dtype) output = mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes,