diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index f7825e0420..326dd8ca1a 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -28,7 +28,6 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunctio class TestDeduceTypeConstraints(unittest.TestCase): _SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN = ( - "_aten_as_strided_onnx", "_aten_unfold_onnx", "_aten_embedding_bag_onnx", "_aten_embedding_bag_1d_padding_idx_onnx", diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5896f13471..a6ec606119 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -817,72 +817,80 @@ def aten_argwhere(self: TensorType) -> TensorType: @torch_op("aten::as_strided", trace_only=True) def aten_as_strided( - self: TTensor, size: INT64, stride: Sequence[int], storage_offset: int = 0 + self: TTensor, + size: Sequence[INT64], + stride: Sequence[INT64], + storage_offset: Optional[INT64] = None, ) -> TTensor: """as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)""" - rank = len(stride) - return _aten_as_strided_onnx(self, size, stride, storage_offset, rank) - - -@torch_op("aten::as_strided", private=True) -def _aten_as_strided_onnx( - self: TTensor, size: INT64, stride: INT64, storage_offset: int = 0, rank: int = 0 -) -> TTensor: - # e.g. when size=[2,3,4], stride=[2,1,3], indices=[0] - # i = 0 - # indices=[0], add_value=[0,3,6,9] - # expand(shape=[4]) to [0,0,0,0] - # then + add_value = [0,3,6,9] - # i = 1 - # indices=[0,3,6,9], add_value=[0,1,2] - # expand(shape=[3,4] to [[0,3,6,9],[0,3,6,9],[0,3,6,9]] - # indices + add_value = [[0,3,6,9],[1,3,7,10],[2,5,8,11]] - # i = 2 - # indices = [[0,3,6,9],[1,3,7,10],[2,5,8,11]], add_value=[0,2] - # expand(shape=[2,3,4]) to [[[0,3,6,9],[1,3,7,10],[2,5,8,11]]],[[0,3,6,9],[1,3,7,10],[2,5,8,11]]] - # indices + add_value = [[[0,3,6,9],[1,3,7,10],[2,5,8,11]]],[[2,5,8,11],[3,5,9,12],[4,7,10,13]]] - neg_1 = op.Constant(value_ints=[-1]) - rank_tensor = op.Reshape(rank, neg_1) # should be 3 - # The final indices for op.Gather(data, indices), will be continually changed during the loop - indices = op.Constant(value_int=0) - one_seq = op.SequenceEmpty() - for i in range(rank): - # Get the index from back to front, should be 2,1,0 when to i=0,1,2 - j = rank - i - 1 - j_tensor = op.Reshape(j, neg_1) - # Get size according to index_j, should be 4,3,2 when i=0,1,2 - size_dim_j = op.Gather(size, j_tensor, axis=0) - # Get right size according to index_j, should be [4],[3,4],[2,3,4] when i=0,1,2 - size_after_j = op.Slice(size, j_tensor, rank_tensor) - # Get stride according to index_j, should be 3,1,2 when i=0,1,2 - stride_dim_j = op.Gather(stride, j_tensor, axis=0) - indices = op.Expand(indices, size_after_j) - # When size[j]=4, stride[j]=3, then add_value = [0,1,2,3] * 3 = [0,3,6,9] - # When size[j]=3, stride[j]=1, then add_value = [0,1,2] * 1 = [0,1,2] - # When size[j]=2, stride[j]=2, then add_value = [0,1] * 2 = [0,2] - add_value = op.Range(0, size_dim_j, 1) * stride_dim_j - # Compute the shape for add_value for correct broadcasting - if i == 0: - # shape = [dim_size] - shape = size_dim_j - else: - # shape = [dim_size, 1, 1, ...], the count of 1 euqal to i - ones = op.ConcatFromSequence(one_seq, axis=0) - shape = op.Concat(op.Cast(size_dim_j, to=FLOAT.dtype), ones, axis=0) - shape = op.Cast(shape, to=INT64.dtype) - - add_value = op.Reshape(add_value, shape) - # Broadcasting add value to indices according to size and stride value - indices = indices + add_value - # Dims after dim_size to reshape(add_value), should be [1],[1,1],[1,1,1] when i=0,1,2 - one_seq = op.SequenceInsert(one_seq, op.Constant(value_floats=[1.0])) - + # torch.as_strided produces a view of `self`'s underlying contiguous storage + # with the requested `size` (the output shape) and `stride` (the step, in + # elements of storage, taken along each output dimension), starting at + # `storage_offset` elements into the storage. For an output element at + # position (i_0, ..., i_{n-1}) the element read from storage lives at the flat + # index storage_offset + sum_d i_d * stride[d]. So if we flatten `self` to 1-D + # and gather it with a tensor of those flat indices shaped like the output, we + # reproduce the view as a single Gather. This avoids the hard-to-fold loop of + # the previous implementation. + rank = len(size) + # `self_flatten` is the contiguous storage as a 1-D tensor; Gather indexes into it. self_flatten = op.Reshape(self, op.Constant(value_ints=[-1])) - indices = op.Add(indices, storage_offset) - result = op.Gather(self_flatten, indices) - return result + # A missing storage_offset means "start at the beginning of the storage". + if storage_offset is None: + storage_offset = 0 + + if ( + all(isinstance(s, int) for s in size) + and all(isinstance(s, int) for s in stride) + and isinstance(storage_offset, int) + ): + # Static fast path: every size/stride/offset is known at trace time, so we + # compute the full index tensor with NumPy and emit it as a single + # constant that downstream passes can fold trivially. + # Start from the storage_offset; the per-dimension contributions are added in. + indices = np.array(storage_offset, dtype=np.int64) + for dim, (dim_size, dim_stride) in enumerate(zip(size, stride)): + # Contribution of dimension `dim`: index i_dim contributes i_dim * stride[dim]. + add_value = np.arange(dim_size, dtype=np.int64) * dim_stride + # Reshape that 1-D contribution so it broadcasts along `dim` only + # (length dim_size at position `dim`, length 1 everywhere else), which + # lets the running sum build the full n-D index grid. + broadcast_shape = [1] * rank + broadcast_shape[dim] = dim_size + indices = indices + add_value.reshape(broadcast_shape) + # `indices` now has shape `size`; gathering yields the strided view. + return op.Gather(self_flatten, op.Constant(value=ir.tensor(indices))) + + # Dynamic path: at least one SymInt is a runtime value, so the index tensor + # cannot be folded to a constant. We build it with ONNX ops, mirroring the + # NumPy math above. The per-dimension loop is unrolled at trace time because + # the rank is always static, so no Loop/Scan is emitted. + zero = op.Constant(value_int=0) + one = op.Constant(value_int=1) + # `empty_shape` reshapes a value to a 0-D scalar (shape []). + empty_shape = op.Constant(value=ir.tensor(np.array([], dtype=np.int64))) + # Start the running index from storage_offset as an INT64 scalar; SymInt + # runtime values are assumed to be INT64. + indices = op.Reshape(storage_offset, empty_shape) + for dim in range(rank): + # Reshape this dimension's size and stride to INT64 scalars. + dim_size = op.Reshape(size[dim], empty_shape) + dim_stride = op.Reshape(stride[dim], empty_shape) + # add_value = arange(dim_size) * dim_stride, a 1-D tensor of length dim_size + # holding the storage offsets contributed by index 0..dim_size-1 along `dim`. + add_value = op.Mul(op.Range(zero, dim_size, one), dim_stride) + # Insert singleton axes everywhere except `dim` so this 1-D contribution + # broadcasts along dimension `dim` only when added to the running index, + # matching the NumPy `reshape(broadcast_shape)` in the static path. + unsqueeze_axes = [axis for axis in range(rank) if axis != dim] + if unsqueeze_axes: + add_value = op.Unsqueeze(add_value, op.Constant(value_ints=unsqueeze_axes)) + indices = op.Add(indices, add_value) + + # `indices` now has shape `size`; gathering yields the strided view. + return op.Gather(self_flatten, indices) def aten_as_strided_copy( diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 440e2316c1..d9539c5e04 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1004,6 +1004,80 @@ def forward(self, x): got = onnx_program.call_reference({"x": inputs[0]}) torch.testing.assert_close(expected, got[0]) + def test_aten_as_strided_static_multi_dim(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.as_strided(x, (2, 3), (4, 1), 2) + + model = Model() + x = torch.arange(24, dtype=torch.float32).reshape(4, 6) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_aten_as_strided_static_single_dim(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.as_strided(x, (4,), (2,)) + + model = Model() + x = torch.arange(12, dtype=torch.float32) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_aten_as_strided_static_overlapping(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.as_strided(x, (3, 3), (1, 1)) + + model = Model() + x = torch.arange(10, dtype=torch.float32) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_aten_as_strided_static_scalar(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.as_strided(x, (), (), 3) + + model = Model() + x = torch.arange(12, dtype=torch.float32) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_aten_as_strided_dynamic_size(self): + class Model(torch.nn.Module): + def forward(self, x): + n = x.shape[0] - 1 + return torch.as_strided(x, (n, 2), (1, 1)) + + model = Model() + x = torch.arange(12, dtype=torch.float32) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_as_strided_dynamic_size_with_offset(self): + class Model(torch.nn.Module): + def forward(self, x): + n = x.shape[0] - 2 + return torch.as_strided(x, (n,), (1,), 1) + + model = Model() + x = torch.arange(12, dtype=torch.float32) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main()