diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5896f13471..0217871a85 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8290,12 +8290,15 @@ def aten_repeat_interleave_self_int( .. code-block:: python x = torch.tensor([[0, 1, 2], [3, 4, 5]]) - x.repeat((1, 2)).reshape((-1, t.shape[1])) + x.repeat((1, 2)).reshape((-1, x.shape[1])) """ if dim is None: - raise NotImplementedError("No conversion available yet when dim is None.") + self = op.Reshape(self, [-1]) + dim = 0 + self_rank = 1 + else: + self_rank = len(self.shape) - self_rank = len(self.shape) pos_dim = (dim + self_rank) % self_rank unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) if isinstance(repeats, int): @@ -8311,8 +8314,6 @@ def aten_repeat_interleave_self_int( axis=0, ) tiled = op.Expand(unsqueezed, tile_repeat) - if self_rank == 1: - return op.Identity(tiled) final_shape = op.Concat( op.Shape(self, start=0, end=dim), op.Constant(value_ints=[-1]), diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 440e2316c1..0f47fd8ab3 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -141,6 +141,20 @@ def forward(self, x, ind): ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_int_dim_none(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 2) + + inputs = (torch.tensor([2]),) + onnx_program = torch.onnx.export( + Model(), + inputs, + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_symbolic_tensor(self): class Model(torch.nn.Module): def forward(self, x, y): diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 241ac98cf9..c4f47f2097 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1141,10 +1141,6 @@ def _where_input_wrangler( reason=("ignore cases when repeasts is a Tensor"), ) .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) .skip( matcher=lambda sample: sample.input.numel() == 0, reason="fixme: conversion not implemented when input tensor is empty", @@ -1155,10 +1151,6 @@ def _where_input_wrangler( reason=("ignore cases when repeasts is an int"), ) .skip(dtypes=(torch.bool,), reason="bool not supported") - .skip( - matcher=lambda sample: sample.kwargs.get("dim") is None, - reason="fixme: conversion not implemented if dim is None", - ) .skip( matcher=lambda sample: sample.input.numel() == 0, reason="fixme: conversion not implemented when input tensor is empty",