Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,18 @@ def forward_step(model, batch):
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

auto_quantize_constraints = {
"effective_bits": args.auto_quantize_bits,
"cost_model": args.auto_quantize_cost_model,
}
if args.auto_quantize_active_moe_expert_ratio is not None:
auto_quantize_constraints["cost"] = {
"active_moe_expert_ratio": args.auto_quantize_active_moe_expert_ratio
}

language_model, _ = mtq.auto_quantize(
language_model,
constraints={"effective_bits": args.auto_quantize_bits},
constraints=auto_quantize_constraints,
data_loader=calib_dataloader,
forward_step=forward_step,
loss_func=loss_func, # Only used for gradient-based method
Expand Down Expand Up @@ -1401,6 +1410,27 @@ def parse_args() -> argparse.Namespace:
"(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
),
)
parser.add_argument(
"--auto_quantize_cost_model",
type=str,
default="weight",
choices=["weight", "active_moe"],
help=(
"Cost model for auto_quantize effective-bits accounting. 'weight' counts all "
"quantizable weights equally. 'active_moe' scales routed MoE expert weights by "
"--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config."
),
)
parser.add_argument(
"--auto_quantize_active_moe_expert_ratio",
type=float,
default=None,
help=(
"Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. "
"For top-k MoE this is top_k / num_experts. If omitted, common model config "
"fields such as num_experts_per_tok and num_experts are used when available."
),
)
parser.add_argument(
"--moe_calib_experts_ratio",
type=float,
Expand Down Expand Up @@ -1434,6 +1464,18 @@ def parse_args() -> argparse.Namespace:
args = parser.parse_args()
if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0):
parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
if args.auto_quantize_active_moe_expert_ratio is not None and not (
0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0
):
parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].")
if (
args.auto_quantize_cost_model == "weight"
and args.auto_quantize_active_moe_expert_ratio is not None
):
parser.error(
"--auto_quantize_active_moe_expert_ratio requires "
"--auto_quantize_cost_model active_moe."
)

if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense":
parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).")
Expand Down
Loading
Loading