diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index b3c0d3b8..f28a427e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.6 +++++ +* :pr:`357`: complete simple_loop_for, an easier to rewrite loops * :pr:`356`: include qwen embedding part * :pr:`355`: better command line to export models * :pr:`353`, :pr:`354`: add command line to compare two onnx models diff --git a/_doc/api/export/cf_simple_loop_for.rst b/_doc/api/export/cf_simple_loop_for.rst new file mode 100644 index 00000000..56bfb314 --- /dev/null +++ b/_doc/api/export/cf_simple_loop_for.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.export.cf_simple_loop_for +========================================= + +.. automodule:: onnx_diagnostic.export.cf_simple_loop_for + :members: + :no-undoc-members: diff --git a/_doc/api/export/index.rst b/_doc/api/export/index.rst index 8c806fe6..1b19dd5b 100644 --- a/_doc/api/export/index.rst +++ b/_doc/api/export/index.rst @@ -12,6 +12,13 @@ onnx_diagnostic.export shape_helper validate +.. toctree:: + :maxdepth: 1 + :caption: higher order ops + + cf_simple_loop_for + + CoupleInputsDynamicShapes +++++++++++++++++++++++++ diff --git a/_unittests/ut_export/test_cf_simple_loop_for.py b/_unittests/ut_export/test_cf_simple_loop_for.py new file mode 100644 index 00000000..4ab63310 --- /dev/null +++ b/_unittests/ut_export/test_cf_simple_loop_for.py @@ -0,0 +1,270 @@ +import unittest +from typing import Tuple +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch +from onnx_diagnostic.export.control_flow_onnx import ( + enable_code_export_control_flow, +) +from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for, SimpleLoopForOp + + +class TestCfSimpleLoopFor(ExtTestCase): + @requires_torch("2.9.99") + def test_simple_loop_for_int(self): + class Model(torch.nn.Module): + def forward(self, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(1),) + + return simple_loop_for(4, body, (x,)) + + model = Model() + x = torch.arange(10, dtype=torch.float32) + expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1) + got = model(x) + self.assertEqualArray(expected, got) + + with enable_code_export_control_flow(): + got = model(x) + self.assertEqualArray(expected, got) + + ep = torch.export.export( + model, (x,), dynamic_shapes=(({0: torch.export.Dim.DYNAMIC},)) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + # Loop should be unrolled. + self.assertEqual(len(check), 0) + + @requires_torch("2.9.99") + def test_simple_loop_for_no_inputs(self): + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor) -> Tuple[torch.Tensor]: + return (torch.arange(i + 1, dtype=torch.int64),) + + y = simple_loop_for(n_iter, body) + torch._check(isinstance(y, torch.Tensor), lambda: f"y is {type(y)}") + return x.unsqueeze(1) + y.unsqueeze(0).to(x.device) + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(8, dtype=torch.float32) + expected = x.reshape((-1, 1)) + torch.tensor( + [[0, 0, 1, 0, 1, 2, 0, 1, 2, 3]], dtype=x.dtype + ) + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + with enable_code_export_control_flow(): + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + self.assertEqual(len(check), 1) + + @requires_torch("2.9.99") + def test_simple_loop_for_1(self): + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(1),) + + return simple_loop_for(n_iter, body, (x,)) + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1) + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + with enable_code_export_control_flow(): + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + self.assertEqual(len(check), 1) + + @requires_torch("2.9.99") + def test_simple_loop_for_2(self): + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(1)) + + return simple_loop_for(n_iter, body, (x,)) + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + expected = ( + torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1), + torch.tensor( + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 4, + 5, + 6, + 7, + 8, + 9, + ], + dtype=x.dtype, + ).unsqueeze(1), + ) + got = model(n_iter, x) + self.assertEqualArray(expected[0], got[0]) + self.assertEqualArray(expected[1], got[1]) + + with enable_code_export_control_flow(): + got = model(n_iter, x) + self.assertEqualArray(expected[0], got[0]) + self.assertEqualArray(expected[1], got[1]) + + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + self.assertEqual(len(check), 1) + + @requires_torch("2.9.99") + def test_simple_loop_for_2_concatenation_dims(self): + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0)) + + return simple_loop_for(n_iter, body, (x,), (0, 1)) + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + expected = ( + torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1), + torch.tensor( + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 4, + 5, + 6, + 7, + 8, + 9, + ], + dtype=x.dtype, + ).unsqueeze(0), + ) + got = model(n_iter, x) + self.assertEqualArray(expected[0], got[0]) + self.assertEqualArray(expected[1], got[1]) + + with enable_code_export_control_flow(): + got = model(n_iter, x) + self.assertEqualArray(expected[0], got[0]) + self.assertEqualArray(expected[1], got[1]) + + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + self.assertEqual(len(check), 1) + + @requires_torch("2.9.99") + def test_simple_loop_for_1_with_concatenation_dims(self): + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(0),) + + return simple_loop_for(n_iter, body, (x,), 1) + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(0) + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + with enable_code_export_control_flow(): + got = model(n_iter, x) + self.assertEqualArray(expected, got) + + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + check = [] + for node in ep.graph.nodes: + if isinstance(node.target, SimpleLoopForOp): + check.append(node) + self.assertEqual(len(check), 1) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_export/test_control_flow.py b/_unittests/ut_export/test_control_flow.py deleted file mode 100644 index e0409a82..00000000 --- a/_unittests/ut_export/test_control_flow.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch -from onnx_diagnostic.export.control_flow import loop_for - - -class TestControlFlow(ExtTestCase): - @requires_torch("2.9.99") - def test_loop_for(self): - class Model(torch.nn.Module): - def forward(self, n_iter, x): - def body(i, x): - return x[: i.item() + 1].unsqueeze(1) - - return loop_for(n_iter, body, (x,)) - - model = Model() - n_iter = torch.tensor(4, dtype=torch.int64) - x = torch.arange(10, dtype=torch.float32) - expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1) - got = model(n_iter, x) - self.assertEqualArray(expected, got) - - ep = torch.export.export( - model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) - ) - names = set(m for m, _ in ep.module().named_modules()) - self.assertIn("", names) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/_unittests/ut_export/test_control_flow_onnx.py b/_unittests/ut_export/test_control_flow_onnx.py index 4791f22b..0e409d6f 100644 --- a/_unittests/ut_export/test_control_flow_onnx.py +++ b/_unittests/ut_export/test_control_flow_onnx.py @@ -1,43 +1,13 @@ import unittest -from typing import Tuple import torch from onnxscript import script, FLOAT, INT64 from onnxscript import opset18 as op -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, never_test -from onnx_diagnostic.export.control_flow_onnx import ( - enable_code_export_control_flow, - loop_for_onnx, -) -from onnx_diagnostic.export.control_flow_research import simple_loop_for as loop_for_r +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch +from onnx_diagnostic.export.control_flow_onnx import loop_for_onnx from onnx_diagnostic.export.api import to_onnx class TestControlFlowOnnx(ExtTestCase): - @never_test() - def test_loop_one_research(self): - class Model(torch.nn.Module): - def forward(self, n_iter, x): - def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: - return (x[: i.item() + 1].unsqueeze(1),) - - return loop_for_r(n_iter, body, (x,))[0] - - model = Model() - n_iter = torch.tensor(4, dtype=torch.int64) - x = torch.arange(10, dtype=torch.float32) - expected = torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3], dtype=x.dtype).unsqueeze(1) - got = model(n_iter, x) - self.assertEqualArray(expected, got) - - with enable_code_export_control_flow(): - got = model(n_iter, x) - self.assertEqualArray(expected, got) - - ep = torch.export.export( - model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) - ) - print(ep) - def test_onnxscript_loop(self): @script() def concatenation(N: INT64[1], x: FLOAT[None]) -> FLOAT[None, 1]: diff --git a/onnx_diagnostic/export/cf_simple_loop_for.py b/onnx_diagnostic/export/cf_simple_loop_for.py new file mode 100644 index 00000000..a2db3c31 --- /dev/null +++ b/onnx_diagnostic/export/cf_simple_loop_for.py @@ -0,0 +1,351 @@ +import contextlib +from typing import Callable, List, Optional, Sequence, Tuple, Union +import torch +from torch._C import DispatchKey +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +import torch.utils._pytree as pytree +from torch._higher_order_ops.utils import ( + check_input_alias_and_mutation_return_outputs, + reenter_make_fx, + unique_graph_id, + validate_subgraph_args_types, +) +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +class SimpleLoopForOp(HigherOrderOperator): + """Higher order op for :func:`simple_loop_for`.""" + + def __init__(self): + super().__init__("simple_loop_for") + + def __call__(self, n_iter, body_fn, operands, concatenation_dims=None): + validate_subgraph_args_types(operands) + return super().__call__(n_iter, body_fn, operands, concatenation_dims) + + def gen_schema(self, n_iter, body_fn, operands, concatenation_dims): + from torch._higher_order_ops.schema import HopSchemaGenerator + from torch._higher_order_ops.utils import materialize_as_graph + + body_gm: torch.fx.GraphModule = materialize_as_graph( # type: ignore[annotation-unchecked] + body_fn, (torch.tensor(0, dtype=torch.int64), *operands) + ) + ( + _, + _, + _, + body_mutated_inputs, + body_outputs, + ) = check_input_alias_and_mutation_return_outputs(body_gm) + mutated_inputs = body_mutated_inputs + + schema_gen = HopSchemaGenerator(self) + schema_gen.add_arg("n_iter", n_iter) + schema_gen.add_arg("body_fn", body_gm) + for idx, arg in enumerate(operands): + schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs) + + for out in body_outputs: + schema_gen.add_output(out) + assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), ( + f"concatenation_dims={concatenation_dims} but its length should be equal to " + f"the number of outputs ({len(body_outputs)})" + ) + schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims) + return schema_gen.gen_schema() + + +simple_loop_for_op = SimpleLoopForOp() + + +def _simple_loop_for_fn( + n_iter: torch.Tensor, + body_fn: Callable, + operands: Tuple[torch.Tensor, ...] = (), + concatenation_dims: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, ...]: + """ + Python implementation of the loop. + + :param n_iter: number of iteration + :param body_fn: function implementing the body + :param concatenation_dims: dimension used to reduce the list produced by the loop + :param operands: arguments to the loop body + :return: results + """ + torch._check( + isinstance(n_iter, (int, torch.Tensor)), + lambda: f"Unexpected type {type(n_iter)} for n_iter", + ) + torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn") + torch._check( + concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)), + lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims", + ) + torch._check( + isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands" + ) + res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = [] + for i in torch.arange( + n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype + ): + r = body_fn(i, *operands) + if isinstance(r, tuple): + assert not res or len(r) == len(res[-1]), ( + f"Unexpected number of results {len(r)} for function {body_fn}, " + f"expected {len(res[-1])}" + ) + res.append(r) + else: + assert isinstance(r, torch.Tensor), ( + f"Unexpected type {r} for function {body_fn}, " + f"it must be a tuple or a Tensor." + ) + assert not res or len(res[-1]) == 1, ( + f"Unexpected number of results {len(r)} for function {body_fn}, " + f"expected {len(res[-1])}" + ) + res.append((r,)) + + if not res: + return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device) + + n_res = len(res[0]) + return tuple( + torch.cat( + [r[i] for r in res], + dim=( + 0 + if concatenation_dims is None or i >= len(concatenation_dims) + else concatenation_dims[i] + ), + ) + for i in range(n_res) + ) + + +# from torch._functorch.utils import exposed_in +# @exposed_in("torch") +def _simple_loop_for( + n_iter: Union[int, torch.Tensor], + body_fn: Callable, + operands: Tuple[torch.Tensor, ...] = (), + concatenation_dims: Optional[Sequence[int]] = None, +) -> Tuple[torch.Tensor, ...]: + def _validate_input(n_iter, body_fn, operands, concatenation_dims): + assert isinstance( + n_iter, (int, torch.Tensor, torch.SymInt) + ), f"Expected pred to be bool or tensor, but got {n_iter}." + assert ( + not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1 + ), f"Expected pred to be bool or single-element tensor, but got {n_iter}." + assert callable(body_fn), "Expect both branches to be callable." + assert isinstance(operands, (tuple, list)) and pytree.tree_all( + lambda t: isinstance(t, torch.Tensor), operands + ), ( + "Expect operands to be a tuple of possibly nested dict/list/tuple that only " + f"consists of tensor leaves, but got {operands}." + ) + assert concatenation_dims is None or ( + isinstance(concatenation_dims, (list, tuple)) + and all(isinstance(i, int) for i in concatenation_dims) + ), ( + f"concatenation_dims should be None or a list of integers but it is " + f"{concatenation_dims}. Its length should be equal to the number of outputs." + ) + assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support." + + if torch.compiler.is_dynamo_compiling(): + return simple_loop_for_op( + n_iter, body_fn, (n_iter, *operands), concatenation_dims=concatenation_dims + ) + + if isinstance(n_iter, (bool, int, float)): + torch._check( + isinstance(n_iter, int), + lambda: f"n_iter must be an integer or a tensor not {type(n_iter)}", + ) + return _simple_loop_for_fn( + n_iter, body_fn, operands, concatenation_dims=concatenation_dims + ) + + def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims): + return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims) + + _validate_input(n_iter, body_fn, operands, concatenation_dims) + + # This requires torch>=2.10. + from torch._higher_order_ops.utils import setup_compilation_env + + with setup_compilation_env() as _backend: + return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims) + # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( + # n_iter, body_fn, operands, concatenation_dims) + + +def trace_simple_loop_for( + proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims +): + """See function ``simple_loop_for``.""" + assert isinstance(operands, (list, tuple)) and ( + concatenation_dims is None + or ( + isinstance(concatenation_dims, (list, tuple)) + and all(isinstance(i, int) for i in concatenation_dims) + ) + ), ( + f"simple_loop_for operands must be a list or tuple of tensors and SymInts and " + f"concatenation_dims must be None or a list of integer, " + f"operands={[type(o) for o in operands]}, " + f"concatenation_dims={concatenation_dims}" + ) + + body_graph = reenter_make_fx(body_fn)(n_iter, *operands) + + body_outs = [] + for node in body_graph.graph.nodes: + if node.op == "output": + body_outs.extend(node.args) + + # flat_body_outs = pytree.arg_tree_leaves(*body_outs) + _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph") + proxy_mode.tracer.root.register_module(body_name, body_graph) + args = (n_iter, body_graph, operands, concatenation_dims) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {}) + out = func_overload(n_iter, body_graph, operands, concatenation_dims) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None): + """Registered eager mode implementation.""" + assert all(isinstance(o, torch.Tensor) for o in operands) and ( + concatenation_dims is None + or ( + isinstance(concatenation_dims, (list, tuple)) + and all(isinstance(i, int) for i in concatenation_dims) + ) + ), ( + f"simple_loop_for operands must be a list or tuple of tensors and SymInts and " + f"concatenation_dims must be None or a list of integer, " + f"operands={[type(o) for o in operands]}, " + f"concatenation_dims={concatenation_dims}" + ) + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return _simple_loop_for_fn( + n_iter, body_fn, operands, concatenation_dims=concatenation_dims + ) + + +@simple_loop_for_op.py_impl(ProxyTorchDispatchMode) +def inner(mode, n_iter, body_fn, operands, concatenation_dims=None): + """Registered tracing implementation.""" + return trace_simple_loop_for( + mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims + ) + + +@simple_loop_for_op.py_impl(FakeTensorMode) +def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None): + """Registered FakeMode implementation.""" + ignore_fresh_unbacked = contextlib.nullcontext() + if mode.shape_env: + ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols() + + with mode, ignore_fresh_unbacked: + flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands)) + + return pytree.tree_unflatten(flat_body_outs, true_body_spec) + + +# Registration for autograd. +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU) +simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA) + + +def simple_loop_for( + n_iter: Union[int, torch.Tensor], + body_fn: Callable, + operands: Tuple[torch.Tensor, ...] = (), + concatenation_dims: Optional[Union[int, Sequence[int]]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """ + Implements a simple loop for, the body is defined by a function which takes the + iteration number stored in a tensor, and other tensors. + It results one or several tensors in a tuple. All of them + are finally concatenated along the first dimension. + + :param n_iter: iteration number + :param body: function + :param operands: bidy arguments + :param concatenation_dims: dimension or dimensions used to concatenate the output sequences + :return: contenated outputs, the output is a Tensor + + An example with one output: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for + + + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i, x): + return (x[: i.item() + 1].unsqueeze(1),) + + return simple_loop_for(n_iter, body, (x,)) + + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + print(ep) + + Another example with two outputs and a final concatenation on different axes. + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for + + + class Model(torch.nn.Module): + def forward(self, n_iter, x): + def body(i: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor]: + return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0)) + + return simple_loop_for(n_iter, body, (x,), (0, 1)) + + + model = Model() + n_iter = torch.tensor(4, dtype=torch.int64) + x = torch.arange(10, dtype=torch.float32) + ep = torch.export.export( + model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})) + ) + print(ep) + """ + res = _simple_loop_for( + n_iter, + body_fn, + operands, + concatenation_dims=( + (concatenation_dims,) + if isinstance(concatenation_dims, int) + else concatenation_dims + ), + ) + torch._check( + isinstance(res, tuple), f"Output of the loop should be a tuple not {type(res)}." + ) + return res[0] if len(res) == 1 else res diff --git a/onnx_diagnostic/export/control_flow.py b/onnx_diagnostic/export/control_flow.py deleted file mode 100644 index 243a07ef..00000000 --- a/onnx_diagnostic/export/control_flow.py +++ /dev/null @@ -1,214 +0,0 @@ -import contextlib -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union -import torch -from torch._higher_order_ops.utils import ( - materialize_as_graph, - check_input_alias_and_mutation_return_outputs, - # _maybe_reenter_make_fx, -) - -_TEST_EXPORT = False - - -@contextlib.contextmanager -def enable_code_export_control_flow(): - """Enables the code meant to be exported.""" - global _TEST_EXPORT - old = _TEST_EXPORT - _TEST_EXPORT = True - try: - yield - finally: - _TEST_EXPORT = old - - -def is_exporting() -> bool: - """ - Returns :func:`torch.compiler.is_exporting` or - :func:`torch.compiler.is_compiling`. - Changes ``_TEST_EXPORT`` to make it trigger. - """ - return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling() - - -def _loop_for_fn(n_iter, body_fn, reduction_dim, args): - """ - Python implementation of the loop. - - :param n_iter: number of iteration - :param body_fn: function implementing the body - :param reduction_dim: dimension used to reduce the list produced by the loop - :param args: arguments to the loop body - :return: results - """ - res = [] - for i in torch.arange(n_iter, dtype=n_iter.dtype): - r = body_fn(i, *args) - if isinstance(r, tuple): - assert not res or len(r) == len(res[-1]), ( - f"Unexpected number of results {len(r)} for function {body_fn}, " - f"expected {len(res[-1])}" - ) - res.append(r) - else: - assert isinstance(r, torch.Tensor), ( - f"Unexpected type {r} for function {body_fn}, " - f"it must be a tuple or a Tensor." - ) - assert not res or len(res[-1]) == 1, ( - f"Unexpected number of results {len(r)} for function {body_fn}, " - f"expected {len(res[-1])}" - ) - res.append((r,)) - - if not res: - return torch.empty(tuple(), dtype=torch.float32, device=args[0].device) - if len(res) == 1: - final = res[0] - else: - n_res = len(res[0]) - final = [ - torch.cat( - [r[i] for r in res], - dim=( - 0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i] - ), - ) - for i in range(n_res) - ] - return tuple(final) if len(final) > 1 else final[0] - - -def make_custom_loop_for( - n_iter: torch.Tensor, - body_fn: Callable, - reduction_dim: Optional[Sequence[int]], - args: Sequence[torch.Tensor], - body_gm: Optional[torch.fx.GraphModule] = None, - body_mutated_inputs: Optional[List[Any]] = None, - body_outputs: Optional[List[Any]] = None, -) -> Tuple[str, torch.library.CustomOpDef]: - """ - Defines a custom operator for a loop in order to avoid - :func:`torch.export.export` digging into it. - It registers the custom op and a custom conversion - to ONNX. - - :param n_iter: number of iterations defined by a tensor of no dimension - :param body_fn: the loop body defined as a function - :param reduction_dim: dimension used to concatenated the results - :param args: list of tensors, input to the body - :param body_gm: torch.fx.GraphModule equivalent to *body_gm* - :param body_mutated_inputs: inputs to *body_gm* - :param body_outputs: outputs to *body_gm* - :return: a name and the custom op definition, the name - is used to cache the custom op - """ - assert body_gm is not None, "body_gm cannot be None" - assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None" - assert body_outputs is not None, "body_outputs cannot be None" - - srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs) - sred = "x".join(map(str, reduction_dim)) if reduction_dim else "" - full_name = ( - body_fn.__qualname__.replace("", "L") - .replace("", "l") - .replace(".", "_") - ) - name = f"loop_for_onnx_{full_name}_{srank}_{sred}" - - schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor" - if len(body_outputs) > 1: - schema += "[]" - custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn) - custom_def.register_kernel("cpu")(body_fn) - - custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: ( - tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0]) - ) - return name, custom_def - - -def loop_for( - n_iter: Union[torch.SymInt, torch.Tensor], - body_fn: Callable[..., Tuple[torch.Tensor]], - args: Sequence[torch.Tensor], - reduction_dim: Optional[Sequence[int]] = None, -) -> Tuple[torch.Tensor, ...]: - """ - High operators used to easily export a loop in ONNX. - Does not fully work with :func:`torch.export.export`, - it does replaces a custom op with a loop operator afterwards. - Every iteration produces tensors, all of them are gathered - into lists, all these lists are concatenated into tensors. - - :param n_iter: number of iterations, it can be fixed on - variable, in that case it should a tensor with no dimension - :param body_fn: function body, takes only tensors and returns - only tensors, the first argument is the iteration number - in a tensor with no dimension, all the others - are not changed during the loop - :param args: the available tensors at every loop - :param reduction_dim: the loop aggregated the results into list, - one of each output, each of them is concatenated into one - tensor along one dimension, by default, it is the first - dimension, but it can be defined otherwise - """ - assert args, "The function should have at least one arg." - assert ( - isinstance(n_iter, torch.Tensor) - and n_iter.dtype == torch.int64 - and len(n_iter.shape) == 0 - ), f"Only a tensor for one int64 is allowed for n_iter but it equal to {n_iter}." - if is_exporting(): - from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER - - # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer - root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root - # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph - - body_gm: torch.fx.GraphModule = materialize_as_graph( - body_fn, (torch.tensor(0, dtype=torch.int64), *args) - ) - ( - _1, - _2, - _3, - body_mutated_inputs, - body_outputs, - ) = check_input_alias_and_mutation_return_outputs(body_gm) - name, _custom_ops = make_custom_loop_for( - n_iter, - body_fn, - reduction_dim, - args, - body_gm=body_gm, - body_mutated_inputs=body_mutated_inputs, - body_outputs=body_outputs, - ) - root.register_module(name, body_gm) - # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args) - return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args) - - return _loop_for_fn(n_iter, body_fn, reduction_dim, args) - - -""" - proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph) - proxy_mode.tracer.root.register_module(body_graph_name, body_graph) - - args = (cond_graph, body_graph, carried_inputs, additional_inputs) - - proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) - - out_proxy = proxy_mode.tracer.create_proxy( - "call_function", op, proxy_args, {}, name=op._name - ) - - out = op( - cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs - ) - return track_tensor_tree( - out, out_proxy, constant=None, tracer=proxy_mode.tracer - ) -""" diff --git a/onnx_diagnostic/export/control_flow_research.py b/onnx_diagnostic/export/control_flow_research.py deleted file mode 100644 index c10d9b07..00000000 --- a/onnx_diagnostic/export/control_flow_research.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Any, Callable, Union -import torch -from torch._C import DispatchKey - -# from torch._higher_order_ops import BaseHOP -from torch._ops import HigherOrderOperator -from torch._functorch.utils import exposed_in -import torch.utils._pytree as pytree -from torch._higher_order_ops.utils import ( - check_input_alias_and_mutation_return_outputs, - reenter_make_fx, - unique_graph_id, - validate_subgraph_args_types, -) -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree -from torch.utils._python_dispatch import _get_current_dispatch_mode -from .control_flow_onnx import _loop_for_onnx_fn - - -class SimpleLoopForOp(HigherOrderOperator): - def __init__(self): - super().__init__("simple_loop_for") - - def __call__(self, n_iter, body_fn, operands): - validate_subgraph_args_types(operands) - return super().__call__(n_iter, body_fn, operands) - - def gen_schema(self, n_iter, body_fn, operands): - from torch._higher_order_ops.schema import HopSchemaGenerator - from torch._higher_order_ops.utils import materialize_as_graph - - body_gm: torch.fx.GraphModule = materialize_as_graph( # type: ignore[annotation-unchecked] - body_fn, (torch.tensor(0, dtype=torch.int64), *operands) - ) - ( - _, - _, - _, - body_mutated_inputs, - body_outputs, - ) = check_input_alias_and_mutation_return_outputs(body_gm) - mutated_inputs = body_mutated_inputs - - schema_gen = HopSchemaGenerator(self) - schema_gen.add_arg("n_iter", n_iter) - schema_gen.add_arg("body_fn", body_gm) - for idx, arg in enumerate(operands): - schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs) - - for out in body_outputs: - schema_gen.add_output(out) - schema_gen.add_schema_tree_spec(n_iter, body_fn, operands) - return schema_gen.gen_schema() - - -simple_loop_for_op = SimpleLoopForOp() - - -@exposed_in("torch") -def simple_loop_for( - n_iter: Union[int, torch.Tensor], - body_fn: Callable, - operands: Union[tuple, list] = (), -) -> Any: - if torch.compiler.is_dynamo_compiling(): - return simple_loop_for_op(n_iter, body_fn, (n_iter, *operands)) - - if isinstance(n_iter, (bool, int, float)): - return _loop_for_onnx_fn(body_fn, n_iter, None, *operands) - - def _validate_input(n_iter, body_fn, operands): - assert isinstance( - n_iter, (int, torch.Tensor, torch.SymInt) - ), f"Expected pred to be bool or tensor, but got {n_iter}." - assert ( - not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1 - ), f"Expected pred to be bool or single-element tensor, but got {n_iter}." - assert callable(body_fn), "Expect both branches to be callable." - assert isinstance(operands, (tuple, list)) and pytree.tree_all( - lambda t: isinstance(t, torch.Tensor), operands - ), ( - "Expect operands to be a tuple of possibly nested dict/list/tuple that only " - f"consists of tensor leaves, but got {operands}." - ) - - _validate_input(n_iter, body_fn, operands) - - assert torch._dynamo.is_dynamo_supported(), "torch.cond requires dynamo support." - - def _loop_for_op_wrapper(*args, **kwargs): - return simple_loop_for_op(*args, **kwargs) - - from torch._higher_order_ops.utils import setup_compilation_env - - with setup_compilation_env() as _backend: - return _loop_for_op_wrapper(n_iter, body_fn, *operands) - # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)( - # n_iter, body_fn, operands - # ) - - -def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands): - assert isinstance( - operands, (list, tuple) - ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}" - - body_graph = reenter_make_fx(body_fn)(n_iter, *operands) - - body_outs = [] - for node in body_graph.graph.nodes: - if node.op == "output": - body_outs.extend(node.args) - - # flat_body_outs = pytree.arg_tree_leaves(*body_outs) - _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph") - proxy_mode.tracer.root.register_module(body_name, body_graph) - args = (n_iter, body_graph, body_graph, operands) - proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) - out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {}) - out = func_overload(n_iter, body_graph, operands) - return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) - - -@simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd) -def loop_for_op_dense(n_iter, body_fn, operands): - assert all( - isinstance(o, (torch.Tensor, int)) for o in operands - ), f"Dense implementation operands must be a list of tensors and ints {operands}" - mode = _get_current_dispatch_mode() - assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return _loop_for_onnx_fn(body_fn, n_iter, None, operands) - - -@simple_loop_for_op.py_impl(ProxyTorchDispatchMode) -def inner(mode, n_iter, body_fn, operands): - return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands) - - -simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU) -simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 211794c9..638462b4 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -774,6 +774,7 @@ def verbose(self) -> int: def setUpClass(cls): logger = logging.getLogger("onnxscript.optimizer.constant_folding") logger.setLevel(logging.ERROR) + warnings.filterwarnings("ignore", category=DeprecationWarning) unittest.TestCase.setUpClass() @classmethod