Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/infinicore/nn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

#include "nn/embedding.hpp"
#include "nn/linear.hpp"
#include "nn/mrope.hpp"
#include "nn/rmsnorm.hpp"
#include "nn/rope.hpp"
59 changes: 0 additions & 59 deletions include/infinicore/nn/mrope.hpp

This file was deleted.

54 changes: 29 additions & 25 deletions include/infinicore/nn/rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#include "rope_scaling_configs.hpp"
#include <cmath>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

namespace infinicore::nn {

Expand All @@ -32,6 +35,9 @@ class RoPE : public Module {
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
* @param scaling RoPE scaling configuration (default: nullptr)
* @param mrope_section Optional MRoPE section sizes [t, h, w], whose sum must equal rotary_dim / 2.
* When set, the pair-returning forward overloads apply MRoPE to q/k using a positions tensor of shape [3, num_tokens].
* @param mrope_interleaved Whether to interleave MRoPE axes/frequency sections.
*/
RoPE(size_t head_dim,
size_t rotary_dim,
Expand All @@ -40,50 +46,47 @@ class RoPE : public Module {
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device(),
std::shared_ptr<RopeScalingConfig> scaling = nullptr);
std::shared_ptr<RopeScalingConfig> scaling = nullptr,
std::optional<std::vector<int>> mrope_section = std::nullopt,
bool mrope_interleaved = false);

/**
* @brief Forward pass: apply RoPE to a tensor
* @brief Forward pass: apply standard RoPE to a tensor
*
* @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @param in_place If true, modify input tensor in place (default: false)
* @return Rotated tensor with same shape as input
*
* Applies rotary position embeddings to the input tensor.
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, rotary_dim]
* - [batch, seq_len, num_heads, rotary_dim]
* - [seq_len, rotary_dim]
*/
Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const;

/**
* @brief Forward pass: apply RoPE to a tensor in place
*
* @param y Output tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param x Input tensor of shape (..., rotary_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @return Rotated tensor with same shape as input
* @brief Apply MRoPE to q and k.
*
* Applies rotary position embeddings to the input tensor.
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, rotary_dim]
* - [batch, seq_len, num_heads, rotary_dim]
* - [seq_len, rotary_dim]
* Requires construction with mrope_section. q/k may be either
* [num_tokens, num_heads * head_dim] or [num_tokens, num_heads, head_dim].
* positions is [3, num_tokens] with axes ordered as t, h, w.
*/
Tensor forward(const Tensor &y, const Tensor &x, const Tensor &pos) const;
std::pair<Tensor, Tensor> forward(const Tensor &q, const Tensor &k, const Tensor &positions) const;

/**
* @brief Apply MRoPE to q and k into caller-provided outputs.
*
* @param q_out Caller-provided tensor that receives the rotated query.
* @param k_out Caller-provided tensor that receives the rotated key.
* @param q Query tensor. NOTE(review): presumably [num_tokens, num_heads * head_dim]
*          or [num_tokens, num_heads, head_dim] — confirm against the implementation.
* @param k Key tensor. NOTE(review): presumably the same layouts as q — confirm.
* @param positions Position IDs. NOTE(review): presumably [3, num_tokens] with axes
*                  ordered t, h, w — confirm.
* @return Pair of tensors. NOTE(review): presumably (q_out, k_out) — confirm.
*
* NOTE(review): presumably requires construction with mrope_section, and q_out/k_out
* are expected to match the shapes of q/k — neither is established by this declaration.
*/
std::pair<Tensor, Tensor> forward(const Tensor &q_out,
const Tensor &k_out,
const Tensor &q,
const Tensor &k,
const Tensor &positions) const;

// Module information
// Read-only accessors: each one returns the identically named private member.
size_t rotary_dim() const { return rotary_dim_; }
size_t head_dim() const { return head_dim_; }
size_t max_seq_len() const { return max_seq_len_; }
double theta() const { return theta_; }
Algo algo() const { return algo_; }
DataType dtype() const { return dtype_; }
// Returned by const reference (no copy of the vector); the optional may be empty.
const std::optional<std::vector<int>> &mrope_section() const { return mrope_section_; }
bool mrope_interleaved() const { return mrope_interleaved_; }

// String representation
std::string extra_repr() const;
Expand All @@ -95,14 +98,15 @@ class RoPE : public Module {

private:
void initialize_cache();

size_t rotary_dim_; // Number of dimensions to apply rotation to (must be even).
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<RopeScalingConfig> scaling_; // RoPE scaling configuration
std::optional<std::vector<int>> mrope_section_; // Optional MRoPE section sizes [t, h, w]; may be empty
bool mrope_interleaved_; // Whether MRoPE axes/frequency sections are interleaved
};

} // namespace infinicore::nn
3 changes: 1 addition & 2 deletions python/infinicore/nn/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .container import InfiniCoreModuleList as ModuleList
from .linear import Linear
from .module import InfiniCoreModule as Module
from .mrope import MRoPE
from .normalization import RMSNorm
from .rope import RoPE
from .sparse import Embedding

__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "MRoPE", "ModuleList", "Module"]
__all__ = ["Linear", "RMSNorm", "Embedding", "RoPE", "ModuleList", "Module"]
105 changes: 0 additions & 105 deletions python/infinicore/nn/modules/mrope.py

This file was deleted.

Loading
Loading