34 changes: 34 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2411,6 +2411,39 @@ def _impl_v18(cls, bb, inputs, attr, params):
)


class AffineGrid(OnnxOpConverter):
"""Converts an onnx AffineGrid node into an equivalent Relax expression."""

@classmethod
def _impl_v20(cls, bb, inputs, attr, params):
theta = inputs[0] # [N, 2, 3] for 2D
size = get_constant(inputs[1], params) # [N, C, H, W] for 2D
align_corners = attr.get("align_corners", 0)

if align_corners != 1:
raise NotImplementedError(
"AffineGrid with align_corners=0 is not yet supported in TVM"
)

# Extract size values
if isinstance(size, relax.Constant):
size_vals = size.data.numpy().astype("int64").tolist()
elif isinstance(size, relax.expr.ShapeExpr):
size_vals = [int(v.value) for v in size.values]
else:
raise NotImplementedError(f"Dynamic size of type {type(size)} is not supported")

# Only 2D is supported: size = [N, C, H, W]
if len(size_vals) != 4:
raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is supported")
target_h, target_w = size_vals[2], size_vals[3]

# Relax affine_grid outputs [N, 2, H, W]
grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w)))
# Permute to ONNX convention [N, H, W, 2]
return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))


class Einsum(OnnxOpConverter):
"""Converts an onnx Einsum node into an equivalent Relax expression."""

@@ -4151,6 +4184,7 @@ def _get_convert_map():
"NonMaxSuppression": NonMaxSuppression,
"AllClassNMS": AllClassNMS,
"GridSample": GridSample,
"AffineGrid": AffineGrid,
"Upsample": Upsample,
# others
"DepthToSpace": DepthToSpace,
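For reference, a minimal import sketch that exercises this converter end to end (a hedged sketch: it assumes the standard Relax ONNX entry point tvm.relax.frontend.onnx.from_onnx and its opset keyword; the graph mirrors the test added in tests/python/relax/test_frontend_onnx.py below):

from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# 2D AffineGrid with align_corners=1, the only mode the converter accepts.
node = helper.make_node("AffineGrid", ["theta", "size"], ["grid"], align_corners=1)
graph = helper.make_graph(
    [node],
    "affine_grid_demo",
    inputs=[helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 2, 3])],
    initializer=[helper.make_tensor("size", TensorProto.INT64, [4], [2, 3, 16, 16])],
    outputs=[helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 16, 16, 2])],
)
model = helper.make_model(graph, producer_name="affine_grid_demo")
mod = from_onnx(model, opset=20)  # emits R.image.affine_grid followed by R.permute_dims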
24 changes: 24 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1123,6 +1123,29 @@ def _grid_sampler_2d(self, node: fx.Node) -> relax.Var:
)
)

def _affine_grid_generator(self, node: fx.Node) -> relax.Var:
"""Convert torch.nn.functional.affine_grid to relax.op.image.affine_grid."""
args = self.retrieve_args(node)
theta = args[0] # [N, 2, 3]
size = args[1] # [N, C, H, W]
align_corners = args[2] if len(args) > 2 else False

if not align_corners:
raise NotImplementedError(
"affine_grid with align_corners=False is not yet supported in TVM"
)

# Extract spatial dimensions (H, W) from PyTorch's [N, C, H, W] size
target_h = size[2]
target_w = size[3]

# Relax affine_grid outputs [N, 2, H, W]
grid = self.block_builder.emit(
relax.op.image.affine_grid(theta, (target_h, target_w))
)
# Permute to PyTorch convention [N, H, W, 2]
return self.block_builder.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))

def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
"""Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
args = self.retrieve_args(node)
@@ -1768,6 +1791,7 @@ def create_convert_map(
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
"grid_sampler_2d.default": self._grid_sampler_2d,
"affine_grid_generator.default": self._affine_grid_generator,
"roi_align.default": self._torchvision_roi_align,
# datatype
"to.dtype": self._to,
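As a quick PyTorch-side sanity check of the layout convention this handler targets (illustrative only; it mirrors the exported-program tests added below):

import torch

theta = torch.randn(1, 2, 3)
# align_corners=True is the only mode the Relax lowering currently accepts.
grid = torch.nn.functional.affine_grid(theta, [1, 3, 16, 16], align_corners=True)
print(grid.shape)  # torch.Size([1, 16, 16, 2]); Relax emits [1, 2, 16, 16] and permute_dims restores this layout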
2 changes: 1 addition & 1 deletion python/tvm/relax/op/image/__init__.py
@@ -17,4 +17,4 @@
# under the License.
"""Image operators."""

from .image import grid_sample, resize2d, resize3d
from .image import affine_grid, grid_sample, resize2d, resize3d
41 changes: 41 additions & 0 deletions python/tvm/relax/op/image/image.py
@@ -16,13 +16,16 @@
# under the License.
"""Image operators."""

from typing import cast

from tvm import DataType
from tvm.ir.expr import PrimExpr

from ...expr import Expr, ShapeExpr
from . import _ffi_api

PrimExprLike = int | PrimExpr
SizeLike = PrimExprLike | tuple[PrimExprLike, ...]


def resize2d(
@@ -229,3 +232,41 @@ def grid_sample(
padding_mode,
align_corners,
)


def affine_grid(
data: Expr,
size: Expr | SizeLike,
) -> Expr:
"""Generate a 2D sampling grid using an affine transformation matrix.

This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
It generates a uniform sampling grid within the target shape, normalizes it
to [-1, 1], and applies the provided affine transformation.

Parameters
----------
data : relax.Expr
The input affine matrix tensor with shape [batch, 2, 3].

size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
The target output spatial shape (H, W). If a single integer or PrimExpr
is provided, it is interpreted as a square output shape (size, size).

Returns
-------
result : relax.Expr
The output grid tensor with shape [batch, 2, H, W].

Note
----
Only `align_corners=True` is supported by this operator, matching the
behavior of the underlying TOPI implementation. When using this operator
via PyTorch or ONNX frontends, `align_corners=False` will be rejected.
"""
if isinstance(size, int | PrimExpr):
size = (size, size)
Comment on lines +267 to +268 (Contributor, severity: medium):
The behavior of treating a single int or PrimExpr for size as (size, size) (i.e., square dimensions) is an implicit convention. It would be beneficial to explicitly document this behavior in the docstring for clarity, or add a comment here.
Suggested change:
-    if isinstance(size, int | PrimExpr):
-        size = (size, size)
+    if isinstance(size, int | PrimExpr):  # Assume square dimensions if a single value is provided
+        size = (size, size)

if isinstance(size, tuple | list):
size = ShapeExpr(size)

return cast(Expr, _ffi_api.affine_grid(data, size))
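As a worked illustration of the docstring's description (a NumPy sketch under the align_corners=True convention, with the (x, y) channel ordering the frontends assume; this is not the TOPI kernel, only the same arithmetic):

import numpy as np

def affine_grid_ref(theta: np.ndarray, height: int, width: int) -> np.ndarray:
    # theta: [N, 2, 3] affine matrices; returns a grid of shape [N, 2, H, W].
    xs = np.linspace(-1.0, 1.0, width)   # normalized target x coordinates
    ys = np.linspace(-1.0, 1.0, height)  # normalized target y coordinates
    xv, yv = np.meshgrid(xs, ys)         # each of shape [H, W]
    base = np.stack([xv, yv, np.ones_like(xv)]).reshape(3, -1)  # homogeneous coords, [3, H*W]
    out = theta.reshape(-1, 3) @ base    # [N*2, H*W]
    return out.reshape(theta.shape[0], 2, height, width)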
18 changes: 17 additions & 1 deletion python/tvm/relax/transform/legalize_ops/image.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name
"""Default legalization function for image operators."""

from tvm import topi
from tvm import tir, topi

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
@@ -54,6 +54,22 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.image.affine_grid")
def _image_affine_grid(bb: BlockBuilder, call: Call) -> Expr:
for v in call.args[1].values:
        if not isinstance(v, (int, tir.IntImm)):
raise ValueError(
"affine_grid legalization requires static target_shape, "
f"got symbolic value: {v}"
)
target_shape = [int(v) for v in call.args[1].values]
Review comment (Contributor, severity: high):
Converting PrimExpr values to int using int(v) can lead to runtime errors if v is a symbolic tir.Var that cannot be evaluated to a concrete integer at this stage. If topi.image.affine_grid requires concrete integer shapes, a check should be added to ensure v is not symbolic, or a more robust conversion mechanism should be used if topi can handle symbolic shapes.

Review comment (Member):
int(v) will crash with an unhelpful error on symbolic shapes. Consider adding an explicit check with a clear message, e.g.:

for v in call.args[1].values:
    if not isinstance(v, (int, tir.IntImm)):
        raise ValueError(
            "affine_grid legalization requires static target_shape, "
            f"got symbolic value: {v}"
        )

return bb.call_te(
topi.image.affine_grid,
call.args[0],
target_shape=target_shape,
)


@register_legalize("relax.image.resize3d")
def _image_resize3d(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
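A short end-to-end sketch of the lowering path this registration enables (hedged: it assumes the standard relax.transform.LegalizeOps pass and a static (H, W); the module below is illustrative):

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Mod:
    @R.function
    def main(theta: R.Tensor((1, 2, 3), "float32")) -> R.Tensor((1, 2, 8, 8), "float32"):
        return R.image.affine_grid(theta, size=(8, 8))

# LegalizeOps replaces the high-level call with a call_te into topi.image.affine_grid.
lowered = relax.transform.LegalizeOps()(Mod)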
94 changes: 94 additions & 0 deletions src/relax/op/image/resize.cc
@@ -340,5 +340,99 @@ TVM_REGISTER_OP("relax.image.grid_sample")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.image.affine_grid */

Expr affine_grid(Expr data, Expr size) {
static const Op& op = Op::Get("relax.image.affine_grid");
return Call(op, {std::move(data), std::move(size)}, Attrs(), {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.op.image.affine_grid", affine_grid);
}

StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 2) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "AffineGrid expects two arguments, while the given number of arguments is "
<< call->args.size());
}

const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
const auto* size_value = call->args[1].as<ShapeExprNode>();

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "AffineGrid expects the input data to be a Tensor, while the given data is "
<< call->args[0]->GetTypeKey());
}
if (size_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "AffineGrid expects the target size to be a Shape, while the given one is "
<< call->args[1]->GetTypeKey());
}
if (size_sinfo->ndim != 2) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "AffineGrid expects the target size to be a 2-dim shape, while the given "
"one has ndim "
<< size_sinfo->ndim);
}

// data should be 3-D: [batch, 2, 3]
if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "AffineGrid expects the input data to be 3-D (batch, 2, 3), but got ndim "
<< data_sinfo->ndim);
}

const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape != nullptr) {
// Check that the affine matrix has shape [batch, 2, 3]
if (data_shape->values.size() >= 2) {
auto* dim1 = data_shape->values[1].as<IntImmNode>();
if (dim1 != nullptr && dim1->value != 2) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "AffineGrid expects the second dimension of input to be 2, but got "
<< dim1->value);
}
}
if (data_shape->values.size() >= 3) {
auto* dim2 = data_shape->values[2].as<IntImmNode>();
if (dim2 != nullptr && dim2->value != 3) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "AffineGrid expects the third dimension of input to be 3, but got "
<< dim2->value);
}
}
}

DataType out_dtype = data_sinfo->dtype;

if (data_shape == nullptr || size_value == nullptr) {
return TensorStructInfo(out_dtype, /*ndim=*/4, data_sinfo->vdevice);
}

// Output shape: [batch, 2, target_height, target_width]
ffi::Array<PrimExpr> out_shape;
out_shape.push_back(data_shape->values[0]); // batch
out_shape.push_back(IntImm(DataType::Int(64), 2)); // 2 (spatial dimensions)
out_shape.push_back(size_value->values[0]); // target_height
out_shape.push_back(size_value->values[1]); // target_width

return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.image.affine_grid")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input affine matrix tensor.")
.add_argument("size", "Shape", "The target output shape (H, W).")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAffineGrid)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relax/op/image/resize.h
@@ -48,6 +48,9 @@ Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout
Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
ffi::String padding_mode, bool align_corners);

/*! \brief Image affine_grid operator. */
Expr affine_grid(Expr data, Expr size);

} // namespace relax
} // namespace tvm

57 changes: 57 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -9095,5 +9095,62 @@ def false_fn(x):
)


def test_affine_grid():
class AffineGrid(Module):
def forward(self, theta):
return torch.nn.functional.affine_grid(
theta, [1, 3, 16, 16], align_corners=True
)

@tvm.script.ir_module
class expected:
@R.function
def main(
theta: R.Tensor((1, 2, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 2, 16, 16), dtype="float32") = R.image.affine_grid(
theta, size=(16, 16)
)
lv1: R.Tensor((1, 16, 16, 2), dtype="float32") = R.permute_dims(
lv, axes=[0, 2, 3, 1]
)
gv: R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")) = (lv1,)
R.output(gv)
return gv

example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
# Disable decomposition to keep aten.affine_grid_generator as a single op
verify_model(AffineGrid(), example_args, {}, expected, run_ep_decomposition=False)


def test_affine_grid_numerically():
"""Verify affine_grid numerical correctness: PyTorch vs TVM via our converter."""

class AffineGrid(Module):
def forward(self, theta):
return torch.nn.functional.affine_grid(
theta, [2, 3, 8, 12], align_corners=True
)

model = AffineGrid()
example_args = (torch.randn(2, 2, 3, dtype=torch.float32),)

with torch.no_grad():
pytorch_output = model(*example_args)

exported_program = export(model, args=example_args)
mod = from_exported_program(exported_program, run_ep_decomposition=False)

exe = tvm.compile(mod, target="llvm")
vm = relax.VirtualMachine(exe, tvm.cpu())

tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
tvm_output = vm["main"](*tvm_args)
tvm_output_np = tvm_output[0].numpy()

tvm.testing.assert_allclose(tvm_output_np, pytorch_output.numpy(), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
tvm.testing.main()
26 changes: 26 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
@@ -4011,6 +4011,32 @@ def test_nms_score_threshold():
)


def test_affine_grid():
affine_grid_node = helper.make_node(
"AffineGrid",
["theta", "size"],
["grid"],
align_corners=1,
)

graph = helper.make_graph(
[affine_grid_node],
"affine_grid_test",
inputs=[
helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 2, 3]),
],
initializer=[
helper.make_tensor("size", TensorProto.INT64, [4], [2, 3, 16, 16]),
],
outputs=[
helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 16, 16, 2]),
],
)

model = helper.make_model(graph, producer_name="affine_grid_test")
check_correctness(model, opset=20)


@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])