Skip to content
Merged

YaRN #445

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 208 additions & 25 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import math
from abc import abstractmethod
from enum import Enum
from typing import Annotated, Optional, overload
from numbers import Real
from typing import Annotated, Literal, Optional, overload

import torch
import torch.nn as nn
from pydantic import BaseModel, Field, model_validator, validator
from pydantic import BaseModel, Field, field_validator, model_validator, validator

from modalities.config.lookup_enum import LookupEnum
from modalities.config.utils import convert_base_model_config_to_dict
Expand All @@ -31,6 +32,60 @@
# GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT


def _validate_numeric_field(field_name: str, value: object) -> float:
"""Validate that a value is a real number (excluding bool) and cast to float."""
if isinstance(value, bool) or not isinstance(value, Real):
raise ValueError(f"rope_scaling.{field_name} must be a float")
return float(value)


class DefaultRopeScalingConfig(BaseModel):
"""Configuration for default RoPE behavior."""

rope_type: Literal["default"] = "default"


class YarnRopeScalingConfig(BaseModel):
"""Configuration for YaRN RoPE scaling."""

rope_type: Literal["yarn"] = "yarn"
original_max_position_embeddings: Annotated[int, Field(strict=True, ge=1)]
factor: Optional[Annotated[float, Field(ge=1.0)]] = None
attention_factor: Optional[Annotated[float, Field(gt=0.0)]] = None
mscale: Optional[Annotated[float, Field(ge=0.0)]] = None
mscale_all_dim: Optional[Annotated[float, Field(ge=0.0)]] = None
beta_fast: Annotated[float, Field(ge=0.0)] = 32.0
beta_slow: Annotated[float, Field(ge=0.0)] = 1.0
truncate: bool = True

@field_validator(
"factor",
"attention_factor",
"mscale",
"mscale_all_dim",
"beta_fast",
"beta_slow",
mode="before",
)
@classmethod
def validate_numeric_fields(cls, value: object, info):
if value is None:
return value
return _validate_numeric_field(info.field_name, value)

@model_validator(mode="after")
def validate_mscale_pair(self) -> "YarnRopeScalingConfig":
if (self.mscale is None) != (self.mscale_all_dim is None):
raise ValueError("rope_scaling.mscale and rope_scaling.mscale_all_dim must be provided together")
return self


RopeScalingConfig = Annotated[
DefaultRopeScalingConfig | YarnRopeScalingConfig,
Field(discriminator="rope_type"),
]


class LayerNorms(LookupEnum):
"""
Enum lookup class for LayerNorms.
Expand Down Expand Up @@ -120,7 +175,15 @@ class RotaryTransform(QueryKeyValueTransform):
XFormers implementation and removed in this implementation.#
"""

def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000):
def __init__(
self,
n_embd: int,
n_head: int,
seq_length_dim: int = -2,
base_freq: int = 10000,
max_position_embeddings: int | None = None,
rope_scaling: RopeScalingConfig | None = None,
):
"""
Initializes the RotaryTransform object.

Expand All @@ -136,16 +199,33 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
self.dim_model = n_embd // n_head
self.seq_length_dim = seq_length_dim
self.base_freq = base_freq
self.max_position_embeddings = max_position_embeddings

if rope_scaling is not None and not isinstance(rope_scaling, (DefaultRopeScalingConfig, YarnRopeScalingConfig)):
raise TypeError(
"rope_scaling must be an instance of DefaultRopeScalingConfig, YarnRopeScalingConfig, or None"
)

self.rope_scaling = rope_scaling
self.attention_scaling = 1.0

self.reset_parameters()

def reset_parameters(self):
# If previously initialized on or moved to a device, reuse that device.
# Otherwise, use the default device of the current environment.
device = self.inv_freq.device if hasattr(self, "inv_freq") else None
inv_freq = 1.0 / (
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
)
device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None

rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default"

if rope_type == "yarn":
inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
else:
inv_freq = 1.0 / (
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
)
self.attention_scaling = 1.0

self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
Expand All @@ -166,24 +246,6 @@ def rotate_half(self, x: torch.Tensor):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

def _update_cos_sin_tables(self, x):
# Update the cosine and sine tables.
seq_len = x.shape[self.seq_length_dim]

# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
self._seq_len_cached = seq_len
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(
x.device
) # here, we combine the two matrices (not zipping them).
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

return self._cos_cached, self._sin_cached

def apply_rotary_pos_emb(self, x, cos, sin):
"""
Applies rotary positional embedding to the input tensor.
Expand Down Expand Up @@ -228,6 +290,118 @@ def forward(

return q, k, v

def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pleace place private methods below the public interface of the class.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I addressed this in e12db1a

"""Compute YaRN inverse frequencies and the attention scaling factor."""
if not isinstance(self.rope_scaling, YarnRopeScalingConfig):
raise ValueError("YaRN requires a rope_scaling config.")
if self.max_position_embeddings is None:
raise ValueError("YaRN requires max_position_embeddings to be set.")

original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings
factor = self.rope_scaling.factor
if factor is None:
factor = self.max_position_embeddings / original_max_position_embeddings
factor_float = float(factor)

attention_factor = self.rope_scaling.attention_factor
mscale_pair = None
if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None:
mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim)

