Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunctio

class TestDeduceTypeConstraints(unittest.TestCase):
_SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN = (
"_aten_as_strided_onnx",
"_aten_unfold_onnx",
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
Expand Down
127 changes: 66 additions & 61 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,72 +817,77 @@ def aten_argwhere(self: TensorType) -> TensorType:

@torch_op("aten::as_strided", trace_only=True)
def aten_as_strided(
    self: TTensor,
    size: Sequence[INT64],
    stride: Sequence[INT64],
    storage_offset: Optional[INT64] = None,
) -> TTensor:
    """as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)

    The strided view is reproduced with a single Gather over the flattened
    input: for an output position (i_0, ..., i_{n-1}) the element is read from
    flat index ``storage_offset + sum_d i_d * stride[d]``.

    Args:
        self: Tensor to read from; it is flattened to 1-D before gathering.
        size: Output shape, one entry per output dimension. Entries may be
            Python ints or runtime values.
        stride: Step, in elements of the flattened input, taken along each
            output dimension. Must have the same length as ``size``.
        storage_offset: Flat index of the first output element. ``None`` is
            treated as 0.

    Returns:
        The gathered tensor, shaped like ``size``.

    Raises:
        ValueError: If ``size`` and ``stride`` have different lengths.
    """
    rank = len(size)
    if len(stride) != rank:
        # Without this check the static path's zip() would silently truncate
        # and produce an index grid of the wrong shape.
        raise ValueError(
            f"size and stride must have the same length, got {rank} and {len(stride)}"
        )

    # `self_flatten` is the 1-D tensor that Gather indexes into.
    self_flatten = op.Reshape(self, op.Constant(value_ints=[-1]))

    # A missing storage_offset means "start at the beginning of the storage".
    if storage_offset is None:
        storage_offset = 0

    if (
        all(isinstance(s, int) for s in size)
        and all(isinstance(s, int) for s in stride)
        and isinstance(storage_offset, int)
    ):
        # Static fast path: every size/stride/offset is a Python int at trace
        # time, so the whole index tensor is computed with NumPy and emitted as
        # one constant.
        indices = np.array(storage_offset, dtype=np.int64)
        for dim, (dim_size, dim_stride) in enumerate(zip(size, stride)):
            # Index i along `dim` contributes i * stride[dim] to the flat index.
            add_value = np.arange(dim_size, dtype=np.int64) * dim_stride
            # Shape the 1-D contribution as (1, ..., dim_size, ..., 1) so that it
            # broadcasts along `dim` only; the running sum grows into the full
            # n-D index grid.
            broadcast_shape = [1] * rank
            broadcast_shape[dim] = dim_size
            indices = indices + add_value.reshape(broadcast_shape)
        # `indices` now has shape `size`; gathering yields the strided view.
        return op.Gather(self_flatten, op.Constant(value=ir.tensor(indices)))

    # Dynamic path: at least one value is not a Python int, so the index tensor
    # is built with ONNX ops that mirror the NumPy math above. The loop below is
    # an ordinary Python loop over the trace-time rank, so it is unrolled while
    # tracing rather than emitted as a graph-level loop.
    zero = op.Constant(value_int=0)
    one = op.Constant(value_int=1)
    # Reshaping to `empty_shape` (shape []) turns a value into a 0-D scalar.
    empty_shape = op.Constant(value=ir.tensor(np.array([], dtype=np.int64)))
    # Start the running index from storage_offset as a scalar; runtime SymInt
    # values are assumed to be INT64 here.
    indices = op.Reshape(storage_offset, empty_shape)
    for dim in range(rank):
        # Normalize this dimension's size and stride to scalars.
        dim_size = op.Reshape(size[dim], empty_shape)
        dim_stride = op.Reshape(stride[dim], empty_shape)
        # add_value = arange(dim_size) * dim_stride: the flat-index contribution
        # of positions 0..dim_size-1 along `dim`.
        add_value = op.Mul(op.Range(zero, dim_size, one), dim_stride)
        # Insert singleton axes everywhere except `dim` so the contribution
        # broadcasts along `dim` only, matching reshape(broadcast_shape) in the
        # static path. With rank 1 there is nothing to insert.
        unsqueeze_axes = [axis for axis in range(rank) if axis != dim]
        if unsqueeze_axes:
            add_value = op.Unsqueeze(add_value, op.Constant(value_ints=unsqueeze_axes))
        indices = op.Add(indices, add_value)

    # `indices` now has shape `size`; gathering yields the strided view.
    return op.Gather(self_flatten, indices)


def aten_as_strided_copy(
Expand Down
74 changes: 74 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,80 @@ def forward(self, x):
got = onnx_program.call_reference({"x": inputs[0]})
torch.testing.assert_close(expected, got[0])

def test_aten_as_strided_static_multi_dim(self):
    # Constant 2-D size (2, 3), stride (4, 1) and a non-zero storage offset of 2.
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (2, 3), (4, 1), 2)

    # 24 consecutive floats laid out as a 4x6 matrix.
    sample = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_single_dim(self):
    # Constant 1-D size (4,) and stride (2,); storage_offset is left at its default.
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (4,), (2,))

    # 12 consecutive floats in a 1-D tensor.
    sample = torch.arange(12, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_overlapping(self):
    # Size (3, 3) with stride (1, 1): several output positions map to the same
    # input element.
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (3, 3), (1, 1))

    # 10 consecutive floats in a 1-D tensor.
    sample = torch.arange(10, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_static_scalar(self):
    # Empty size and stride with storage offset 3: the rank-0 case.
    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.as_strided(x, (), (), 3)

    # 12 consecutive floats in a 1-D tensor.
    sample = torch.arange(12, dtype=torch.float32)
    exported = torch.onnx.export(Model(), (sample,), dynamo=True, verbose=False)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_dynamic_size(self):
    # The leading output size is derived from the input's first dimension, which
    # is exported as dynamic, so it is not a constant at export time.
    class Model(torch.nn.Module):
        def forward(self, x):
            n = x.shape[0] - 1
            return torch.as_strided(x, (n, 2), (1, 1))

    # 12 consecutive floats in a 1-D tensor.
    sample = torch.arange(12, dtype=torch.float32)
    export_options = {
        "dynamic_shapes": ({0: "length"},),
        "dynamo": True,
        "verbose": False,
    }
    exported = torch.onnx.export(Model(), (sample,), **export_options)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)

def test_aten_as_strided_dynamic_size_with_offset(self):
    # 1-D output whose length is derived from the dynamic input dimension,
    # combined with a constant storage offset of 1.
    class Model(torch.nn.Module):
        def forward(self, x):
            n = x.shape[0] - 2
            return torch.as_strided(x, (n,), (1,), 1)

    # 12 consecutive floats in a 1-D tensor.
    sample = torch.arange(12, dtype=torch.float32)
    export_options = {
        "dynamic_shapes": ({0: "length"},),
        "dynamo": True,
        "verbose": False,
    }
    exported = torch.onnx.export(Model(), (sample,), **export_options)
    # Hand the exported program to the shared checker.
    _testing.assert_onnx_program(exported)


# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()
Loading