From 93d4791ea5c696cf4764e3a684c099b31c320b27 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 17 Apr 2024 19:00:23 -0700 Subject: [PATCH] fix embedding_4bit resize (#3118) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/3118 Reviewed By: larryliu0820 Differential Revision: D56282683 --- kernels/quantized/cpu/op_embedding4b.cpp | 2 +- kernels/quantized/test/op_embedding4b_test.cpp | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/kernels/quantized/cpu/op_embedding4b.cpp b/kernels/quantized/cpu/op_embedding4b.cpp index 33be86e5cc4..f234ee224ca 100644 --- a/kernels/quantized/cpu/op_embedding4b.cpp +++ b/kernels/quantized/cpu/op_embedding4b.cpp @@ -195,7 +195,7 @@ void resize_out_tensor( for (size_t i = 0; i < indices.dim(); i++) { expected_output_size[i] = indices.size(i); } - const size_t embedding_dim = weight.size(1); + const size_t embedding_dim = weight.size(1) * 2; expected_output_size[out.dim() - 1] = embedding_dim; exec_aten::ArrayRef output_size{ diff --git a/kernels/quantized/test/op_embedding4b_test.cpp b/kernels/quantized/test/op_embedding4b_test.cpp index 56944c57857..1eb7aa11b2a 100644 --- a/kernels/quantized/test/op_embedding4b_test.cpp +++ b/kernels/quantized/test/op_embedding4b_test.cpp @@ -19,6 +19,7 @@ using namespace ::testing; using exec_aten::ArrayRef; using exec_aten::optional; +using exec_aten::RuntimeContext; using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::native::quantized_embedding_4bit_out; @@ -60,6 +61,20 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) { EXPECT_TENSOR_EQ(out, expected); + out = tf.zeros({3, 4}); + auto context = RuntimeContext(); + torch::executor::native::quantized_embedding_4bit_out( + context, + qweight, + weight_scales, + weight_zero_points, + quant_min, + quant_max, + indices, + out); + + EXPECT_TENSOR_EQ(out, expected); + // Groupwise quantization. groupsize = 2 weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0}); weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});