From 73d5e7eed0d370e86f2293e3adaf2ca3f8d34818 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Tue, 16 Apr 2024 17:43:04 -0700 Subject: [PATCH 1/4] {executorch][llama] support mqa Summary: This diff adds support for multi query attention for sdpa with kv cache Reviewed By: iseeyuan Differential Revision: D56212419 --- examples/models/llama2/custom_ops/op_sdpa.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp index dd0fa67ec08..58cb5c6c2c3 100644 --- a/examples/models/llama2/custom_ops/op_sdpa.cpp +++ b/examples/models/llama2/custom_ops/op_sdpa.cpp @@ -240,6 +240,7 @@ void cpu_flash_attention( " and num kv heads=%" PRId64, num_head, num_heads_kv); + int64_t num_reps = num_head / num_heads_kv; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); From f7c1459e1e9b926abcb3ea3f10fa7c8e4e0fb950 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 16 Apr 2024 23:10:18 -0700 Subject: [PATCH 2/4] 4b embedding quantizer (#3081) Summary: 4b embedding quantizer Reviewed By: larryliu0820 Differential Revision: D56229021 --- examples/models/llama2/quantize.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/models/llama2/quantize.py b/examples/models/llama2/quantize.py index ec2662ecc5c..21464125fe8 100644 --- a/examples/models/llama2/quantize.py +++ b/examples/models/llama2/quantize.py @@ -436,10 +436,18 @@ def __init__( @torch.no_grad() def forward(self, indices: torch.Tensor) -> torch.Tensor: if not self.packed: # 8bit +<<<<<<< HEAD return torch.ops.quantized_decomposed.embedding_byte.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) else: # 4bit packed return torch.ops.quantized_decomposed.embedding_4bit.dtype( +======= + return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype( + self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype + ) + else: # 4bit packed + return torch.ops.llama_quantized.embedding_4bit.dtype( +>>>>>>> 6b3b7228c (4b embedding quantizer (#3081)) self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) From f938acb8760310a5df129e236bd4691c67c1c9d7 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 17 Apr 2024 08:54:41 -0700 Subject: [PATCH 3/4] Patch Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- examples/models/llama2/export_llama_lib.py | 4 ++-- examples/models/llama2/quantize.py | 8 -------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 449824a33b9..05e314eae0b 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -29,7 +29,7 @@ Transformer, ) from executorch.exir.backend.backend_details import CompileSpec - +from executorch.exir.passes import * from executorch.sdk.etrecord import generate_etrecord from executorch.util.activation_memory_profiler import generate_memory_trace from sentencepiece import SentencePieceProcessor @@ -539,7 +539,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: bitwidth = int(bitwidth) transforms.append( lambda model: EmbeddingQuantHandler( - model, bitwidth=bitwidth, group_size=group_size + model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth==4), ).quantized_model() ) diff --git a/examples/models/llama2/quantize.py b/examples/models/llama2/quantize.py index 21464125fe8..ec2662ecc5c 100644 --- a/examples/models/llama2/quantize.py +++ b/examples/models/llama2/quantize.py @@ -436,18 +436,10 @@ def __init__( @torch.no_grad() def forward(self, indices: torch.Tensor) -> torch.Tensor: if not self.packed: # 8bit -<<<<<<< HEAD return torch.ops.quantized_decomposed.embedding_byte.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) else: # 4bit packed return torch.ops.quantized_decomposed.embedding_4bit.dtype( -======= - return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype( - self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype - ) - else: # 4bit packed - return torch.ops.llama_quantized.embedding_4bit.dtype( ->>>>>>> 6b3b7228c (4b embedding quantizer (#3081)) self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) From a28e73bc951d895f9affdcca55cb4b9cd9fcb2e2 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 17 Apr 2024 18:58:29 -0700 Subject: [PATCH 4/4] Define embedding_4bit ops Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- examples/models/llama2/custom_ops/op_sdpa.cpp | 1 - examples/models/llama2/export_llama_lib.py | 7 +++++-- exir/passes/_quant_patterns_and_replacements.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp index 58cb5c6c2c3..dd0fa67ec08 100644 --- a/examples/models/llama2/custom_ops/op_sdpa.cpp +++ b/examples/models/llama2/custom_ops/op_sdpa.cpp @@ -240,7 +240,6 @@ void cpu_flash_attention( " and num kv heads=%" PRId64, num_head, num_heads_kv); - int64_t num_reps = num_head / num_heads_kv; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 05e314eae0b..a20c006677c 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -29,7 +29,7 @@ Transformer, ) from executorch.exir.backend.backend_details import CompileSpec -from executorch.exir.passes import * + from executorch.sdk.etrecord import generate_etrecord from executorch.util.activation_memory_profiler import generate_memory_trace from sentencepiece import SentencePieceProcessor @@ -539,7 +539,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager: bitwidth = int(bitwidth) transforms.append( lambda model: EmbeddingQuantHandler( - model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth==4), + model, + bitwidth=bitwidth, + group_size=group_size, + packed=(bitwidth == 4), ).quantized_model() ) diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py index e198ff383e9..ec543560b86 100644 --- a/exir/passes/_quant_patterns_and_replacements.py +++ b/exir/passes/_quant_patterns_and_replacements.py @@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta( quantized_decomposed_lib.define( "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, " - "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", + "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", )