From 09b644ff4bf30d7ee245fd4748d4a9648639c849 Mon Sep 17 00:00:00 2001
From: Mengtao Yuan
Date: Fri, 19 Apr 2024 14:37:05 -0700
Subject: [PATCH] Update model arg name rope_theta to be consistent with those
 in llama's website (#3147)

Summary: As title

Reviewed By: larryliu0820

Differential Revision: D56357117

Pulled By: iseeyuan
---
 examples/models/llama2/llama_transformer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 4184861f091..298e0463c07 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -88,7 +88,10 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
-    rope_freq_base: float = 10000.0  # The base frequency for RoPE
+    rope_theta: Optional[float] = (
+        None  # The official name to override self.rope_freq_base.
+    )
+    rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -99,6 +102,10 @@ def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
 
+        # rope_theta overrides rope_freq_base since it's the official name.
+        if self.rope_theta is not None:
+            self.rope_freq_base = self.rope_theta
+
         if self.use_sdpa_with_kv_cache_op:
             assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
 