From 2ff9055fdc1b7493a27158d57d4f09be5a7d8b22 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 16 Apr 2024 21:45:40 -0700
Subject: [PATCH 1/2] [executorch][llama] support mqa

This diff adds support for multi-query attention (MQA) in the sdpa_with_kv_cache op.

Differential Revision: [D56228316](https://our.internmc.facebook.com/intern/diff/D56228316/)

[ghstack-poisoned]
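The kv cache keeps num_heads_kv heads while the query can carry num_head = num_reps * num_heads_kv heads; inside the kernel, query head j now reads kv head j / num_reps (the new j_kv index). A rough eager-mode sketch of the same mapping, with illustrative head counts that are not taken from this diff:

    import torch
    import torch.nn.functional as F

    n_heads_q, n_heads_kv, head_dim = 8, 4, 4   # illustrative sizes
    n_reps = n_heads_q // n_heads_kv            # query heads per kv head
    q = torch.rand(1, n_heads_q, 1, head_dim)   # [B, n_heads_q, seq_len, head_dim]
    k = torch.rand(1, n_heads_kv, 5, head_dim)  # [B, n_heads_kv, kv_len, head_dim]
    v = torch.rand(1, n_heads_kv, 5, head_dim)
    # Expand kv heads so each group of n_reps query heads shares one kv head,
    # i.e. query head j attends to kv head j // n_reps.
    k = k.repeat_interleave(n_reps, dim=1)
    v = v.repeat_interleave(n_reps, dim=1)
    out = F.scaled_dot_product_attention(q, k, v)  # [B, n_heads_q, seq_len, head_dim]

The reference path in the new tests does the same thing via repeat_interleave before calling F.scaled_dot_product_attention.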
---
 examples/models/llama2/custom_ops/TARGETS     |  14 ++
 examples/models/llama2/custom_ops/op_sdpa.cpp |  21 +-
 .../custom_ops/test_sdpa_with_kv_cache.py     | 203 ++++++++++++++++++
 3 files changed, 236 insertions(+), 2 deletions(-)
 create mode 100644 examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py

diff --git a/examples/models/llama2/custom_ops/TARGETS b/examples/models/llama2/custom_ops/TARGETS
index 2341af9282f..518ede09dcb 100644
--- a/examples/models/llama2/custom_ops/TARGETS
+++ b/examples/models/llama2/custom_ops/TARGETS
@@ -1,8 +1,22 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.
 
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 define_common_targets()
+
+python_unittest(
+    name = "test_sdpa_with_kv_cache",
+    srcs = [
+        "test_sdpa_with_kv_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
diff --git a/examples/models/llama2/custom_ops/op_sdpa.cpp b/examples/models/llama2/custom_ops/op_sdpa.cpp
index bf8c31de73d..dd0fa67ec08 100644
--- a/examples/models/llama2/custom_ops/op_sdpa.cpp
+++ b/examples/models/llama2/custom_ops/op_sdpa.cpp
@@ -219,13 +219,29 @@ void cpu_flash_attention(
   int64_t qSize = query.size(2);
   int64_t headSize = query.size(3);
   int64_t kvSize = value.size(2);
+  int64_t num_heads_kv = key.size(1);
 
   if (is_with_kv_cache) {
     num_head = query.size(2);
+    num_heads_kv = key.size(2);
     qSize = query.size(1);
     kvSize = value.size(1);
   }
 
+  ET_CHECK_MSG(
+      num_heads_kv <= num_head,
+      "FlashAttention does not support num kv heads > num query heads. Got num query heads=%" PRId64
+      " num kv heads=%" PRId64,
+      num_head,
+      num_heads_kv);
+  ET_CHECK_MSG(
+      num_head % num_heads_kv == 0,
+      "FlashAttention: num query heads must be divisible by num kv heads but got num query heads=%" PRId64
+      " and num kv heads=%" PRId64,
+      num_head,
+      num_heads_kv);
+  int64_t num_reps = num_head / num_heads_kv;
+
   bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
   if (has_attn_mask) {
     /*
@@ -365,6 +381,7 @@ void cpu_flash_attention(
         fill_stub(
             qk_max_data, -std::numeric_limits<accum_t>::infinity(), qBlockSize);
         int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize;
+        auto j_kv = j / num_reps;
         for (int64_t n = 0; n < num_keys; n += kvSplitSize) {
           int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n);
           // Calculate scale * q @ k.T
@@ -376,7 +393,7 @@ void cpu_flash_attention(
             qBlockSize,
             headSize,
             static_cast<accum_t>(1),
-            k_data + i * kStrideB + j * kStrideH + n * kStrideN,
+            k_data + i * kStrideB + j_kv * kStrideH + n * kStrideN,
             kStrideN,
             q_data + i * qStrideB + j * qStrideH + m * qStrideM,
             qStrideM,
@@ -460,7 +477,7 @@ void cpu_flash_attention(
             qBlockSize,
             kvBlockSize,
             static_cast<accum_t>(1),
-            v_data + i * vStrideB + j * vStrideH + n * vStrideN,
+            v_data + i * vStrideB + j_kv * vStrideH + n * vStrideN,
             vStrideN,
             conditional_data_ptr(qk_data, qk_reduced_data),
             kvBlockSize,
diff --git a/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
new file mode 100644
index 00000000000..949fdeab2c4
--- /dev/null
+++ b/examples/models/llama2/custom_ops/test_sdpa_with_kv_cache.py
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+
+class SDPATest(unittest.TestCase):
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.k_cache = torch.zeros((1, 5, 8, 4))
+        self.v_cache = torch.zeros((1, 5, 8, 4))
+        self.mask = torch.full(
+            (5, 5),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
+        print(f"at start_pos:{start_pos}")
+        print(q)
+        print(k)
+        print(v)
+        attn_mask = mask[start_pos].view((1, -1))
+        attn_mask = attn_mask[:, : start_pos + 1]
+        q = q.transpose(1, 2)
+        k_cache[:, start_pos] = k
+        v_cache[:, start_pos] = v
+        sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
+        sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
+        sliced_k_cache = sliced_k_cache.transpose(1, 2)
+        sliced_v_cache = sliced_v_cache.transpose(1, 2)
+        # print(sliced_k_cache.size())
+        # print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
+        # print("q @ k")
+        # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
+        # qk_softmax = torch.softmax(qk, dim=-1)
+        # qkv = torch.matmul(qk_softmax, sliced_v_cache)
+        # print(qk)
+        # print(qk_softmax)
+        # print(qkv)
+        out = F.scaled_dot_product_attention(
+            q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+        )
+        out = out.transpose(1, 2)
+        print(out)
+        print(f"-------- start pos {start_pos} done -----")
+        return out
+
+    def test_sdpa_with_cache_no_mqa_1(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 0
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_2(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_3(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 2
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 2, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_no_mqa_4(self):
+        q = torch.rand((1, 1, 8, 4))
+        k = torch.rand((1, 1, 8, 4))
+        v = torch.rand((1, 1, 8, 4))
+
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 3
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 3, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+
+class SDPATestWithMQA(unittest.TestCase):
+
+    def setup_caches(self):
+        self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
+        self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.n_heads_kv = 4
+        self.n_heads_q = 8
+        self.setup_caches()
+        self.mask = torch.full(
+            (5, 5),
+            float("-inf"),
+        )
+        self.mask = torch.triu(self.mask, diagonal=1)
+
+    def _sdpa_with_kv_cache_ref(self, q, k, v, k_cache, v_cache, mask, start_pos):
+        print(f"at start_pos:{start_pos}")
+        print(q)
+        print(k)
+        print(v)
+        attn_mask = mask[start_pos].view((1, -1))
+        attn_mask = attn_mask[:, : start_pos + 1]
+        q = q.transpose(1, 2)
+        k_cache[:, start_pos] = k
+        v_cache[:, start_pos] = v
+        sliced_k_cache = k_cache[:, : start_pos + 1, :, :]
+        sliced_v_cache = v_cache[:, : start_pos + 1, :, :]
+        sliced_k_cache = sliced_k_cache.transpose(1, 2)
+        sliced_v_cache = sliced_v_cache.transpose(1, 2)
+        # print(sliced_k_cache.size())
+        # print(torch.matmul(q, sliced_k_cache.transpose(2, 3)))
+        # print("q @ k")
+        # qk = torch.matmul(q, sliced_k_cache.transpose(2, 3))
+        # qk_softmax = torch.softmax(qk, dim=-1)
+        # qkv = torch.matmul(qk_softmax, sliced_v_cache)
+        # print(qk)
+        # print(qk_softmax)
+        # print(qkv)
+        num_heads_q = q.size(1)
+        num_heads_kv = sliced_k_cache.size(1)
+        if num_heads_q != num_heads_kv:
+            assert (
+                num_heads_q % num_heads_kv == 0
+            ), f"{num_heads_q} not divisible by {num_heads_kv}"
+        n_reps = num_heads_q // num_heads_kv
+        if n_reps > 1:
+            sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
+            sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
+        out = F.scaled_dot_product_attention(
+            q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
+        )
+        out = out.transpose(1, 2)
+        print(out)
+        print(f"-------- start pos {start_pos} done -----")
+        return out
+
+    def test_sdpa_with_cache_mqa_1(self):
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 0
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_mqa_2(self):
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))
+
+    def test_sdpa_with_cache_mqa_3(self):
+        self.n_heads_q = 14
+        self.n_heads_kv = 7
+        self.setup_caches()
+        q = torch.rand((1, 1, self.n_heads_q, 4))
+        k = torch.rand((1, 1, self.n_heads_kv, 4))
+        v = torch.rand((1, 1, self.n_heads_kv, 4))
+        ref_output = self._sdpa_with_kv_cache_ref(
+            q, k, v, self.k_cache, self.v_cache, self.mask, 1
+        )
+        op_output = torch.ops.llama.sdpa_with_kv_cache(
+            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
+        )
+        self.assertTrue(torch.allclose(ref_output, op_output))

From 26aced7565e18f5eb92709634ecb27f63021df18 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 17 Apr 2024 07:08:56 -0700
Subject: [PATCH 2/2] Update on "[executorch][llama] support mqa"

This diff adds support for multi-query attention (MQA) in the sdpa_with_kv_cache op.

Differential Revision: [D56228316](https://our.internmc.facebook.com/intern/diff/D56228316/)

[ghstack-poisoned]
---
 examples/models/llama2/custom_ops/TARGETS | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama2/custom_ops/TARGETS b/examples/models/llama2/custom_ops/TARGETS
index 518ede09dcb..199cbe363d0 100644
--- a/examples/models/llama2/custom_ops/TARGETS
+++ b/examples/models/llama2/custom_ops/TARGETS
@@ -1,14 +1,14 @@
 # Any targets that should be shared between fbcode and xplat must be defined in
 # targets.bzl. This file can contain fbcode-only targets.
 
-load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
 define_common_targets()
 
-python_unittest(
+runtime.python_test(
     name = "test_sdpa_with_kv_cache",
     srcs = [
         "test_sdpa_with_kv_cache.py",