diff --git a/include/infinicore/nn.hpp b/include/infinicore/nn.hpp index 9f7ac5bfe..16f463a72 100644 --- a/include/infinicore/nn.hpp +++ b/include/infinicore/nn.hpp @@ -2,5 +2,5 @@ #include "nn/embedding.hpp" #include "nn/linear.hpp" -#include "nn/mrope.hpp" #include "nn/rmsnorm.hpp" +#include "nn/rope.hpp" diff --git a/include/infinicore/nn/mrope.hpp b/include/infinicore/nn/mrope.hpp deleted file mode 100644 index 2ac69ac85..000000000 --- a/include/infinicore/nn/mrope.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include "../context/context.hpp" -#include "../tensor.hpp" -#include "module.hpp" -#include -#include -#include - -namespace infinicore::nn { - -class MRoPE : public Module { -public: - MRoPE(size_t head_dim, - size_t rotary_dim, - size_t max_seq_len, - double theta, - std::array section, - bool interleaved, - const DataType &dtype, - const Device &device); - - std::pair forward(const Tensor &q, - const Tensor &k, - const Tensor &positions) const; - - std::pair forward(const Tensor &q_out, - const Tensor &k_out, - const Tensor &q, - const Tensor &k, - const Tensor &positions) const; - - size_t rotary_dim() const { return rotary_dim_; } - size_t head_dim() const { return head_dim_; } - size_t max_seq_len() const { return max_seq_len_; } - double theta() const { return theta_; } - const std::array §ion() const { return section_; } - bool interleaved() const { return interleaved_; } - DataType dtype() const { return dtype_; } - - std::string extra_repr() const; - -protected: - INFINICORE_NN_BUFFER(sin_cache); - INFINICORE_NN_BUFFER(cos_cache); - -private: - void initialize_cache(); - - size_t head_dim_; - size_t rotary_dim_; - size_t max_seq_len_; - double theta_; - std::array section_; - bool interleaved_; - DataType dtype_; -}; - -} // namespace infinicore::nn diff --git a/include/infinicore/nn/rope.hpp b/include/infinicore/nn/rope.hpp index 020257f0d..cbffb81c8 100644 --- a/include/infinicore/nn/rope.hpp +++ b/include/infinicore/nn/rope.hpp @@ -6,6 +6,9 @@ #include "rope_scaling_configs.hpp" #include #include +#include +#include +#include namespace infinicore::nn { @@ -32,6 +35,9 @@ class RoPE : public Module { * @param dtype Data type for sin/cos cache (default: DataType::F32) * @param device Device to create the cache on * @param scaling RoPE scaling configuration (default: nullptr) + * @param mrope_section Optional MRoPE section sizes [t, h, w], whose sum must equal rotary_dim / 2. + * When set, pair forward overloads apply MRoPE to q/k using positions [3, num_tokens]. + * @param mrope_interleaved Whether to interleave MRoPE axes/frequency sections. */ RoPE(size_t head_dim, size_t rotary_dim, @@ -40,50 +46,47 @@ class RoPE : public Module { Algo algo = Algo::GPT_J, const DataType &dtype = DataType::F32, const Device &device = Device(), - std::shared_ptr scaling = nullptr); + std::shared_ptr scaling = nullptr, + std::optional> mrope_section = std::nullopt, + bool mrope_interleaved = false); /** - * @brief Forward pass: apply RoPE to a tensor + * @brief Forward pass: apply standard RoPE to a tensor * * @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions * @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len] * @param in_place If true, modify input tensor in place (default: false) * @return Rotated tensor with same shape as input - * - * Applies rotary position embeddings to the input tensor. - * For attention mechanisms, call this method separately for query and key tensors. - * - * Common input shapes: - * - [batch, num_heads, seq_len, rotary_dim] - * - [batch, seq_len, num_heads, rotary_dim] - * - [seq_len, rotary_dim] */ Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const; /** - * @brief Forward pass: apply RoPE to a tensor in place - * - * @param y Output tensor of shape (..., rotary_dim) where ... is any number of dimensions - * @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions - * @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len] - * @return Rotated tensor with same shape as input + * @brief Apply MRoPE to q and k. * - * Applies rotary position embeddings to the input tensor. - * For attention mechanisms, call this method separately for query and key tensors. - * - * Common input shapes: - * - [batch, num_heads, seq_len, rotary_dim] - * - [batch, seq_len, num_heads, rotary_dim] - * - [seq_len, rotary_dim] + * Requires construction with mrope_section. q/k may be either + * [num_tokens, num_heads * head_dim] or [num_tokens, num_heads, head_dim]. + * positions is [3, num_tokens] with axes ordered as t, h, w. */ - Tensor forward(const Tensor &y, const Tensor &x, const Tensor &pos) const; + std::pair forward(const Tensor &q, const Tensor &k, const Tensor &positions) const; + + /** + * @brief Apply MRoPE to q and k into caller-provided outputs. + */ + std::pair forward(const Tensor &q_out, + const Tensor &k_out, + const Tensor &q, + const Tensor &k, + const Tensor &positions) const; // Module information size_t rotary_dim() const { return rotary_dim_; } + size_t head_dim() const { return head_dim_; } size_t max_seq_len() const { return max_seq_len_; } double theta() const { return theta_; } Algo algo() const { return algo_; } DataType dtype() const { return dtype_; } + const std::optional> &mrope_section() const { return mrope_section_; } + bool mrope_interleaved() const { return mrope_interleaved_; } // String representation std::string extra_repr() const; @@ -95,7 +98,6 @@ class RoPE : public Module { private: void initialize_cache(); - size_t rotary_dim_; // Number of dimensions to apply rotation to (must be even). size_t head_dim_; // Dimension of each attention head size_t max_seq_len_; // Maximum sequence length @@ -103,6 +105,8 @@ class RoPE : public Module { Algo algo_; // RoPE algorithm type DataType dtype_; // Data type for cache tables std::shared_ptr scaling_; // RoPE scaling configuration + std::optional> mrope_section_; + bool mrope_interleaved_; }; } // namespace infinicore::nn diff --git a/python/infinicore/nn/modules/__init__.py b/python/infinicore/nn/modules/__init__.py index f7d9cb9c2..1ec1349ad 100644 --- a/python/infinicore/nn/modules/__init__.py +++ b/python/infinicore/nn/modules/__init__.py @@ -1,9 +1,8 @@ from .container import InfiniCoreModuleList as ModuleList from .linear import Linear from .module import InfiniCoreModule as Module -from .mrope import MRoPE from .normalization import RMSNorm from .rope import RoPE from .sparse import Embedding -__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "MRoPE", "ModuleList", "Module"] +__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "ModuleList", "Module"] diff --git a/python/infinicore/nn/modules/mrope.py b/python/infinicore/nn/modules/mrope.py deleted file mode 100644 index a25ba9191..000000000 --- a/python/infinicore/nn/modules/mrope.py +++ /dev/null @@ -1,105 +0,0 @@ -import numpy as np - -import infinicore -from infinicore.ops.mrope import mrope - -from ...tensor import Tensor -from .module import InfiniCoreModule as Module - - -def create_sin_cos_table_numpy(max_position, rotary_dim, theta=10000.0): - if rotary_dim % 2 != 0: - raise ValueError("rotary_dim must be even") - pos = np.arange(0, max_position) - freqs = 1.0 / ( - theta - ** (np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(float) / rotary_dim) - ) - angles = np.outer(pos, freqs) - sin_table = np.sin(angles, dtype=np.float32) - cos_table = np.cos(angles, dtype=np.float32) - return sin_table, cos_table - - -def create_sin_cos_table( - max_position, rotary_dim, theta=10000.0, device=None, dtype=None -): - sin_table_np, cos_table_np = create_sin_cos_table_numpy( - max_position, rotary_dim, theta - ) - return ( - infinicore.from_numpy(sin_table_np, dtype=dtype, device=device), - infinicore.from_numpy(cos_table_np, dtype=dtype, device=device), - ) - - -class MRoPE(Module): - r"""Multimodal rotary position embedding with vLLM-style 2D sin/cos cache.""" - - __constants__ = [ - "max_position_embeddings", - "rope_theta", - "head_dim", - "rotary_dim", - "section", - "interleaved", - ] - - def __init__( - self, - max_position_embeddings: int, - rope_theta: float, - head_dim: int, - rotary_dim: int, - section: tuple[int, int, int], - interleaved: bool = False, - device=None, - dtype=None, - ): - super().__init__() - if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: - raise ValueError("rotary_dim must be positive, even, and <= head_dim") - if len(section) != 3 or 2 * sum(section) != rotary_dim: - raise ValueError("section must contain 3 values and sum to rotary_dim / 2") - - factory_kwargs = { - "device": infinicore.device("cpu", 0) if device is None else device, - "dtype": infinicore.float32 if dtype is None else dtype, - } - - self.max_position_embeddings = max_position_embeddings - self.rope_theta = rope_theta - self.head_dim = head_dim - self.rotary_dim = rotary_dim - self.section = tuple(section) - self.interleaved = interleaved - - self._sin_table, self._cos_table = create_sin_cos_table( - self.max_position_embeddings, - self.rotary_dim, - self.rope_theta, - **factory_kwargs, - ) - - def forward( - self, - q: Tensor, - k: Tensor, - positions: Tensor, - *, - out: tuple[Tensor, Tensor] | None = None, - ) -> tuple[Tensor, Tensor]: - return mrope( - q, - k, - self._cos_table, - self._sin_table, - positions, - self.head_dim, - self.rotary_dim, - self.section[0], - self.section[1], - self.section[2], - self.interleaved, - out=out, - ) diff --git a/python/infinicore/nn/modules/rope.py b/python/infinicore/nn/modules/rope.py index 5d579e2d3..1260b295c 100644 --- a/python/infinicore/nn/modules/rope.py +++ b/python/infinicore/nn/modules/rope.py @@ -2,17 +2,19 @@ import infinicore from infinicore.nn import functional as F +from infinicore.ops.mrope import mrope from ...tensor import Tensor from ..functional import RopeAlgo from .module import InfiniCoreModule as Module -def create_sin_cos_table_numpy(max_position, head_dim, theta=10000.0): - assert head_dim % 2 == 0, "Embedding dimension must be even." +def create_sin_cos_table_numpy(max_position, rotary_dim, theta=10000.0): + assert rotary_dim % 2 == 0, "Embedding dimension must be even." pos = np.arange(0, max_position) freqs = 1.0 / ( - theta ** (np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(float) / head_dim) + theta + ** (np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(float) / rotary_dim) ) angles = np.outer(pos, freqs) sin_table = np.sin(angles, dtype=np.float32) @@ -22,13 +24,13 @@ def create_sin_cos_table_numpy(max_position, head_dim, theta=10000.0): def create_sin_cos_table( max_position, - head_dim, + rotary_dim, theta=10000.0, device=None, dtype=None, ): sin_table_np, cos_table_np = create_sin_cos_table_numpy( - max_position, head_dim, theta + max_position, rotary_dim, theta ) sin_table_infini = infinicore.from_numpy(sin_table_np, dtype=dtype, device=device) @@ -38,22 +40,24 @@ def create_sin_cos_table( class RoPE(Module): - r"""Rotary Position Embedding(RoPE).. + r"""Rotary Position Embedding(RoPE). - Args: - max_position_embeddings (int): The maximum sequence length that this model might ever be used with. - rope_theta (float): The base period of the RoPE embeddings. - head_dim (int): The attention head dimension. - - Shape: - - Input: hidden_states, ( bs, seq_len, num_heads, head_dim). - - Output: hidden_states, ( bs, seq_len, num_heads, head_dim). + Standard RoPE is used when ``mrope_section`` is None. MRoPE is enabled by passing + ``mrope_section=[t, h, w]`` and then calling the module as ``rope(q, k, positions)``. """ - __constants__ = ["max_position_embeddings", "rope_theta", "head_dim"] + __constants__ = [ + "max_position_embeddings", + "rope_theta", + "head_dim", + "rotary_dim", + "mrope_section", + "mrope_interleaved", + ] max_position_embeddings: int rope_theta: float head_dim: int + rotary_dim: int def __init__( self, @@ -62,6 +66,9 @@ def __init__( head_dim: int, device=None, dtype=None, + rotary_dim: int | None = None, + mrope_section: list[int] | tuple[int, int, int] | None = None, + mrope_interleaved: bool = False, ): factory_kwargs = { "device": infinicore.device("cpu", 0) if device is None else device, @@ -72,21 +79,79 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.rope_theta = rope_theta self.head_dim = head_dim + self.rotary_dim = head_dim if rotary_dim is None else rotary_dim + self.mrope_section = None if mrope_section is None else list(mrope_section) + self.mrope_interleaved = mrope_interleaved + + if ( + self.rotary_dim <= 0 + or self.rotary_dim > self.head_dim + or self.rotary_dim % 2 != 0 + ): + raise ValueError("rotary_dim must be positive, even, and <= head_dim") + if self.mrope_section is not None and ( + len(self.mrope_section) != 3 + or 2 * sum(self.mrope_section) != self.rotary_dim + ): + raise ValueError( + "mrope_section must contain 3 values and sum to rotary_dim / 2" + ) self._sin_table, self._cos_table = create_sin_cos_table( self.max_position_embeddings, - head_dim=self.head_dim, + rotary_dim=self.rotary_dim, theta=self.rope_theta, **factory_kwargs, ) - def forward(self, states: Tensor, position_ids: Tensor, algo=RopeAlgo.GPT_NEOX): + def forward( + self, + states: Tensor, + position_ids: Tensor, + *args, + algo=RopeAlgo.GPT_NEOX, + out=None, + ): + if args: + q = states + k = position_ids + positions = args[0] + if self.mrope_section is not None: + return mrope( + q, + k, + self._cos_table, + self._sin_table, + positions, + self.head_dim, + self.rotary_dim, + self.mrope_section[0], + self.mrope_section[1], + self.mrope_section[2], + self.mrope_interleaved, + out=out, + ) + + if out is None: + q_out = infinicore.empty(q.shape, dtype=q.dtype, device=q.device) + k_out = infinicore.empty(k.shape, dtype=k.dtype, device=k.device) + else: + q_out, k_out = out + F.rope(q, positions, self._sin_table, self._cos_table, algo=algo, out=q_out) + F.rope(k, positions, self._sin_table, self._cos_table, algo=algo, out=k_out) + return q_out, k_out + + if self.mrope_section is not None: + raise NotImplementedError( + "MRoPE single-tensor forward is not implemented; use fused forward(q, k, positions) instead" + ) + target = states if out is None else out F.rope( states, position_ids, self._sin_table, self._cos_table, algo=algo, - out=states, + out=target, ) - return states + return target diff --git a/src/infinicore/nn/mrope.cc b/src/infinicore/nn/mrope.cc deleted file mode 100644 index a01d1580e..000000000 --- a/src/infinicore/nn/mrope.cc +++ /dev/null @@ -1,154 +0,0 @@ -#include "infinicore/nn/mrope.hpp" -#include "../../utils.h" -#include "../utils.hpp" -#include "infinicore/ops/mrope.hpp" -#include -#include -#include - -namespace infinicore::nn { - -MRoPE::MRoPE(size_t head_dim, - size_t rotary_dim, - size_t max_seq_len, - double theta, - std::array section, - bool interleaved, - const DataType &dtype, - const Device &device) - : head_dim_(head_dim), - rotary_dim_(rotary_dim), - max_seq_len_(max_seq_len), - theta_(theta), - section_(section), - interleaved_(interleaved), - dtype_(dtype) { - if (rotary_dim_ % 2 != 0) { - throw std::invalid_argument("rotary_dim must be even for MRoPE, got " + std::to_string(rotary_dim_)); - } - if (rotary_dim_ == 0 || rotary_dim_ > head_dim_) { - throw std::invalid_argument("rotary_dim must be in (0, head_dim] for MRoPE"); - } - if (2 * static_cast(section_[0] + section_[1] + section_[2]) != rotary_dim_) { - throw std::invalid_argument("MRoPE section sum must equal rotary_dim / 2"); - } - device_ = device; - initialize_cache(); -} - -void MRoPE::initialize_cache() { - const size_t cache_dim = rotary_dim_ / 2; - const size_t numel = max_seq_len_ * cache_dim; - INFINICORE_NN_BUFFER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_)); - INFINICORE_NN_BUFFER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_)); - - std::vector sin_data(numel); - std::vector cos_data(numel); - for (size_t pos = 0; pos < max_seq_len_; ++pos) { - for (size_t dim_idx = 0; dim_idx < cache_dim; ++dim_idx) { - const float inv_freq = 1.0f / std::pow(static_cast(theta_), 2.0f * static_cast(dim_idx) / static_cast(rotary_dim_)); - const float angle = static_cast(pos) * inv_freq; - const size_t offset = pos * cache_dim + dim_idx; - sin_data[offset] = std::sin(angle); - cos_data[offset] = std::cos(angle); - } - } - - const auto cpu_device = Device(Device::Type::CPU, 0); - if (dtype_ == DataType::F32) { - auto sin_cpu = Tensor::from_blob(sin_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device); - auto cos_cpu = Tensor::from_blob(cos_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device); - sin_cache_->copy_from(sin_cpu); - cos_cache_->copy_from(cos_cpu); - return; - } - if (dtype_ == DataType::BF16) { - std::vector sin_bf16(numel); - std::vector cos_bf16(numel); - for (size_t i = 0; i < numel; ++i) { - sin_bf16[i] = utils::cast(sin_data[i]); - cos_bf16[i] = utils::cast(cos_data[i]); - } - auto sin_cpu = Tensor::from_blob(sin_bf16.data(), {max_seq_len_, cache_dim}, DataType::BF16, cpu_device); - auto cos_cpu = Tensor::from_blob(cos_bf16.data(), {max_seq_len_, cache_dim}, DataType::BF16, cpu_device); - sin_cache_->copy_from(sin_cpu); - cos_cache_->copy_from(cos_cpu); - return; - } - if (dtype_ == DataType::F16) { - std::vector sin_f16(numel); - std::vector cos_f16(numel); - for (size_t i = 0; i < numel; ++i) { - sin_f16[i] = utils::cast(sin_data[i]); - cos_f16[i] = utils::cast(cos_data[i]); - } - auto sin_cpu = Tensor::from_blob(sin_f16.data(), {max_seq_len_, cache_dim}, DataType::F16, cpu_device); - auto cos_cpu = Tensor::from_blob(cos_f16.data(), {max_seq_len_, cache_dim}, DataType::F16, cpu_device); - sin_cache_->copy_from(sin_cpu); - cos_cache_->copy_from(cos_cpu); - return; - } - throw std::runtime_error("MRoPE cache dtype conversion not supported for dtype: " + std::to_string(static_cast(dtype_))); -} - -std::pair MRoPE::forward(const Tensor &q, - const Tensor &k, - const Tensor &positions) const { - const size_t num_tokens = q->size(0); - auto q_flat = q->contiguous()->view({num_tokens, q->size(1) * head_dim_}); - auto k_flat = k->contiguous()->view({num_tokens, k->size(1) * head_dim_}); - auto q_out = Tensor::empty(q_flat->shape(), q_flat->dtype(), q_flat->device()); - auto k_out = Tensor::empty(k_flat->shape(), k_flat->dtype(), k_flat->device()); - op::mrope_(q_out, - k_out, - q_flat, - k_flat, - cos_cache_, - sin_cache_, - positions, - static_cast(head_dim_), - static_cast(rotary_dim_), - section_[0], - section_[1], - section_[2], - interleaved_); - return {q_out->view(q->shape()), k_out->view(k->shape())}; -} - -std::pair MRoPE::forward(const Tensor &q_out, - const Tensor &k_out, - const Tensor &q, - const Tensor &k, - const Tensor &positions) const { - const size_t num_tokens = q->size(0); - auto q_flat = q->contiguous()->view({num_tokens, q->size(1) * head_dim_}); - auto k_flat = k->contiguous()->view({num_tokens, k->size(1) * head_dim_}); - auto q_out_flat = q_out->view({num_tokens, q->size(1) * head_dim_}); - auto k_out_flat = k_out->view({num_tokens, k->size(1) * head_dim_}); - op::mrope_(q_out_flat, - k_out_flat, - q_flat, - k_flat, - cos_cache_, - sin_cache_, - positions, - static_cast(head_dim_), - static_cast(rotary_dim_), - section_[0], - section_[1], - section_[2], - interleaved_); - return {q_out, k_out}; -} - -std::string MRoPE::extra_repr() const { - return "MRoPE(head_dim=" + std::to_string(head_dim_) - + ", rotary_dim=" + std::to_string(rotary_dim_) - + ", max_seq_len=" + std::to_string(max_seq_len_) - + ", theta=" + std::to_string(theta_) - + ", section=[" + std::to_string(section_[0]) + "," + std::to_string(section_[1]) + "," + std::to_string(section_[2]) + "]" - + ", interleaved=" + (interleaved_ ? "true" : "false") - + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; -} - -} // namespace infinicore::nn diff --git a/src/infinicore/nn/rope.cc b/src/infinicore/nn/rope.cc index 831347674..115e162e2 100644 --- a/src/infinicore/nn/rope.cc +++ b/src/infinicore/nn/rope.cc @@ -1,7 +1,8 @@ #include "infinicore/nn/rope.hpp" #include "../../utils.h" #include "../utils.hpp" -#include "infinicore/ops.hpp" +#include "infinicore/ops/mrope.hpp" +#include "infinicore/ops/rope.hpp" #include #include #include @@ -19,19 +20,31 @@ RoPE::RoPE(size_t head_dim, Algo algo, const DataType &dtype, const Device &device, - std::shared_ptr scaling) + std::shared_ptr scaling, + std::optional> mrope_section, + bool mrope_interleaved) : rotary_dim_(rotary_dim), head_dim_(head_dim), max_seq_len_(max_seq_len), theta_(theta), algo_(algo), dtype_(dtype), - scaling_(scaling) { - // TODO use head_dim + scaling_(scaling), + mrope_section_(mrope_section), + mrope_interleaved_(mrope_interleaved) { if (rotary_dim % 2 != 0) { throw std::invalid_argument("rotary_dim must be even for RoPE, got " + std::to_string(rotary_dim)); } assert((rotary_dim > 0) && (rotary_dim <= head_dim_)); + if (mrope_section_.has_value()) { + const auto §ion = mrope_section_.value(); + if (section.size() != 3 || section[0] <= 0 || section[1] <= 0 || section[2] <= 0) { + throw std::invalid_argument("mrope_section must contain 3 positive values"); + } + if (2 * static_cast(section[0] + section[1] + section[2]) != rotary_dim_) { + throw std::invalid_argument("MRoPE section sum must equal rotary_dim / 2"); + } + } device_ = device; // Initialize cache tables @@ -122,6 +135,9 @@ void RoPE::initialize_cache() { } Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const { + if (mrope_section_.has_value()) { + throw std::runtime_error("MRoPE single-tensor forward is not implemented; use fused forward(q, k, positions) instead"); + } Tensor y; if (in_place) { y = Tensor(x); @@ -139,17 +155,112 @@ Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const { return y; } -Tensor RoPE::forward(const Tensor &y, const Tensor &x, const Tensor &pos) const { - size_t ndim = x->ndim(); - op::rope_(y->narrow({{ndim - 1, 0, rotary_dim_}}), - x->narrow({{ndim - 1, 0, rotary_dim_}}), - pos, sin_cache_, cos_cache_, algo_); - return y; +static Tensor mrope_flatten_input(const Tensor &x, size_t head_dim, const char *name) { + if (x->ndim() == 2) { + if (x->size(1) % head_dim != 0) { + throw std::runtime_error(std::string("MRoPE expects ") + name + " hidden size to be a multiple of head_dim"); + } + return x; + } + if (x->ndim() == 3 && x->size(2) == head_dim) { + return x->view({x->size(0), x->size(1) * head_dim}); + } + throw std::runtime_error(std::string("MRoPE expects ") + name + " with shape [num_tokens, num_heads * head_dim] or [num_tokens, num_heads, head_dim]"); +} + +static Tensor mrope_flatten_output(const Tensor &x, size_t head_dim, const char *name) { + if (x->ndim() == 2) { + if (x->size(1) % head_dim != 0) { + throw std::runtime_error(std::string("MRoPE expects ") + name + " hidden size to be a multiple of head_dim"); + } + return x->view({x->size(0), x->size(1)}); + } + if (x->ndim() == 3 && x->size(2) == head_dim) { + return x->view({x->size(0), x->size(1) * head_dim}); + } + throw std::runtime_error(std::string("MRoPE expects ") + name + " with shape [num_tokens, num_heads * head_dim] or [num_tokens, num_heads, head_dim]"); +} + +std::pair RoPE::forward(const Tensor &q, const Tensor &k, const Tensor &positions) const { + if (!mrope_section_.has_value()) { + auto q_out = Tensor::empty(q->shape(), q->dtype(), q->device()); + auto k_out = Tensor::empty(k->shape(), k->dtype(), k->device()); + return forward(q_out, k_out, q, k, positions); + } + auto q_flat = mrope_flatten_input(q, head_dim_, "q"); + auto k_flat = mrope_flatten_input(k, head_dim_, "k"); + auto q_out = Tensor::empty(q_flat->shape(), q_flat->dtype(), q_flat->device()); + auto k_out = Tensor::empty(k_flat->shape(), k_flat->dtype(), k_flat->device()); + const auto §ion = mrope_section_.value(); + op::mrope_(q_out, + k_out, + q_flat, + k_flat, + cos_cache_, + sin_cache_, + positions, + static_cast(head_dim_), + static_cast(rotary_dim_), + section[0], + section[1], + section[2], + mrope_interleaved_); + return {q_out->view(q->shape()), k_out->view(k->shape())}; +} + +std::pair RoPE::forward(const Tensor &q_out, + const Tensor &k_out, + const Tensor &q, + const Tensor &k, + const Tensor &positions) const { + if (!mrope_section_.has_value()) { + auto apply_standard = [this, &positions](Tensor out, const Tensor &in) { + if (rotary_dim_ < head_dim_) { + out->copy_from(in); + } + size_t ndim = in->ndim(); + op::rope_(out->narrow({{ndim - 1, 0, rotary_dim_}}), + in->narrow({{ndim - 1, 0, rotary_dim_}}), + positions, + sin_cache_, + cos_cache_, + algo_); + }; + apply_standard(q_out, q); + apply_standard(k_out, k); + return {q_out, k_out}; + } + auto q_flat = mrope_flatten_input(q, head_dim_, "q"); + auto k_flat = mrope_flatten_input(k, head_dim_, "k"); + auto q_out_flat = mrope_flatten_output(q_out, head_dim_, "q_out"); + auto k_out_flat = mrope_flatten_output(k_out, head_dim_, "k_out"); + const auto §ion = mrope_section_.value(); + op::mrope_(q_out_flat, + k_out_flat, + q_flat, + k_flat, + cos_cache_, + sin_cache_, + positions, + static_cast(head_dim_), + static_cast(rotary_dim_), + section[0], + section[1], + section[2], + mrope_interleaved_); + return {q_out, k_out}; } std::string RoPE::extra_repr() const { std::string algo_str = (algo_ == Algo::GPT_J) ? "GPT_J" : "GPT_NEOX"; - return "RoPE(head_dim=" + std::to_string(head_dim_) + ", rotary_dim=" + std::to_string(rotary_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; + std::string repr = "RoPE(head_dim=" + std::to_string(head_dim_) + ", rotary_dim=" + std::to_string(rotary_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast(dtype_)); + if (mrope_section_.has_value()) { + const auto §ion = mrope_section_.value(); + repr += ", mrope_section=[" + std::to_string(section[0]) + "," + std::to_string(section[1]) + "," + std::to_string(section[2]) + "]"; + repr += ", mrope_interleaved=" + std::string(mrope_interleaved_ ? "true" : "false"); + } + repr += ")"; + return repr; } } // namespace infinicore::nn diff --git a/test/infinicore/nn/mrope.py b/test/infinicore/nn/mrope.py index ce55dfe44..fbbd35e80 100644 --- a/test/infinicore/nn/mrope.py +++ b/test/infinicore/nn/mrope.py @@ -16,10 +16,11 @@ import infinicore _TEST_CASES_DATA = [ - # batch, seq_len, num_q_heads, num_kv_heads, head_dim, rotary_dim, sections, interleaved - (1, 5, 2, 1, 32, 32, (4, 6, 6), False), - (2, 4, 3, 1, 32, 24, (2, 4, 6), False), - (1, 6, 2, 2, 32, 24, (2, 3, 7), True), + # batch, seq_len, num_q_heads, num_kv_heads, head_dim, rotary_dim, mrope_section, mrope_interleaved, q_strides, k_strides + (1, 5, 2, 1, 32, 32, (4, 6, 6), False, None, None), + (2, 4, 3, 1, 32, 24, (2, 4, 6), False, None, None), + (1, 6, 2, 2, 32, 24, (2, 3, 7), True, None, None), + (1, 5, 2, 1, 32, 32, (4, 6, 6), False, (160, 1), (96, 1)), ] _TENSOR_DTYPES = [infinicore.float16, infinicore.float32] _POSITION_DTYPES = [infinicore.int32, infinicore.int64] @@ -51,18 +52,20 @@ def parse_test_cases(): num_kv_heads, head_dim, rotary_dim, - sections, - interleaved, + mrope_section, + mrope_interleaved, + q_strides, + k_strides, ) in _TEST_CASES_DATA: num_tokens = batch * seq_len positions = make_positions(num_tokens) for dtype in _TENSOR_DTYPES: for position_dtype in _POSITION_DTYPES: q_spec = TensorSpec.from_tensor( - (num_tokens, num_q_heads * head_dim), None, dtype + (num_tokens, num_q_heads * head_dim), q_strides, dtype ) k_spec = TensorSpec.from_tensor( - (num_tokens, num_kv_heads * head_dim), None, dtype + (num_tokens, num_kv_heads * head_dim), k_strides, dtype ) pos_spec = TensorSpec.from_tensor( positions.shape, @@ -79,13 +82,13 @@ def parse_test_cases(): "rope_theta": 10000.0, "head_dim": head_dim, "rotary_dim": rotary_dim, - "section": sections, - "interleaved": interleaved, + "mrope_section": mrope_section, + "mrope_interleaved": mrope_interleaved, }, comparison_target=None, tolerance=_TOLERANCE_MAP[dtype], output_count=2, - description="nn.MRoPE - OUT_OF_PLACE", + description="nn.RoPE - MROPE_OUT_OF_PLACE", ) ) return test_cases @@ -93,7 +96,7 @@ def parse_test_cases(): class OpTest(BaseOperatorTest): def __init__(self): - super().__init__("nn.MRoPE") + super().__init__("nn.RoPE MRoPE") def get_test_cases(self): return parse_test_cases() @@ -107,8 +110,8 @@ def torch_operator( rope_theta, head_dim, rotary_dim, - section, - interleaved, + mrope_section, + mrope_interleaved, ): cos, sin = create_sin_cos_table( max_position_embeddings, rotary_dim, rope_theta, q.device @@ -123,10 +126,10 @@ def torch_operator( positions, head_size=head_dim, rotary_dim=rotary_dim, - section_t=section[0], - section_h=section[1], - section_w=section[2], - interleaved=interleaved, + section_t=mrope_section[0], + section_h=mrope_section[1], + section_w=mrope_section[2], + interleaved=mrope_interleaved, ) def infinicore_operator( @@ -138,18 +141,18 @@ def infinicore_operator( rope_theta, head_dim, rotary_dim, - section, - interleaved, + mrope_section, + mrope_interleaved, ): - module = infinicore.nn.MRoPE( + module = infinicore.nn.RoPE( max_position_embeddings, rope_theta, head_dim, - rotary_dim, - section, - interleaved=interleaved, device=q.device, dtype=q.dtype, + rotary_dim=rotary_dim, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, ) return module(q, k, positions)