diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 4373395e3214..f5a2831382be 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -66,6 +66,15 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: return cos_freq, sin_freq, {freq_var: freq} +def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): + """Compute the inverse frequency of RoPE for gptj RoPE scaling.""" + freq = s / tir.power(theta, 2 * (d // 2) % d_range / tir.const(d_range, "float32")) + freq_var = tir.Var("freq", "float32") + cos_freq = tir.cos(freq_var).astype(dtype) + sin_freq = tir.sin(freq_var).astype(dtype) + return cos_freq, sin_freq, {freq_var: freq} + + def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals s: tir.Var, d: tir.Var, @@ -123,12 +132,74 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments return cos_freq, sin_freq, {freq_var: freq} +def yarn_find_correction_dim( + num_rotations: int, + d: tir.Var, + theta: float, + max_position_embeddings: int, +): + """Inverse dim formula to find dim based on number of rotations""" + return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(theta) + ) + + +def yarn_find_correction_range( + low_rot: int, + high_rot: int, + d: tir.Var, + theta: float, + max_position_embeddings: int, +): + """Find the correction range based on the number of rotations""" + low = tir.floor(yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)) + high = tir.ceil(yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)) + return tir.max(low, 0), tir.min(high, d - 1) + + +def rope_freq_yarn( + s: tir.Var, + d: tir.Var, + d_range: int, + theta: float, + dtype: str, + original_max_position_embeddings: int, + scaling_factor: float, + beta_fast: int, + beta_slow: int, +): # pylint: disable=too-many-arguments, too-many-locals + """Compute the inverse frequency of RoPE for yarn RoPE scaling.""" + freq_extra = tir.const(1, "float32") / tir.power( + theta, d * 2 % d_range / tir.const(d_range, "float32") + ) + + freq_inter = tir.const(1, "float32") / tir.power( + scaling_factor * theta, d * 2 % d_range / tir.const(d_range, "float32") + ) + + low, high = yarn_find_correction_range( + beta_fast, beta_slow, d, theta, original_max_position_embeddings + ) + high = tir.if_then_else(low == high, high + 0.001, high) + inv_freq_mask = tir.const(1, "float32") - tir.max( + tir.min((d - low) / (high - low), 1.0), 0.0 + ).astype("float32") + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + freq = s * inv_freq + freq_var = tir.Var("freq", "float32") + cos_freq = tir.cos(freq_var).astype(dtype) + sin_freq = tir.sin(freq_var).astype(dtype) + return cos_freq, sin_freq, {freq_var: freq} + + def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: """Return the RoPE inverse frequency computation function based on the given RoPE scaling. """ if "rope_type" not in rope_scaling: return rope_freq_default + if rope_scaling["rope_type"] == "gptj": + return rope_freq_gptj if rope_scaling["rope_type"] == "llama3": return partial( rope_freq_llama3, @@ -143,6 +214,14 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: max_position_embeddings=rope_scaling["max_position_embeddings"], original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], ) + if rope_scaling["rope_type"] == "yarn": + return partial( + rope_freq_yarn, + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + scaling_factor=rope_scaling["factor"], + beta_fast=rope_scaling["beta_fast"], + beta_slow=rope_scaling["beta_slow"], + ) raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}') @@ -220,11 +299,18 @@ def _rope( # pylint: disable=too-many-arguments (s + offset) * scale, d, rotary_dim, theta, dtype ) cos = cos_freq * x[b, s, h, d] - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[b, s, h, d + rotary_dim // 2], - x[b, s, h, d - rotary_dim // 2], - ) + if rope_scaling["rope_type"] == "gptj": + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[b, s, h, d + 1], + x[b, s, h, d - 1], + ) + else: + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[b, s, h, d + rotary_dim // 2], + x[b, s, h, d - rotary_dim // 2], + ) expr = cos + sin for var, value in var_map.items(): expr = tir.Let(var, value, expr) @@ -341,11 +427,18 @@ def _rope( # pylint: disable=too-many-arguments pos * scale, d, rotary_dim, theta, "float32", **kwargs ) cos = cos_freq * x[s, h, d].astype("float32") - sin = sin_freq * tir.if_then_else( - d < rotary_dim // 2, - -x[s, h, d + rotary_dim // 2], - x[s, h, d - rotary_dim // 2], - ).astype("float32") + if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": + sin = sin_freq * tir.if_then_else( + d % 2 == 0, + -x[s, h, d + 1], + x[s, h, d - 1], + ).astype("float32") + else: + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -x[s, h, d + rotary_dim // 2], + x[s, h, d - rotary_dim // 2], + ).astype("float32") expr = (cos + sin).astype(dtype) for var, value in var_map.items(): expr = tir.Let(var, value, expr)