diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index bf06ce37c5c..e198ff383e9 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -46,7 +46,7 @@
 )
 
 
-def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
+def embedding_weight_checks(weight, weight_scales, weight_zero_points):
     assert weight.dtype in [
         torch.int8,
         torch.uint8,
@@ -86,7 +86,7 @@ def embedding_byte(
     weight_quant_max: int,
     indices: torch.Tensor,
 ) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = weight.size(1) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
     )
@@ -133,7 +133,7 @@ def embedding_byte_dtype(
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
 ) -> torch.Tensor:
-    embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
     group_size = weight.size(1) // (
         weight_scales.size(1) if weight_scales.dim() == 2 else 1
     )
@@ -172,6 +172,134 @@ def embedding_byte_dtype_out_meta(
     )
 
 
+quantized_decomposed_lib.define(
+    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
+def embedding_4bit(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.out")
+def embedding_4bit_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
+def embedding_4bit_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.dtype_out")
+def embedding_4bit_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
 )
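
For reviewers, a standalone sketch of the 4-bit layout the new ops assume: each byte of `weight` packs two quantized values (even-indexed element in the high nibble, odd-indexed in the low nibble), stored unsigned with a +8 offset so they fit in [0, 15]. That is why `embedding_4bit` computes `group_size` from `2 * weight.size(1)` (each packed column holds two logical columns) and recenters with `.add(-8)` after reinterpreting as int8. The snippet below is illustrative only, not part of the patch; the toy tensors and the hand-rolled group-wise dequant stand in for `dequantize_per_channel_group` (assuming zero_point = 0).

```python
import torch

# 1 row, 2 packed bytes = 4 logical 4-bit values.
packed = torch.tensor([[0x1F, 0x80]], dtype=torch.uint8)

# Unpack: high nibble = byte // 16, low nibble = byte % 16, then interleave,
# mirroring the div/remainder/stack/view sequence in embedding_4bit.
high = packed.div(16, rounding_mode="trunc")
low = packed.remainder(16)
unpacked = torch.stack((high, low), dim=-1).view(packed.shape[0], -1)

# Recenter [0, 15] -> [-8, 7] to recover the signed 4-bit quantized values.
qvals = unpacked.view(torch.int8).add(-8)
print(qvals)  # tensor([[-7,  7,  0, -8]], dtype=torch.int8)

# Group-wise dequantization with one scale per group (zero_point = 0 assumed):
# group_size is measured in logical (unpacked) columns, hence the 2x factor.
scales = torch.tensor([[0.5, 0.25]])                  # 2 groups per row
group_size = (2 * packed.size(1)) // scales.size(1)   # = 2 values per group
deq = qvals.view(1, -1, group_size) * scales.unsqueeze(-1)
print(deq.view(1, -1))  # tensor([[-3.5000,  3.5000,  0.0000, -2.0000]])
```

One caveat worth noting: the stack order here (even element from the high nibble, odd from the low) must match whatever packing the export-time quantizer produced; flipping the order would silently swap adjacent embedding columns while all shapes still check out.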