Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 132 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9292,10 +9292,114 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("MiMoV2FlashForCausalLM")
@ModelBase.register("MiMoV2FlashForCausalLM", "MiMoV2ForCausalLM")
class MimoV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MIMO2

@staticmethod
def _tp_aware_qkv_dequant(weight: Tensor, scale_inv: Tensor,
                          n_q: int, n_kv: int, hd: int, vhd: int,
                          bs: int = 128) -> Tensor:
    """Dequantize a TP-sharded fused qkv_proj weight and return it as [Q | K | V].

    MiMo-V2.5 (TP=4) and V2.5-Pro (TP=8) ship qkv_proj sharded across TP
    ranks; per rank, rows are stacked as [Q_per | K_per | V_per].
    weight_scale_inv has ceil(rows_per_rank / bs) block-rows per rank (the
    last one may extend past rows_per_rank with phantom rows that are not in
    the weight). A naive repeat_interleave aligns rank 0 only and mis-applies
    scales to later ranks once rows_per_rank isn't a multiple of bs.

    Args:
        weight: 2-D quantized weight with (n_q*hd + n_kv*hd + n_kv*vhd) rows.
        scale_inv: 2-D block scales, indexed [row-block, col-block].
        n_q: number of query heads.
        n_kv: number of key/value heads.
        hd: head dimension of Q and K.
        vhd: head dimension of V.
        bs: quantization block size along both rows and columns.

    Returns:
        Float32 dequantized weight with rows re-grouped into the unified
        [Q | K | V] layout, so a plain split by (q_size, k_size, v_size) works.

    Raises:
        ValueError: if the weight row count does not match q+k+v, or if no
            supported TP degree (8 or 4) is consistent with both the scale
            block count and the Q/K/V section sizes.
    """
    q_size = n_q * hd
    k_size = n_kv * hd
    v_size = n_kv * vhd
    total_rows = q_size + k_size + v_size
    if weight.shape[0] != total_rows:
        raise ValueError(f"qkv_proj weight rows {weight.shape[0]} != q+k+v {total_rows}")

    # detect TP from scale_inv block count, descending order so larger matches first
    tp = None
    for cand in (8, 4):
        # Every section must shard evenly: the per-rank layout is
        # [Q_per | K_per | V_per], so a candidate that only divides the total
        # row count would yield a wrong rows_per_rank and misaligned scales.
        if q_size % cand or k_size % cand or v_size % cand:
            continue
        rpr = total_rows // cand
        bpr = (rpr + bs - 1) // bs
        if scale_inv.shape[0] == cand * bpr:
            tp = cand
            break
    if tp is None:
        raise ValueError(
            f"qkv_proj: cannot detect TP — scale_inv rows {scale_inv.shape[0]}, "
            f"q+k+v {total_rows}")

    q_per = q_size // tp
    k_per = k_size // tp
    v_per = v_size // tp
    rows_per_rank = q_per + k_per + v_per
    blocks_per_rank = (rows_per_rank + bs - 1) // bs

    scale_inv = scale_inv.float()
    # per-row scale-row index: rank * blocks_per_rank + (row_in_rank // bs)
    row_idx = torch.arange(total_rows, device=scale_inv.device)
    row_in_rank = row_idx % rows_per_rank
    rank = row_idx // rows_per_rank
    scale_row_idx = rank * blocks_per_rank + (row_in_rank // bs)
    # gather: (total_rows, n_col_blocks)
    scale_per_row_block = scale_inv[scale_row_idx]
    # expand col-blocks -> cols: each block-col covers `bs` weight cols
    scale_full = scale_per_row_block.repeat_interleave(bs, dim=1)
    # crop to weight col count (in case last col-block isn't full)
    n_cols = weight.shape[1]
    scale_full = scale_full[:, :n_cols]
    dequant = weight.float() * scale_full

    # Re-group per-rank [Q_per|K_per|V_per] rows into unified [Q | K | V]:
    # view as (rank, row_in_rank, col), cut each rank into its three sections,
    # then flatten each section rank-major.
    per_rank = dequant.reshape(tp, rows_per_rank, n_cols)
    q, k, v = per_rank.split([q_per, k_per, v_per], dim=1)
    return torch.cat(
        [q.reshape(q_size, n_cols), k.reshape(k_size, n_cols), v.reshape(v_size, n_cols)],
        dim=0,
    )

def dequant_model(self):
    """Dequantize the model, using the TP-aware path for FP8 fused qkv_proj weights.

    The raw (weight, scale_inv) loaders for every
    ``model.layers.N.self_attn.qkv_proj`` tensor are captured BEFORE the base
    class rewrites them with the existing dequant. The base implementation
    then runs, so scale_inv removal still happens via the standard path, and
    afterwards its qkv_proj entries are replaced with lambdas that call
    ``_tp_aware_qkv_dequant`` on the captured raw loaders.

    Nothing is overridden unless ``quantization_config`` is a dict whose
    ``quant_method`` is ``"fp8"`` and a matching weight/scale_inv pair exists.
    """
    # weight name -> (raw weight loader, raw scale_inv loader, layer index)
    qkv_overrides: dict[str, tuple[Callable, Callable, int]] = {}
    qc = self.hparams.get("quantization_config")
    if isinstance(qc, dict) and qc.get("quant_method") == "fp8":
        pat = re.compile(r"^model\.layers\.(\d+)\.self_attn\.qkv_proj\.weight_scale_inv$")
        # iterate over a snapshot of the keys
        for name in list(self.model_tensors.keys()):
            m = pat.match(name)
            if not m:
                continue
            weight_name = name.removesuffix("_scale_inv")
            if weight_name not in self.model_tensors:
                continue
            qkv_overrides[weight_name] = (
                self.model_tensors[weight_name],
                self.model_tensors[name],
                int(m.group(1)),
            )

    super().dequant_model()

    if not qkv_overrides:
        return

    hp = self.hparams
    hybrid = hp["hybrid_layer_pattern"]
    for weight_name, (w_fn, s_fn, bid) in qkv_overrides.items():
        # Pick the head geometry per layer type, matching the swa_* keys that
        # modify_tensors uses when it splits the fused tensor; otherwise the
        # row-count check in the dequant would reject SWA layers whose geometry
        # differs from the global-attention layers.
        if hybrid[bid] == 1:
            n_q = hp.get("swa_num_attention_heads", hp["num_attention_heads"])
            n_kv = hp["swa_num_key_value_heads"]
            hd = hp.get("swa_head_dim", hp["head_dim"])
            vhd = hp.get("swa_v_head_dim", hp["v_head_dim"])
        else:
            n_q = hp["num_attention_heads"]
            n_kv = hp["num_key_value_heads"]
            hd = hp["head_dim"]
            vhd = hp["v_head_dim"]
        # Bind everything as lambda defaults so each entry keeps its own
        # values instead of the loop's last ones (late-binding closures).
        self.model_tensors[weight_name] = (
            lambda w_fn=w_fn, s_fn=s_fn, n_q=n_q, n_kv=n_kv, hd=hd, vhd=vhd:
                MimoV2Model._tp_aware_qkv_dequant(w_fn(), s_fn(), n_q, n_kv, hd, vhd)
        )

def set_gguf_parameters(self):
super().set_gguf_parameters()

Expand All @@ -9320,6 +9424,10 @@ def set_gguf_parameters(self):

self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))

