From 328f93d19065853551d4f9ee6e805b1b12c352d9 Mon Sep 17 00:00:00 2001 From: pratikwayase Date: Mon, 15 Jun 2026 20:19:06 +0530 Subject: [PATCH 1/5] Fix aten_repeat_interleave_self_int ONNX export bug --- onnxscript/function_libs/torch_lib/ops/core.py | 8 +++++--- tests/function_libs/torch_lib/ops_test_data.py | 8 -------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5896f13471..1efe36348d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8293,7 +8293,11 @@ def aten_repeat_interleave_self_int( x.repeat((1, 2)).reshape((-1, t.shape[1])) """ if dim is None: - raise NotImplementedError("No conversion available yet when dim is None.") + self = op.Reshape(self,[-1]) + dim = 0 + self_rank = 1 + else: + self_rank = len(self.shape) self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank @@ -8311,8 +8315,6 @@ def aten_repeat_interleave_self_int( axis=0, ) tiled = op.Expand(unsqueezed, tile_repeat) - if self_rank == 1: - return op.Identity(tiled) final_shape = op.Concat( op.Shape(self, start=0, end=dim), op.Constant(value_ints=[-1]), diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 241ac98cf9..c4f47f2097 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1141,10 +1141,6 @@ def _where_input_wrangler( reason=("ignore cases when repeasts is a Tensor"), ) .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) .skip( matcher=lambda sample: sample.input.numel() == 0, reason="fixme: conversion not implemented when input tensor is empty", @@ -1155,10 +1151,6 @@ def _where_input_wrangler( reason=("ignore cases when repeasts is an int"), ) .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) .skip( matcher=lambda sample: sample.input.numel() == 0, reason="fixme: conversion not implemented when input tensor is empty", From fddd146918481fa46873a24103f39f6877058aad Mon Sep 17 00:00:00 2001 From: pratikwayase Date: Tue, 16 Jun 2026 16:02:31 +0530 Subject: [PATCH 2/5] Fix aten::repeat_interleave.self_int shape inference error when dim is None --- .../function_libs/torch_lib/ops/core.py | 19 ++++------- tests/function_libs/torch_lib/ops_test.py | 33 +++++++++++++++++++ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1efe36348d..20907b7bab 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8277,29 +8277,22 @@ def aten_repeat_interleave_self_int( output_size: Optional[int] = None, ) -> TensorType: """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor - The trick is to repeat in one direction orthogonal to reshape. - .. code-block:: python - - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat_interleave(2, dim=0) - + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat_interleave(2, dim=0) is equivalent to: - .. code-block:: python - - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat((1, 2)).reshape((-1, t.shape[1])) + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, t.shape[1])) """ if dim is None: - self = op.Reshape(self,[-1]) + self = op.Reshape(self, [-1]) dim = 0 self_rank = 1 else: self_rank = len(self.shape) - - self_rank = len(self.shape) + pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) if isinstance(repeats, int): diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index beb74b5462..9282624c0f 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -324,5 +324,38 @@ def test_complex_output_match_opinfo_( TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"] ) +def test_repeat_interleave_dim_none_regression(): + """Regression test for repeat_interleave shape inference error when dim is None. + + Previously, this would raise: + InferenceError: Inferred shape and existing shape differ in rank: (2) vs (1) + """ + import torch + import onnx + import tempfile + import os + + class MyModule(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 2) + + model = MyModule() + model.eval() + x = torch.tensor([2]) + + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "model.onnx") + + # Use standard torch.onnx.export which is compatible across PyTorch 2.x versions + torch.onnx.export( + model, + (x,), + onnx_path, + opset_version=18, + ) + + # raise an InferenceError if the fix is not applied + onnx.shape_inference.infer_shapes_path(onnx_path, strict_mode=True) + if __name__ == "__main__": unittest.main() From 8b48666aab09208640663598566321e71802855b Mon Sep 17 00:00:00 2001 From: pratikwayase Date: Wed, 17 Jun 2026 10:28:33 +0530 Subject: [PATCH 3/5] Fix Add regression test in e2e_ops_tests.py to prevent future errors. --- .../function_libs/torch_lib/ops/core.py | 2 +- .../function_libs/torch_lib/e2e_ops_tests.py | 14 ++++++++ tests/function_libs/torch_lib/ops_test.py | 33 ------------------- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 20907b7bab..a94de05f64 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8292,7 +8292,7 @@ def aten_repeat_interleave_self_int( self_rank = 1 else: self_rank = len(self.shape) - + pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) if isinstance(repeats, int): diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 440e2316c1..0f47fd8ab3 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -141,6 +141,20 @@ def forward(self, x, ind): ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_int_dim_none(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 2) + + inputs = (torch.tensor([2]),) + onnx_program = torch.onnx.export( + Model(), + inputs, + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_symbolic_tensor(self): class Model(torch.nn.Module): def forward(self, x, y): diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 9282624c0f..beb74b5462 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -324,38 +324,5 @@ def test_complex_output_match_opinfo_( TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"] ) -def test_repeat_interleave_dim_none_regression(): - """Regression test for repeat_interleave shape inference error when dim is None. - - Previously, this would raise: - InferenceError: Inferred shape and existing shape differ in rank: (2) vs (1) - """ - import torch - import onnx - import tempfile - import os - - class MyModule(torch.nn.Module): - def forward(self, x): - return torch.repeat_interleave(x, 2) - - model = MyModule() - model.eval() - x = torch.tensor([2]) - - with tempfile.TemporaryDirectory() as tmpdir: - onnx_path = os.path.join(tmpdir, "model.onnx") - - # Use standard torch.onnx.export which is compatible across PyTorch 2.x versions - torch.onnx.export( - model, - (x,), - onnx_path, - opset_version=18, - ) - - # raise an InferenceError if the fix is not applied - onnx.shape_inference.infer_shapes_path(onnx_path, strict_mode=True) - if __name__ == "__main__": unittest.main() From 90ba4563fa1536af2b0da36e6b91d373581c34dd Mon Sep 17 00:00:00 2001 From: pratikwayase Date: Fri, 19 Jun 2026 23:45:15 +0530 Subject: [PATCH 4/5] Fix incorrect rank handling in aten_repeat_interleave_self_int ONNX export (fixes #2932) --- .../function_libs/torch_lib/ops/core.py | 44 ++++++--- .../function_libs/torch_lib/e2e_ops_tests.py | 93 ++++++++++++++----- 2 files changed, 101 insertions(+), 36 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a94de05f64..8c53168595 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8277,43 +8277,59 @@ def aten_repeat_interleave_self_int( output_size: Optional[int] = None, ) -> TensorType: """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + The trick is to repeat in one direction orthogonal to reshape. + .. code-block:: python - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat_interleave(2, dim=0) - is equivalent to: - .. code-block:: python - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat((1, 2)).reshape((-1, t.shape[1])) + + import torch + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + result = x.repeat_interleave(2, dim=0) + + print(result) """ + if dim is None: - self = op.Reshape(self, [-1]) - dim = 0 - self_rank = 1 - else: - self_rank = len(self.shape) + flat_self = op.Reshape(self, [-1]) + unsqueezed = op.Unsqueeze(flat_self, [1]) + + if isinstance(repeats, int): + tile_repeat = op.Constant(value=ir.tensor([1, repeats], dtype=INT64.dtype)) + else: + # repeats is a symbolic tensor + tile_repeat = op.Concat( + op.Constant(value=ir.tensor([1], dtype=INT64.dtype)), + op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), + axis=0, + ) + + tiled = op.Expand(unsqueezed, tile_repeat) + return op.Reshape(tiled, op.Constant(value_ints=[-1])) + self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) + if isinstance(repeats, int): tiles = [1] * (self_rank + 1) tiles[pos_dim + 1] = repeats tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) else: - # repeats is a symbolic tensor tile_repeat = op.Concat( op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), axis=0, ) + tiled = op.Expand(unsqueezed, tile_repeat) final_shape = op.Concat( - op.Shape(self, start=0, end=dim), + op.Shape(self, start=0, end=pos_dim), op.Constant(value_ints=[-1]), op.Shape(self, start=pos_dim + 1), axis=0, - ) + ) return op.Reshape(tiled, final_shape) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 0f47fd8ab3..f96430e9a2 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -66,7 +66,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x**2 onnx_program = torch.onnx.export( - PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False + PowModel(), + (torch.tensor(0.5, dtype=torch.float16),), + dynamo=True, + optimize=False, ) _testing.assert_onnx_program(onnx_program) @@ -76,7 +79,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x**0.5 onnx_program = torch.onnx.export( - PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False + PowModel(), + (torch.tensor(0.5, dtype=torch.float16),), + dynamo=True, + optimize=False, ) _testing.assert_onnx_program(onnx_program) @@ -146,9 +152,11 @@ class Model(torch.nn.Module): def forward(self, x): return torch.repeat_interleave(x, 2) - inputs = (torch.tensor([2]),) + model = Model().eval() + inputs = (torch.tensor([2], dtype=torch.int64),) + onnx_program = torch.onnx.export( - Model(), + model, inputs, dynamo=True, optimize=False, @@ -158,9 +166,9 @@ def forward(self, x): def test_repeat_interleave_symbolic_tensor(self): class Model(torch.nn.Module): def forward(self, x, y): - return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave( - y, x.shape[1], dim=1 - ) + return torch.repeat_interleave( + x, y.shape[1], dim=1 + ) * torch.repeat_interleave(y, x.shape[1], dim=1) inputs = ( torch.arange(4, dtype=torch.float32).reshape((2, 2)), @@ -334,7 +342,9 @@ def forward(self, x, offset, weight, bias, mask): def test_dft_axis_promoted_from_attribute_to_input(self): class Model(torch.nn.Module): def forward(self, x): - return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access + return torch.ops.aten._fft_r2c( + x, [0], normalization=1, onesided=True + ) # pylint: disable=protected-access onnx_program = torch.onnx.export( Model(), @@ -349,12 +359,24 @@ def test_avg_pool(self): class Model(torch.nn.Module): def forward(self, x2d, x3d, x4d, x5d): return ( - torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable - torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable - torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable - torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable - torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable - torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d( + x2d, 2 + ), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d( + x3d, 2 + ), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d( + x3d, 2 + ), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d( + x4d, 2 + ), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d( + x4d, 2 + ), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d( + x5d, 2 + ), # pylint: disable=not-callable ) x2d = torch.randn(10, 10) @@ -532,7 +554,9 @@ def forward(self, x): def test_aten_unique_consecutive_return(self): class Model(torch.nn.Module): def forward(self, x): - return torch.unique_consecutive(x, return_inverse=True, return_counts=True) + return torch.unique_consecutive( + x, return_inverse=True, return_counts=True + ) model = Model() x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64) @@ -578,7 +602,9 @@ def test_aten_stft_3(self): class Model(torch.nn.Module): def forward(self, x): window = torch.ones(16, dtype=torch.float32) - return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False) + return torch.ops.aten.stft( + x, n_fft=16, window=window, return_complex=False + ) x = torch.randn(100, dtype=torch.float32) @@ -686,7 +712,9 @@ def forward(self, tokens, h, c): tokens = torch.tensor([1]) h = torch.randn(2, 1, 64) # 2 layers c = torch.randn(2, 1, 64) # 2 layers - onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False) + onnx_program = torch.onnx.export( + model, (tokens, h, c), dynamo=True, verbose=False + ) _testing.assert_onnx_program(onnx_program) def test_unbind_dynamic_dim0(self): @@ -778,7 +806,12 @@ def forward(self, x): (2,), "contiguous_non_broadcast_indices_new_dim1", ), - ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (), + "contiguous_non_broadcast_indices_scalar", + ), # Multiple advanced indices, with broadcasting among indices. # Contiguous advanced indices: # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) @@ -829,8 +862,18 @@ def forward(self, x): "non_contiguous_non_first", ), ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"), - ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"), - ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"), + ( + (6, 6, 6), + [0, None, [0, 1]], + (2, 6), + "non_contiguous_scalar_index_and_1d_index", + ), + ( + (6, 6, 6), + [None, 0, [0, 1]], + (6, 2), + "contiguous_scalar_index_and_1d_index", + ), # (TODO): Exporter doesn't yet support all None indices # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"), ] @@ -885,7 +928,9 @@ def forward(self, update, index1, index2): update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) index1 = torch.tensor([1, 2], dtype=torch.int64) index2 = torch.tensor([3, 4], dtype=torch.int64) - feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) + feeds = dict( + zip(["update", "index1", "index2"], (update, index1, index2)) + ) onnx_program = torch.onnx.export( Model(dimension), tuple(feeds.values()), @@ -957,7 +1002,11 @@ def forward(self, x, index, update): output_names=["output"], opset_version=18, dynamo=True, - dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), + dynamic_shapes=( + {0: "a", 1: "b", 2: "c"}, + {0: "d"}, + {0: "e", 1: "f", 2: "g"}, + ), ) _testing.assert_onnx_program(onnx_program) From 0fcb634b9f012e2eaa8d81c19a77b180f3a66be6 Mon Sep 17 00:00:00 2001 From: pratikwayase Date: Sat, 20 Jun 2026 09:52:16 +0530 Subject: [PATCH 5/5] Revert unnecessary changes and fix docstring formatting for repeat_interleave --- .../function_libs/torch_lib/ops/core.py | 40 +++----- .../function_libs/torch_lib/e2e_ops_tests.py | 93 +++++-------------- 2 files changed, 37 insertions(+), 96 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 8c53168595..0217871a85 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8282,54 +8282,44 @@ def aten_repeat_interleave_self_int( .. code-block:: python - import torch - x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - result = x.repeat_interleave(2, dim=0) + x.repeat_interleave(2, dim=0) + + is equivalent to: - print(result) + .. code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, x.shape[1])) """ - if dim is None: - flat_self = op.Reshape(self, [-1]) - unsqueezed = op.Unsqueeze(flat_self, [1]) - - if isinstance(repeats, int): - tile_repeat = op.Constant(value=ir.tensor([1, repeats], dtype=INT64.dtype)) - else: - # repeats is a symbolic tensor - tile_repeat = op.Concat( - op.Constant(value=ir.tensor([1], dtype=INT64.dtype)), - op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), - axis=0, - ) - - tiled = op.Expand(unsqueezed, tile_repeat) - return op.Reshape(tiled, op.Constant(value_ints=[-1])) + self = op.Reshape(self, [-1]) + dim = 0 + self_rank = 1 + else: + self_rank = len(self.shape) - self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) - if isinstance(repeats, int): tiles = [1] * (self_rank + 1) tiles[pos_dim + 1] = repeats tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) else: + # repeats is a symbolic tensor tile_repeat = op.Concat( op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), axis=0, ) - tiled = op.Expand(unsqueezed, tile_repeat) final_shape = op.Concat( - op.Shape(self, start=0, end=pos_dim), + op.Shape(self, start=0, end=dim), op.Constant(value_ints=[-1]), op.Shape(self, start=pos_dim + 1), axis=0, - ) + ) return op.Reshape(tiled, final_shape) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index f96430e9a2..0f47fd8ab3 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -66,10 +66,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x**2 onnx_program = torch.onnx.export( - PowModel(), - (torch.tensor(0.5, dtype=torch.float16),), - dynamo=True, - optimize=False, + PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False ) _testing.assert_onnx_program(onnx_program) @@ -79,10 +76,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x**0.5 onnx_program = torch.onnx.export( - PowModel(), - (torch.tensor(0.5, dtype=torch.float16),), - dynamo=True, - optimize=False, + PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False ) _testing.assert_onnx_program(onnx_program) @@ -152,11 +146,9 @@ class Model(torch.nn.Module): def forward(self, x): return torch.repeat_interleave(x, 2) - model = Model().eval() - inputs = (torch.tensor([2], dtype=torch.int64),) - + inputs = (torch.tensor([2]),) onnx_program = torch.onnx.export( - model, + Model(), inputs, dynamo=True, optimize=False, @@ -166,9 +158,9 @@ def forward(self, x): def test_repeat_interleave_symbolic_tensor(self): class Model(torch.nn.Module): def forward(self, x, y): - return torch.repeat_interleave( - x, y.shape[1], dim=1 - ) * torch.repeat_interleave(y, x.shape[1], dim=1) + return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave( + y, x.shape[1], dim=1 + ) inputs = ( torch.arange(4, dtype=torch.float32).reshape((2, 2)), @@ -342,9 +334,7 @@ def forward(self, x, offset, weight, bias, mask): def test_dft_axis_promoted_from_attribute_to_input(self): class Model(torch.nn.Module): def forward(self, x): - return torch.ops.aten._fft_r2c( - x, [0], normalization=1, onesided=True - ) # pylint: disable=protected-access + return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access onnx_program = torch.onnx.export( Model(), @@ -359,24 +349,12 @@ def test_avg_pool(self): class Model(torch.nn.Module): def forward(self, x2d, x3d, x4d, x5d): return ( - torch.nn.functional.avg_pool1d( - x2d, 2 - ), # pylint: disable=not-callable - torch.nn.functional.avg_pool1d( - x3d, 2 - ), # pylint: disable=not-callable - torch.nn.functional.avg_pool2d( - x3d, 2 - ), # pylint: disable=not-callable - torch.nn.functional.avg_pool2d( - x4d, 2 - ), # pylint: disable=not-callable - torch.nn.functional.avg_pool3d( - x4d, 2 - ), # pylint: disable=not-callable - torch.nn.functional.avg_pool3d( - x5d, 2 - ), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable ) x2d = torch.randn(10, 10) @@ -554,9 +532,7 @@ def forward(self, x): def test_aten_unique_consecutive_return(self): class Model(torch.nn.Module): def forward(self, x): - return torch.unique_consecutive( - x, return_inverse=True, return_counts=True - ) + return torch.unique_consecutive(x, return_inverse=True, return_counts=True) model = Model() x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64) @@ -602,9 +578,7 @@ def test_aten_stft_3(self): class Model(torch.nn.Module): def forward(self, x): window = torch.ones(16, dtype=torch.float32) - return torch.ops.aten.stft( - x, n_fft=16, window=window, return_complex=False - ) + return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False) x = torch.randn(100, dtype=torch.float32) @@ -712,9 +686,7 @@ def forward(self, tokens, h, c): tokens = torch.tensor([1]) h = torch.randn(2, 1, 64) # 2 layers c = torch.randn(2, 1, 64) # 2 layers - onnx_program = torch.onnx.export( - model, (tokens, h, c), dynamo=True, verbose=False - ) + onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False) _testing.assert_onnx_program(onnx_program) def test_unbind_dynamic_dim0(self): @@ -806,12 +778,7 @@ def forward(self, x): (2,), "contiguous_non_broadcast_indices_new_dim1", ), - ( - (6, 6, 6), - [None, [0, 1], [2, 3]], - (), - "contiguous_non_broadcast_indices_scalar", - ), + ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"), # Multiple advanced indices, with broadcasting among indices. # Contiguous advanced indices: # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) @@ -862,18 +829,8 @@ def forward(self, x): "non_contiguous_non_first", ), ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"), - ( - (6, 6, 6), - [0, None, [0, 1]], - (2, 6), - "non_contiguous_scalar_index_and_1d_index", - ), - ( - (6, 6, 6), - [None, 0, [0, 1]], - (6, 2), - "contiguous_scalar_index_and_1d_index", - ), + ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"), + ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"), # (TODO): Exporter doesn't yet support all None indices # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"), ] @@ -928,9 +885,7 @@ def forward(self, update, index1, index2): update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) index1 = torch.tensor([1, 2], dtype=torch.int64) index2 = torch.tensor([3, 4], dtype=torch.int64) - feeds = dict( - zip(["update", "index1", "index2"], (update, index1, index2)) - ) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) onnx_program = torch.onnx.export( Model(dimension), tuple(feeds.values()), @@ -1002,11 +957,7 @@ def forward(self, x, index, update): output_names=["output"], opset_version=18, dynamo=True, - dynamic_shapes=( - {0: "a", 1: "b", 2: "c"}, - {0: "d"}, - {0: "e", 1: "f", 2: "g"}, - ), + dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), ) _testing.assert_onnx_program(onnx_program)