From 06f0be30f62764687f6a222d8dd3e3660af647c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 00:42:27 +0100 Subject: [PATCH 01/12] Add command line to compare two onnx models --- _doc/api/torch_onnx/compare.rst | 7 + _doc/api/torch_onnx/index.rst | 1 + _doc/cmds/compare.rst | 38 +++ _doc/cmds/index.rst | 1 + _unittests/ut_torch_onnx/test_compare.py | 89 +++++++ _unittests/ut_xrun_doc/test_command_lines.py | 8 + .../ut_xrun_doc/test_command_lines_exe.py | 8 + onnx_diagnostic/_command_lines_parser.py | 47 ++++ onnx_diagnostic/torch_onnx/compare.py | 223 ++++++++++++++++++ 9 files changed, 422 insertions(+) create mode 100644 _doc/api/torch_onnx/compare.rst create mode 100644 _doc/cmds/compare.rst create mode 100644 _unittests/ut_torch_onnx/test_compare.py create mode 100644 onnx_diagnostic/torch_onnx/compare.py diff --git a/_doc/api/torch_onnx/compare.rst b/_doc/api/torch_onnx/compare.rst new file mode 100644 index 00000000..4e9b560c --- /dev/null +++ b/_doc/api/torch_onnx/compare.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_onnx.compare +================================== + +.. automodule:: onnx_diagnostic.torch_onnx.compare + :members: + :no-undoc-members: diff --git a/_doc/api/torch_onnx/index.rst b/_doc/api/torch_onnx/index.rst index 7ace52ec..0eff1c87 100644 --- a/_doc/api/torch_onnx/index.rst +++ b/_doc/api/torch_onnx/index.rst @@ -5,6 +5,7 @@ onnx_diagnostic.torch_onnx :maxdepth: 1 :caption: submodules + compare runtime_info sbs sbs_dataclasses diff --git a/_doc/cmds/compare.rst b/_doc/cmds/compare.rst new file mode 100644 index 00000000..be5fcc81 --- /dev/null +++ b/_doc/cmds/compare.rst @@ -0,0 +1,38 @@ +-m onnx_diagnostic compare ... compares two models +================================================== + +Description ++++++++++++ + +The command lines compares two models assuming they represent +the same models and most parts of both are the same. +Different options were used to export or an optimization +was different. This highlights the differences. + +.. runpython:: + + from onnx_diagnostic._command_lines_parser import get_parser_compare + + get_parser_compare().print_help() + +Example ++++++++ + +.. code-block:: bash + + python -m onnx_diagnostic compare + +.. code-block:: text + + -- loading 'two_nodes.onnx' + -- loading 'two_nodes.onnx' + -- starts comparison + -- done with distance 0 + 0000 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w + 0001 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b + 0002 INPUT FLOAT s0x(((s1 - 1)//8)) linear | INPUT FLOAT s0x(((s1 - 1)//8)) linear + 0003 INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 | INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 + 0004 NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 | NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 + 0005 NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 | NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 + 0006 OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 | OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 + 0007 OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 | OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 diff --git a/_doc/cmds/index.rst b/_doc/cmds/index.rst index adabb760..4c81b46a 100644 --- a/_doc/cmds/index.rst +++ b/_doc/cmds/index.rst @@ -8,6 +8,7 @@ Command Lines .. toctree:: :maxdepth: 1 + compare config sbs validate diff --git a/_unittests/ut_torch_onnx/test_compare.py b/_unittests/ut_torch_onnx/test_compare.py new file mode 100644 index 00000000..25373324 --- /dev/null +++ b/_unittests/ut_torch_onnx/test_compare.py @@ -0,0 +1,89 @@ +import unittest +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh +from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.torch_onnx.compare import ObsCompare + +TFLOAT = onnx.TensorProto.FLOAT +TINT64 = onnx.TensorProto.INT64 + + +class TestCompare(ExtTestCase): + def _get_model(self, cast=True): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + ( + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1) + if cast + else oh.make_node("Identity", ["xmc2"], ["xm2"]) + ), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + return model + + def test_edit_distance_0(self): + model = self._get_model() + seq = ObsCompare.obs_sequence_from_model(model) + dist, path, pair_cmp = ObsCompare.distance_sequence(seq, seq) + self.assertEqual(dist, 0) + self.assertEqual(path, [(i, i) for i in range(len(path))]) + self.assertEqual(len(path), len(pair_cmp)) + uni = set() + for o1, o2 in pair_cmp: + self.assertIsInstance(o1, ObsCompare) + self.assertIsInstance(o2, ObsCompare) + self.assertEqual(o1, o2) + row = f"{o1} | {o2}" + uni.add(len(row)) + self.assertEqual(len(uni), 1) + + def test_edit_distance_1(self): + model = self._get_model() + model2 = self._get_model(cast=False) + seq1 = ObsCompare.obs_sequence_from_model(model) + seq2 = ObsCompare.obs_sequence_from_model(model2) + dist, path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + self.assertGreater(dist, 2000) + expected_path = [ + *[(i, i) for i in range(11)], + *[(10, 11), (11, 11)], + *[(i, i) for i in range(12, len(seq1))], + ] + self.assertEqual(expected_path, path) + self.assertEqual(len(path), len(pair_cmp)) + for o1, o2 in pair_cmp: + if o1: + self.assertIsInstance(o1, ObsCompare) + if o2: + self.assertIsInstance(o2, ObsCompare) + if o1 and o2: + self.assertEqual(o1, o2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index b055763c..1a193229 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -5,6 +5,7 @@ from onnx_diagnostic._command_lines_parser import ( get_main_parser, get_parser_agg, + get_parser_compare, get_parser_config, get_parser_dot, get_parser_find, @@ -170,6 +171,13 @@ def test_parser_dot(self): text = st.getvalue() self.assertIn("--run", text) + def test_parser_compare(self): + st = StringIO() + with redirect_stdout(st): + get_parser_compare().print_help() + text = st.getvalue() + self.assertIn("compare", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index 01470b34..c9ff8f15 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -203,6 +203,14 @@ def forward(self, x): # text is empty is dot is not installed self.assertIn("converts into dot", text) + def test_j_parser_compare(self): + st = StringIO() + with redirect_stdout(st): + main(["compare", self.dummy_path, self.dummy_path]) + text = st.getvalue() + print(text) + self.assertIn("done with distance 0", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 049a4e97..16538eaf 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1508,6 +1508,51 @@ def _size(name): print("-- done") +def get_parser_compare() -> ArgumentParser: + parser = ArgumentParser( + prog="compare", + description=textwrap.dedent( + """ + Compares two onnx models by aligning the nodes between both models. + This is done through an edit distance. + """ + ), + epilog=textwrap.dedent( + """ + Each element (initializer, input, node, output) of the model + is converted into an observation. Then it defines a distance between + two elements. And finally, it finds the best alignment with + an edit distance. + """ + ), + ) + parser.add_argument("model1", type=str, help="first model to compare") + parser.add_argument("model2", type=str, help="second model to compare") + return parser + + +def _cmd_compare(argv: List[Any]): + import onnx + from .torch_onnx.compare import ObsCompare + + parser = get_parser_compare() + args = parser.parse_args(argv[1:]) + print(f"-- loading {args.model1!r}") + seq1 = ObsCompare.obs_sequence_from_model(onnx.load(args.model1, load_external_data=False)) + print(f"-- loading {args.model2!r}") + seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False)) + print("-- starts comparison") + dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + print(f"-- done with distance {dist}") + for i, (o1, o2) in enumerate(pair_cmp): + print(f"{i:04d} {o1} | {o2}") + + +############# +# main parser +############# + + def get_main_parser() -> ArgumentParser: parser = ArgumentParser( prog="onnx_diagnostic", @@ -1555,6 +1600,7 @@ def get_main_parser() -> ArgumentParser: def main(argv: Optional[List[Any]] = None): fcts = dict( agg=_cmd_agg, + compare=_cmd_compare, config=_cmd_config, dot=_cmd_dot, exportsample=_cmd_export_sample, @@ -1580,6 +1626,7 @@ def main(argv: Optional[List[Any]] = None): else: parsers = dict( agg=get_parser_agg, + compare=get_parser_compare, config=get_parser_config, dot=get_parser_dot, exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py new file mode 100644 index 00000000..8051f484 --- /dev/null +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -0,0 +1,223 @@ +import enum +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import onnx +from ..helpers.onnx_helper import onnx_dtype_name + + +def _align(res: str, limit: int) -> str: + if len(res) == limit: + return res + if len(res) > limit: + return res[:limit] + return res + " " * (limit - len(res)) + + +class ObsType(enum.IntEnum): + """Observation kind.""" + + RESULT = 1 + INITIALIZER = 2 + SPARSE_INITIALIZER = 4 + INPUT = 8 + OUTPUT = 16 + NODE = 32 + + def __repr__(self): + return f"{self.__class__.__name__}.{self._name_}" + + +@dataclass +class ObsCompare: + """ + The description of an observation, a node, an input, an output, an initializer. + + :param kind: node type, see :class:`ObsType` + :param name_or_outputs: name of an initilizer or the outputs of a node + :param itype: onnx type + :param index: index of an input or output + :param shape: shape + :param op_type: node op_type + :param comment: comment, unused + """ + + kind: ObsType + name_or_outputs: Tuple[str] + itype: int = 0 + index: int = 0 + shape: Optional[Tuple[Tuple[Union[int, str], ...]]] = None + op_type: str = "" + comment: str = "" + + def __str__(self) -> str: + "usual" + els = [ + _align(self.kind._name_, 6), + _align(onnx_dtype_name(self.itype) if self.itype else "?", 8), + _align("?" if self.shape is None else "x".join(map(str, self.shape)), 18), + _align(self.op_type or "", 15), + _align(", ".join(self.name_or_outputs), 35), + ] + return " ".join(els) + + def distance(self, obs: "ObsCompare") -> float: + """Computes a cost between two observations.""" + if self.kind != obs.kind: + return 1e6 + if self.itype != obs.itype: + return 1e5 + if self.kind == ObsType.NODE: + if self.op_type != obs.op_type: + return 1e4 + if len(self.name_or_outputs) == 1: + return 0 if self.name_or_outputs == obs.name_or_outputs else 1e2 + a = set(self.name_or_outputs) & set(obs.name_or_outputs) + b = set(self.name_or_outputs) | set(obs.name_or_outputs) + return 1e2 * (len(b) - len(a)) + if self.kind == ObsType.INPUT: + return ( + 999.7 + if self.itype != obs.itype + or self.shape != obs.shape + or self.index != obs.index + else 0 + ) + if self.kind == ObsType.INITIALIZER or self.kind == ObsType.SPARSE_INITIALIZER: + return 1e3 if self.itype != obs.itype or self.shape or obs.shape else 0 + if self.kind == ObsType.OUTPUT: + return ( + 999.1 + if self.itype != obs.itype + or self.shape != obs.shape + or self.index != obs.index + else 0 + ) + return 1e8 + + @classmethod + def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[ + float, + List[Tuple[int, int]], + List[Tuple[Optional["ObsCompare"], Optional["ObsCompare"]]], + ]: + """ + Computes the distance between two sequences of results. + + :param s1: first sequence + :param s2: second sequence + :return: distance and alignment + """ + delay = max(50, abs(len(s2) - len(s1)) + 1) + distance = {(-1, -1): 0} + predecessor = {(-1, -1): None} + insert_cost = 1e3 + for i in range(len(s1)): + for j in range(max(0, i - delay), min(len(s2), i + delay)): + best = distance.get((i, j), 1e100) + pred = None + ki, kj = i - 1, j - 1 + if (ki, kj) in distance: + d = distance[ki, kj] + s1[i].distance(s2[j]) + if d < best: + best = d + pred = (ki, kj) + ki, kj = i - 1, j + if (ki, kj) in distance: + d = distance[ki, kj] + insert_cost + 1 + if d < best: + best = d + pred = (ki, kj) + ki, kj = i, j - 1 + if (ki, kj) in distance: + d = distance[ki, kj] + insert_cost + 0.1 + if d < best: + best = d + pred = (ki, kj) + distance[i, j] = best + predecessor[i, j] = pred + + # reverse + way = [] + last = len(s1) - 1, len(s2) - 1 + while last is not None: + way.append(last) + last = predecessor[last] + indices = list(reversed(way))[1:] + obs_path = [] + last = -1, -1 + for i, j in indices: + di = i - last[0] + dj = j - last[1] + if di == dj == 1: + obs_path.append((s1[i], s2[j])) + elif di == 0: + obs_path.append((None, s2[j])) + elif dj == 0: + obs_path.append((s1[i], None)) + else: + raise RuntimeError(f"issue with di={di}, dj={dj}") + last = i, j + return distance[len(s1) - 1, len(s2) - 1], indices, obs_path + + @classmethod + def obs_sequence_from_model( + cls, + model: Union[onnx.ModelProto, onnx.GraphProto], + ) -> List["ObsCompare"]: + """ + Creates a sequence of observations bases on a model. + + :param model: model + :return: sequence of observations + """ + graph = model if isinstance(model, onnx.GraphProto) else model.graph + + shapes = {} + types = {} + for info in graph.value_info: + if info.type.tensor_type: + t = info.type.tensor_type + shapes[info.name] = tuple((d.dim_param or d.dim_value) for d in t.shape.dim) + types[info.name] = t.elem_type + + seq = [] + for init in graph.initializer: + obs = ObsCompare( + kind=ObsType.INITIALIZER, + itype=init.data_type, + name_or_outputs=(init.name,), + ) + seq.append(obs) + for i, inp in enumerate(graph.input): + obs = ObsCompare( + kind=ObsType.INPUT, + itype=inp.type.tensor_type.elem_type, + index=i, + shape=tuple( + (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim + ), + name_or_outputs=(inp.name,), + ) + seq.append(obs) + for node in graph.node: + obs = ObsCompare( + kind=ObsType.NODE, + itype=types.get(node.output[0], 0), + index=i, + shape=shapes.get(node.output[0], None), + name_or_outputs=tuple(node.output), + op_type=node.op_type, + ) + seq.append(obs) + for i, inp in enumerate(graph.output): + obs = ObsCompare( + kind=ObsType.OUTPUT, + itype=inp.type.tensor_type.elem_type, + index=i, + shape=tuple( + (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim + ), + name_or_outputs=(inp.name,), + ) + seq.append(obs) + return seq From ae701a6a09245fde7a606526dc440360c87279e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 00:43:38 +0100 Subject: [PATCH 02/12] doc --- CHANGELOGS.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 506e3c21..a162945b 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.6 +++++ +* :pr:`353`: add command line to compare two onnx models + 0.8.5 +++++ From aaf2753d2c63626c46bcca7fb526a486e368e4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 00:46:27 +0100 Subject: [PATCH 03/12] mypy --- onnx_diagnostic/torch_onnx/compare.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 8051f484..8dacb029 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import onnx from ..helpers.onnx_helper import onnx_dtype_name @@ -33,7 +33,7 @@ class ObsCompare: The description of an observation, a node, an input, an output, an initializer. :param kind: node type, see :class:`ObsType` - :param name_or_outputs: name of an initilizer or the outputs of a node + :param name_or_outputs: name of an initializer or the outputs of a node :param itype: onnx type :param index: index of an input or output :param shape: shape @@ -108,8 +108,8 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu :return: distance and alignment """ delay = max(50, abs(len(s2) - len(s1)) + 1) - distance = {(-1, -1): 0} - predecessor = {(-1, -1): None} + distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0} + predecessor: Dict[Tuple[int, int], Tuple[int, int]] = {(-1, -1): None} insert_cost = 1e3 for i in range(len(s1)): for j in range(max(0, i - delay), min(len(s2), i + delay)): @@ -138,7 +138,7 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu # reverse way = [] - last = len(s1) - 1, len(s2) - 1 + last: Optional[Tuple[int, int]] = len(s1) - 1, len(s2) - 1 while last is not None: way.append(last) last = predecessor[last] From 849f8949db829a98b6f06c488fcad3f337ae7bb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 09:29:56 +0100 Subject: [PATCH 04/12] better unit test --- _unittests/ut_torch_onnx/test_compare.py | 10 ++++++++++ onnx_diagnostic/torch_onnx/compare.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_compare.py b/_unittests/ut_torch_onnx/test_compare.py index 25373324..7aef9f1d 100644 --- a/_unittests/ut_torch_onnx/test_compare.py +++ b/_unittests/ut_torch_onnx/test_compare.py @@ -76,13 +76,23 @@ def test_edit_distance_1(self): ] self.assertEqual(expected_path, path) self.assertEqual(len(path), len(pair_cmp)) + n1, n2, n12 = 0, 0, 0 for o1, o2 in pair_cmp: if o1: self.assertIsInstance(o1, ObsCompare) + else: + n1 += 1 if o2: self.assertIsInstance(o2, ObsCompare) + else: + n2 += 1 if o1 and o2: self.assertEqual(o1, o2) + elif not o1 and not o2: + n12 += 1 + self.assertEqual(n1, 1) + self.assertEqual(n2, 1) + self.assertEqual(n12, 0) if __name__ == "__main__": diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 8dacb029..2fe21bcd 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -109,7 +109,7 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu """ delay = max(50, abs(len(s2) - len(s1)) + 1) distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0} - predecessor: Dict[Tuple[int, int], Tuple[int, int]] = {(-1, -1): None} + predecessor: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {(-1, -1): None} insert_cost = 1e3 for i in range(len(s1)): for j in range(max(0, i - delay), min(len(s2), i + delay)): @@ -143,7 +143,7 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu way.append(last) last = predecessor[last] indices = list(reversed(way))[1:] - obs_path = [] + obs_path: List[Optional[Tuple[int, int]]] = [] last = -1, -1 for i, j in indices: di = i - last[0] From dcc44e1e16db24634bbda4639d919b90f82a469a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 09:41:59 +0100 Subject: [PATCH 05/12] type --- onnx_diagnostic/torch_onnx/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 2fe21bcd..4cf1fadc 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -138,7 +138,7 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu # reverse way = [] - last: Optional[Tuple[int, int]] = len(s1) - 1, len(s2) - 1 + last: Optional[Tuple[Optional[int], Optional[int]]] = len(s1) - 1, len(s2) - 1 while last is not None: way.append(last) last = predecessor[last] From d90aeb782db7baa1d9dae655b7e633d71daad395 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 09:51:04 +0100 Subject: [PATCH 06/12] fix --- onnx_diagnostic/torch_onnx/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 4cf1fadc..7e1a8c77 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -138,12 +138,12 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu # reverse way = [] - last: Optional[Tuple[Optional[int], Optional[int]]] = len(s1) - 1, len(s2) - 1 + last: Optional[Tuple[int, int]] = len(s1) - 1, len(s2) - 1 while last is not None: way.append(last) last = predecessor[last] indices = list(reversed(way))[1:] - obs_path: List[Optional[Tuple[int, int]]] = [] + obs_path: List[Tuple[Optional[ObsCompare], Optional[ObsCompare]]] = [] last = -1, -1 for i, j in indices: di = i - last[0] From 7a351f5ecb656f7cb44519abcae7303959a6e129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 09:39:04 +0000 Subject: [PATCH 07/12] improves rendering --- onnx_diagnostic/_command_lines_parser.py | 2 +- onnx_diagnostic/torch_onnx/compare.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 16538eaf..869f32b3 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1545,7 +1545,7 @@ def _cmd_compare(argv: List[Any]): dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) print(f"-- done with distance {dist}") for i, (o1, o2) in enumerate(pair_cmp): - print(f"{i:04d} {o1} | {o2}") + print(f"{i:04d} {ObsCompare.to_str(o1)} | {ObsCompare.to_str(o2)}") ############# diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 7e1a8c77..4071e6f6 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -60,6 +60,12 @@ def __str__(self) -> str: ] return " ".join(els) + @classmethod + def to_str(cls, obs: Optional["ObsCompare"]) -> str: + if obs: + return str(obs) + return " " * (6 + 8 + 18 + 15 + 35 + 4) + def distance(self, obs: "ObsCompare") -> float: """Computes a cost between two observations.""" if self.kind != obs.kind: From 391a003058e7b3548157a2287088d6943f2da0c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 09:59:56 +0000 Subject: [PATCH 08/12] improve --- onnx_diagnostic/torch_onnx/compare.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 4071e6f6..1af8e611 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -73,13 +73,26 @@ def distance(self, obs: "ObsCompare") -> float: if self.itype != obs.itype: return 1e5 if self.kind == ObsType.NODE: + d = 0 if self.op_type != obs.op_type: - return 1e4 - if len(self.name_or_outputs) == 1: - return 0 if self.name_or_outputs == obs.name_or_outputs else 1e2 - a = set(self.name_or_outputs) & set(obs.name_or_outputs) - b = set(self.name_or_outputs) | set(obs.name_or_outputs) - return 1e2 * (len(b) - len(a)) + is_gemm1 = self.op_type in {"Gemm", "MatMul"} + is_gemm2 = obs.op_type in {"Gemm", "MatMul"} + d += 1e2 if is_gemm1 and is_gemm2 else (1e4 if not is_gemm1 and not is_gemm2 else 1e5) + if len(self.name_or_outputs) == 1 and len(obs.name_or_outputs) == 1: + if self.name_or_outputs[0] != obs.name_or_outputs[0]: + n1 = self.name_or_outputs[0] + n2 = obs.name_or_outputs[0] + n1 = n1.replace("_", "") + n2 = n2.replace("_", "") + if n1 == n2: + d += 1 + else: + d += 1e4 + else: + a = set(self.name_or_outputs) & set(obs.name_or_outputs) + b = set(self.name_or_outputs) | set(obs.name_or_outputs) + d += 1e4 * (len(b) - len(a)) + return d if self.kind == ObsType.INPUT: return ( 999.7 From 5e5a940035e4382d4218bb523336ec93f5c37f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 10:10:54 +0000 Subject: [PATCH 09/12] conflict --- onnx_diagnostic/_command_lines_parser.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 7069718b..869f32b3 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1545,11 +1545,7 @@ def _cmd_compare(argv: List[Any]): dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) print(f"-- done with distance {dist}") for i, (o1, o2) in enumerate(pair_cmp): -<<<<<<< HEAD print(f"{i:04d} {ObsCompare.to_str(o1)} | {ObsCompare.to_str(o2)}") -======= - print(f"{i:04d} {o1} | {o2}") ->>>>>>> 6b27604067df0fda92ba7500b1538e7a3d4292b2 ############# From 89cd087d88b9e75503bd92a079137608f0725a91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 16:53:14 +0100 Subject: [PATCH 10/12] fix comp --- _unittests/ut_torch_onnx/test_compare.py | 84 +++++-- .../ut_xrun_doc/test_command_lines_exe.py | 1 - onnx_diagnostic/_command_lines_parser.py | 5 +- onnx_diagnostic/torch_onnx/compare.py | 205 ++++++++++++------ 4 files changed, 206 insertions(+), 89 deletions(-) diff --git a/_unittests/ut_torch_onnx/test_compare.py b/_unittests/ut_torch_onnx/test_compare.py index 7aef9f1d..ecc96771 100644 --- a/_unittests/ut_torch_onnx/test_compare.py +++ b/_unittests/ut_torch_onnx/test_compare.py @@ -3,8 +3,9 @@ import onnx import onnx.helper as oh import onnx.numpy_helper as onh -from onnx_diagnostic.ext_test_case import ExtTestCase -from onnx_diagnostic.torch_onnx.compare import ObsCompare +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.torch_onnx.compare import ObsCompare, ObsComparePair TFLOAT = onnx.TensorProto.FLOAT TINT64 = onnx.TensorProto.INT64 @@ -54,7 +55,8 @@ def test_edit_distance_0(self): self.assertEqual(path, [(i, i) for i in range(len(path))]) self.assertEqual(len(path), len(pair_cmp)) uni = set() - for o1, o2 in pair_cmp: + for pair in pair_cmp: + o1, o2 = pair.side1, pair.side2 self.assertIsInstance(o1, ObsCompare) self.assertIsInstance(o2, ObsCompare) self.assertEqual(o1, o2) @@ -68,16 +70,12 @@ def test_edit_distance_1(self): seq1 = ObsCompare.obs_sequence_from_model(model) seq2 = ObsCompare.obs_sequence_from_model(model2) dist, path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) - self.assertGreater(dist, 2000) - expected_path = [ - *[(i, i) for i in range(11)], - *[(10, 11), (11, 11)], - *[(i, i) for i in range(12, len(seq1))], - ] - self.assertEqual(expected_path, path) + self.assertGreaterOrEqual(dist, 900) + self.assertEqual([(i, i) for i in range(len(path))], path) self.assertEqual(len(path), len(pair_cmp)) n1, n2, n12 = 0, 0, 0 - for o1, o2 in pair_cmp: + for pair in pair_cmp: + o1, o2 = pair.side1, pair.side2 if o1: self.assertIsInstance(o1, ObsCompare) else: @@ -87,13 +85,71 @@ def test_edit_distance_1(self): else: n2 += 1 if o1 and o2: - self.assertEqual(o1, o2) + pass elif not o1 and not o2: n12 += 1 - self.assertEqual(n1, 1) - self.assertEqual(n2, 1) + self.assertEqual(n1, 0) + self.assertEqual(n2, 0) self.assertEqual(n12, 0) + @ignore_warnings(DeprecationWarning) + def test_comp_model_gemm(self): + import torch + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 5) + self.fc1 = torch.nn.Linear(144, 64) + self.fc2 = torch.nn.Linear(64, 128) + self.fc3 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = torch.nn.functional.max_pool2d( + torch.nn.functional.relu(self.conv1(x)), (4, 4) + ) + # x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = torch.flatten(x, 1) + x = torch.nn.functional.relu(self.fc1(x)) + x = torch.nn.functional.relu(self.fc2(x)) + y = self.fc3(x) + return y + + model = Model() + x = torch.randn((2, 3, 16, 17), dtype=torch.float32) + model(x) + dynamic_shapes = ({0: "batch", 3: "dim"},) + onx_file = self.get_dump_file("test_comp_model_gemm.onnx") + to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True + ).save(onx_file) + onx = onnx.load(onx_file) + self.assert_onnx_disc("test_comp_model_gemm", onx, model, (x,), use_ort=True) + seq1 = ObsCompare.obs_sequence_from_model(onx) + seq2 = ObsCompare.obs_sequence_from_model(onx) + dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + text = str(pair_cmp[0]) + self.assertIn("0000 INITIA", text) + self.assertNotIn("(", text) + text = ObsComparePair.to_str(pair_cmp) + self.assertEqual(dist, 0) + self.assertNotIn("?", text) + self.assertIn("0013 NODE", text) + onx_file0 = self.get_dump_file("test_comp_model_gemm0.onnx") + to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False + ).save(onx_file0) + onx0 = onnx.load(onx_file0) + seq1 = ObsCompare.obs_sequence_from_model(onx0) + seq2 = ObsCompare.obs_sequence_from_model(onx) + _dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + text = ObsComparePair.to_str(pair_cmp) + self.assertIn("Conv", text) + for pair in pair_cmp: + assert ( + pair.side1.op_type != "Conv" or pair.side2.op_type == "FusedConv" + ), f"wrong pair {pair!r}" + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index c9ff8f15..801d19c8 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -208,7 +208,6 @@ def test_j_parser_compare(self): with redirect_stdout(st): main(["compare", self.dummy_path, self.dummy_path]) text = st.getvalue() - print(text) self.assertIn("done with distance 0", text) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 869f32b3..be58320e 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1533,7 +1533,7 @@ def get_parser_compare() -> ArgumentParser: def _cmd_compare(argv: List[Any]): import onnx - from .torch_onnx.compare import ObsCompare + from .torch_onnx.compare import ObsCompare, ObsComparePair parser = get_parser_compare() args = parser.parse_args(argv[1:]) @@ -1544,8 +1544,7 @@ def _cmd_compare(argv: List[Any]): print("-- starts comparison") dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) print(f"-- done with distance {dist}") - for i, (o1, o2) in enumerate(pair_cmp): - print(f"{i:04d} {ObsCompare.to_str(o1)} | {ObsCompare.to_str(o2)}") + print(ObsComparePair.to_str(pair_cmp)) ############# diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 1af8e611..8364a10f 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -1,10 +1,28 @@ import enum from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union +import numpy as np import onnx from ..helpers.onnx_helper import onnx_dtype_name +_NOT_SO_FAR_OPS = [ + {"MatMul", "Gemm", "FusedMatMul"}, + {"Conv", "FusedConv"}, + {"MaxPool"}, +] + + +def _sum_sets(sets): + t = set() + for s in sets: + t |= s + return t + + +_ALL_NOT_SO_FAR_OPS = _sum_sets(_NOT_SO_FAR_OPS) + + def _align(res: str, limit: int) -> str: if len(res) == limit: return res @@ -32,6 +50,7 @@ class ObsCompare: """ The description of an observation, a node, an input, an output, an initializer. + :param position: index of this observation in the original model :param kind: node type, see :class:`ObsType` :param name_or_outputs: name of an initializer or the outputs of a node :param itype: onnx type @@ -41,6 +60,7 @@ class ObsCompare: :param comment: comment, unused """ + position: int kind: ObsType name_or_outputs: Tuple[str] itype: int = 0 @@ -52,6 +72,7 @@ class ObsCompare: def __str__(self) -> str: "usual" els = [ + _align(f"{self.position:04d}", 4), _align(self.kind._name_, 6), _align(onnx_dtype_name(self.itype) if self.itype else "?", 8), _align("?" if self.shape is None else "x".join(map(str, self.shape)), 18), @@ -62,22 +83,31 @@ def __str__(self) -> str: @classmethod def to_str(cls, obs: Optional["ObsCompare"]) -> str: + assert not obs or isinstance(obs, ObsCompare), f"unexpected type {type(obs)}" if obs: return str(obs) - return " " * (6 + 8 + 18 + 15 + 35 + 4) + return " " * (4 + 6 + 8 + 18 + 15 + 35 + 5) def distance(self, obs: "ObsCompare") -> float: """Computes a cost between two observations.""" if self.kind != obs.kind: return 1e6 + d = 0 if self.itype != obs.itype: - return 1e5 + d += 1e5 if self.kind == ObsType.NODE: + cost = 9997 d = 0 if self.op_type != obs.op_type: - is_gemm1 = self.op_type in {"Gemm", "MatMul"} - is_gemm2 = obs.op_type in {"Gemm", "MatMul"} - d += 1e2 if is_gemm1 and is_gemm2 else (1e4 if not is_gemm1 and not is_gemm2 else 1e5) + if self.op_type in _ALL_NOT_SO_FAR_OPS or obs.op_type in _ALL_NOT_SO_FAR_OPS: + d += 1e2 + for aset in _NOT_SO_FAR_OPS: + if self.op_type in aset and obs.op_type in aset: + cost = 97 + elif self.op_type in aset or obs.op_type in aset: + d += 5e4 + else: + d += 9e2 if len(self.name_or_outputs) == 1 and len(obs.name_or_outputs) == 1: if self.name_or_outputs[0] != obs.name_or_outputs[0]: n1 = self.name_or_outputs[0] @@ -87,11 +117,11 @@ def distance(self, obs: "ObsCompare") -> float: if n1 == n2: d += 1 else: - d += 1e4 + d += cost else: a = set(self.name_or_outputs) & set(obs.name_or_outputs) b = set(self.name_or_outputs) | set(obs.name_or_outputs) - d += 1e4 * (len(b) - len(a)) + d += cost * (len(b) - len(a)) return d if self.kind == ObsType.INPUT: return ( @@ -102,7 +132,7 @@ def distance(self, obs: "ObsCompare") -> float: else 0 ) if self.kind == ObsType.INITIALIZER or self.kind == ObsType.SPARSE_INITIALIZER: - return 1e3 if self.itype != obs.itype or self.shape or obs.shape else 0 + return 1e3 if self.itype != obs.itype or self.shape != obs.shape else 0 if self.kind == ObsType.OUTPUT: return ( 999.1 @@ -113,11 +143,79 @@ def distance(self, obs: "ObsCompare") -> float: ) return 1e8 + @classmethod + def obs_sequence_from_model( + cls, + model: Union[onnx.ModelProto, onnx.GraphProto], + ) -> List["ObsCompare"]: + """ + Creates a sequence of observations bases on a model. + + :param model: model + :return: sequence of observations + """ + graph = model if isinstance(model, onnx.GraphProto) else model.graph + + shapes = {} + types = {} + for info in [*graph.value_info, *graph.input, *graph.output]: + if info.type.tensor_type: + t = info.type.tensor_type + shapes[info.name] = tuple((d.dim_param or d.dim_value) for d in t.shape.dim) + types[info.name] = t.elem_type + + seq = [] + for init in graph.initializer: + obs = ObsCompare( + position=len(seq), + kind=ObsType.INITIALIZER, + itype=init.data_type, + shape=tuple(init.dims), + name_or_outputs=(init.name,), + ) + seq.append(obs) + for i, inp in enumerate(graph.input): + obs = ObsCompare( + position=len(seq), + kind=ObsType.INPUT, + itype=inp.type.tensor_type.elem_type, + index=i, + shape=tuple( + (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim + ), + name_or_outputs=(inp.name,), + ) + seq.append(obs) + for node in graph.node: + obs = ObsCompare( + position=len(seq), + kind=ObsType.NODE, + itype=types.get(node.output[0], 0), + index=i, + shape=shapes.get(node.output[0], None), + name_or_outputs=tuple(node.output), + op_type=node.op_type, + ) + seq.append(obs) + for i, inp in enumerate(graph.output): + obs = ObsCompare( + position=len(seq), + kind=ObsType.OUTPUT, + itype=inp.type.tensor_type.elem_type, + index=i, + shape=tuple( + (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim + ), + name_or_outputs=(inp.name,), + ) + seq.append(obs) + return seq + @classmethod def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[ float, List[Tuple[int, int]], - List[Tuple[Optional["ObsCompare"], Optional["ObsCompare"]]], + List["ObsComparePair"], ]: """ Computes the distance between two sequences of results. @@ -167,76 +265,41 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu for i, j in indices: di = i - last[0] dj = j - last[1] + cost = distance.get((i, j), np.nan) if di == dj == 1: - obs_path.append((s1[i], s2[j])) + obs_path.append(ObsComparePair(s1[i], s2[j], distance=cost)) elif di == 0: - obs_path.append((None, s2[j])) + obs_path.append(ObsComparePair(None, s2[j], distance=cost)) elif dj == 0: - obs_path.append((s1[i], None)) + obs_path.append(ObsComparePair(s1[i], None, distance=cost)) else: raise RuntimeError(f"issue with di={di}, dj={dj}") last = i, j return distance[len(s1) - 1, len(s2) - 1], indices, obs_path - @classmethod - def obs_sequence_from_model( - cls, - model: Union[onnx.ModelProto, onnx.GraphProto], - ) -> List["ObsCompare"]: - """ - Creates a sequence of observations bases on a model. - :param model: model - :return: sequence of observations - """ - graph = model if isinstance(model, onnx.GraphProto) else model.graph +@dataclass +class ObsComparePair: + """ + Defines a pair of comparison objects - shapes = {} - types = {} - for info in graph.value_info: - if info.type.tensor_type: - t = info.type.tensor_type - shapes[info.name] = tuple((d.dim_param or d.dim_value) for d in t.shape.dim) - types[info.name] = t.elem_type + :param side1: object from first side + :param side2: object from first side + :param distance: distance + """ - seq = [] - for init in graph.initializer: - obs = ObsCompare( - kind=ObsType.INITIALIZER, - itype=init.data_type, - name_or_outputs=(init.name,), - ) - seq.append(obs) - for i, inp in enumerate(graph.input): - obs = ObsCompare( - kind=ObsType.INPUT, - itype=inp.type.tensor_type.elem_type, - index=i, - shape=tuple( - (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim - ), - name_or_outputs=(inp.name,), - ) - seq.append(obs) - for node in graph.node: - obs = ObsCompare( - kind=ObsType.NODE, - itype=types.get(node.output[0], 0), - index=i, - shape=shapes.get(node.output[0], None), - name_or_outputs=tuple(node.output), - op_type=node.op_type, - ) - seq.append(obs) - for i, inp in enumerate(graph.output): - obs = ObsCompare( - kind=ObsType.OUTPUT, - itype=inp.type.tensor_type.elem_type, - index=i, - shape=tuple( - (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim - ), - name_or_outputs=(inp.name,), - ) - seq.append(obs) - return seq + side1: Optional[ObsCompare] + side2: Optional[ObsCompare] + distance: float + + def __str__(self) -> str: + "nice dislay" + return ( + f"{self.distance:.4e} | " + f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}" + ) + + @classmethod + def to_str(cls, seq: List["ObsComparePair"]) -> str: + """Displays every pair in text.""" + return "\n".join([f"{str(pair)}" for pair in seq]) From 0001c29df06949f64b68354b84b2ecd976ba9085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 17:37:17 +0100 Subject: [PATCH 11/12] fix mypy --- _doc/cmds/compare.rst | 57 +++++++++--- _unittests/ut_torch_onnx/test_compare.py | 8 +- onnx_diagnostic/_command_lines_parser.py | 2 +- onnx_diagnostic/torch_onnx/compare.py | 110 +++++++++++++++++------ 4 files changed, 129 insertions(+), 48 deletions(-) diff --git a/_doc/cmds/compare.rst b/_doc/cmds/compare.rst index be5fcc81..fb601077 100644 --- a/_doc/cmds/compare.rst +++ b/_doc/cmds/compare.rst @@ -22,17 +22,46 @@ Example python -m onnx_diagnostic compare -.. code-block:: text - - -- loading 'two_nodes.onnx' - -- loading 'two_nodes.onnx' - -- starts comparison - -- done with distance 0 - 0000 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w - 0001 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b - 0002 INPUT FLOAT s0x(((s1 - 1)//8)) linear | INPUT FLOAT s0x(((s1 - 1)//8)) linear - 0003 INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 | INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 - 0004 NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 | NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 - 0005 NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 | NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 - 0006 OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 | OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 - 0007 OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 | OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 +This example is based on python but it produces the same output +than the command line. + +.. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.api import to_onnx + from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare + + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 5) + self.fc1 = torch.nn.Linear(144, 64) + self.fc2 = torch.nn.Linear(64, 128) + self.fc3 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv1(x)), (4, 4)) + # x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = torch.flatten(x, 1) + x = torch.nn.functional.relu(self.fc1(x)) + x = torch.nn.functional.relu(self.fc2(x)) + y = self.fc3(x) + return y + + + model = Model() + x = torch.randn((2, 3, 16, 17), dtype=torch.float32) + dynamic_shapes = ({0: "batch", 3: "dim"},) + onnx_optimized = to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True + ).model_proto + onnx_not_optimized = to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False + ).model_proto + seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized) + seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized) + _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) + text = ObsComparePair.to_str(pair_cmp) + print(text) diff --git a/_unittests/ut_torch_onnx/test_compare.py b/_unittests/ut_torch_onnx/test_compare.py index ecc96771..9f7e4334 100644 --- a/_unittests/ut_torch_onnx/test_compare.py +++ b/_unittests/ut_torch_onnx/test_compare.py @@ -50,7 +50,7 @@ def _get_model(self, cast=True): def test_edit_distance_0(self): model = self._get_model() seq = ObsCompare.obs_sequence_from_model(model) - dist, path, pair_cmp = ObsCompare.distance_sequence(seq, seq) + dist, path, pair_cmp = ObsComparePair.distance_sequence(seq, seq) self.assertEqual(dist, 0) self.assertEqual(path, [(i, i) for i in range(len(path))]) self.assertEqual(len(path), len(pair_cmp)) @@ -69,7 +69,7 @@ def test_edit_distance_1(self): model2 = self._get_model(cast=False) seq1 = ObsCompare.obs_sequence_from_model(model) seq2 = ObsCompare.obs_sequence_from_model(model2) - dist, path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + dist, path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) self.assertGreaterOrEqual(dist, 900) self.assertEqual([(i, i) for i in range(len(path))], path) self.assertEqual(len(path), len(pair_cmp)) @@ -127,7 +127,7 @@ def forward(self, x): self.assert_onnx_disc("test_comp_model_gemm", onx, model, (x,), use_ort=True) seq1 = ObsCompare.obs_sequence_from_model(onx) seq2 = ObsCompare.obs_sequence_from_model(onx) - dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) text = str(pair_cmp[0]) self.assertIn("0000 INITIA", text) self.assertNotIn("(", text) @@ -142,7 +142,7 @@ def forward(self, x): onx0 = onnx.load(onx_file0) seq1 = ObsCompare.obs_sequence_from_model(onx0) seq2 = ObsCompare.obs_sequence_from_model(onx) - _dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) text = ObsComparePair.to_str(pair_cmp) self.assertIn("Conv", text) for pair in pair_cmp: diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index be58320e..bb080332 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1542,7 +1542,7 @@ def _cmd_compare(argv: List[Any]): print(f"-- loading {args.model2!r}") seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False)) print("-- starts comparison") - dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2) + dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) print(f"-- done with distance {dist}") print(ObsComparePair.to_str(pair_cmp)) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index 8364a10f..bab55173 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -92,7 +92,7 @@ def distance(self, obs: "ObsCompare") -> float: """Computes a cost between two observations.""" if self.kind != obs.kind: return 1e6 - d = 0 + d: float = 0 if self.itype != obs.itype: d += 1e5 if self.kind == ObsType.NODE: @@ -116,6 +116,12 @@ def distance(self, obs: "ObsCompare") -> float: n2 = n2.replace("_", "") if n1 == n2: d += 1 + elif (n1.startswith(("val_", "_onx_")) or "::" in n1 or "--" in n1) and ( + n2.startswith(("val_", "_onx_")) or "::" in n2 or "--" in n2 + ): + # These are name given the exporter + # and not inspired from the model itself. + d += cost / 100 else: d += cost else: @@ -211,6 +217,33 @@ def obs_sequence_from_model( seq.append(obs) return seq + +@dataclass +class ObsComparePair: + """ + Defines a pair of comparison objects + + :param side1: object from first side + :param side2: object from first side + :param distance: distance + """ + + side1: Optional[ObsCompare] + side2: Optional[ObsCompare] + distance: float + + def __str__(self) -> str: + "nice display" + return ( + f"{self.distance:.4e} | " + f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}" + ) + + @classmethod + def to_str(cls, seq: List["ObsComparePair"]) -> str: + """Displays every pair in text.""" + return "\n".join([f"{str(pair)}" for pair in seq]) + @classmethod def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[ float, @@ -223,6 +256,52 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu :param s1: first sequence :param s2: second sequence :return: distance and alignment + + An example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.api import to_onnx + from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare + + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 5) + self.fc1 = torch.nn.Linear(144, 64) + self.fc2 = torch.nn.Linear(64, 128) + self.fc3 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = torch.nn.functional.max_pool2d( + torch.nn.functional.relu(self.conv1(x)), + (4, 4), + ) + # x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = torch.flatten(x, 1) + x = torch.nn.functional.relu(self.fc1(x)) + x = torch.nn.functional.relu(self.fc2(x)) + y = self.fc3(x) + return y + + + model = Model() + x = torch.randn((2, 3, 16, 17), dtype=torch.float32) + dynamic_shapes = ({0: "batch", 3: "dim"},) + onnx_optimized = to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True + ).model_proto + onnx_not_optimized = to_onnx( + model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False + ).model_proto + seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized) + seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized) + _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2) + text = ObsComparePair.to_str(pair_cmp) + print(text) """ delay = max(50, abs(len(s2) - len(s1)) + 1) distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0} @@ -260,7 +339,7 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu way.append(last) last = predecessor[last] indices = list(reversed(way))[1:] - obs_path: List[Tuple[Optional[ObsCompare], Optional[ObsCompare]]] = [] + obs_path: List[ObsComparePair] = [] last = -1, -1 for i, j in indices: di = i - last[0] @@ -276,30 +355,3 @@ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tu raise RuntimeError(f"issue with di={di}, dj={dj}") last = i, j return distance[len(s1) - 1, len(s2) - 1], indices, obs_path - - -@dataclass -class ObsComparePair: - """ - Defines a pair of comparison objects - - :param side1: object from first side - :param side2: object from first side - :param distance: distance - """ - - side1: Optional[ObsCompare] - side2: Optional[ObsCompare] - distance: float - - def __str__(self) -> str: - "nice dislay" - return ( - f"{self.distance:.4e} | " - f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}" - ) - - @classmethod - def to_str(cls, seq: List["ObsComparePair"]) -> str: - """Displays every pair in text.""" - return "\n".join([f"{str(pair)}" for pair in seq]) From 77d8b8d4ee673952c9e430b87ac2109943e17837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 10 Dec 2025 17:39:31 +0100 Subject: [PATCH 12/12] mypy --- onnx_diagnostic/torch_onnx/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_onnx/compare.py b/onnx_diagnostic/torch_onnx/compare.py index bab55173..25e8ac44 100644 --- a/onnx_diagnostic/torch_onnx/compare.py +++ b/onnx_diagnostic/torch_onnx/compare.py @@ -170,7 +170,7 @@ def obs_sequence_from_model( shapes[info.name] = tuple((d.dim_param or d.dim_value) for d in t.shape.dim) types[info.name] = t.elem_type - seq = [] + seq: List[ObsCompare] = [] for init in graph.initializer: obs = ObsCompare( position=len(seq),