Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8282,43 +8282,54 @@

.. code-block:: python

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat_interleave(2, dim=0)

is equivalent to:

.. code-block:: python
import torch

x = torch.tensor([[0, 1, 2], [3, 4, 5]])
x.repeat((1, 2)).reshape((-1, t.shape[1]))
result = x.repeat_interleave(2, dim=0)

print(result)
"""

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

if dim is None:
raise NotImplementedError("No conversion available yet when dim is None.")
flat_self = op.Reshape(self, [-1])
unsqueezed = op.Unsqueeze(flat_self, [1])

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

if isinstance(repeats, int):
tile_repeat = op.Constant(value=ir.tensor([1, repeats], dtype=INT64.dtype))
else:
# repeats is a symbolic tensor
tile_repeat = op.Concat(
op.Constant(value=ir.tensor([1], dtype=INT64.dtype)),
op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
axis=0,
)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

tiled = op.Expand(unsqueezed, tile_repeat)
return op.Reshape(tiled, op.Constant(value_ints=[-1]))

self_rank = len(self.shape)
pos_dim = (dim + self_rank) % self_rank
unsqueezed = op.Unsqueeze(self, [pos_dim + 1])

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

if isinstance(repeats, int):
tiles = [1] * (self_rank + 1)
tiles[pos_dim + 1] = repeats
tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
else:
# repeats is a symbolic tensor
tile_repeat = op.Concat(
op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)),
op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))),
op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)),
axis=0,
)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

tiled = op.Expand(unsqueezed, tile_repeat)
if self_rank == 1:
return op.Identity(tiled)
final_shape = op.Concat(
op.Shape(self, start=0, end=dim),
op.Shape(self, start=0, end=pos_dim),
op.Constant(value_ints=[-1]),
op.Shape(self, start=pos_dim + 1),
axis=0,
)
)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

return op.Reshape(tiled, final_shape)


Expand Down
103 changes: 83 additions & 20 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**2

onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
PowModel(),
(torch.tensor(0.5, dtype=torch.float16),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Expand All @@ -76,7 +79,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**0.5

onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
PowModel(),
(torch.tensor(0.5, dtype=torch.float16),),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

Expand Down Expand Up @@ -141,12 +147,28 @@ def forward(self, x, ind):
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_int_dim_none(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 2)

model = Model().eval()
inputs = (torch.tensor([2], dtype=torch.int64),)

onnx_program = torch.onnx.export(
model,
inputs,
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)

def test_repeat_interleave_symbolic_tensor(self):
class Model(torch.nn.Module):
def forward(self, x, y):
return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave(
y, x.shape[1], dim=1
)
return torch.repeat_interleave(
x, y.shape[1], dim=1
) * torch.repeat_interleave(y, x.shape[1], dim=1)

inputs = (
torch.arange(4, dtype=torch.float32).reshape((2, 2)),
Expand Down Expand Up @@ -320,7 +342,9 @@ def forward(self, x, offset, weight, bias, mask):
def test_dft_axis_promoted_from_attribute_to_input(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access
return torch.ops.aten._fft_r2c(
x, [0], normalization=1, onesided=True
) # pylint: disable=protected-access

onnx_program = torch.onnx.export(
Model(),
Expand All @@ -335,12 +359,24 @@ def test_avg_pool(self):
class Model(torch.nn.Module):
def forward(self, x2d, x3d, x4d, x5d):
return (
torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable
torch.nn.functional.avg_pool1d(
x2d, 2
), # pylint: disable=not-callable
torch.nn.functional.avg_pool1d(
x3d, 2
), # pylint: disable=not-callable
torch.nn.functional.avg_pool2d(
x3d, 2
), # pylint: disable=not-callable
torch.nn.functional.avg_pool2d(
x4d, 2
), # pylint: disable=not-callable
torch.nn.functional.avg_pool3d(
x4d, 2
), # pylint: disable=not-callable
torch.nn.functional.avg_pool3d(
x5d, 2
), # pylint: disable=not-callable
)

x2d = torch.randn(10, 10)
Expand Down Expand Up @@ -518,7 +554,9 @@ def forward(self, x):
def test_aten_unique_consecutive_return(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.unique_consecutive(x, return_inverse=True, return_counts=True)
return torch.unique_consecutive(
x, return_inverse=True, return_counts=True
)

model = Model()
x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64)
Expand Down Expand Up @@ -564,7 +602,9 @@ def test_aten_stft_3(self):
class Model(torch.nn.Module):
def forward(self, x):
window = torch.ones(16, dtype=torch.float32)
return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False)
return torch.ops.aten.stft(
x, n_fft=16, window=window, return_complex=False
)

x = torch.randn(100, dtype=torch.float32)

Expand Down Expand Up @@ -672,7 +712,9 @@ def forward(self, tokens, h, c):
tokens = torch.tensor([1])
h = torch.randn(2, 1, 64) # 2 layers
c = torch.randn(2, 1, 64) # 2 layers
onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False)
onnx_program = torch.onnx.export(
model, (tokens, h, c), dynamo=True, verbose=False
)
_testing.assert_onnx_program(onnx_program)

def test_unbind_dynamic_dim0(self):
Expand Down Expand Up @@ -764,7 +806,12 @@ def forward(self, x):
(2,),
"contiguous_non_broadcast_indices_new_dim1",
),
((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"),
(
(6, 6, 6),
[None, [0, 1], [2, 3]],
(),
"contiguous_non_broadcast_indices_scalar",
),
# Multiple advanced indices, with broadcasting among indices.
# Contiguous advanced indices:
# This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2)
Expand Down Expand Up @@ -815,8 +862,18 @@ def forward(self, x):
"non_contiguous_non_first",
),
((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"),
((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"),
((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"),
(
(6, 6, 6),
[0, None, [0, 1]],
(2, 6),
"non_contiguous_scalar_index_and_1d_index",
),
(
(6, 6, 6),
[None, 0, [0, 1]],
(6, 2),
"contiguous_scalar_index_and_1d_index",
),
# (TODO): Exporter doesn't yet support all None indices
# ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"),
]
Expand Down Expand Up @@ -871,7 +928,9 @@ def forward(self, update, index1, index2):
update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
index1 = torch.tensor([1, 2], dtype=torch.int64)
index2 = torch.tensor([3, 4], dtype=torch.int64)
feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
feeds = dict(
zip(["update", "index1", "index2"], (update, index1, index2))
)
onnx_program = torch.onnx.export(
Model(dimension),
tuple(feeds.values()),
Expand Down Expand Up @@ -943,7 +1002,11 @@ def forward(self, x, index, update):
output_names=["output"],
opset_version=18,
dynamo=True,
dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}),
dynamic_shapes=(
{0: "a", 1: "b", 2: "c"},
{0: "d"},
{0: "e", 1: "f", 2: "g"},
),
)
_testing.assert_onnx_program(onnx_program)

Expand Down
8 changes: 0 additions & 8 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,10 +1141,6 @@ def _where_input_wrangler(
reason=("ignore cases when repeasts is a Tensor"),
)
.skip(dtypes=(torch.bool,), reason="bool not supported")
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
Expand All @@ -1155,10 +1151,6 @@ def _where_input_wrangler(
reason=("ignore cases when repeasts is an int"),
)
.skip(dtypes=(torch.bool,), reason="bool not supported")
.skip(
matcher=lambda sample: sample.kwargs.get("dim") is None,
reason="fixme: conversion not implemented if dim is None",
)
.skip(
matcher=lambda sample: sample.input.numel() == 0,
reason="fixme: conversion not implemented when input tensor is empty",
Expand Down
Loading