From 09e1b6188f93b2339268b575abcc875b98bcf914 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:17:22 -0700 Subject: [PATCH 1/6] add mimo-v2.5 support --- convert_hf_to_gguf.py | 102 +++++++++++++++++++++++++++++++++++++- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + src/llama-hparams.h | 3 ++ src/llama-model-saver.cpp | 1 + src/llama-model.cpp | 2 + src/models/mimo2-iswa.cpp | 8 +++ 7 files changed, 117 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf8af863a47..523328894e6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9292,10 +9292,83 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("MiMoV2FlashForCausalLM") +@ModelBase.register("MiMoV2FlashForCausalLM", "MiMoV2ForCausalLM") class MimoV2Model(TextModel): model_arch = gguf.MODEL_ARCH.MIMO2 + @staticmethod + def _tp_aware_qkv_dequant(weight: Tensor, scale_inv: Tensor, + n_q: int, n_kv: int, hd: int, vhd: int, + tp: int = 4, bs: int = 128) -> Tensor: + # MiMo-V2.5 ships qkv_proj sharded TP=4: rows are stacked per-rank as + # [Q_per | K_per | V_per] for r in 0..3. weight_scale_inv has + # ceil(rows_per_rank/bs) block-rows per rank (last may extend past + # rows_per_rank with phantom rows that don't appear in the weight). + # Existing repeat_interleave aligns rank 0 only and mis-applies scales + # to ranks 1..3 once rows_per_rank isn't a multiple of bs. + q_per = (n_q * hd) // tp + k_per = (n_kv * hd) // tp + v_per = (n_kv * vhd) // tp + rows_per_rank = q_per + k_per + v_per + blocks_per_rank = (rows_per_rank + bs - 1) // bs + total_rows = tp * rows_per_rank + if weight.shape[0] != total_rows: + raise ValueError(f"qkv_proj weight rows {weight.shape[0]} != tp*rows_per_rank {total_rows}") + if scale_inv.shape[0] != tp * blocks_per_rank: + raise ValueError(f"scale_inv rows {scale_inv.shape[0]} != tp*blocks_per_rank {tp * blocks_per_rank}") + + scale_inv = scale_inv.float() + # per-row scale-row index: rank * blocks_per_rank + (rr_in_rank // bs) + row_idx = torch.arange(total_rows) + rr = row_idx % rows_per_rank + rank = row_idx // rows_per_rank + scale_row_idx = rank * blocks_per_rank + (rr // bs) + # gather: (total_rows, n_col_blocks) + scale_per_row_block = scale_inv[scale_row_idx] + # expand col-blocks → cols: each block-col covers `bs` weight cols + scale_full = scale_per_row_block.repeat_interleave(bs, dim=1) + # crop to weight col count (in case last col-block isn't full) + scale_full = scale_full[:, : weight.shape[1]] + return weight.float() * scale_full + + def dequant_model(self): + # Capture raw FP8 (weight, scale_inv) lambdas for qkv_proj BEFORE super + # rewrites them with the existing dequant. Replace super's lambda after + # it runs so scale_inv removal still happens via the standard path. + qkv_overrides: dict[str, tuple[Callable, Callable, int]] = {} + qc = self.hparams.get("quantization_config") + if isinstance(qc, dict) and qc.get("quant_method") == "fp8": + pat = re.compile(r"^model\.layers\.(\d+)\.self_attn\.qkv_proj\.weight_scale_inv$") + for name in list(self.model_tensors.keys()): + m = pat.match(name) + if not m: + continue + weight_name = name.removesuffix("_scale_inv") + if weight_name not in self.model_tensors: + continue + qkv_overrides[weight_name] = ( + self.model_tensors[weight_name], + self.model_tensors[name], + int(m.group(1)), + ) + + super().dequant_model() + + if not qkv_overrides: + return + + n_q = self.hparams["num_attention_heads"] + hd = self.hparams["head_dim"] + vhd = self.hparams["v_head_dim"] + hybrid = self.hparams["hybrid_layer_pattern"] + for weight_name, (w_fn, s_fn, bid) in qkv_overrides.items(): + is_swa = hybrid[bid] == 1 + n_kv = self.hparams["swa_num_key_value_heads" if is_swa else "num_key_value_heads"] + self.model_tensors[weight_name] = ( + lambda w_fn=w_fn, s_fn=s_fn, n_q=n_q, n_kv=n_kv, hd=hd, vhd=vhd: + MimoV2Model._tp_aware_qkv_dequant(w_fn(), s_fn(), n_q, n_kv, hd, vhd) + ) + def set_gguf_parameters(self): super().set_gguf_parameters() @@ -9320,6 +9393,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5)) + v_scale = self.hparams.get("attention_value_scale") + if v_scale is not None: + self.gguf_writer.add_attn_value_scale(float(v_scale)) + _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch, name, bid): @@ -9333,6 +9410,29 @@ def modify_tensors(self, data_torch, name, bid): if "model.mtp." in name: return + # MiMo-V2.5 (non-Pro) ships audio/visual modules we don't run in llama.cpp + if name.startswith(("audio_encoder.", "visual.", "speech_embeddings.")): + return + + # split fused qkv_proj into separate q/k/v tensors (MiMoV2ForCausalLM uses fused_qkv layout) + if "self_attn.qkv_proj" in name: + assert bid is not None + is_swa = self.hparams["hybrid_layer_pattern"][bid] == 1 + num_q_heads = self.hparams["swa_num_attention_heads" if is_swa else "num_attention_heads"] + num_kv_heads = self.hparams["swa_num_key_value_heads" if is_swa else "num_key_value_heads"] + head_dim = self.hparams["swa_head_dim" if is_swa else "head_dim"] + v_head_dim = self.hparams["swa_v_head_dim" if is_swa else "v_head_dim"] + q_size = num_q_heads * head_dim + k_size = num_kv_heads * head_dim + v_size = num_kv_heads * v_head_dim + q, k, v = data_torch.split([q_size, k_size, v_size], dim=0) + suffix = ".weight" if name.endswith(".weight") else ".bias" + base = name.replace(f"qkv_proj{suffix}", "") + yield from super().modify_tensors(q, f"{base}q_proj{suffix}", bid) + yield from super().modify_tensors(k, f"{base}k_proj{suffix}", bid) + yield from super().modify_tensors(v, f"{base}v_proj{suffix}", bid) + return + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 633a66fc665..59dde99e362 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -232,6 +232,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_VALUE_SCALE, "%s.attention.value_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 8f335f5c7b3..e37d548c98e 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -236,6 +236,7 @@ enum llm_kv { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_VALUE_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index ac7f9ee8650..704ca81e069 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -166,6 +166,9 @@ struct llama_hparams { float f_attn_out_scale = 0.0f; uint32_t attn_temp_length = 0; + // mimo-v2: per-head V scalar applied before attention (1.0 = disabled) + float f_attn_value_scale = 1.0f; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 26864c18e97..e83056557bf 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -268,6 +268,7 @@ void llama_model_saver::add_kv_from_model() { // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale); add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9e2a13cbd43..a3d2cf0d4fc 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2896,6 +2896,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale, false); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); switch (hparams.n_layer) { @@ -8182,6 +8183,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); diff --git a/src/models/mimo2-iswa.cpp b/src/models/mimo2-iswa.cpp index 52c6acfe214..b321e82e6f4 100644 --- a/src/models/mimo2-iswa.cpp +++ b/src/models/mimo2-iswa.cpp @@ -10,6 +10,8 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); + const float v_scale = hparams.f_attn_value_scale; + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -35,6 +37,11 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); + if (v_scale != 1.0f) { + Vcur = ggml_scale(ctx0, Vcur, v_scale); + cb(Vcur, "Vcur_scaled", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); @@ -60,6 +67,7 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ cur = build_attn(inp_attn, model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { From 548fde30ae34cecd9addccb9efe40691225c1e34 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:51:33 -0700 Subject: [PATCH 2/6] mimo-v2.5: fix modify_tensors row split --- convert_hf_to_gguf.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 523328894e6..c3832996923 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9425,6 +9425,24 @@ def modify_tensors(self, data_torch, name, bid): q_size = num_q_heads * head_dim k_size = num_kv_heads * head_dim v_size = num_kv_heads * v_head_dim + # MiMo-V2.5 ships qkv_proj sharded TP=4: rows are stacked per-rank as + # [Q_per | K_per | V_per] for r in 0..3, not as a single [Q | K | V]. + # Re-group rows from each rank into a unified [Q | K | V] before split. + tp = 4 + total_rows = data_torch.shape[0] + if data_torch.ndim == 2 and total_rows == q_size + k_size + v_size and total_rows % tp == 0: + q_per = q_size // tp + k_per = k_size // tp + v_per = v_size // tp + rows_per_rank = q_per + k_per + v_per + if rows_per_rank * tp == total_rows: + qs, ks, vs = [], [], [] + for r in range(tp): + base = r * rows_per_rank + qs.append(data_torch[base : base + q_per]) + ks.append(data_torch[base + q_per : base + q_per + k_per]) + vs.append(data_torch[base + q_per + k_per : base + rows_per_rank]) + data_torch = torch.cat(qs + ks + vs, dim=0) q, k, v = data_torch.split([q_size, k_size, v_size], dim=0) suffix = ".weight" if name.endswith(".weight") else ".bias" base = name.replace(f"qkv_proj{suffix}", "") From 287ac836e85b5d5b0715af985e588bb2a7f2cb4c Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:06:25 -0700 Subject: [PATCH 3/6] mimi-v2.5: forgot `add_attn_value_scale` plumbing --- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 83ae51ce9ce..8b46605889a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -175,6 +175,7 @@ class Attention: SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" OUTPUT_SCALE = "{arch}.attention.output_scale" + VALUE_SCALE = "{arch}.attention.value_scale" TEMPERATURE_LENGTH = "{arch}.attention.temperature_length" KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a81ca37d8c..8040328389d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -943,6 +943,9 @@ def add_attention_scale(self, value: float) -> None: def add_attn_output_scale(self, value: float) -> None: self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value) + def add_attn_value_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.VALUE_SCALE.format(arch=self.arch), value) + def add_attn_temperature_length(self, value: int) -> None: self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value) From 3dcaba9853a88c4d32fbd47a9ece58860489430c Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Wed, 29 Apr 2026 01:07:27 -0700 Subject: [PATCH 4/6] mimi-v2.5: fix tp dequant to detect tp rows --- convert_hf_to_gguf.py | 81 +++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c3832996923..7d3ccded588 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9299,23 +9299,42 @@ class MimoV2Model(TextModel): @staticmethod def _tp_aware_qkv_dequant(weight: Tensor, scale_inv: Tensor, n_q: int, n_kv: int, hd: int, vhd: int, - tp: int = 4, bs: int = 128) -> Tensor: - # MiMo-V2.5 ships qkv_proj sharded TP=4: rows are stacked per-rank as - # [Q_per | K_per | V_per] for r in 0..3. weight_scale_inv has - # ceil(rows_per_rank/bs) block-rows per rank (last may extend past - # rows_per_rank with phantom rows that don't appear in the weight). - # Existing repeat_interleave aligns rank 0 only and mis-applies scales - # to ranks 1..3 once rows_per_rank isn't a multiple of bs. - q_per = (n_q * hd) // tp - k_per = (n_kv * hd) // tp - v_per = (n_kv * vhd) // tp + bs: int = 128) -> Tensor: + # MiMo-V2.5 (TP=4) and V2.5-Pro (TP=8) ship qkv_proj sharded across TP + # ranks; per rank, rows are stacked as [Q_per | K_per | V_per]. + # weight_scale_inv has ceil(rows_per_rank/bs) block-rows per rank (last + # may extend past rows_per_rank with phantom rows not in the weight). + # Naive repeat_interleave aligns rank 0 only and mis-applies scales to + # later ranks once rows_per_rank isn't a multiple of bs. + # Re-group the per-rank [Q_per|K_per|V_per] rows into a unified + # [Q | K | V] layout so the standard split in modify_tensors works. + q_size = n_q * hd + k_size = n_kv * hd + v_size = n_kv * vhd + total_rows = q_size + k_size + v_size + if weight.shape[0] != total_rows: + raise ValueError(f"qkv_proj weight rows {weight.shape[0]} != q+k+v {total_rows}") + + # auto-detect TP from scale_inv shape + tp = None + for cand in (1, 2, 4, 8, 16): + if total_rows % cand != 0: + continue + rpr = total_rows // cand + bpr = (rpr + bs - 1) // bs + if scale_inv.shape[0] == cand * bpr: + tp = cand + break + if tp is None: + raise ValueError( + f"qkv_proj: cannot detect TP — scale_inv rows {scale_inv.shape[0]}, " + f"q+k+v {total_rows}") + + q_per = q_size // tp + k_per = k_size // tp + v_per = v_size // tp rows_per_rank = q_per + k_per + v_per blocks_per_rank = (rows_per_rank + bs - 1) // bs - total_rows = tp * rows_per_rank - if weight.shape[0] != total_rows: - raise ValueError(f"qkv_proj weight rows {weight.shape[0]} != tp*rows_per_rank {total_rows}") - if scale_inv.shape[0] != tp * blocks_per_rank: - raise ValueError(f"scale_inv rows {scale_inv.shape[0]} != tp*blocks_per_rank {tp * blocks_per_rank}") scale_inv = scale_inv.float() # per-row scale-row index: rank * blocks_per_rank + (rr_in_rank // bs) @@ -9329,7 +9348,19 @@ def _tp_aware_qkv_dequant(weight: Tensor, scale_inv: Tensor, scale_full = scale_per_row_block.repeat_interleave(bs, dim=1) # crop to weight col count (in case last col-block isn't full) scale_full = scale_full[:, : weight.shape[1]] - return weight.float() * scale_full + dequant = weight.float() * scale_full + + if tp == 1: + return dequant + + # Re-group per-rank [Q_per|K_per|V_per] rows into unified [Q | K | V] + qs, ks, vs = [], [], [] + for r in range(tp): + base = r * rows_per_rank + qs.append(dequant[base : base + q_per]) + ks.append(dequant[base + q_per : base + q_per + k_per]) + vs.append(dequant[base + q_per + k_per : base + rows_per_rank]) + return torch.cat(qs + ks + vs, dim=0) def dequant_model(self): # Capture raw FP8 (weight, scale_inv) lambdas for qkv_proj BEFORE super @@ -9425,24 +9456,6 @@ def modify_tensors(self, data_torch, name, bid): q_size = num_q_heads * head_dim k_size = num_kv_heads * head_dim v_size = num_kv_heads * v_head_dim - # MiMo-V2.5 ships qkv_proj sharded TP=4: rows are stacked per-rank as - # [Q_per | K_per | V_per] for r in 0..3, not as a single [Q | K | V]. - # Re-group rows from each rank into a unified [Q | K | V] before split. - tp = 4 - total_rows = data_torch.shape[0] - if data_torch.ndim == 2 and total_rows == q_size + k_size + v_size and total_rows % tp == 0: - q_per = q_size // tp - k_per = k_size // tp - v_per = v_size // tp - rows_per_rank = q_per + k_per + v_per - if rows_per_rank * tp == total_rows: - qs, ks, vs = [], [], [] - for r in range(tp): - base = r * rows_per_rank - qs.append(data_torch[base : base + q_per]) - ks.append(data_torch[base + q_per : base + q_per + k_per]) - vs.append(data_torch[base + q_per + k_per : base + rows_per_rank]) - data_torch = torch.cat(qs + ks + vs, dim=0) q, k, v = data_torch.split([q_size, k_size, v_size], dim=0) suffix = ".weight" if name.endswith(".weight") else ".bias" base = name.replace(f"qkv_proj{suffix}", "") From 24364b3673652663fcbf0c51cbb0c593427272b5 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:33:34 -0700 Subject: [PATCH 5/6] mimo-v2.5: fix TP iteration to be descending --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7d3ccded588..bc495dbad97 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9315,9 +9315,9 @@ def _tp_aware_qkv_dequant(weight: Tensor, scale_inv: Tensor, if weight.shape[0] != total_rows: raise ValueError(f"qkv_proj weight rows {weight.shape[0]} != q+k+v {total_rows}") - # auto-detect TP from scale_inv shape + # detect TP from scale_inv block count, descending order so larger matches first tp = None - for cand in (1, 2, 4, 8, 16): + for cand in (8, 4): if total_rows % cand != 0: continue rpr = total_rows // cand From 027d57567f8506fca217b6c2a840b461369cf950 Mon Sep 17 00:00:00 2001 From: Aes Sedai <7980540+AesSedai@users.noreply.github.com> Date: Thu, 30 Apr 2026 23:41:26 -0700 Subject: [PATCH 6/6] mimo-v2.5: fix comment --- src/llama-hparams.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 704ca81e069..9934717fb5c 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -166,7 +166,7 @@ struct llama_hparams { float f_attn_out_scale = 0.0f; uint32_t attn_temp_length = 0; - // mimo-v2: per-head V scalar applied before attention (1.0 = disabled) + // mimo-v2.5: per-head V scalar applied before attention (1.0 = disabled) float f_attn_value_scale = 1.0f; bool causal_attn = true;