From c94b062355f3815d6c456687b96af93b20766752 Mon Sep 17 00:00:00 2001 From: beomwookang Date: Mon, 16 Mar 2026 16:07:11 +0900 Subject: [PATCH 1/4] fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging `_detag_boundary_nodes` removes Q/DQ nodes from partition boundaries after `CapabilityBasedPartitioner` has produced cycle-free partitions. However, this de-tagging can introduce dependency cycles for models with complex attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped into a single large partition). The cycle occurs because removing Q/DQ bridge nodes creates paths that exit the partition and re-enter it through the now-unpartitioned nodes, making it impossible to extract the partition as a valid subgraph. This change adds cycle validation after `_detag_boundary_nodes`. When a cycle is detected, the partition is split into connected components of the surviving (still-tagged) nodes. Each component becomes a separate partition; the components are expected to be individually cycle-free after de-tagging, but they are not re-validated after the split. - Add `_validate_partition()`: BFS-based cycle detection (same algorithm as `torch.fx.passes.utils.fuser_utils.validate_partition`) - Add `_find_connected_components()`: undirected graph traversal to split surviving nodes into disjoint sub-partitions - Guard the nocompute-partition `tags.remove()` against already-removed tags from the cycle-split path Tested with MobileViT-S on Ethos-U85: previously failed with `AssertionError: Invalid partition, found dependency cycles`, now successfully produces a .pte file (5.7 MB). Nine attention-block partitions are each split into 3 sub-partitions. All sub-partitions remain on NPU (no CPU fallback). Existing CNN-only models (ResNet, MobileNetV2, EfficientNet) are unaffected as their partitions have no cycles after de-tagging. 
--- backends/arm/tosa/partitioner.py | 99 +++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index a7ef79abbef..877eee2e0f0 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -14,6 +14,7 @@ """ import logging +from collections import deque from itertools import count from typing import Callable, List, Optional, Sequence, Tuple @@ -131,6 +132,77 @@ def reject_partition( ) +def _validate_partition(nodes: set[torch.fx.Node]) -> bool: + """Check whether a set of nodes can be extracted as a subgraph without + cycles. + + Perform a BFS from the external users of partition nodes. If any node + reached by BFS is itself inside the partition, then extracting the + partition would create a dependency cycle in the remaining graph. + + Args: + nodes: The set of FX nodes that form the partition. + + Returns: + True if the partition is valid (no cycles), False otherwise. + + """ + outputs: list[torch.fx.Node] = [] + for node in nodes: + for user in node.users: + if user not in nodes: + outputs.append(user) + + visited: set[torch.fx.Node] = set() + queue = deque(outputs) + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + if current in nodes: + return False + for user in current.users: + if user not in visited: + queue.append(user) + return True + + +def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]: + """Find connected components in a set of nodes treating edges as undirected. + + Two nodes are connected if one is an input or user of the other and both + are in ``nodes``. + + Args: + nodes: The node set to partition into components. + + Returns: + A list of disjoint node sets, one per connected component. 
+ + """ + remaining = set(nodes) + components: list[set[torch.fx.Node]] = [] + while remaining: + seed = next(iter(remaining)) + component: set[torch.fx.Node] = set() + queue = deque([seed]) + while queue: + node = queue.popleft() + if node in component or node not in remaining: + continue + component.add(node) + for inp in node.all_input_nodes: + if inp in remaining and inp not in component: + queue.append(inp) + for user in node.users: + if user in remaining and user not in component: + queue.append(user) + remaining -= component + components.append(component) + return components + + class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -285,6 +357,30 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: reporter, ) + # After de-tagging, the remaining tagged nodes may form + # dependency cycles. This happens when models contain complex + # attention blocks (e.g. MobileViT) where Q/DQ nodes act as + # bridges between partition segments. Detect such cycles and + # split the partition into valid connected components. + surviving = {n for n in partition.nodes if is_partitioned(n, tag)} + if surviving and not _validate_partition(surviving): + components = _find_connected_components(surviving) + logger.info( + f"Partition {tag} has dependency cycle after Q/DQ " + f"de-tagging. Splitting into {len(components)} " + f"sub-partition(s)." + ) + # Remove the original tag from all nodes + for node in surviving: + del node.meta["delegation_tag"] + tags.remove(tag) + # Re-tag each connected component as a new partition + for component in components: + new_tag = f"tag{next(tag_iterator)}" + tags.add(new_tag) + for node in component: + node.meta["delegation_tag"] = new_tag + # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." 
is_nocompute_partition = all( _is_noop_clone(node) @@ -303,7 +399,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: partition, reporter, ) - tags.remove(tag) + if tag in tags: + tags.remove(tag) return tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: From 7b7c5dfd27860ab5ebcf255b61cfa887640ddbb8 Mon Sep 17 00:00:00 2001 From: beomwookang Date: Thu, 23 Apr 2026 10:42:21 +0900 Subject: [PATCH 2/4] fix(arm): handle nocompute-partition check for cycle-split sub-partitions After a cycle split, the nocompute check still iterated over the original partition nodes and only attempted to remove the original tag. This left orphan tags in the returned set when sub-partitions were rejected. Group the nocompute check by active tag so each sub-partition is evaluated and cleaned up independently. Also update reject_partition() to accept an iterable of nodes instead of a Partition object. --- backends/arm/tosa/partitioner.py | 59 +++++++++++++++++++------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 877eee2e0f0..1dae743fc80 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -16,7 +16,7 @@ import logging from collections import deque from itertools import count -from typing import Callable, List, Optional, Sequence, Tuple +from typing import Callable, Iterable, List, Optional, Sequence, Tuple import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -112,18 +112,19 @@ def is_partitioned( def reject_partition( - reason: str, partition: Partition, reporter: WhyNoPartitionReporter + reason: str, + nodes: Iterable[torch.fx.Node], + reporter: WhyNoPartitionReporter, ) -> None: """Remove a proposed partition and record the rejection reason. Args: reason (str): Human-readable explanation for rejection. 
- partition (object): Proposed partition object from the - capability partitioner. + nodes: The nodes to de-tag. reporter (WhyNoPartitionReporter): used to report why nodes were rejected. """ - for node in partition.nodes: + for node in nodes: if "delegation_tag" in node.meta: del node.meta["delegation_tag"] reporter.report_reject( @@ -381,26 +382,36 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: for node in component: node.meta["delegation_tag"] = new_tag - # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." - is_nocompute_partition = all( - _is_noop_clone(node) - or _is_noop_alias_copy(node) - or _is_noop_expand(node) - or _is_noop_detach_copy(node) - or _is_noop_to_dim_order_copy(node) - or _is_view_copy(node) - or node.target in Q_OPS - or node.target in DQ_OPS - for node in partition.nodes - ) - if is_nocompute_partition: - reject_partition( - "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", - partition, - reporter, + # Check whether the partition contains only no-op or non-computational + # ops. Such partitions don't make sense to delegate, and in the worst + # case may be optimized away during lowering, which can break + # compilation. After a cycle split the nodes may belong to multiple + # sub-partitions, so collect every active tag and check each group. 
+ active_tags: dict[str, list[torch.fx.Node]] = {} + for node in partition.nodes: + node_tag = node.meta.get("delegation_tag") + if node_tag is not None and node_tag in tags: + active_tags.setdefault(node_tag, []).append(node) + + for active_tag, nodes in active_tags.items(): + is_nocompute_partition = all( + _is_noop_clone(node) + or _is_noop_alias_copy(node) + or _is_noop_expand(node) + or _is_noop_detach_copy(node) + or _is_noop_to_dim_order_copy(node) + or _is_view_copy(node) + or node.target in Q_OPS + or node.target in DQ_OPS + for node in nodes ) - if tag in tags: - tags.remove(tag) + if is_nocompute_partition: + reject_partition( + "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", + nodes, + reporter, + ) + tags.remove(active_tag) return tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: From 9b912a8c70ee14af5dc961da12eee2012b41f8a4 Mon Sep 17 00:00:00 2001 From: beomwookang Date: Thu, 23 Apr 2026 10:42:58 +0900 Subject: [PATCH 3/4] test(arm): add unit tests for partition cycle detection utilities Add tests for _validate_partition and _find_connected_components using synthetic torch.fx graphs. Cover contiguous/non-contiguous partitions, single nodes, branching graphs, and empty sets. --- .../misc/test_partition_cycle_detection.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 backends/arm/test/misc/test_partition_cycle_detection.py diff --git a/backends/arm/test/misc/test_partition_cycle_detection.py b/backends/arm/test/misc/test_partition_cycle_detection.py new file mode 100644 index 00000000000..3bf579dbd0b --- /dev/null +++ b/backends/arm/test/misc/test_partition_cycle_detection.py @@ -0,0 +1,91 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.arm.tosa.partitioner import ( + _find_connected_components, + _validate_partition, +) + + +def _build_linear_graph(): + """Build a linear graph: x -> a -> b -> c -> output. + + Returns the graph and nodes (x, a, b, c, output). + """ + graph = torch.fx.Graph() + x = graph.placeholder("x") + a = graph.call_function(torch.add, (x, x)) + b = graph.call_function(torch.mul, (a, a)) + c = graph.call_function(torch.sub, (b, b)) + output = graph.output(c) + return graph, (x, a, b, c, output) + + +class TestValidatePartition(unittest.TestCase): + def test_contiguous_partition_is_valid(self): + """A contiguous slice of a linear graph has no cycle.""" + _, (_, a, b, _, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a, b})) + + def test_non_contiguous_partition_has_cycle(self): + """Nodes {a, c} with b in between create a cycle: extracting a and c + would force b to depend on a (inside) and c to depend on b (outside), + while c is also inside.""" + _, (_, a, _, c, _) = _build_linear_graph() + self.assertFalse(_validate_partition({a, c})) + + def test_single_node_is_valid(self): + _, (_, a, _, _, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a})) + + def test_full_graph_interior_is_valid(self): + """All interior nodes form a valid partition.""" + _, (_, a, b, c, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a, b, c})) + + +class TestFindConnectedComponents(unittest.TestCase): + def test_single_component(self): + _, (_, a, b, _, _) = _build_linear_graph() + components = _find_connected_components({a, b}) + self.assertEqual(len(components), 1) + self.assertEqual(components[0], {a, b}) + + def test_disconnected_components(self): + """Nodes {a, c} with b not in the set form two components.""" + _, (_, a, _, c, _) = _build_linear_graph() + components = _find_connected_components({a, c}) + self.assertEqual(len(components), 2) + component_sets = [frozenset(c) for c in 
components] + self.assertIn(frozenset({a}), component_sets) + self.assertIn(frozenset({c}), component_sets) + + def test_empty_set(self): + components = _find_connected_components(set()) + self.assertEqual(len(components), 0) + + def test_branching_graph(self): + """Graph with a fork: x -> a -> b, x -> a -> c. {b, c} are disconnected + when a is excluded.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + a = graph.call_function(torch.add, (x, x)) + b = graph.call_function(torch.mul, (a, a)) + c = graph.call_function(torch.sub, (a, a)) + _ = graph.output((b, c)) + + components = _find_connected_components({b, c}) + self.assertEqual(len(components), 2) + + # With a included, all three form one component + components = _find_connected_components({a, b, c}) + self.assertEqual(len(components), 1) + + +if __name__ == "__main__": + unittest.main() From 199a8600aa83f4cf11f55fd712489ef6b0aea498 Mon Sep 17 00:00:00 2001 From: beomwookang Date: Thu, 23 Apr 2026 10:50:34 +0900 Subject: [PATCH 4/4] chore(arm): fix lint warnings in partitioner and cycle detection tests --- backends/arm/test/misc/test_partition_cycle_detection.py | 3 ++- backends/arm/tosa/partitioner.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/backends/arm/test/misc/test_partition_cycle_detection.py b/backends/arm/test/misc/test_partition_cycle_detection.py index 3bf579dbd0b..288204d9759 100644 --- a/backends/arm/test/misc/test_partition_cycle_detection.py +++ b/backends/arm/test/misc/test_partition_cycle_detection.py @@ -35,7 +35,8 @@ def test_contiguous_partition_is_valid(self): def test_non_contiguous_partition_has_cycle(self): """Nodes {a, c} with b in between create a cycle: extracting a and c would force b to depend on a (inside) and c to depend on b (outside), - while c is also inside.""" + while c is also inside. 
+ """ _, (_, a, _, c, _) = _build_linear_graph() self.assertFalse(_validate_partition({a, c})) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 1dae743fc80..7b4d068e9b8 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -43,7 +43,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import any_chain, OperatorSupportBase logger = logging.getLogger(__name__)