Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions _doc/cmds/compare.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,46 @@ Example

python -m onnx_diagnostic compare <mode1.onnx> <mode1.onnx>

.. code-block:: text

-- loading 'two_nodes.onnx'
-- loading 'two_nodes.onnx'
-- starts comparison
-- done with distance 0
0000 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.w
0001 INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b | INITIA FLOAT ? encoder.encoders.0.layer_norm_att.b
0002 INPUT FLOAT s0x(((s1 - 1)//8)) linear | INPUT FLOAT s0x(((s1 - 1)//8)) linear
0003 INPUT FLOAT s0x(((s1 - 1)//8)) mul_178 | INPUT FLOAT s0x(((s1 - 1)//8)) mul_178
0004 NODE FLOAT s0x(((s1 - 1)//8)) Add add_256 | NODE FLOAT s0x(((s1 - 1)//8)) Add add_256
0005 NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1 | NODE FLOAT s0x(((s1 - 1)//8)) LayerNormalizat layer_norm_1
0006 OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1 | OUTPUT FLOAT s0x(((s1 - 1)//8)) layer_norm_1
0007 OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256 | OUTPUT FLOAT s0x(((s1 - 1)//8)) add_256
This example is based on python but it produces the same output
than the command line.

.. runpython::
:showcode:

import torch
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 5)
self.fc1 = torch.nn.Linear(144, 64)
self.fc2 = torch.nn.Linear(64, 128)
self.fc3 = torch.nn.Linear(128, 10)

def forward(self, x):
x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv1(x)), (4, 4))
# x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
y = self.fc3(x)
return y


model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
dynamic_shapes = ({0: "batch", 3: "dim"},)
onnx_optimized = to_onnx(
model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).model_proto
onnx_not_optimized = to_onnx(
model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).model_proto
seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
print(text)
88 changes: 72 additions & 16 deletions _unittests/ut_torch_onnx/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx_diagnostic.ext_test_case import ExtTestCase
from onnx_diagnostic.torch_onnx.compare import ObsCompare
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsCompare, ObsComparePair

TFLOAT = onnx.TensorProto.FLOAT
TINT64 = onnx.TensorProto.INT64
Expand Down Expand Up @@ -49,12 +50,13 @@ def _get_model(self, cast=True):
def test_edit_distance_0(self):
model = self._get_model()
seq = ObsCompare.obs_sequence_from_model(model)
dist, path, pair_cmp = ObsCompare.distance_sequence(seq, seq)
dist, path, pair_cmp = ObsComparePair.distance_sequence(seq, seq)
self.assertEqual(dist, 0)
self.assertEqual(path, [(i, i) for i in range(len(path))])
self.assertEqual(len(path), len(pair_cmp))
uni = set()
for o1, o2 in pair_cmp:
for pair in pair_cmp:
o1, o2 = pair.side1, pair.side2
self.assertIsInstance(o1, ObsCompare)
self.assertIsInstance(o2, ObsCompare)
self.assertEqual(o1, o2)
Expand All @@ -67,17 +69,13 @@ def test_edit_distance_1(self):
model2 = self._get_model(cast=False)
seq1 = ObsCompare.obs_sequence_from_model(model)
seq2 = ObsCompare.obs_sequence_from_model(model2)
dist, path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2)
self.assertGreater(dist, 2000)
expected_path = [
*[(i, i) for i in range(11)],
*[(10, 11), (11, 11)],
*[(i, i) for i in range(12, len(seq1))],
]
self.assertEqual(expected_path, path)
dist, path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
self.assertGreaterOrEqual(dist, 900)
self.assertEqual([(i, i) for i in range(len(path))], path)
self.assertEqual(len(path), len(pair_cmp))
n1, n2, n12 = 0, 0, 0
for o1, o2 in pair_cmp:
for pair in pair_cmp:
o1, o2 = pair.side1, pair.side2
if o1:
self.assertIsInstance(o1, ObsCompare)
else:
Expand All @@ -87,13 +85,71 @@ def test_edit_distance_1(self):
else:
n2 += 1
if o1 and o2:
self.assertEqual(o1, o2)
pass
elif not o1 and not o2:
n12 += 1
self.assertEqual(n1, 1)
self.assertEqual(n2, 1)
self.assertEqual(n1, 0)
self.assertEqual(n2, 0)
self.assertEqual(n12, 0)

@ignore_warnings(DeprecationWarning)
def test_comp_model_gemm(self):
import torch

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 5)
self.fc1 = torch.nn.Linear(144, 64)
self.fc2 = torch.nn.Linear(64, 128)
self.fc3 = torch.nn.Linear(128, 10)

def forward(self, x):
x = torch.nn.functional.max_pool2d(
torch.nn.functional.relu(self.conv1(x)), (4, 4)
)
# x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
y = self.fc3(x)
return y

model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
model(x)
dynamic_shapes = ({0: "batch", 3: "dim"},)
onx_file = self.get_dump_file("test_comp_model_gemm.onnx")
to_onnx(
model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).save(onx_file)
onx = onnx.load(onx_file)
self.assert_onnx_disc("test_comp_model_gemm", onx, model, (x,), use_ort=True)
seq1 = ObsCompare.obs_sequence_from_model(onx)
seq2 = ObsCompare.obs_sequence_from_model(onx)
dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = str(pair_cmp[0])
self.assertIn("0000 INITIA", text)
self.assertNotIn("(", text)
text = ObsComparePair.to_str(pair_cmp)
self.assertEqual(dist, 0)
self.assertNotIn("?", text)
self.assertIn("0013 NODE", text)
onx_file0 = self.get_dump_file("test_comp_model_gemm0.onnx")
to_onnx(
model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).save(onx_file0)
onx0 = onnx.load(onx_file0)
seq1 = ObsCompare.obs_sequence_from_model(onx0)
seq2 = ObsCompare.obs_sequence_from_model(onx)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
self.assertIn("Conv", text)
for pair in pair_cmp:
assert (
pair.side1.op_type != "Conv" or pair.side2.op_type == "FusedConv"
), f"wrong pair {pair!r}"


if __name__ == "__main__":
unittest.main(verbosity=2)
1 change: 0 additions & 1 deletion _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def test_j_parser_compare(self):
with redirect_stdout(st):
main(["compare", self.dummy_path, self.dummy_path])
text = st.getvalue()
print(text)
self.assertIn("done with distance 0", text)


Expand Down
7 changes: 3 additions & 4 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,7 @@ def get_parser_compare() -> ArgumentParser:

def _cmd_compare(argv: List[Any]):
import onnx
from .torch_onnx.compare import ObsCompare
from .torch_onnx.compare import ObsCompare, ObsComparePair

parser = get_parser_compare()
args = parser.parse_args(argv[1:])
Expand All @@ -1542,10 +1542,9 @@ def _cmd_compare(argv: List[Any]):
print(f"-- loading {args.model2!r}")
seq2 = ObsCompare.obs_sequence_from_model(onnx.load(args.model2, load_external_data=False))
print("-- starts comparison")
dist, _path, pair_cmp = ObsCompare.distance_sequence(seq1, seq2)
dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
print(f"-- done with distance {dist}")
for i, (o1, o2) in enumerate(pair_cmp):
print(f"{i:04d} {o1} | {o2}")
print(ObsComparePair.to_str(pair_cmp))


#############
Expand Down
Loading
Loading