From 722123863cf5797308c2fc7f605370d96d87b009 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 1 Jul 2026 09:40:00 +0800 Subject: [PATCH] feat: mrope nn module --- include/infinicore/nn.hpp | 1 + include/infinicore/nn/mrope.hpp | 59 ++++++ python/infinicore/__init__.py | 2 + python/infinicore/nn/modules/__init__.py | 3 +- python/infinicore/nn/modules/mrope.py | 105 ++++++++++ python/infinicore/ops/mrope.py | 52 +++++ src/infinicore/nn/mrope.cc | 154 ++++++++++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/mrope.hpp | 46 ++++ src/infiniop/ops/mrope/cuda/kernel.cuh | 12 +- src/infiniop/ops/mrope/info.h | 15 +- src/infiniop/ops/mrope/nvidia/mrope_nvidia.cu | 6 +- test/infinicore/nn/mrope.py | 163 +++++++++++++++ test/infinicore/ops/mrope.py | 196 ++++++++++++++++++ test/infiniop/mrope.py | 21 +- 15 files changed, 808 insertions(+), 29 deletions(-) create mode 100644 include/infinicore/nn/mrope.hpp create mode 100644 python/infinicore/nn/modules/mrope.py create mode 100644 python/infinicore/ops/mrope.py create mode 100644 src/infinicore/nn/mrope.cc create mode 100644 src/infinicore/pybind11/ops/mrope.hpp create mode 100644 test/infinicore/nn/mrope.py create mode 100644 test/infinicore/ops/mrope.py diff --git a/include/infinicore/nn.hpp b/include/infinicore/nn.hpp index b927b294b..9f7ac5bfe 100644 --- a/include/infinicore/nn.hpp +++ b/include/infinicore/nn.hpp @@ -2,4 +2,5 @@ #include "nn/embedding.hpp" #include "nn/linear.hpp" +#include "nn/mrope.hpp" #include "nn/rmsnorm.hpp" diff --git a/include/infinicore/nn/mrope.hpp b/include/infinicore/nn/mrope.hpp new file mode 100644 index 000000000..2ac69ac85 --- /dev/null +++ b/include/infinicore/nn/mrope.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "../context/context.hpp" +#include "../tensor.hpp" +#include "module.hpp" +#include +#include +#include + +namespace infinicore::nn { + +class MRoPE : public Module { +public: + MRoPE(size_t head_dim, + size_t rotary_dim, + size_t max_seq_len, + double theta, + std::array section, + bool interleaved, + const DataType &dtype, + const Device &device); + + std::pair forward(const Tensor &q, + const Tensor &k, + const Tensor &positions) const; + + std::pair forward(const Tensor &q_out, + const Tensor &k_out, + const Tensor &q, + const Tensor &k, + const Tensor &positions) const; + + size_t rotary_dim() const { return rotary_dim_; } + size_t head_dim() const { return head_dim_; } + size_t max_seq_len() const { return max_seq_len_; } + double theta() const { return theta_; } + const std::array §ion() const { return section_; } + bool interleaved() const { return interleaved_; } + DataType dtype() const { return dtype_; } + + std::string extra_repr() const; + +protected: + INFINICORE_NN_BUFFER(sin_cache); + INFINICORE_NN_BUFFER(cos_cache); + +private: + void initialize_cache(); + + size_t head_dim_; + size_t rotary_dim_; + size_t max_seq_len_; + double theta_; + std::array section_; + bool interleaved_; + DataType dtype_; +}; + +} // namespace infinicore::nn diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index d84a1ce74..cdefb904b 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -112,6 +112,7 @@ moore_mate_flash_attn_decode, moore_mate_flash_attn_prefill, ) +from infinicore.ops.mrope import mrope from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.nrm2 import nrm2 @@ -215,6 +216,7 @@ "addbmm", "floor", "attention", + "mrope", "block_diag", "kron", "bitwise_right_shift", diff --git a/python/infinicore/nn/modules/__init__.py b/python/infinicore/nn/modules/__init__.py index 1ec1349ad..f7d9cb9c2 100644 --- a/python/infinicore/nn/modules/__init__.py +++ b/python/infinicore/nn/modules/__init__.py @@ -1,8 +1,9 @@ from .container import InfiniCoreModuleList as ModuleList from .linear import Linear from .module import InfiniCoreModule as Module +from .mrope import MRoPE from .normalization import RMSNorm from .rope import RoPE from .sparse import Embedding -__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "ModuleList", "Module"] +__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "MRoPE", "ModuleList", "Module"] diff --git a/python/infinicore/nn/modules/mrope.py b/python/infinicore/nn/modules/mrope.py new file mode 100644 index 000000000..a25ba9191 --- /dev/null +++ b/python/infinicore/nn/modules/mrope.py @@ -0,0 +1,105 @@ +import numpy as np + +import infinicore +from infinicore.ops.mrope import mrope + +from ...tensor import Tensor +from .module import InfiniCoreModule as Module + + +def create_sin_cos_table_numpy(max_position, rotary_dim, theta=10000.0): + if rotary_dim % 2 != 0: + raise ValueError("rotary_dim must be even") + pos = np.arange(0, max_position) + freqs = 1.0 / ( + theta + ** (np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(float) / rotary_dim) + ) + angles = np.outer(pos, freqs) + sin_table = np.sin(angles, dtype=np.float32) + cos_table = np.cos(angles, dtype=np.float32) + return sin_table, cos_table + + +def create_sin_cos_table( + max_position, rotary_dim, theta=10000.0, device=None, dtype=None +): + sin_table_np, cos_table_np = create_sin_cos_table_numpy( + max_position, rotary_dim, theta + ) + return ( + infinicore.from_numpy(sin_table_np, dtype=dtype, device=device), + infinicore.from_numpy(cos_table_np, dtype=dtype, device=device), + ) + + +class MRoPE(Module): + r"""Multimodal rotary position embedding with vLLM-style 2D sin/cos cache.""" + + __constants__ = [ + "max_position_embeddings", + "rope_theta", + "head_dim", + "rotary_dim", + "section", + "interleaved", + ] + + def __init__( + self, + max_position_embeddings: int, + rope_theta: float, + head_dim: int, + rotary_dim: int, + section: tuple[int, int, int], + interleaved: bool = False, + device=None, + dtype=None, + ): + super().__init__() + if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: + raise ValueError("rotary_dim must be positive, even, and <= head_dim") + if len(section) != 3 or 2 * sum(section) != rotary_dim: + raise ValueError("section must contain 3 values and sum to rotary_dim / 2") + + factory_kwargs = { + "device": infinicore.device("cpu", 0) if device is None else device, + "dtype": infinicore.float32 if dtype is None else dtype, + } + + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.head_dim = head_dim + self.rotary_dim = rotary_dim + self.section = tuple(section) + self.interleaved = interleaved + + self._sin_table, self._cos_table = create_sin_cos_table( + self.max_position_embeddings, + self.rotary_dim, + self.rope_theta, + **factory_kwargs, + ) + + def forward( + self, + q: Tensor, + k: Tensor, + positions: Tensor, + *, + out: tuple[Tensor, Tensor] | None = None, + ) -> tuple[Tensor, Tensor]: + return mrope( + q, + k, + self._cos_table, + self._sin_table, + positions, + self.head_dim, + self.rotary_dim, + self.section[0], + self.section[1], + self.section[2], + self.interleaved, + out=out, + ) diff --git a/python/infinicore/ops/mrope.py b/python/infinicore/ops/mrope.py new file mode 100644 index 000000000..8111ce9d7 --- /dev/null +++ b/python/infinicore/ops/mrope.py @@ -0,0 +1,52 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def mrope( + q: Tensor, + k: Tensor, + cos: Tensor, + sin: Tensor, + positions: Tensor, + head_size: int, + rotary_dim: int, + section_t: int, + section_h: int, + section_w: int, + interleaved: bool, + *, + out=None, +) -> tuple[Tensor, Tensor]: + if out is None: + q_out, k_out = _infinicore.mrope( + q._underlying, + k._underlying, + cos._underlying, + sin._underlying, + positions._underlying, + head_size, + rotary_dim, + section_t, + section_h, + section_w, + interleaved, + ) + return Tensor(q_out), Tensor(k_out) + + q_out, k_out = out + _infinicore.mrope_( + q_out._underlying, + k_out._underlying, + q._underlying, + k._underlying, + cos._underlying, + sin._underlying, + positions._underlying, + head_size, + rotary_dim, + section_t, + section_h, + section_w, + interleaved, + ) + return q_out, k_out diff --git a/src/infinicore/nn/mrope.cc b/src/infinicore/nn/mrope.cc new file mode 100644 index 000000000..a01d1580e --- /dev/null +++ b/src/infinicore/nn/mrope.cc @@ -0,0 +1,154 @@ +#include "infinicore/nn/mrope.hpp" +#include "../../utils.h" +#include "../utils.hpp" +#include "infinicore/ops/mrope.hpp" +#include +#include +#include + +namespace infinicore::nn { + +MRoPE::MRoPE(size_t head_dim, + size_t rotary_dim, + size_t max_seq_len, + double theta, + std::array section, + bool interleaved, + const DataType &dtype, + const Device &device) + : head_dim_(head_dim), + rotary_dim_(rotary_dim), + max_seq_len_(max_seq_len), + theta_(theta), + section_(section), + interleaved_(interleaved), + dtype_(dtype) { + if (rotary_dim_ % 2 != 0) { + throw std::invalid_argument("rotary_dim must be even for MRoPE, got " + std::to_string(rotary_dim_)); + } + if (rotary_dim_ == 0 || rotary_dim_ > head_dim_) { + throw std::invalid_argument("rotary_dim must be in (0, head_dim] for MRoPE"); + } + if (2 * static_cast(section_[0] + section_[1] + section_[2]) != rotary_dim_) { + throw std::invalid_argument("MRoPE section sum must equal rotary_dim / 2"); + } + device_ = device; + initialize_cache(); +} + +void MRoPE::initialize_cache() { + const size_t cache_dim = rotary_dim_ / 2; + const size_t numel = max_seq_len_ * cache_dim; + INFINICORE_NN_BUFFER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_)); + INFINICORE_NN_BUFFER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_)); + + std::vector sin_data(numel); + std::vector cos_data(numel); + for (size_t pos = 0; pos < max_seq_len_; ++pos) { + for (size_t dim_idx = 0; dim_idx < cache_dim; ++dim_idx) { + const float inv_freq = 1.0f / std::pow(static_cast(theta_), 2.0f * static_cast(dim_idx) / static_cast(rotary_dim_)); + const float angle = static_cast(pos) * inv_freq; + const size_t offset = pos * cache_dim + dim_idx; + sin_data[offset] = std::sin(angle); + cos_data[offset] = std::cos(angle); + } + } + + const auto cpu_device = Device(Device::Type::CPU, 0); + if (dtype_ == DataType::F32) { + auto sin_cpu = Tensor::from_blob(sin_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device); + auto cos_cpu = Tensor::from_blob(cos_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device); + sin_cache_->copy_from(sin_cpu); + cos_cache_->copy_from(cos_cpu); + return; + } + if (dtype_ == DataType::BF16) { + std::vector sin_bf16(numel); + std::vector cos_bf16(numel); + for (size_t i = 0; i < numel; ++i) { + sin_bf16[i] = utils::cast(sin_data[i]); + cos_bf16[i] = utils::cast(cos_data[i]); + } + auto sin_cpu = Tensor::from_blob(sin_bf16.data(), {max_seq_len_, cache_dim}, DataType::BF16, cpu_device); + auto cos_cpu = Tensor::from_blob(cos_bf16.data(), {max_seq_len_, cache_dim}, DataType::BF16, cpu_device); + sin_cache_->copy_from(sin_cpu); + cos_cache_->copy_from(cos_cpu); + return; + } + if (dtype_ == DataType::F16) { + std::vector sin_f16(numel); + std::vector cos_f16(numel); + for (size_t i = 0; i < numel; ++i) { + sin_f16[i] = utils::cast(sin_data[i]); + cos_f16[i] = utils::cast(cos_data[i]); + } + auto sin_cpu = Tensor::from_blob(sin_f16.data(), {max_seq_len_, cache_dim}, DataType::F16, cpu_device); + auto cos_cpu = Tensor::from_blob(cos_f16.data(), {max_seq_len_, cache_dim}, DataType::F16, cpu_device); + sin_cache_->copy_from(sin_cpu); + cos_cache_->copy_from(cos_cpu); + return; + } + throw std::runtime_error("MRoPE cache dtype conversion not supported for dtype: " + std::to_string(static_cast(dtype_))); +} + +std::pair MRoPE::forward(const Tensor &q, + const Tensor &k, + const Tensor &positions) const { + const size_t num_tokens = q->size(0); + auto q_flat = q->contiguous()->view({num_tokens, q->size(1) * head_dim_}); + auto k_flat = k->contiguous()->view({num_tokens, k->size(1) * head_dim_}); + auto q_out = Tensor::empty(q_flat->shape(), q_flat->dtype(), q_flat->device()); + auto k_out = Tensor::empty(k_flat->shape(), k_flat->dtype(), k_flat->device()); + op::mrope_(q_out, + k_out, + q_flat, + k_flat, + cos_cache_, + sin_cache_, + positions, + static_cast(head_dim_), + static_cast(rotary_dim_), + section_[0], + section_[1], + section_[2], + interleaved_); + return {q_out->view(q->shape()), k_out->view(k->shape())}; +} + +std::pair MRoPE::forward(const Tensor &q_out, + const Tensor &k_out, + const Tensor &q, + const Tensor &k, + const Tensor &positions) const { + const size_t num_tokens = q->size(0); + auto q_flat = q->contiguous()->view({num_tokens, q->size(1) * head_dim_}); + auto k_flat = k->contiguous()->view({num_tokens, k->size(1) * head_dim_}); + auto q_out_flat = q_out->view({num_tokens, q->size(1) * head_dim_}); + auto k_out_flat = k_out->view({num_tokens, k->size(1) * head_dim_}); + op::mrope_(q_out_flat, + k_out_flat, + q_flat, + k_flat, + cos_cache_, + sin_cache_, + positions, + static_cast(head_dim_), + static_cast(rotary_dim_), + section_[0], + section_[1], + section_[2], + interleaved_); + return {q_out, k_out}; +} + +std::string MRoPE::extra_repr() const { + return "MRoPE(head_dim=" + std::to_string(head_dim_) + + ", rotary_dim=" + std::to_string(rotary_dim_) + + ", max_seq_len=" + std::to_string(max_seq_len_) + + ", theta=" + std::to_string(theta_) + + ", section=[" + std::to_string(section_[0]) + "," + std::to_string(section_[1]) + "," + std::to_string(section_[2]) + "]" + + ", interleaved=" + (interleaved_ ? "true" : "false") + + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; +} + +} // namespace infinicore::nn diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index e5fb4441f..582008fcd 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -82,6 +82,7 @@ #include "ops/mha.hpp" #include "ops/mha_kvcache.hpp" #include "ops/mha_varlen.hpp" +#include "ops/mrope.hpp" #include "ops/mul.hpp" #include "ops/multi_margin_loss.hpp" #include "ops/nrm2.hpp" @@ -188,6 +189,7 @@ inline void bind(py::module &m) { bind_mha_kvcache(m); bind_mha_varlen(m); bind_mha(m); + bind_mrope(m); bind_hardswish(m); bind_hardtanh(m); bind_gaussian_nll_loss(m); diff --git a/src/infinicore/pybind11/ops/mrope.hpp b/src/infinicore/pybind11/ops/mrope.hpp new file mode 100644 index 000000000..765584880 --- /dev/null +++ b/src/infinicore/pybind11/ops/mrope.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "infinicore/ops/mrope.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_mrope(py::module &m) { + m.def("mrope", + &op::mrope, + py::arg("q"), + py::arg("k"), + py::arg("cos"), + py::arg("sin"), + py::arg("positions"), + py::arg("head_size"), + py::arg("rotary_dim"), + py::arg("section_t"), + py::arg("section_h"), + py::arg("section_w"), + py::arg("interleaved"), + R"doc(Multimodal rotary position embedding for q and k.)doc"); + + m.def("mrope_", + &op::mrope_, + py::arg("q_out"), + py::arg("k_out"), + py::arg("q"), + py::arg("k"), + py::arg("cos"), + py::arg("sin"), + py::arg("positions"), + py::arg("head_size"), + py::arg("rotary_dim"), + py::arg("section_t"), + py::arg("section_h"), + py::arg("section_w"), + py::arg("interleaved"), + R"doc(In-place multimodal rotary position embedding for q and k.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/mrope/cuda/kernel.cuh b/src/infiniop/ops/mrope/cuda/kernel.cuh index 5a715184a..2844ae0aa 100644 --- a/src/infiniop/ops/mrope/cuda/kernel.cuh +++ b/src/infiniop/ops/mrope/cuda/kernel.cuh @@ -52,9 +52,7 @@ __device__ void rotateOne( ptrdiff_t out_stride_head, ptrdiff_t in_stride_token, ptrdiff_t in_stride_head, - ptrdiff_t cos_stride_axis, ptrdiff_t cos_stride_position, - ptrdiff_t sin_stride_axis, ptrdiff_t sin_stride_position, ptrdiff_t positions_stride_axis, ptrdiff_t positions_stride_token, @@ -90,9 +88,9 @@ __device__ void rotateOne( : static_cast(token_idx) * positions_stride_token; const int64_t raw_pos = static_cast(positions[pos_offset]); const size_t position = (raw_pos >= 0 && static_cast(raw_pos) < max_position_embeddings) ? static_cast(raw_pos) : 0; - const ptrdiff_t table_offset = axis * cos_stride_axis + static_cast(position) * cos_stride_position + i; + const ptrdiff_t table_offset = static_cast(position) * cos_stride_position + i; const Tangle cos_v = loadAs(cos[table_offset]); - const Tangle sin_v = loadAs(sin[axis * sin_stride_axis + static_cast(position) * sin_stride_position + i]); + const Tangle sin_v = loadAs(sin[static_cast(position) * sin_stride_position + i]); const Tangle x0 = loadAs(in[in_offset + i]); const Tangle x1 = loadAs(in[in_offset + i + half_rotary_dim]); out[out_offset + i] = storeAs(x0 * cos_v - x1 * sin_v); @@ -125,9 +123,7 @@ __device__ void mropeBlock( ptrdiff_t q_stride_head, ptrdiff_t k_stride_token, ptrdiff_t k_stride_head, - ptrdiff_t cos_stride_axis, ptrdiff_t cos_stride_position, - ptrdiff_t sin_stride_axis, ptrdiff_t sin_stride_position, ptrdiff_t positions_stride_axis, ptrdiff_t positions_stride_token, @@ -140,12 +136,12 @@ __device__ void mropeBlock( const size_t head_idx = blockIdx.y; rotateOne(q_out, q, cos, sin, positions, head_idx, num_q_heads, head_size, rotary_dim, half_rotary_dim, q_out_stride_token, q_out_stride_head, q_stride_token, q_stride_head, - cos_stride_axis, cos_stride_position, sin_stride_axis, sin_stride_position, + cos_stride_position, sin_stride_position, positions_stride_axis, positions_stride_token, max_position_embeddings, section_t, section_h, section_w, positions_has_axes, interleaved); rotateOne(k_out, k, cos, sin, positions, head_idx, num_kv_heads, head_size, rotary_dim, half_rotary_dim, k_out_stride_token, k_out_stride_head, k_stride_token, k_stride_head, - cos_stride_axis, cos_stride_position, sin_stride_axis, sin_stride_position, + cos_stride_position, sin_stride_position, positions_stride_axis, positions_stride_token, max_position_embeddings, section_t, section_h, section_w, positions_has_axes, interleaved); } diff --git a/src/infiniop/ops/mrope/info.h b/src/infiniop/ops/mrope/info.h index 144b8380a..5e2c958cc 100644 --- a/src/infiniop/ops/mrope/info.h +++ b/src/infiniop/ops/mrope/info.h @@ -16,7 +16,7 @@ class MRoPEInfo { size_t section_t, section_h, section_w; ptrdiff_t q_out_stride_token, q_out_stride_head, k_out_stride_token, k_out_stride_head; ptrdiff_t q_stride_token, q_stride_head, k_stride_token, k_stride_head; - ptrdiff_t cos_stride_axis, cos_stride_position, sin_stride_axis, sin_stride_position; + ptrdiff_t cos_stride_position, sin_stride_position; ptrdiff_t positions_stride_axis, positions_stride_token; bool positions_has_axes; bool interleaved; @@ -51,9 +51,8 @@ class MRoPEInfo { CHECK_OR_RETURN(q_desc->ndim() == 2 && k_desc->ndim() == 2 && q_out_desc->ndim() == 2 && k_out_desc->ndim() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(cos_desc->ndim() == 3 && sin_desc->ndim() == 3, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(cos_desc->dim(0) == 3 && sin_desc->dim(0) == 3, INFINI_STATUS_BAD_TENSOR_SHAPE); - CHECK_OR_RETURN(cos_desc->dim(1) == sin_desc->dim(1), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(cos_desc->ndim() == 2 && sin_desc->ndim() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(cos_desc->dim(0) == sin_desc->dim(0), INFINI_STATUS_BAD_TENSOR_SHAPE); const size_t num_tokens = q_desc->dim(0); CHECK_OR_RETURN(k_desc->dim(0) == num_tokens && q_out_desc->dim(0) == num_tokens && k_out_desc->dim(0) == num_tokens, @@ -72,7 +71,7 @@ class MRoPEInfo { } else { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - CHECK_OR_RETURN(cos_desc->dim(2) == size_t(rotary_dim / 2) && sin_desc->dim(2) == size_t(rotary_dim / 2), + CHECK_OR_RETURN(cos_desc->dim(1) == size_t(rotary_dim / 2) && sin_desc->dim(1) == size_t(rotary_dim / 2), INFINI_STATUS_BAD_TENSOR_SHAPE); CHECK_OR_RETURN(q_desc->dim(1) % size_t(head_size) == 0 && k_desc->dim(1) % size_t(head_size) == 0, @@ -82,7 +81,7 @@ class MRoPEInfo { CHECK_OR_RETURN(q_desc->stride(1) == 1 && k_desc->stride(1) == 1 && q_out_desc->stride(1) == 1 && k_out_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); - CHECK_OR_RETURN(cos_desc->stride(2) == 1 && sin_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(cos_desc->stride(1) == 1 && sin_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); return utils::Result(MRoPEInfo{ data_type, @@ -93,7 +92,7 @@ class MRoPEInfo { size_t(head_size), size_t(rotary_dim), size_t(rotary_dim / 2), - cos_desc->dim(1), + cos_desc->dim(0), size_t(section_t), size_t(section_h), size_t(section_w), @@ -106,9 +105,7 @@ class MRoPEInfo { k_desc->stride(0), ptrdiff_t(head_size), cos_desc->stride(0), - cos_desc->stride(1), sin_desc->stride(0), - sin_desc->stride(1), positions_stride_axis, positions_stride_token, positions_has_axes, diff --git a/src/infiniop/ops/mrope/nvidia/mrope_nvidia.cu b/src/infiniop/ops/mrope/nvidia/mrope_nvidia.cu index dc57290c2..043f45b85 100644 --- a/src/infiniop/ops/mrope/nvidia/mrope_nvidia.cu +++ b/src/infiniop/ops/mrope/nvidia/mrope_nvidia.cu @@ -35,9 +35,7 @@ INFINIOP_CUDA_KERNEL mropeKernel( ptrdiff_t q_stride_head, ptrdiff_t k_stride_token, ptrdiff_t k_stride_head, - ptrdiff_t cos_stride_axis, ptrdiff_t cos_stride_position, - ptrdiff_t sin_stride_axis, ptrdiff_t sin_stride_position, ptrdiff_t positions_stride_axis, ptrdiff_t positions_stride_token, @@ -51,7 +49,7 @@ INFINIOP_CUDA_KERNEL mropeKernel( q_out, k_out, q, k, cos, sin, positions, num_q_heads, num_kv_heads, head_size, rotary_dim, half_rotary_dim, q_out_stride_token, q_out_stride_head, k_out_stride_token, k_out_stride_head, q_stride_token, q_stride_head, k_stride_token, k_stride_head, - cos_stride_axis, cos_stride_position, sin_stride_axis, sin_stride_position, + cos_stride_position, sin_stride_position, positions_stride_axis, positions_stride_token, max_position_embeddings, section_t, section_h, section_w, positions_has_axes, interleaved); } @@ -75,7 +73,7 @@ infiniStatus_t launchMRoPE( info.num_q_heads, info.num_kv_heads, info.head_size, info.rotary_dim, info.half_rotary_dim, info.q_out_stride_token, info.q_out_stride_head, info.k_out_stride_token, info.k_out_stride_head, info.q_stride_token, info.q_stride_head, info.k_stride_token, info.k_stride_head, - info.cos_stride_axis, info.cos_stride_position, info.sin_stride_axis, info.sin_stride_position, + info.cos_stride_position, info.sin_stride_position, info.positions_stride_axis, info.positions_stride_token, info.max_position_embeddings, info.section_t, info.section_h, info.section_w, info.positions_has_axes, info.interleaved); return INFINI_STATUS_SUCCESS; diff --git a/test/infinicore/nn/mrope.py b/test/infinicore/nn/mrope.py new file mode 100644 index 000000000..ce55dfe44 --- /dev/null +++ b/test/infinicore/nn/mrope.py @@ -0,0 +1,163 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) +from ops.mrope import make_positions, torch_mrope + +import infinicore + +_TEST_CASES_DATA = [ + # batch, seq_len, num_q_heads, num_kv_heads, head_dim, rotary_dim, sections, interleaved + (1, 5, 2, 1, 32, 32, (4, 6, 6), False), + (2, 4, 3, 1, 32, 24, (2, 4, 6), False), + (1, 6, 2, 2, 32, 24, (2, 3, 7), True), +] +_TENSOR_DTYPES = [infinicore.float16, infinicore.float32] +_POSITION_DTYPES = [infinicore.int32, infinicore.int64] +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, +} + + +def create_sin_cos_table(max_position, rotary_dim, theta, device): + pos = torch.arange(max_position, dtype=torch.float32, device=device) + freqs = 1.0 / ( + theta + ** ( + torch.arange(0, rotary_dim, 2, dtype=torch.float32, device=device) + / rotary_dim + ) + ) + angles = torch.outer(pos, freqs) + return torch.cos(angles), torch.sin(angles) + + +def parse_test_cases(): + test_cases = [] + for ( + batch, + seq_len, + num_q_heads, + num_kv_heads, + head_dim, + rotary_dim, + sections, + interleaved, + ) in _TEST_CASES_DATA: + num_tokens = batch * seq_len + positions = make_positions(num_tokens) + for dtype in _TENSOR_DTYPES: + for position_dtype in _POSITION_DTYPES: + q_spec = TensorSpec.from_tensor( + (num_tokens, num_q_heads * head_dim), None, dtype + ) + k_spec = TensorSpec.from_tensor( + (num_tokens, num_kv_heads * head_dim), None, dtype + ) + pos_spec = TensorSpec.from_tensor( + positions.shape, + None, + position_dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=positions, + ) + test_cases.append( + TestCase( + inputs=[q_spec, k_spec, pos_spec], + kwargs={ + "max_position_embeddings": num_tokens * 2 + 4, + "rope_theta": 10000.0, + "head_dim": head_dim, + "rotary_dim": rotary_dim, + "section": sections, + "interleaved": interleaved, + }, + comparison_target=None, + tolerance=_TOLERANCE_MAP[dtype], + output_count=2, + description="nn.MRoPE - OUT_OF_PLACE", + ) + ) + return test_cases + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("nn.MRoPE") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator( + self, + q, + k, + positions, + max_position_embeddings, + rope_theta, + head_dim, + rotary_dim, + section, + interleaved, + ): + cos, sin = create_sin_cos_table( + max_position_embeddings, rotary_dim, rope_theta, q.device + ) + cos = cos.to(q.dtype) + sin = sin.to(q.dtype) + return torch_mrope( + q, + k, + cos, + sin, + positions, + head_size=head_dim, + rotary_dim=rotary_dim, + section_t=section[0], + section_h=section[1], + section_w=section[2], + interleaved=interleaved, + ) + + def infinicore_operator( + self, + q, + k, + positions, + max_position_embeddings, + rope_theta, + head_dim, + rotary_dim, + section, + interleaved, + ): + module = infinicore.nn.MRoPE( + max_position_embeddings, + rope_theta, + head_dim, + rotary_dim, + section, + interleaved=interleaved, + device=q.device, + dtype=q.dtype, + ) + return module(q, k, positions) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/mrope.py b/test/infinicore/ops/mrope.py new file mode 100644 index 000000000..67bfd4e20 --- /dev/null +++ b/test/infinicore/ops/mrope.py @@ -0,0 +1,196 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) + +import infinicore + +_TEST_CASES_DATA = [ + # num_tokens, num_q_heads, num_kv_heads, head_size, rotary_dim, sections, interleaved + (5, 2, 1, 32, 32, (4, 6, 6), False), + (7, 3, 1, 32, 24, (2, 4, 6), False), + (6, 2, 2, 32, 24, (2, 3, 7), True), +] +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] +_POSITION_DTYPES = [infinicore.int32, infinicore.int64] +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-3}, + infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, +} + + +def make_positions(num_tokens): + return torch.stack( + [ + torch.arange(num_tokens, dtype=torch.int64) * 2, + torch.arange(num_tokens, dtype=torch.int64) * 2 + 1, + torch.arange(num_tokens, dtype=torch.int64) * 2 + 2, + ] + ) + + +def parse_test_cases(): + test_cases = [] + for ( + num_tokens, + num_q_heads, + num_kv_heads, + head_size, + rotary_dim, + sections, + interleaved, + ) in _TEST_CASES_DATA: + max_positions = num_tokens * 2 + 4 + positions = make_positions(num_tokens) + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP[dtype] + for position_dtype in _POSITION_DTYPES: + q_spec = TensorSpec.from_tensor( + (num_tokens, num_q_heads * head_size), None, dtype + ) + k_spec = TensorSpec.from_tensor( + (num_tokens, num_kv_heads * head_size), None, dtype + ) + cos_spec = TensorSpec.from_tensor( + (max_positions, rotary_dim // 2), None, dtype + ) + sin_spec = TensorSpec.from_tensor( + (max_positions, rotary_dim // 2), None, dtype + ) + pos_spec = TensorSpec.from_tensor( + positions.shape, + None, + position_dtype, + init_mode=TensorInitializer.MANUAL, + set_tensor=positions, + ) + kwargs = { + "head_size": head_size, + "rotary_dim": rotary_dim, + "section_t": sections[0], + "section_h": sections[1], + "section_w": sections[2], + "interleaved": interleaved, + } + test_cases.append( + TestCase( + inputs=[q_spec, k_spec, cos_spec, sin_spec, pos_spec], + kwargs=kwargs, + comparison_target=None, + tolerance=tolerance, + output_count=2, + description="MRoPE - OUT_OF_PLACE", + ) + ) + test_cases.append( + TestCase( + inputs=[q_spec, k_spec, cos_spec, sin_spec, pos_spec], + kwargs=kwargs, + output_specs=[ + TensorSpec.from_tensor( + (num_tokens, num_q_heads * head_size), None, dtype + ), + TensorSpec.from_tensor( + (num_tokens, num_kv_heads * head_size), None, dtype + ), + ], + comparison_target="out", + tolerance=tolerance, + output_count=2, + description="MRoPE - INPLACE(out)", + ) + ) + return test_cases + + +def axis_for_dim(dim, section_t, section_h, section_w, interleaved): + if interleaved: + mod = dim % 3 + if mod == 1 and dim < section_h * 3: + return 1 + if mod == 2 and dim < section_w * 3: + return 2 + return 0 + if dim < section_t: + return 0 + if dim < section_t + section_h: + return 1 + return 2 + + +def torch_mrope_one( + x, + cos, + sin, + positions, + head_size, + rotary_dim, + section_t, + section_h, + section_w, + interleaved, +): + num_tokens = x.shape[0] + num_heads = x.shape[1] // head_size + half = rotary_dim // 2 + x = x.reshape(num_tokens, num_heads, head_size) + out = x.clone() + cos_row = torch.empty((num_tokens, half), dtype=torch.float32, device=x.device) + sin_row = torch.empty((num_tokens, half), dtype=torch.float32, device=x.device) + has_axes = positions.ndim == 2 + for i in range(half): + axis = axis_for_dim(i, section_t, section_h, section_w, interleaved) + pos = positions[axis] if has_axes else positions + cos_row[:, i] = cos[pos, i].float() + sin_row[:, i] = sin[pos, i].float() + x0 = x[:, :, :half].float() + x1 = x[:, :, half:rotary_dim].float() + cos_row = cos_row[:, None, :] + sin_row = sin_row[:, None, :] + out[:, :, :half] = (x0 * cos_row - x1 * sin_row).to(out.dtype) + out[:, :, half:rotary_dim] = (x1 * cos_row + x0 * sin_row).to(out.dtype) + return out.reshape(num_tokens, num_heads * head_size) + + +def torch_mrope(q, k, cos, sin, positions, **kwargs): + q_out = torch_mrope_one(q, cos, sin, positions, **kwargs) + k_out = torch_mrope_one(k, cos, sin, positions, **kwargs) + return q_out, k_out + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("MRoPE") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, out=None, **kwargs): + q_out, k_out = torch_mrope(*args, **kwargs) + if out is not None: + out[0].copy_(q_out) + out[1].copy_(k_out) + return out + return q_out, k_out + + def infinicore_operator(self, q, k, cos, sin, positions, out=None, **kwargs): + return infinicore.mrope(q, k, cos, sin, positions, out=out, **kwargs) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infiniop/mrope.py b/test/infiniop/mrope.py index 81bb84cb5..d5edbddac 100644 --- a/test/infiniop/mrope.py +++ b/test/infiniop/mrope.py @@ -18,13 +18,19 @@ test_operator, ) -_TEST_CASES = [ +_BASE_TEST_CASES = [ # num_tokens, num_q_heads, num_kv_heads, head_size, rotary_dim, sections, interleaved (5, 2, 1, 128, 128, (16, 24, 24), False), (7, 4, 2, 128, 128, (16, 24, 24), False), (3, 3, 1, 128, 96, (8, 16, 24), False), (6, 2, 2, 128, 96, (8, 16, 24), True), ] +_POSITION_DTYPES = [InfiniDtype.I32, InfiniDtype.I64] +_TEST_CASES = [ + (*case, position_dtype) + for case in _BASE_TEST_CASES + for position_dtype in _POSITION_DTYPES +] _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, @@ -61,8 +67,8 @@ def torch_mrope_one( for i in range(half): axis = axis_for_dim(i, sections, interleaved) pos = positions[axis] if has_axes else positions - cos_row[:, i] = cos[axis, pos, i] - sin_row[:, i] = sin[axis, pos, i] + cos_row[:, i] = cos[pos, i] + sin_row[:, i] = sin[pos, i] x0 = x[:, :, :half].float() x1 = x[:, :, half:rotary_dim].float() cos_row = cos_row[:, None, :].float() @@ -82,6 +88,7 @@ def test( rotary_dim, sections, interleaved, + position_dtype, dtype, sync=None, ): @@ -89,7 +96,7 @@ def test( f"Testing MRoPE on {InfiniDeviceNames[device]} tokens={num_tokens} " f"q_heads={num_q_heads} kv_heads={num_kv_heads} head_size={head_size} " f"rotary_dim={rotary_dim} sections={sections} interleaved={interleaved} " - f"dtype={InfiniDtypeNames[dtype]}" + f"dtype={InfiniDtypeNames[dtype]} position_dtype={InfiniDtypeNames[position_dtype]}" ) q = TestTensor((num_tokens, num_q_heads * head_size), None, dtype, device) k = TestTensor((num_tokens, num_kv_heads * head_size), None, dtype, device) @@ -100,8 +107,8 @@ def test( (num_tokens, num_kv_heads * head_size), None, dtype, device, mode="zeros" ) max_positions = max(num_tokens + 3, 16) - cos = TestTensor((3, max_positions, rotary_dim // 2), None, dtype, device) - sin = TestTensor((3, max_positions, rotary_dim // 2), None, dtype, device) + cos = TestTensor((max_positions, rotary_dim // 2), None, dtype, device) + sin = TestTensor((max_positions, rotary_dim // 2), None, dtype, device) positions_torch = torch.stack( [ torch.arange(num_tokens, dtype=torch.int64), @@ -109,7 +116,7 @@ def test( torch.arange(num_tokens, dtype=torch.int64) + 2, ] ) - positions = TestTensor.from_torch(positions_torch, InfiniDtype.I64, device) + positions = TestTensor.from_torch(positions_torch, position_dtype, device) expected_q = torch_mrope_one( q.torch_tensor(), cos.torch_tensor(),