From 6a0e570a2e5935df702f9c72e3a85c19771d9245 Mon Sep 17 00:00:00 2001 From: RissyRan Date: Mon, 8 Jul 2024 18:19:37 +0000 Subject: [PATCH] Enable quantization for MoE Matmul implementation --- MaxText/layers/linears.py | 3 +++ MaxText/layers/mistral.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 7382fdd85f..f1701ffd43 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -278,6 +278,7 @@ class MoeBlock(nn.Module): kernel_axes: Tuple with axes to apply kernel function. weight_dtype: Type for the weights. dtype: Type for the dense layer. + quant: Optional quantization config, no quantization if None. """ config: Config @@ -288,6 +289,7 @@ class MoeBlock(nn.Module): kernel_axes: Tuple[str, ...] weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 + quant: Optional[Quant] = None def generate_kernels(self, num_experts, emb_dim, mlp_dim): @@ -411,6 +413,7 @@ def __call__(self, inputs): self.num_experts, dtype=self.dtype, weight_dtype=self.weight_dtype, + quant=self.quant, kernel_init=self.kernel_init, kernel_axes=self.kernel_axes, name="gate")(inputs) diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 9994cfc717..4b7501df0c 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -132,6 +132,7 @@ def __call__( kernel_axes=('embed', 'mlp'), dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, + quant=self.quant, )(hidden_states) mlp_lnx = nn.with_logical_constraint( mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') @@ -145,6 +146,7 @@ def __call__( weight_dtype=cfg.weight_dtype, name="mlp", config=cfg, + quant=self.quant, )(hidden_states, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed"))