From c702b696bdc7d5d2e011f0d1ac4f317c1a22eaf2 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Sat, 6 Dec 2025 03:34:12 +0900
Subject: [PATCH 1/2] [Relax] Introduce ModuleDict

---
 python/tvm/relax/frontend/nn/__init__.py     |  2 +-
 python/tvm/relax/frontend/nn/core.py         | 67 +++++++++++++++++++
 .../python/relax/test_frontend_nn_modules.py | 18 ++++++
 3 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index f490af7062b0..d9036348835a 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -17,7 +17,7 @@
 """A PyTorch-like API to build IRModules."""
 # pylint: disable=redefined-builtin
 from . import op, spec
-from .core import Effect, Module, ModuleList, Object, Parameter, Tensor
+from .core import Effect, Module, ModuleDict, ModuleList, Object, Parameter, Tensor
 from .exporter import add_extern
 from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 8529dda00686..b15ba685b76d 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -540,6 +540,62 @@ def _compile(spec, device, pipeline, debug):
         raise ValueError(f"Unknown out_format: {out_format}")
 
 
+class ModuleDict(Module):
+    """Holds submodules in an insertion-ordered dict keyed by name.
+
+    Iteration and ``in`` follow the ``dict`` protocol and operate on keys; use
+    ``values()`` or ``items()`` to reach the submodules themselves.
+    """
+
+    def __init__(self, modules: Optional[OrderedDict[str, Module]] = None):
+        if modules is None:
+            self.modules = OrderedDict()
+        else:
+            # Copy, so later changes to the caller's mapping do not leak in here.
+            self.modules = OrderedDict(modules)
+
+    def __iter__(self) -> Iterator[str]:
+        # Mappings iterate over keys; this keeps ``for k in d`` consistent with ``k in d``.
+        return iter(self.modules)
+
+    def __getitem__(self, key: str) -> Module:
+        return self.modules[key]
+
+    def __setitem__(self, key: str, module: Module) -> None:
+        self.modules[key] = module
+
+    def __len__(self) -> int:
+        return len(self.modules)
+
+    def keys(self) -> Iterator[str]:
+        return self.modules.keys()
+
+    def values(self) -> Iterator[Module]:
+        return self.modules.values()
+
+    def items(self) -> Iterator[Tuple[str, Module]]:
+        return self.modules.items()
+
+    def get(self, key: str, default: Optional[Module] = None) -> Optional[Module]:
+        return self.modules.get(key, default)
+
+    def update(self, modules: Dict[str, Module]) -> None:
+        self.modules.update(modules)
+
+    def clear(self) -> None:
+        self.modules.clear()
+
+    def pop(self, key: str) -> Module:
+        return self.modules.pop(key)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.modules
+
+    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
+        for module in self.modules.values():
+            module.to(dtype=dtype)
+
+
 class ModuleList(Module):
     """Holds submodules in a list."""
 
@@ -611,6 +667,10 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
         for i, subitem in enumerate(root):
             yield from _attribute_finder(subitem, prefix + f"{i}.", condition_yield)
         return
+    elif isinstance(root, ModuleDict):
+        for name, subitem in root.items():
+            yield from _attribute_finder(subitem, prefix + f"{name}.", condition_yield)
+        return
     for name, item in root.__dict__.items():
         if condition_yield(item):
             yield prefix + name, item
@@ -620,6 +680,13 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any]
                 prefix + name + ".",
                 condition_yield,
             )
