Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.6
+++++

* :pr:`357`: complete simple_loop_for, an easier to rewrite loops
* :pr:`356`: include qwen embedding part
* :pr:`355`: better command line to export models
* :pr:`353`, :pr:`354`: add command line to compare two onnx models
Expand Down
7 changes: 7 additions & 0 deletions _doc/api/export/cf_simple_loop_for.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.export.cf_simple_loop_for
=========================================

.. automodule:: onnx_diagnostic.export.cf_simple_loop_for
:members:
:no-undoc-members:
7 changes: 7 additions & 0 deletions _doc/api/export/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ onnx_diagnostic.export
shape_helper
validate

.. toctree::
:maxdepth: 1
:caption: higher order ops

cf_simple_loop_for


CoupleInputsDynamicShapes
+++++++++++++++++++++++++

Expand Down
270 changes: 270 additions & 0 deletions _unittests/ut_export/test_cf_simple_loop_for.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
import unittest
from typing import Tuple
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.export.control_flow_onnx import (
enable_code_export_control_flow,
)
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for, SimpleLoopForOp


class TestCfSimpleLoopFor(ExtTestCase):
@requires_torch("2.9.99")
def test_simple_loop_for_int(self):
class Model(torch.nn.Module):
def forward(self, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(1),)

return simple_loop_for(4, body, (x,))

model = Model()
x = torch.arange(10, dtype=torch.float32)
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
got = model(x)
self.assertEqualArray(expected, got)

with enable_code_export_control_flow():
got = model(x)
self.assertEqualArray(expected, got)

ep = torch.export.export(
model, (x,), dynamic_shapes=(({0: torch.export.Dim.DYNAMIC},))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
# Loop should be unrolled.
self.assertEqual(len(check), 0)

@requires_torch("2.9.99")
def test_simple_loop_for_no_inputs(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor) -> Tuple[torch.Tensor]:
return (torch.arange(i + 1, dtype=torch.int64),)

y = simple_loop_for(n_iter, body)
torch._check(isinstance(y, torch.Tensor), lambda: f"y is {type(y)}")
return x.unsqueeze(1) + y.unsqueeze(0).to(x.device)

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(8, dtype=torch.float32)
expected = x.reshape((-1, 1)) + torch.tensor(
[[0, 0, 1, 0, 1, 2, 0, 1, 2, 3]], dtype=x.dtype
)
got = model(n_iter, x)
self.assertEqualArray(expected, got)

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected, got)

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
self.assertEqual(len(check), 1)

@requires_torch("2.9.99")
def test_simple_loop_for_1(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(1),)

return simple_loop_for(n_iter, body, (x,))

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
got = model(n_iter, x)
self.assertEqualArray(expected, got)

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected, got)

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
self.assertEqual(len(check), 1)

@requires_torch("2.9.99")
def test_simple_loop_for_2(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(1))

return simple_loop_for(n_iter, body, (x,))

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = (
torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1),
torch.tensor(
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
2,
3,
4,
5,
6,
7,
8,
9,
3,
4,
5,
6,
7,
8,
9,
4,
5,
6,
7,
8,
9,
],
dtype=x.dtype,
).unsqueeze(1),
)
got = model(n_iter, x)
self.assertEqualArray(expected[0], got[0])
self.assertEqualArray(expected[1], got[1])

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected[0], got[0])
self.assertEqualArray(expected[1], got[1])

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
self.assertEqual(len(check), 1)

@requires_torch("2.9.99")
def test_simple_loop_for_2_concatenation_dims(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))

return simple_loop_for(n_iter, body, (x,), (0, 1))

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = (
torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1),
torch.tensor(
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
2,
3,
4,
5,
6,
7,
8,
9,
3,
4,
5,
6,
7,
8,
9,
4,
5,
6,
7,
8,
9,
],
dtype=x.dtype,
).unsqueeze(0),
)
got = model(n_iter, x)
self.assertEqualArray(expected[0], got[0])
self.assertEqualArray(expected[1], got[1])

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected[0], got[0])
self.assertEqualArray(expected[1], got[1])

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
self.assertEqual(len(check), 1)

@requires_torch("2.9.99")
def test_simple_loop_for_1_with_concatenation_dims(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(0),)

return simple_loop_for(n_iter, body, (x,), 1)

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(0)
got = model(n_iter, x)
self.assertEqualArray(expected, got)

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected, got)

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
check = []
for node in ep.graph.nodes:
if isinstance(node.target, SimpleLoopForOp):
check.append(node)
self.assertEqual(len(check), 1)


if __name__ == "__main__":
unittest.main(verbosity=2)
32 changes: 0 additions & 32 deletions _unittests/ut_export/test_control_flow.py

This file was deleted.

34 changes: 2 additions & 32 deletions _unittests/ut_export/test_control_flow_onnx.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,13 @@
import unittest
from typing import Tuple
import torch
from onnxscript import script, FLOAT, INT64
from onnxscript import opset18 as op
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test
from onnx_diagnostic.export.control_flow_onnx import (
enable_code_export_control_flow,
loop_for_onnx,
)
from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx
from onnx_diagnostic.export.api import to_onnx


class TestControlFlowOnnx(ExtTestCase):
@never_test()
def test_loop_one_research(self):
class Model(torch.nn.Module):
def forward(self, n_iter, x):
def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (x[: i.item() + 1].unsqueeze(1),)

return loop_for_r(n_iter, body, (x,))[0]

model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1)
got = model(n_iter, x)
self.assertEqualArray(expected, got)

with enable_code_export_control_flow():
got = model(n_iter, x)
self.assertEqualArray(expected, got)

ep = torch.export.export(
model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

def test_onnxscript_loop(self):
@script()
def concatenation(N: INT64[1], x: FLOAT[None]) -> FLOAT[None, 1]:
Expand Down
Loading
Loading