Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.8.6
+++++

* :pr:`353`: add command line to compare two onnx models

0.8.5
+++++

Expand Down
7 changes: 7 additions & 0 deletions _doc/api/torch_onnx/compare.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_onnx.compare
==================================

.. automodule:: onnx_diagnostic.torch_onnx.compare
:members:
:no-undoc-members:
1 change: 1 addition & 0 deletions _doc/api/torch_onnx/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ onnx_diagnostic.torch_onnx
:maxdepth: 1
:caption: submodules

compare
runtime_info
sbs
sbs_dataclasses
Expand Down
38 changes: 38 additions & 0 deletions _doc/cmds/compare.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
-m onnx_diagnostic compare ... compares two models
==================================================

Description
+++++++++++

The command lines compares two models assuming they represent
the same models and most parts of both are the same.
Different options were used to export or an optimization
was different. This highlights the differences.

.. runpython::

from onnx_diagnostic._command_lines_parser import get_parser_compare

get_parser_compare().print_help()

Example
+++++++

.. code-block:: bash

python -m onnx_diagnostic compare <mode1.onnx> <mode1.onnx>

.. code-block:: text

-- loading 'two_nodes.onnx'
-- loading 'two_nodes.onnx'
-- starts comparison
-- done with distance 0
0000 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w
0001 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b
0002 INPUT FLOAT s0x(((s1 - 1)//8)) linear | INPUT FLOAT s0x(((s1 - 1)//8)) linear
0003 INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 | INPUT FLOAT s0x(((s1 - 1)//8)) mul_178
0004 NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 | NODE FLOAT s0x(((s1 - 1)//8)) Add add_256
0005 NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 | NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1
0006 OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 | OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1
0007 OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 | OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256
1 change: 1 addition & 0 deletions _doc/cmds/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Command Lines
.. toctree::
:maxdepth: 1

compare
config
sbs
validate
99 changes: 99 additions & 0 deletions _unittests/ut_torch_onnx/test_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_onnx.compare import ObsCompare

TFLOAT = onnx.TensorProto.FLOAT
TINT64 = onnx.TensorProto.INT64


class TestCompare(ExtTestCase):
def _get_model(self, cast=True):
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
(
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1)
if cast
else oh.make_node("Identity", ["xmc2"], ["xm2"])
),
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
],
"dummy",
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
[
onh.from_array(
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
),
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
onh.from_array(np.array([1], dtype=np.int64), name="un"),
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
return model

def test_edit_distance_0(self):
model = self._get_model()
seq = ObsCompare.obs_sequence_from_model(model)
dist, path, pair_cmp = ObsCompare.distance_sequence(seq, seq)
self.assertEqual(dist, 0)
self.assertEqual(path, [(i, i) for i in range(len(path))])
self.assertEqual(len(path), len(pair_cmp))
uni = set()
for o1, o2 in pair_cmp:
self.assertIsInstance(o1, ObsCompare)
self.assertIsInstance(o2, ObsCompare)
self.assertEqual(o1, o2)
row = f"{o1} | {o2}"
uni.add(len(row))
self.assertEqual(len(uni), 1)

def test_edit_distance_1(self):
model = self._get_model()
model2 = self._get_model(cast=False)
seq1 = ObsCompare.obs_sequence_from_model(model)
seq2 = ObsCompare.obs_sequence_from_model(model2)
dist, path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2)
self.assertGreater(dist, 2000)
expected_path = [
*[(i, i) for i in range(11)],
*[(10, 11), (11, 11)],
*[(i, i) for i in range(12, len(seq1))],
]
self.assertEqual(expected_path, path)
self.assertEqual(len(path), len(pair_cmp))
n1, n2, n12 = 0, 0, 0
for o1, o2 in pair_cmp:
if o1:
self.assertIsInstance(o1, ObsCompare)
else:
n1 += 1
if o2:
self.assertIsInstance(o2, ObsCompare)
else:
n2 += 1
if o1 and o2:
self.assertEqual(o1, o2)
elif not o1 and not o2:
n12 += 1
self.assertEqual(n1, 1)
self.assertEqual(n2, 1)
self.assertEqual(n12, 0)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from onnx_diagnostic._command_lines_parser import (
get_main_parser,
get_parser_agg,
get_parser_compare,
get_parser_config,
get_parser_dot,
get_parser_find,
Expand Down Expand Up @@ -170,6 +171,13 @@ def test_parser_dot(self):
text = st.getvalue()
self.assertIn("--run", text)

def test_parser_compare(self):
st = StringIO()
with redirect_stdout(st):
get_parser_compare().print_help()
text = st.getvalue()
self.assertIn("compare", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ def forward(self, x):
# text is empty is dot is not installed
self.assertIn("converts into dot", text)

def test_j_parser_compare(self):
st = StringIO()
with redirect_stdout(st):
main(["compare", self.dummy_path, self.dummy_path])
text = st.getvalue()
print(text)
self.assertIn("done with distance 0", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
47 changes: 47 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,51 @@ def _size(name):
print("-- done")


def get_parser_compare() -> ArgumentParser:
parser = ArgumentParser(
prog="compare",
description=textwrap.dedent(
"""
Compares two onnx models by aligning the nodes between both models.
This is done through an edit distance.
"""
),
epilog=textwrap.dedent(
"""
Each element (initializer, input, node, output) of the model
is converted into an observation. Then it defines a distance between
two elements. And finally, it finds the best alignment with
an edit distance.
"""
),
)
parser.add_argument("model1", type=str, help="first model to compare")
parser.add_argument("model2", type=str, help="second model to compare")
return parser


def _cmd_compare(argv: List[Any]):
import onnx
from .torch_onnx.compare import ObsCompare

parser = get_parser_compare()
args = parser.parse_args(argv[1:])
print(f"-- loading {args.model1!r}")
seq1 = ObsCompare.obs_sequence_from_model(onnx.load(args.model1, load_external_data=False))
print(f"-- loading {args.model2!r}")
seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False))
print("-- starts comparison")
dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2)
print(f"-- done with distance {dist}")
for i, (o1, o2) in enumerate(pair_cmp):
print(f"{i:04d} {o1} | {o2}")


#############
# main parser
#############


def get_main_parser() -> ArgumentParser:
parser = ArgumentParser(
prog="onnx_diagnostic",
Expand Down Expand Up @@ -1555,6 +1600,7 @@ def get_main_parser() -> ArgumentParser:
def main(argv: Optional[List[Any]] = None):
fcts = dict(
agg=_cmd_agg,
compare=_cmd_compare,
config=_cmd_config,
dot=_cmd_dot,
exportsample=_cmd_export_sample,
Expand All @@ -1580,6 +1626,7 @@ def main(argv: Optional[List[Any]] = None):
else:
parsers = dict(
agg=get_parser_agg,
compare=get_parser_compare,
config=get_parser_config,
dot=get_parser_dot,
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
Expand Down
Loading
Loading