+        elif isinstance(item, ModuleDict):
+            for sub_name, sub_item in item.items():
+                yield from _attribute_finder(
+                    sub_item,
+                    prefix + name + f".{sub_name}.",
+                    condition_yield,
+                )
         elif isinstance(item, Module):
             yield from _attribute_finder(
                 item,
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 23250f28aa9f..e9a4a6f62424 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -715,5 +715,23 @@ def forward(self, x: nn.Tensor):
     assert ["layers.0.0.weight", "layers.0.1.weight"] == sorted(list(named_params.keys()))
 
 
+def test_module_dict():
+    class Module(nn.Module):
+        def __init__(self):
+            self.layers = nn.ModuleDict(
+                {"linear0": nn.Linear(4, 4, bias=False), "linear1": nn.Linear(4, 4, bias=False)}
+            )
+
+        def forward(self, x: nn.Tensor):
+            x = self.layers["linear0"](x)
+            x = self.layers["linear1"](x)
+            return x
+
+    mod = Module()
+    named_params = dict(mod.named_parameters())
+    assert ["layers.linear0.weight", "layers.linear1.weight"] == sorted(list(named_params.keys()))
+    assert list(mod.layers) == ["linear0", "linear1"]  # iteration yields keys, like dict
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From 73d1468e8731a79a87871518e16453aa85f8bc75 Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Sat, 6 Dec 2025 12:44:18 +0900
Subject: [PATCH 2/2] Mutator support

---
 python/tvm/relax/frontend/nn/visitor.py      | 40 +++++++++++-
 .../python/relax/test_frontend_nn_mutator.py | 63 ++++++++++++++++++-
 2 files changed, 99 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/visitor.py b/python/tvm/relax/frontend/nn/visitor.py
index 82f301006697..d2467a2bf81d 100644
--- a/python/tvm/relax/frontend/nn/visitor.py
+++ b/python/tvm/relax/frontend/nn/visitor.py
@@ -79,6 +79,24 @@ def visit_param(self, name: str, node: nn.Effect) -> Any:
         """
         return self.visit(name, node)
 
+    def visit_moduledict(self, name: str, node: nn.ModuleDict) -> Any:
+        """The base visiting method for mutation of nn.ModuleDict nodes.
+
+        Parameters
+        ----------
+        name : str
+            The name of the current node in parent's attribute.
+
+        node : nn.ModuleDict
+            The current node of nn.ModuleDict to mutate.
+
+        Returns
+        ------
+        ret_node: Any
+            The new node to replace current node.
+        """
+        return self.visit(name, node)
+
     def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
         """The base visiting method for mutation of nn.ModuleList nodes.
 
@@ -88,7 +106,7 @@ def visit_modulelist(self, name: str, node: nn.ModuleList) -> Any:
             The name of the current node in parent's attribute.
 
         node : nn.ModuleList
-            The current node of nn.MoModuleListdule to mutate.
+            The current node of nn.ModuleList to mutate.
 
         Returns
         ------
@@ -124,7 +142,9 @@ def _get_child_name(parent: str, child: str) -> str:
 
         if isinstance(node, nn.ModuleList):
             for i in range(len(node)):
-                if isinstance(node[i], nn.ModuleList):
+                if isinstance(node[i], nn.ModuleDict):
+                    node[i] = self.visit_moduledict(f"{name}.{i}", node[i])
+                elif isinstance(node[i], nn.ModuleList):
                     node[i] = self.visit_modulelist(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.Module):
                     node[i] = self.visit_module(f"{name}.{i}", node[i])
@@ -132,9 +152,23 @@ def _get_child_name(parent: str, child: str) -> str:
                     node[i] = self.visit_effect(f"{name}.{i}", node[i])
                 elif isinstance(node[i], nn.Parameter):
                     node[i] = self.visit_param(f"{name}.{i}", node[i])
+        elif isinstance(node, nn.ModuleDict):
+            for k, v in node.items():  # only values are reassigned, so iteration stays valid
+                if isinstance(v, nn.ModuleDict):  # containers first: they are nn.Module too
+                    node[k] = self.visit_moduledict(_get_child_name(name, k), v)
+                elif isinstance(v, nn.ModuleList):
+                    node[k] = self.visit_modulelist(_get_child_name(name, k), v)
+                elif isinstance(v, nn.Module):
+                    node[k] = self.visit_module(_get_child_name(name, k), v)
+                elif isinstance(v, nn.Effect):
+                    node[k] = self.visit_effect(_get_child_name(name, k), v)
+                elif isinstance(v, nn.Parameter):
+                    node[k] = self.visit_param(_get_child_name(name, k), v)
         else:
             for key, value in node.__dict__.items():
-                if isinstance(value, nn.ModuleList):
+                if isinstance(value, nn.ModuleDict):
+                    setattr(node, key, self.visit_moduledict(_get_child_name(name, key), value))
+                elif isinstance(value, nn.ModuleList):
                     setattr(node, key, self.visit_modulelist(_get_child_name(name, key), value))
                 elif isinstance(value, nn.Module):
                     setattr(node, key, self.visit_module(_get_child_name(name, key), value))
diff --git a/tests/python/relax/test_frontend_nn_mutator.py b/tests/python/relax/test_frontend_nn_mutator.py
index ffb6586159b5..253e24a4eddf 100644
--- a/tests/python/relax/test_frontend_nn_mutator.py
+++ b/tests/python/relax/test_frontend_nn_mutator.py
@@ -65,6 +65,37 @@ def visit_param(self, name: str, node: nn.Parameter) -> Any:
     mutator.visit("mod3", mod3)
 
 
+def test_mutator_naming_moduledict():
+    class Module(nn.Module):
+        def __init__(self, dtype) -> None:
+            super().__init__()
+            self.param = nn.Parameter((32, 128), dtype)
+
+    # Each dtype is used exactly once, so it identifies the expected parameter path.
+    expected = {
+        "float64": "mod_dict.k0.0.param",
+        "float32": "mod_dict.k0.1.param",
+        "float16": "mod_dict.k1.0.param",
+        "float8": "mod_dict.k1.1.param",
+    }
+    visited = []
+
+    class Mutator(nn.Mutator):
+        def visit_param(self, name: str, node: nn.Parameter) -> Any:
+            assert name == expected[node.dtype]
+            visited.append(name)
+            return node
+
+    mod_dict = nn.ModuleDict(
+        {
+            "k0": nn.ModuleList([Module("float64"), Module("float32")]),
+            "k1": nn.ModuleList([Module("float16"), Module("float8")]),
+        }
+    )
+    Mutator().visit("mod_dict", mod_dict)
+    assert sorted(visited) == sorted(expected.values())  # guard against a vacuous pass
+
+
 def test_mutator_naming_modulelist():
     class Module(nn.Module):
         def __init__(self, dtype) -> None:
@@ -124,6 +155,37 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
     assert isinstance(module.mod, SubModule2)
 
 
+def test_mutator_moduledict():
+    class Module1(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module2(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Module3(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+    class Mutator(nn.Mutator):
+        def visit_module(self, name: str, node: nn.Module) -> Any:
+            if isinstance(node, Module3):
+                return Module1()
+            else:
+                return node
+
+    mutator = Mutator()
+    module_dict = nn.ModuleDict({"k0": Module1(), "k1": Module2(), "k2": Module3()})
+    assert isinstance(module_dict["k0"], Module1)
+    assert isinstance(module_dict["k1"], Module2)
+    assert isinstance(module_dict["k2"], Module3)
+    module_dict = mutator.visit("", module_dict)
+    assert isinstance(module_dict["k0"], Module1)
+    assert isinstance(module_dict["k1"], Module2)
+    assert isinstance(module_dict["k2"], Module1)
+
+
 def test_mutator_modulelist():
     class Module1(nn.Module):
         def __init__(self) -> None:
@@ -150,7 +212,6 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
     assert isinstance(module_list[1], Module2)
     assert isinstance(module_list[2], Module3)
     module_list = mutator.visit("", module_list)
-    print(module_list[2])
     assert isinstance(module_list[0], Module1)
     assert isinstance(module_list[1], Module2)
     assert isinstance(module_list[2], Module1)