diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 2da4979c0..f43e6e87b 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -2,11 +2,12 @@ import math from abc import abstractmethod from enum import Enum -from typing import Annotated, Optional, overload +from numbers import Real +from typing import Annotated, Literal, Optional, overload import torch import torch.nn as nn -from pydantic import BaseModel, Field, model_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator, validator from modalities.config.lookup_enum import LookupEnum from modalities.config.utils import convert_base_model_config_to_dict @@ -31,6 +32,60 @@ # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT +def _validate_numeric_field(field_name: str, value: object) -> float: + """Validate that a value is a real number (excluding bool) and cast to float.""" + if isinstance(value, bool) or not isinstance(value, Real): + raise ValueError(f"rope_scaling.{field_name} must be a float") + return float(value) + + +class DefaultRopeScalingConfig(BaseModel): + """Configuration for default RoPE behavior.""" + + rope_type: Literal["default"] = "default" + + +class YarnRopeScalingConfig(BaseModel): + """Configuration for YaRN RoPE scaling.""" + + rope_type: Literal["yarn"] = "yarn" + original_max_position_embeddings: Annotated[int, Field(strict=True, ge=1)] + factor: Optional[Annotated[float, Field(ge=1.0)]] = None + attention_factor: Optional[Annotated[float, Field(gt=0.0)]] = None + mscale: Optional[Annotated[float, Field(ge=0.0)]] = None + mscale_all_dim: Optional[Annotated[float, Field(ge=0.0)]] = None + beta_fast: Annotated[float, Field(ge=0.0)] = 32.0 + beta_slow: Annotated[float, Field(ge=0.0)] = 1.0 + truncate: bool = True + + @field_validator( + "factor", + "attention_factor", + "mscale", + "mscale_all_dim", + "beta_fast", + "beta_slow", + mode="before", + ) + @classmethod + def validate_numeric_fields(cls, value: object, info): + if value is None: + return value + return _validate_numeric_field(info.field_name, value) + + @model_validator(mode="after") + def validate_mscale_pair(self) -> "YarnRopeScalingConfig": + if (self.mscale is None) != (self.mscale_all_dim is None): + raise ValueError("rope_scaling.mscale and rope_scaling.mscale_all_dim must be provided together") + return self + + +RopeScalingConfig = Annotated[ + DefaultRopeScalingConfig | YarnRopeScalingConfig, + Field(discriminator="rope_type"), +] + + class LayerNorms(LookupEnum): """ Enum lookup class for LayerNorms. @@ -120,7 +175,15 @@ class RotaryTransform(QueryKeyValueTransform): XFormers implementation and removed in this implementation.# """ - def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000): + def __init__( + self, + n_embd: int, + n_head: int, + seq_length_dim: int = -2, + base_freq: int = 10000, + max_position_embeddings: int | None = None, + rope_scaling: RopeScalingConfig | None = None, + ): """ Initializes the RotaryTransform object. @@ -136,16 +199,33 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq self.dim_model = n_embd // n_head self.seq_length_dim = seq_length_dim self.base_freq = base_freq + self.max_position_embeddings = max_position_embeddings + + if rope_scaling is not None and not isinstance(rope_scaling, (DefaultRopeScalingConfig, YarnRopeScalingConfig)): + raise TypeError( + "rope_scaling must be an instance of DefaultRopeScalingConfig, YarnRopeScalingConfig, or None" + ) + + self.rope_scaling = rope_scaling + self.attention_scaling = 1.0 self.reset_parameters() def reset_parameters(self): # If previously initialized on or moved to a device, reuse that device. # Otherwise, use the default device of the current environment. - device = self.inv_freq.device if hasattr(self, "inv_freq") else None - inv_freq = 1.0 / ( - self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) - ) + device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None + + rope_type = self.rope_scaling.rope_type if self.rope_scaling is not None else "default" + + if rope_type == "yarn": + inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device) + else: + inv_freq = 1.0 / ( + self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) + ) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", inv_freq) self._seq_len_cached = None @@ -166,24 +246,6 @@ def rotate_half(self, x: torch.Tensor): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) - def _update_cos_sin_tables(self, x): - # Update the cosine and sine tables. - seq_len = x.shape[self.seq_length_dim] - - # Reset the tables if the sequence length has changed, - # or if we're on a new device (possibly due to tracing for instance) - if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype: - self._seq_len_cached = seq_len - t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) - freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to( - x.device - ) # here, we combine the two matrices (not zipping them). - self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) - self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) - - return self._cos_cached, self._sin_cached - def apply_rotary_pos_emb(self, x, cos, sin): """ Applies rotary positional embedding to the input tensor. @@ -228,6 +290,118 @@ def forward( return q, k, v + def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]: + """Compute YaRN inverse frequencies and the attention scaling factor.""" + if not isinstance(self.rope_scaling, YarnRopeScalingConfig): + raise ValueError("YaRN requires a rope_scaling config.") + if self.max_position_embeddings is None: + raise ValueError("YaRN requires max_position_embeddings to be set.") + + original_max_position_embeddings = self.rope_scaling.original_max_position_embeddings + factor = self.rope_scaling.factor + if factor is None: + factor = self.max_position_embeddings / original_max_position_embeddings + factor_float = float(factor) + + attention_factor = self.rope_scaling.attention_factor + mscale_pair = None + if self.rope_scaling.mscale is not None and self.rope_scaling.mscale_all_dim is not None: + mscale_pair = (self.rope_scaling.mscale, self.rope_scaling.mscale_all_dim) + + beta_fast = self.rope_scaling.beta_fast + beta_slow = self.rope_scaling.beta_slow + truncate = self.rope_scaling.truncate + + def get_mscale(scale: float, mscale: float = 1.0) -> float: + """Return the YaRN mscale coefficient for a given scaling factor.""" + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + if attention_factor is None: + if mscale_pair is not None: + mscale, mscale_all_dim = mscale_pair + attention_factor = float( + get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim)) + ) + else: + attention_factor = get_mscale(factor_float) + + def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float: + """Map a target number of rotations to a rotary dimension index.""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range( + low_rot: float, + high_rot: float, + dim: int, + base: int, + max_position_embeddings: int, + truncate: bool, + ) -> tuple[float, float]: + """Compute the lower and upper rotary-dimension correction bounds for YaRN.""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor: + """Create a clamped linear ramp used to blend interpolation and extrapolation.""" + if min_value == max_value: + max_value += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + dim = self.dim_model + base = self.base_freq + + pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor_float * pos_freqs) + + low, high = find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + bool(truncate), + ) + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, float(attention_factor) + + def _update_cos_sin_tables(self, x): + # Update the cosine and sine tables. + seq_len = x.shape[self.seq_length_dim] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seq_len != self._seq_len_cached + or self._cos_cached is None + or self._sin_cached is None + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to( + x.device + ) # here, we combine the two matrices (not zipping them). + self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype) + self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + class QueryKeyValueTransformType(Enum): """ @@ -295,6 +469,15 @@ class RotaryTransformConfig(BaseModel): n_head: Annotated[int, Field(strict=True, ge=0)] seq_length_dim: Annotated[int, Field(strict=True)] base_freq: Annotated[int, Field(strict=True, ge=10000)] + max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None + rope_scaling: Optional[RopeScalingConfig] = None + + @model_validator(mode="after") + def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig": + """Validate rope_scaling cross-field constraints.""" + if isinstance(self.rope_scaling, YarnRopeScalingConfig) and self.max_position_embeddings is None: + raise ValueError("YaRN requires max_position_embeddings to be set") + return self @validator("type_hint", pre=True, always=True) def parse_sharding_strategy_by_name(cls, name): diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index c715a01fa..4ad54b226 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,6 +1,6 @@ +import gc from datetime import datetime from enum import Enum -import gc from typing import Callable, Optional import torch @@ -388,7 +388,7 @@ def train( self.gc.run(step_count=training_progress.num_seen_steps_total) evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) - + profiler_cm.step() @staticmethod diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index d3ccd46c2..25abc686b 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -20,14 +20,15 @@ from tests.utility import find_free_port -def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path: +def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path: """Patches the original configuration file to set a custom activation type.""" with original_config_path.open("r", encoding="utf-8") as f: config_dict = yaml.safe_load(f) config_dict["model_raw"]["config"]["activation_type"] = activation_type - tmp_file_path = tmp_dir / original_config_path.name + file_suffix = f"_{file_tag}" if file_tag else "" + tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}" with tmp_file_path.open("w", encoding="utf-8") as f: yaml.safe_dump(config_dict, f) @@ -103,12 +104,16 @@ def _test_tp_sharding_impl( ): # Seed before FSDP2 instantiation torch.manual_seed(42) - fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir) + fsdp2_path = patch_config_file( + fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2" + ) fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path) # Seed again before TP instantiation to match torch.manual_seed(42) - tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir) + tp_path = patch_config_file( + tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp" + ) tp_model, tp_mesh = self._get_components(tp_path, tmp_path) # Ensure models use the correct MLP diff --git a/tests/test_rotary_qkv_transform.py b/tests/test_rotary_qkv_transform.py index fa82715b1..b44868e4b 100644 --- a/tests/test_rotary_qkv_transform.py +++ b/tests/test_rotary_qkv_transform.py @@ -1,6 +1,7 @@ +import pytest import torch -from modalities.models.gpt2.gpt2_model import RotaryTransform +from modalities.models.gpt2.gpt2_model import AttentionConfig, RotaryTransform, YarnRopeScalingConfig def test_rotary_transform(): @@ -41,3 +42,170 @@ def test_rotary_transform(): comp_rot_h = torch.cat([-comp_h_2, comp_h_1], dim=-1) comp_rot_expected = comp * cos_m_theta + comp_rot_h * sin_m_theta assert torch.equal(comp_rot_expected, comp_rot) + + +def _apply_rotary(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor) -> torch.Tensor: + cos_local = cos_cached[:, :, : x.shape[-2], :] + sin_local = sin_cached[:, :, : x.shape[-2], :] + x1, x2 = x.chunk(2, dim=-1) + x_rot = torch.cat((-x2, x1), dim=-1) + return (x * cos_local) + (x_rot * sin_local) + + +def _assert_yarn_outputs_match_reference( + rotary_transform: RotaryTransform, + q: torch.Tensor, + k: torch.Tensor, + q_rot: torch.Tensor, + k_rot: torch.Tensor, + seq_length: int, +) -> None: + t = torch.arange(seq_length, device=q.device, dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, rotary_transform.inv_freq.to(q.dtype)) + emb = torch.cat((freqs, freqs), dim=-1) + cos = (emb.cos() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype) + sin = (emb.sin() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype) + + q_expected = _apply_rotary(q, cos, sin) + k_expected = _apply_rotary(k, cos, sin) + + assert torch.allclose(q_rot, q_expected, atol=1e-5, rtol=1e-5) + assert torch.allclose(k_rot, k_expected, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize( + "rope_scaling", + [ + YarnRopeScalingConfig( + factor=2.0, + beta_fast=32, + beta_slow=1, + original_max_position_embeddings=4, + ), + YarnRopeScalingConfig( + beta_fast=32, + beta_slow=1, + original_max_position_embeddings=4, + ), + ], +) +def test_rotary_transform_yarn_matches_reference(rope_scaling: YarnRopeScalingConfig): + bs = 1 + n_heads = 2 + embedding_dim = 8 + seq_length = 8 + head_dim = embedding_dim // n_heads + + q = torch.randn(bs, n_heads, seq_length, head_dim) + k = torch.randn(bs, n_heads, seq_length, head_dim) + v = torch.randn(bs, n_heads, seq_length, head_dim) + + rotary_transform = RotaryTransform( + n_embd=embedding_dim, + n_head=n_heads, + base_freq=10000, + max_position_embeddings=seq_length, + rope_scaling=rope_scaling, + ) + + q_rot, k_rot, v_rot = rotary_transform(q=q, k=k, v=v) + assert torch.equal(v, v_rot) + + _assert_yarn_outputs_match_reference( + rotary_transform=rotary_transform, + q=q, + k=k, + q_rot=q_rot, + k_rot=k_rot, + seq_length=seq_length, + ) + + +def test_rotary_transform_rejects_dict_rope_scaling(): + rope_scaling = { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + } + + with pytest.raises(TypeError, match="rope_scaling must be an instance"): + RotaryTransform( + n_embd=8, + n_head=2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) + + +@pytest.mark.parametrize( + ("key", "value"), + [ + ("beta_fast", "32"), + ("beta_slow", torch.tensor(1.0)), + ("beta_slow", False), + ], +) +def test_rotary_transform_config_yarn_rejects_invalid_beta_values(key: str, value: object): + rope_scaling = { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + key: value, + } + + with pytest.raises(ValueError, match=rf"rope_scaling\.{key} must be a float"): + AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig( + n_embd=8, + n_head=2, + seq_length_dim=-2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + ) + + +@pytest.mark.parametrize( + ("rope_scaling", "match"), + [ + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": "1.0", + "mscale_all_dim": 1.0, + }, + r"rope_scaling\.mscale must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + "mscale_all_dim": torch.tensor(1.0), + }, + r"rope_scaling\.mscale_all_dim must be a float", + ), + ( + { + "rope_type": "yarn", + "factor": 2.0, + "original_max_position_embeddings": 4, + "mscale": 1.0, + }, + r"rope_scaling\.mscale and rope_scaling\.mscale_all_dim must be provided together", + ), + ], +) +def test_rotary_transform_config_yarn_rejects_invalid_mscale_values(rope_scaling: dict, match: str): + with pytest.raises(ValueError, match=match): + AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig( + n_embd=8, + n_head=2, + seq_length_dim=-2, + base_freq=10000, + max_position_embeddings=8, + rope_scaling=rope_scaling, + )