beta_fast = self.rope_scaling.beta_fast
beta_slow = self.rope_scaling.beta_slow
truncate = self.rope_scaling.truncate

def get_mscale(scale: float, mscale: float = 1.0) -> float:
"""Return the YaRN mscale coefficient for a given scaling factor."""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

if attention_factor is None:
if mscale_pair is not None:
mscale, mscale_all_dim = mscale_pair
attention_factor = float(
get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim))
)
else:
attention_factor = get_mscale(factor_float)

def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float:
"""Map a target number of rotations to a rotary dimension index."""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(
low_rot: float,
high_rot: float,
dim: int,
base: int,
max_position_embeddings: int,
truncate: bool,
) -> tuple[float, float]:
"""Compute the lower and upper rotary-dimension correction bounds for YaRN."""
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)

def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor:
"""Create a clamped linear ramp used to blend interpolation and extrapolation."""
if min_value == max_value:
max_value += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

dim = self.dim_model
base = self.base_freq

pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor_float * pos_freqs)

low, high = find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
bool(truncate),
)
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)

return inv_freq, float(attention_factor)

def _update_cos_sin_tables(self, x):
# Update the cosine and sine tables.
seq_len = x.shape[self.seq_length_dim]

# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seq_len != self._seq_len_cached
or self._cos_cached is None
or self._sin_cached is None
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(
x.device
) # here, we combine the two matrices (not zipping them).
self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype)
self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype)

return self._cos_cached, self._sin_cached


class QueryKeyValueTransformType(Enum):
"""
Expand Down Expand Up @@ -295,6 +469,15 @@ class RotaryTransformConfig(BaseModel):
n_head: Annotated[int, Field(strict=True, ge=0)]
seq_length_dim: Annotated[int, Field(strict=True)]
base_freq: Annotated[int, Field(strict=True, ge=10000)]
max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
rope_scaling: Optional[RopeScalingConfig] = None

@model_validator(mode="after")
def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
"""Validate rope_scaling cross-field constraints."""
if isinstance(self.rope_scaling, YarnRopeScalingConfig) and self.max_position_embeddings is None:
raise ValueError("YaRN requires max_position_embeddings to be set")
return self

@validator("type_hint", pre=True, always=True)
def parse_sharding_strategy_by_name(cls, name):
Expand Down
4 changes: 2 additions & 2 deletions src/modalities/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gc
from datetime import datetime
from enum import Enum
import gc
from typing import Callable, Optional

import torch
Expand Down Expand Up @@ -388,7 +388,7 @@ def train(
self.gc.run(step_count=training_progress.num_seen_steps_total)
evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
checkpointing_callback(training_progress=training_progress)

profiler_cm.step()

@staticmethod
Expand Down
13 changes: 9 additions & 4 deletions tests/fsdp2_parallelization/test_tensor_parallelism.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this change related to the yarn PR?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it isn't related to Yarn. I noticed this test was failing and fixed it. I can open a separate PR for the test fix if needed, but I'm not sure it's worth the extra overhead.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah thats fine. thanks for the clarification

Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
from tests.utility import find_free_port


def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:
def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path:
"""Patches the original configuration file to set a custom activation type."""
with original_config_path.open("r", encoding="utf-8") as f:
config_dict = yaml.safe_load(f)

config_dict["model_raw"]["config"]["activation_type"] = activation_type

tmp_file_path = tmp_dir / original_config_path.name
file_suffix = f"_{file_tag}" if file_tag else ""
tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}"
with tmp_file_path.open("w", encoding="utf-8") as f:
yaml.safe_dump(config_dict, f)

Expand Down Expand Up @@ -103,12 +104,16 @@ def _test_tp_sharding_impl(
):
# Seed before FSDP2 instantiation
torch.manual_seed(42)
fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir)
fsdp2_path = patch_config_file(
fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2"
)
fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path)

# Seed again before TP instantiation to match
torch.manual_seed(42)
tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir)
tp_path = patch_config_file(
tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp"
)
tp_model, tp_mesh = self._get_components(tp_path, tmp_path)

# Ensure models use the correct MLP
Expand Down
Loading
Loading