From f93202c7e66b340c21e48a1a43f4606a864739f8 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 9 Dec 2025 15:21:55 +0800 Subject: [PATCH] fix fope --- .../pytorch/backends/default/rotary_embedding.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py index bbccdc1061..daca2b4994 100644 --- a/lmdeploy/pytorch/backends/default/rotary_embedding.py +++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py @@ -285,10 +285,14 @@ def __init__(self, self.params = params inv_freq = self.params.inv_freq - inv_freq_idx_selected = inv_freq > 2 * torch.pi / self.max_position_embeddings - if self.params.num_inv_freq is not None and inv_freq_idx_selected.sum() > (inv_freq.shape[-1] - - self.params.num_inv_freq): - inv_freq_idx_selected[-self.params.num_inv_freq:] = False + inv_freq_idx_selected = torch.ones_like(inv_freq, dtype=torch.bool) + if self.params.num_inv_freq is not None: + num_inv_freq = self.params.num_inv_freq + inv_freq_idx_selected[num_inv_freq:] = False + else: + inv_freq_idx_selected = inv_freq > (2.0 * torch.pi / self.max_position_embeddings) + num_inv_freq = inv_freq_idx_selected.sum().item() + self.inv_freq = inv_freq[inv_freq_idx_selected] self.register_buffer('inv_freq', self.inv_freq, persistent=False)