From 6f791d1e2083b2f6086644e9637637b3aa7b8311 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 14 May 2026 15:10:22 -0700 Subject: [PATCH] Add active-MoE AutoQuant cost accounting Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- examples/llm_ptq/hf_ptq.py | 44 ++++- modelopt/torch/quantization/algorithms.py | 176 ++++++++++++++++-- modelopt/torch/quantization/model_quant.py | 125 ++++++++++++- .../unit/torch/quantization/test_autoquant.py | 111 +++++++++++ 4 files changed, 437 insertions(+), 19 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d52c3ee40bb..9adce9ca513 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -382,9 +382,18 @@ def forward_step(model, batch): f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" ) + auto_quantize_constraints = { + "effective_bits": args.auto_quantize_bits, + "cost_model": args.auto_quantize_cost_model, + } + if args.auto_quantize_active_moe_expert_ratio is not None: + auto_quantize_constraints["cost"] = { + "active_moe_expert_ratio": args.auto_quantize_active_moe_expert_ratio + } + language_model, _ = mtq.auto_quantize( language_model, - constraints={"effective_bits": args.auto_quantize_bits}, + constraints=auto_quantize_constraints, data_loader=calib_dataloader, forward_step=forward_step, loss_func=loss_func, # Only used for gradient-based method @@ -1401,6 +1410,27 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." ), ) + parser.add_argument( + "--auto_quantize_cost_model", + type=str, + default="weight", + choices=["weight", "active_moe"], + help=( + "Cost model for auto_quantize effective-bits accounting. 'weight' counts all " + "quantizable weights equally. 'active_moe' scales routed MoE expert weights by " + "--auto_quantize_active_moe_expert_ratio, or infers top_k/num_experts from model config." + ), + ) + parser.add_argument( + "--auto_quantize_active_moe_expert_ratio", + type=float, + default=None, + help=( + "Routed MoE expert active ratio for --auto_quantize_cost_model active_moe. " + "For top-k MoE this is top_k / num_experts. If omitted, common model config " + "fields such as num_experts_per_tok and num_experts are used when available." + ), + ) parser.add_argument( "--moe_calib_experts_ratio", type=float, @@ -1434,6 +1464,18 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") + if args.auto_quantize_active_moe_expert_ratio is not None and not ( + 0.0 < args.auto_quantize_active_moe_expert_ratio <= 1.0 + ): + parser.error("--auto_quantize_active_moe_expert_ratio must be in the range (0.0, 1.0].") + if ( + args.auto_quantize_cost_model == "weight" + and args.auto_quantize_active_moe_expert_ratio is not None + ): + parser.error( + "--auto_quantize_active_moe_expert_ratio requires " + "--auto_quantize_cost_model active_moe." + ) if args.specdec_offline_dataset is not None and args.sparsity_fmt != "dense": parser.error("--specdec_offline_dataset is only supported with --sparsity_fmt dense (PTQ).") diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index e4e633e36ae..9576f66c2d0 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -45,6 +45,26 @@ from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import is_quantized_linear +_ROUTED_MOE_EXPERT_NAME_RE = re.compile(r"(^|\.)experts(\.|$)") +_AUTO_QUANTIZE_CONSTRAINT_KEYS = {"effective_bits", "cost_model", "cost"} +_AUTO_QUANTIZE_COST_CONSTRAINT_KEYS = {"active_moe_expert_ratio"} + + +def _is_routed_moe_module_name(name: str) -> bool: + """Return True for routed MoE expert modules, excluding shared experts.""" + return "shared_expert" not in name and _ROUTED_MOE_EXPERT_NAME_RE.search(name) is not None + + +def _get_active_moe_cost_weight( + module_names: Sequence[str], active_moe_expert_ratio: float | None +) -> float: + """Return cost multiplier for the active-MoE cost model.""" + if active_moe_expert_ratio is None: + return 1.0 + if any(_is_routed_moe_module_name(n) for n in module_names): + return active_moe_expert_ratio + return 1.0 + def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float: """Estimate the compression ratio of a quantization configuration. @@ -206,6 +226,7 @@ def __init__( score_modules: list[nn.Module] | None = None, name: str | None = None, quant_module_names: list[str] | None = None, + cost_weight: float = 1.0, ) -> None: """Initializes Hparam with original value and choices.""" choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)}) @@ -213,6 +234,8 @@ def __init__( self.name = name self.quant_module_names = quant_module_names or [] + assert cost_weight > 0.0, "cost_weight must be positive." + self.cost_weight = cost_weight self.quant_modules = list(set(quant_modules or [])) self.score_modules = list(set(score_modules or self.quant_modules)) @@ -305,15 +328,18 @@ def get_score(self, recipe: QuantRecipe) -> float: total_score += importance.item() return total_score - def get_cost(self, recipe: QuantRecipe) -> float: + def get_cost(self, recipe: QuantRecipe, cost_weight: float | None = None) -> float: """Get the cost for a given recipe. The cost is the total weight size of the quantizable modules multiplied by the compression ratio of the recipe. """ + cost_weight = self.cost_weight if cost_weight is None else cost_weight cost = 0 for quant_module in self.quant_modules: - weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) + weight_size = ( + _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) * cost_weight + ) parallel_state = getattr(quant_module, "parallel_state", None) if parallel_state is None: @@ -343,7 +369,7 @@ def get_cost(self, recipe: QuantRecipe) -> float: @property def attrs(self) -> list[str]: """Return the attributes of the hparam for repr.""" - return ["name", *super().attrs] + return ["name", "cost_weight", *super().attrs] class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): @@ -390,6 +416,9 @@ def default_state_dict(self) -> SearchStateDict: """Get the default state dict for AutoQuantize.""" return { "method": self.method_name, + "cost_model": "weight", + "active_moe_expert_ratio": None, + "cost_denominator": None, "candidate_stats": defaultdict(dict), "quantizer_states": {}, "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False}, @@ -407,6 +436,53 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: ) return config + def _get_cost_constraints(self) -> tuple[str, float | None]: + unexpected_constraint_keys = set(self.constraints) - _AUTO_QUANTIZE_CONSTRAINT_KEYS + if unexpected_constraint_keys: + raise ValueError( + f"Unsupported auto_quantize constraints: {unexpected_constraint_keys}. " + "Supported constraints are 'effective_bits', 'cost_model', and 'cost'." + ) + + cost_model = self.constraints.get("cost_model", "weight") + if not isinstance(cost_model, str) or cost_model not in ("weight", "active_moe"): + raise ValueError( + f"Invalid constraints['cost_model']: {cost_model}. " + "Valid options are 'weight' and 'active_moe'." + ) + + cost_constraints = self.constraints.get("cost", {}) + if cost_constraints is None: + cost_constraints = {} + if not isinstance(cost_constraints, dict): + raise ValueError("constraints['cost'] must be a dict when provided.") + unknown_cost_keys = set(cost_constraints) - _AUTO_QUANTIZE_COST_CONSTRAINT_KEYS + if unknown_cost_keys: + raise ValueError(f"Unsupported auto_quantize cost constraints: {unknown_cost_keys}.") + + active_moe_expert_ratio = cost_constraints.get("active_moe_expert_ratio") + if active_moe_expert_ratio is not None: + if not ( + isinstance(active_moe_expert_ratio, (int, float)) + and not isinstance(active_moe_expert_ratio, bool) + and 0.0 < active_moe_expert_ratio <= 1.0 + ): + raise ValueError( + "constraints['cost']['active_moe_expert_ratio'] must be in (0.0, 1.0]." + ) + active_moe_expert_ratio = float(active_moe_expert_ratio) + + if cost_model == "weight" and active_moe_expert_ratio is not None: + raise ValueError( + "constraints['cost']['active_moe_expert_ratio'] requires cost_model='active_moe'." + ) + if cost_model == "active_moe" and active_moe_expert_ratio is None: + raise ValueError( + "constraints['cost']['active_moe_expert_ratio'] must be set when using " + "active_moe cost accounting." + ) + return cost_model, active_moe_expert_ratio + def load_search_checkpoint(self) -> bool: return super().load_search_checkpoint(strict=False) @@ -492,7 +568,9 @@ def _get_score_module_from_name( ) return quant_module - def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None): + def insert_hparams_after_merge_rules( + self, model, quant_recipes, disabled_layers=None, active_moe_expert_ratio=None + ): """Restrict the search space using the merge rules and insert the hparams for the model.""" # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer # Hence we need to restrict the search space so that all these layers share the same recipe @@ -547,6 +625,8 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers quant_modules = [module for module, _, _, _ in module_info_list] disabled = any(disabled for _, _, disabled, _ in module_info_list) score_modules = [score_module for _, _, _, score_module in module_info_list] + quant_module_names = [name for _, name, _, _ in module_info_list] + cost_weight = _get_active_moe_cost_weight(quant_module_names, active_moe_expert_ratio) _quant_recipes = None if disabled else quant_recipes hparam = QuantRecipeHparam( @@ -554,7 +634,8 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers quant_modules=quant_modules, score_modules=score_modules, name=str(group_key), - quant_module_names=[name for _, name, _, _ in module_info_list], + quant_module_names=quant_module_names, + cost_weight=cost_weight, ) for module in quant_modules: @@ -586,23 +667,30 @@ def initialize_candidate_stats(self): if not isinstance(hparam, QuantRecipeHparam): continue - formats, scores, costs = [], [], [] + formats, scores, costs, active_costs = [], [], [], [] prev_score = float("inf") + constraint_cost_weight = ( + hparam.cost_weight if self.config["cost_model"] == "active_moe" else 1.0 + ) for recipe in hparam.choices: formats.append(recipe) score = hparam.get_score(recipe) # type: ignore [arg-type] - cost = hparam.get_cost(recipe) # type: ignore [arg-type] + cost = hparam.get_cost(recipe, cost_weight=constraint_cost_weight) # type: ignore [arg-type] + active_cost = hparam.get_cost(recipe, cost_weight=hparam.cost_weight) # type: ignore [arg-type] score = min(score, prev_score) # TODO: Should we get rid of this? scores.append(score) costs.append(cost) + active_costs.append(active_cost) prev_score = score self.candidate_stats[name]["formats"] = formats self.candidate_stats[name]["scores"] = scores self.candidate_stats[name]["costs"] = costs + self.candidate_stats[name]["active_costs"] = active_costs self.candidate_stats[name]["module_names"] = hparam.quant_module_names + self.candidate_stats[name]["cost_weight"] = hparam.cost_weight def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( @@ -621,18 +709,39 @@ def before_search(self): from .utils import get_quantizer_state_dict, set_quantizer_state_dict super().before_search() + self.config["cost_model"], self.config["active_moe_expert_ratio"] = ( + self._get_cost_constraints() + ) restored_method = getattr(self, "method", None) if self.candidate_stats and restored_method not in (None, self.method_name): raise ValueError( f"Checkpoint method '{restored_method}' does not match current method " f"'{self.method_name}'. Use a different checkpoint path." ) + restored_cost_model = getattr(self, "cost_model", "weight") + restored_active_moe_expert_ratio = getattr(self, "active_moe_expert_ratio", None) + if self.candidate_stats and ( + restored_cost_model != self.config["cost_model"] + or restored_active_moe_expert_ratio != self.config["active_moe_expert_ratio"] + ): + raise ValueError( + "Checkpoint AutoQuantize cost model does not match current search config: " + f"checkpoint=({restored_cost_model}, {restored_active_moe_expert_ratio}), " + f"current=({self.config['cost_model']}, {self.config['active_moe_expert_ratio']}). " + "Use a different checkpoint path." + ) self.method = self.method_name + self.cost_model = self.config["cost_model"] + self.active_moe_expert_ratio = self.config["active_moe_expert_ratio"] + self.cost_denominator = getattr(self, "cost_denominator", None) search_recipes = self._get_search_recipes(self.config["quantization_formats"]) self._verify_constraint(search_recipes) self.insert_hparams_after_merge_rules( - self.model, search_recipes, self.config["disabled_layers"] + self.model, + search_recipes, + self.config["disabled_layers"], + self.config["active_moe_expert_ratio"], ) QuantRecipe.disable_folding_pqs_to_weights() @@ -722,6 +831,17 @@ def _get_total_weight_size(modules): for module in modules ) + @staticmethod + def _get_total_weight_size_from_named_modules(named_modules, active_moe_expert_ratio=None): + total_weight_size = 0.0 + for name, module in named_modules: + if not _AutoQuantizeBaseSearcher._is_auto_quantize_module(module): + continue + total_weight_size += module.weight.numel() * _get_active_moe_cost_weight( + [name], active_moe_expert_ratio + ) + return total_weight_size + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): constraints = { "weight_size_after_compression": ( @@ -731,6 +851,12 @@ def _get_constraints_for_search(self, max_weight_size, lower_bound=None): } return constraints, "weight_size_after_compression" + def _get_search_lower_bounds(self): + cost_model = getattr(self, "cost_model", getattr(self, "config", {}).get("cost_model")) + if cost_model == "active_moe": + return [0.99, 0.90, None] + return [None, 0.99, 0.90] + @abstractmethod def run_search_with_stats(self, max_weight_size, verbose=False): """Run the search with stats to get the best recipe and whether the constraints are satisfied.""" @@ -738,14 +864,32 @@ def run_search_with_stats(self, max_weight_size, verbose=False): def run_search(self): """Search for the best per-layer quantization configuration and return the best model and configuration.""" verbose = self.config["verbose"] - assert len(self.constraints) == 1 and "effective_bits" in self.constraints, ( - f"`constraints` must contain only 'effective_bits' constraint. " - f"Got {self.constraints.keys()}" + assert "effective_bits" in self.constraints and ( + set(self.constraints) <= _AUTO_QUANTIZE_CONSTRAINT_KEYS + ), ( + "`constraints` must contain 'effective_bits' and may contain 'cost_model' and 'cost'. " + f"Got {self.constraints.keys()}." ) compression = self._get_formatted_weight_compression_constraint() - total_weight_size = self._get_total_weight_size(self.model.modules()) + if self.config["cost_model"] == "active_moe": + total_weight_size = self._get_total_weight_size_from_named_modules( + self.model.named_modules(), self.config["active_moe_expert_ratio"] + ) + else: + total_weight_size = self._get_total_weight_size(self.model.modules()) + self.cost_denominator = total_weight_size max_weight_size = total_weight_size * compression + if verbose: + print_rank_0( + "AutoQuantize cost model: " + f"{self.config['cost_model']}" + + ( + f" (active_moe_expert_ratio={self.config['active_moe_expert_ratio']})" + if self.config["cost_model"] == "active_moe" + else "" + ) + ) # Run the search with stats to get the best recipe and whether the constraints are satisfied best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose) @@ -1050,7 +1194,7 @@ def run_search_with_stats(self, max_weight_size, verbose=False): """ # TODO: Do this only for rank 0 in the respective pipeline group - for lower_bound in [None, 0.99, 0.90]: + for lower_bound in self._get_search_lower_bounds(): # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not # specified. I dont know why this happens. # As a workaround, lets specify a lower bound for the weight compression if previous @@ -1379,7 +1523,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): effective_bits = constraints["effective_bits"] compression = effective_bits / 16.0 candidate_stats = search_state["candidate_stats"] - total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + total_weight_size = search_state.get("cost_denominator") or sum( + s["costs"][-1] for s in candidate_stats.values() + ) max_weight_size = total_weight_size * compression method = search_state["method"] @@ -1393,6 +1539,8 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): ) searcher.candidate_stats = candidate_stats + searcher.cost_model = search_state.get("cost_model", "weight") + searcher.config = {"cost_model": searcher.cost_model} best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size, verbose=verbose) best_recipe = {name: info["format"] for name, info in best_recipe_info.items()} diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 3582223c4d3..a375eab07f2 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -265,10 +265,115 @@ def forward_loop(model) -> None: "awq_clip", } +_ACTIVE_MOE_TOP_K_ATTRS = ( + "num_experts_per_tok", + "num_experts_per_token", + "moe_top_k", + "top_k", + "num_selected_experts", +) +_ACTIVE_MOE_NUM_EXPERTS_ATTRS = ( + "num_experts", + "num_local_experts", + "n_routed_experts", + "moe_num_experts", + "num_routed_experts", +) + + +def _iter_model_configs(model: nn.Module): + seen = set() + for obj in (model, getattr(model, "model", None), getattr(model, "language_model", None)): + config = getattr(obj, "config", None) + if config is None or id(config) in seen: + continue + seen.add(id(config)) + yield config + for nested_attr in ("text_config", "language_config"): + nested_config = getattr(config, nested_attr, None) + if nested_config is None or id(nested_config) in seen: + continue + seen.add(id(nested_config)) + yield nested_config + + +def _get_first_numeric_config_attr(config: Any, attr_names: tuple[str, ...]) -> float | None: + for attr_name in attr_names: + value = getattr(config, attr_name, None) + if isinstance(value, (int, float)) and not isinstance(value, bool): + return float(value) + return None + + +def _infer_active_moe_expert_ratio(model: nn.Module) -> float | None: + for config in _iter_model_configs(model): + num_active_experts = _get_first_numeric_config_attr(config, _ACTIVE_MOE_TOP_K_ATTRS) + num_experts = _get_first_numeric_config_attr(config, _ACTIVE_MOE_NUM_EXPERTS_ATTRS) + if num_active_experts is None or num_experts is None or num_experts <= 0: + continue + ratio = num_active_experts / num_experts + if ratio <= 0.0: + continue + return min(ratio, 1.0) + return None + + +def _normalize_auto_quantize_constraints( + model: nn.Module, constraints: dict[str, Any] | None +) -> dict[str, Any]: + constraints = {"effective_bits": 4.8} if constraints is None else dict(constraints) + cost_model = constraints.get("cost_model", "weight") + if cost_model not in ("weight", "active_moe"): + raise ValueError( + f"Invalid constraints['cost_model']: {cost_model}. " + "Valid options are 'weight' and 'active_moe'." + ) + + cost_constraints = constraints.get("cost", {}) + if cost_constraints is None: + cost_constraints = {} + if not isinstance(cost_constraints, dict): + raise ValueError("constraints['cost'] must be a dict when provided.") + cost_constraints = dict(cost_constraints) + + unknown_cost_keys = set(cost_constraints) - {"active_moe_expert_ratio"} + if unknown_cost_keys: + raise ValueError(f"Unsupported auto_quantize cost constraints: {unknown_cost_keys}.") + + active_moe_expert_ratio = cost_constraints.get("active_moe_expert_ratio") + if active_moe_expert_ratio is not None: + if not ( + isinstance(active_moe_expert_ratio, (int, float)) + and not isinstance(active_moe_expert_ratio, bool) + and 0.0 < active_moe_expert_ratio <= 1.0 + ): + raise ValueError( + "constraints['cost']['active_moe_expert_ratio'] must be in (0.0, 1.0]." + ) + cost_constraints["active_moe_expert_ratio"] = float(active_moe_expert_ratio) + + if cost_model == "weight" and active_moe_expert_ratio is not None: + raise ValueError( + "constraints['cost']['active_moe_expert_ratio'] requires cost_model='active_moe'." + ) + if cost_model == "active_moe" and active_moe_expert_ratio is None: + active_moe_expert_ratio = _infer_active_moe_expert_ratio(model) + if active_moe_expert_ratio is None: + raise ValueError( + "Could not infer active_moe_expert_ratio from model.config. " + "Pass it via constraints['cost']['active_moe_expert_ratio']." + ) + cost_constraints["active_moe_expert_ratio"] = active_moe_expert_ratio + + constraints["cost_model"] = cost_model + if cost_constraints or cost_model == "active_moe": + constraints["cost"] = cost_constraints + return constraints + def auto_quantize( model: nn.Module, - constraints: dict[str, float | str] = {"effective_bits": 4.8}, + constraints: dict[str, Any] | None = None, quantization_formats: list[dict[str, Any] | str] = [ mtq.NVFP4_AWQ_LITE_CFG, mtq.FP8_DEFAULT_CFG, @@ -301,8 +406,11 @@ def auto_quantize( Args: model: A pytorch model with quantizer modules. - constraints: Constraints for the search. Currently we support only ``effective_bits``. - ``effective_bits`` specifies the effective number of bits for the quantized model. + constraints: Constraints for the search. ``effective_bits`` specifies the effective number + of bits for the quantized model. ``cost_model`` selects the metric used for the + effective-bits constraint and currently supports ``"weight"`` (default) and + ``"active_moe"``. Additional cost-model parameters are provided through the nested + ``cost`` dict. Here is an example for valid ``effective_bits`` argument: @@ -311,6 +419,13 @@ def auto_quantize( # For an effective quantization bits of 4.8 constraints = {"effective_bits": 4.8} + # For active-MoE accounting where 2 of 8 routed experts are active per token + constraints = { + "effective_bits": 4.8, + "cost_model": "active_moe", + "cost": {"active_moe_expert_ratio": 0.25}, + } + quantization_formats: A list of quantization format config dictionaries or string names to search for. Each config dictionary should be valid as a ``config`` argument in :meth:`quantize `. @@ -514,6 +629,8 @@ def forward_backward_step(model, batch) -> None: else: raise ValueError(f"Invalid method: {method}. Valid options are 'gradient' or 'kl_div'.") + constraints = _normalize_auto_quantize_constraints(model, constraints) + model = apply_mode( model, mode="auto_quantize", @@ -533,7 +650,7 @@ def forward_backward_step(model, batch) -> None: } # Disable all quantizers; AutoQuantize will enable the needed ones set_quantizer_by_cfg(model, [{"quantizer_name": "*", "enable": False}]) - searcher.search(model, constraints, config=search_config) # type: ignore[arg-type] + searcher.search(model, constraints, config=search_config) return model, searcher.state_dict() diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..99676fc5267 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -15,6 +15,7 @@ import copy import io +from types import SimpleNamespace import pytest import torch @@ -24,11 +25,13 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import ( + AutoQuantizeGradientSearcher, QuantRecipe, QuantRecipeHparam, estimate_quant_compression, ) from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg +from modelopt.torch.quantization.model_quant import _infer_active_moe_expert_ratio from modelopt.torch.utils import safe_load from modelopt.torch.utils.distributed import DistributedProcessGroup @@ -62,6 +65,31 @@ def get_input(self): return torch.randn(1, 4, 32) +class _AutoQuantMoeModel(torch.nn.Module): + def __init__(self, num_experts_attr="num_experts"): + super().__init__() + self.config = SimpleNamespace(text_config=SimpleNamespace(num_experts_per_tok=2)) + setattr(self.config.text_config, num_experts_attr, 8) + self.mlp = torch.nn.Module() + self.mlp.experts = torch.nn.ModuleList() + for _ in range(2): + expert = torch.nn.Module() + expert.gate_proj = torch.nn.Linear(32, 32) + expert.up_proj = torch.nn.Linear(32, 32) + expert.down_proj = torch.nn.Linear(32, 32) + self.mlp.experts.append(expert) + self.mlp.shared_expert = torch.nn.Linear(32, 32) + + def forward(self, x): + y = self.mlp.shared_expert(x) + for expert in self.mlp.experts: + y = y + expert.down_proj(expert.gate_proj(x) + expert.up_proj(x)) + return y + + def get_input(self): + return torch.randn(1, 4, 32) + + @pytest.mark.parametrize( ("quant_cfg", "other_quant_cfg", "is_less_than"), [ @@ -109,6 +137,89 @@ def test_quant_recipe_hparam(): assert torch.allclose(output_test, output_ref) +def test_quant_recipe_hparam_cost_weight(): + model_test = mtq.quantize(torch.nn.Linear(4, 16), mtq.INT8_DEFAULT_CFG) + search_recipes = [QuantRecipe(mtq.INT8_DEFAULT_CFG)] + hparam = QuantRecipeHparam( + search_recipes, + quant_modules=[model_test], + quant_module_names=["layers.0.mlp.experts.0.down_proj"], + cost_weight=0.25, + ) + + dense_cost = hparam.get_cost(QuantRecipe(quant_cfg=None)) + int8_cost = hparam.get_cost(QuantRecipe(mtq.INT8_DEFAULT_CFG)) + + assert dense_cost == pytest.approx(model_test.weight.numel() * 0.25) + assert int8_cost == pytest.approx(model_test.weight.numel() * 0.25 * 0.5) + + +@pytest.mark.parametrize("num_experts_attr", ["num_experts", "num_local_experts"]) +def test_auto_quantize_active_moe_cost_model(num_experts_attr): + model = _AutoQuantMoeModel(num_experts_attr) + + _, search_history = mtq.auto_quantize( + model, + constraints={"effective_bits": 6.0, "cost_model": "active_moe"}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + ) + + assert search_history["cost_model"] == "active_moe" + assert search_history["active_moe_expert_ratio"] == pytest.approx(0.25) + weighted_no_quant_cost = sum( + stats["costs"][-1] for stats in search_history["candidate_stats"].values() + ) + assert search_history["cost_denominator"] == pytest.approx(weighted_no_quant_cost) + routed_stats = [ + stats + for stats in search_history["candidate_stats"].values() + if any("mlp.experts" in name for name in stats["module_names"]) + ] + shared_stats = [ + stats + for stats in search_history["candidate_stats"].values() + if any("mlp.shared_expert" in name for name in stats["module_names"]) + ] + assert routed_stats + assert shared_stats + assert all(stats["cost_weight"] == pytest.approx(0.25) for stats in routed_stats) + assert all(stats["cost_weight"] == pytest.approx(1.0) for stats in shared_stats) + assert all("active_costs" in stats for stats in search_history["candidate_stats"].values()) + + +def test_active_moe_ratio_requires_single_config_object(): + model = torch.nn.Module() + model.config = SimpleNamespace( + num_experts_per_tok=2, + text_config=SimpleNamespace(num_experts=8), + ) + + assert _infer_active_moe_expert_ratio(model) is None + + +def test_active_moe_search_prefers_budget_lower_bound(): + searcher = AutoQuantizeGradientSearcher() + searcher.config = {"cost_model": "active_moe"} + searcher.cost_model = "active_moe" + searcher.candidate_stats = { + "layers.0.mlp.quant_recipe": { + "formats": ["under_budget", "near_budget"], + "costs": [1.0, 4.95], + "scores": [0.0, 10.0], + } + } + + best_recipes, is_satisfied = searcher.run_search_with_stats(5.0) + + assert is_satisfied + assert best_recipes["layers.0.mlp.quant_recipe"]["format"] == "near_budget" + + # use this config to test custom quantization config INT8_CUSTOM_QUANT_TEST_CFG = { "quant_cfg": [