From a26715e626f6c93fdb67652ccfb0c85c954f8295 Mon Sep 17 00:00:00 2001 From: RissyRan Date: Tue, 9 Jul 2024 21:36:04 +0000 Subject: [PATCH] Cast type for inputs before kernel call --- MaxText/layers/linears.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 73cfb01419..7382fdd85f 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -364,7 +364,7 @@ def gmm(inputs, kernel, group_sizes): inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)]) inputs = inputs.astype(self.dtype) - kernel = kernel.astype(self.weight_dtype) + kernel = kernel.astype(self.dtype) output = mblx.gmm(lhs=inputs, rhs=kernel, group_sizes=group_sizes,