v_scale = self.hparams.get("attention_value_scale")
if v_scale is not None:
self.gguf_writer.add_attn_value_scale(float(v_scale))

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch, name, bid):
Expand All @@ -9333,6 +9441,29 @@ def modify_tensors(self, data_torch, name, bid):
if "model.mtp." in name:
return

# MiMo-V2.5 (non-Pro) ships audio/visual modules we don't run in llama.cpp
if name.startswith(("audio_encoder.", "visual.", "speech_embeddings.")):
return

# split fused qkv_proj into separate q/k/v tensors (MiMoV2ForCausalLM uses fused_qkv layout)
if "self_attn.qkv_proj" in name:
assert bid is not None
is_swa = self.hparams["hybrid_layer_pattern"][bid] == 1
num_q_heads = self.hparams["swa_num_attention_heads" if is_swa else "num_attention_heads"]
num_kv_heads = self.hparams["swa_num_key_value_heads" if is_swa else "num_key_value_heads"]
head_dim = self.hparams["swa_head_dim" if is_swa else "head_dim"]
v_head_dim = self.hparams["swa_v_head_dim" if is_swa else "v_head_dim"]
q_size = num_q_heads * head_dim
k_size = num_kv_heads * head_dim
v_size = num_kv_heads * v_head_dim
q, k, v = data_torch.split([q_size, k_size, v_size], dim=0)
suffix = ".weight" if name.endswith(".weight") else ".bias"
base = name.replace(f"qkv_proj{suffix}", "")
yield from super().modify_tensors(q, f"{base}q_proj{suffix}", bid)
yield from super().modify_tensors(k, f"{base}k_proj{suffix}", bid)
yield from super().modify_tensors(v, f"{base}v_proj{suffix}", bid)
return

# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
Expand Down
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class Attention:
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_SCALE = "{arch}.attention.output_scale"
VALUE_SCALE = "{arch}.attention.value_scale"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,9 @@ def add_attention_scale(self, value: float) -> None:
def add_attn_output_scale(self, value: float) -> None:
    """Write the attention output scale as a float32 entry under this writer's arch."""
    key = Keys.Attention.OUTPUT_SCALE.format(arch=self.arch)
    self.add_float32(key, value)

def add_attn_value_scale(self, value: float) -> None:
    """Write the attention value scale as a float32 entry under this writer's arch."""
    key = Keys.Attention.VALUE_SCALE.format(arch=self.arch)
    self.add_float32(key, value)

def add_attn_temperature_length(self, value: int) -> None:
    """Write the attention temperature length as a uint32 entry under this writer's arch."""
    key = Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch)
    self.add_uint32(key, value)

Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_VALUE_SCALE, "%s.attention.value_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_VALUE_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
Expand Down
3 changes: 3 additions & 0 deletions src/llama-hparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ struct llama_hparams {
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;

// mimo-v2.5: per-head V scalar applied before attention (1.0 = disabled)
float f_attn_value_scale = 1.0f;

bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
Expand Down
1 change: 1 addition & 0 deletions src/llama-model-saver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ void llama_model_saver::add_kv_from_model() {
// add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???);
add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale);
add_kv(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale);
add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length);
add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale);
add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl);
Expand Down
2 changes: 2 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2896,6 +2896,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);

switch (hparams.n_layer) {
Expand Down Expand Up @@ -8182,6 +8183,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
Expand Down
8 changes: 8 additions & 0 deletions src/models/mimo2-iswa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();

const float v_scale = hparams.f_attn_value_scale;

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

Expand All @@ -35,6 +37,11 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);

if (v_scale != 1.0f) {
Vcur = ggml_scale(ctx0, Vcur, v_scale);
cb(Vcur, "Vcur_scaled", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
Expand All @@ -60,6 +67,7 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_
cur = build_attn(inp_attn,
model.layers[il].wo, NULL, model.layers[il].wo_s,
Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);
cb(cur, "attn_out", il);
}

if (il == n_layer - 1 && inp_out_ids) {
Expand Down
Loading