From 28876179b9dd89c2b9f29200bee4813489206917 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 24 Apr 2025 07:15:16 +0000 Subject: [PATCH 1/9] add pixel shuffle op into torch frontends --- include/tvm/relax/attrs/nn.h | 9 +++ .../torch/base_fx_graph_translator.py | 6 ++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 9 +++ python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 33 +++++++++ python/tvm/relax/transform/legalize_ops/nn.py | 6 ++ python/tvm/topi/nn/pixel_shuffle.py | 70 +++++++++++++++++++ src/relax/op/nn/nn.cc | 64 +++++++++++++++++ src/relax/op/nn/nn.h | 3 + .../test_frontend_from_exported_program.py | 31 ++++++++ tests/python/relax/test_frontend_from_fx.py | 32 +++++++++ tests/python/relax/test_op_nn.py | 20 ++++++ 13 files changed, 285 insertions(+) create mode 100644 python/tvm/topi/nn/pixel_shuffle.py diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index e2ce2be6a882..f0f80ad8f4a0 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -603,6 +603,15 @@ struct PadAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used for the pixel shuffle operator */ +struct PixelShuffleAttrs : public tvm::AttrsNode { + int upscale_factor; + + TVM_DECLARE_ATTRS(PixelShuffleAttrs, "relax.attrs.PixelShuffleAttrs") { + TVM_ATTR_FIELD(upscale_factor).describe("Scale factor for spatial upsampling."); + } +}; + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33f6ffc3132e..69e61f9dbe09 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -883,6 +883,12 @@ def _pad(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value)) + def _pixel_shuffle(self, node: fx.Node) -> relax.Var: + data = self.env[node.args[0]] + upscale_factor = node.args[1] + + return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) query = transpose_S_H(self.env[node.args[0]]) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0434712050ed..58e060b7595d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -309,6 +309,7 @@ def create_convert_map( "log_softmax.int": self._log_softmax, "neg.default": self._unary_op(relax.op.negative), "pad.default": self._pad, + "pixel_shuffle.default": self._pixel_shuffle, "prelu.default": self._prelu, "reciprocal.default": self._reciprocal, "relu.default": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 55abf20fcc03..fb7525a50f81 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -459,6 +459,13 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + upscale_factor = module.upscale_factor + + return self.block_builder.emit(relax.op.nn.pixel_shuffle(x, upscale_factor)) ########## Linear Interpolation ########## @@ -665,6 +672,7 @@ def create_convert_map( nn.Linear: self._linear_module, nn.MaxPool2d: self._max_pool2d_module, nn.modules.sparse.Embedding: self._embedding_module, + nn.PixelShuffle: self._pixel_shuffle_module, # tensor manipulation nn.Flatten: self._flatten_module, ## call_function and call_method @@ -703,6 +711,7 @@ def create_convert_map( "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), "pad": self._pad, + "pixel_shuffle": self._pixel_shuffle, "prelu": self._prelu, "reciprocal": self._reciprocal, "relu": self._unary_op(relax.op.nn.relu), diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 14b5dcfc0681..08ecda275c3e 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -43,6 +43,7 @@ max_pool3d, nll_loss, pad, + pixel_shuffle, prelu, relu, rms_norm, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index e201b596f936..0c18f03e69be 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -549,6 +549,39 @@ def pad( return _ffi_api.pad(data, pad_width, pad_mode, pad_value) +def pixel_shuffle(data: Expr, upscale_factor: int): + r""" + Pixel Shuffle Operator + + This operator performs the pixel shuffle operation on the input tensor, + which is often used for efficient sub-pixel convolution in image + super-resolution tasks. It rearranges elements in a tensor of shape + (N, C × r^2, H, W) to a tensor of shape (N, C, H × r, W × r), where `r` + is the upscale factor. + + Parameters + ---------- + data : relax.Expr + The input tensor to the pixel shuffle operator. It must have 4 dimensions + with the format (N, C * r^2, H, W), where `r` is the upscale factor. + + upscale_factor : int + The upscaling factor `r`. It determines how much to increase the spatial + resolution (height and width) of the input tensor. + + Returns + ------- + result : relax.Expr + The transformed tensor with shape (N, C, H * r, W * r). + + Example + ------- + If the input tensor has shape (1, 8, 10, 15) and `upscale_factor` is 2, + the resulting tensor will have shape (1, 2, 20, 30). + """ + return _ffi_api.pixel_shuffle(data, upscale_factor) + + def max_pool1d( data: Expr, pool_size: Union[int, Tuple[int, int]] = (1,), diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 6a6f0ed6cb93..1e0584ab0ed0 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -249,6 +249,12 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.nn.pixel_shuffle") +def _nn_pixel_shuffle(bb: BlockBuilder, call: Call) -> Expr: + upscale_factor = call.attrs.upscale_factor + return bb.call_te(topi.nn.pixel_shuffle, call.args[0], upscale_factor= upscale_factor) + + @register_legalize("relax.nn.max_pool1d") def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr: if call.attrs.out_layout != call.attrs.layout: diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py new file mode 100644 index 000000000000..6209f706cb12 --- /dev/null +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM operator pixel shuffle compute.""" +from __future__ import absolute_import + +import tvm + +def pixel_shuffle(data, upscale_factor, name="PixelShuffle"): + """PixelShuffle operator that rearranges elements in a tensor of shape + [..., C * r * r, H, W] to [..., C, H * r, W * r]. + + Parameters + ---------- + data : tvm.te.Tensor + N-D input tensor with at least 3 dimensions. Channel must be at index -3. + + upscale_factor : int + The upscale factor (r). + + name : str + Name of the output tensor. + + Returns + ------- + output : tvm.te.Tensor + Pixel shuffled tensor with shape [..., C, H*r, W*r] + """ + assert isinstance(upscale_factor, int) and upscale_factor > 0 + ndim = len(data.shape) + assert ndim >= 3, "Input must be at least 3D" + + r = tvm.tir.const(upscale_factor, "int32") + c_in, h_in, w_in = data.shape[-3], data.shape[-2], data.shape[-1] + + c_out = tvm.tir.floordiv(c_in, r * r) + h_out = h_in * r + w_out = w_in * r + + out_shape = list(data.shape[:-3]) + [c_out, h_out, w_out] + + def _compute(*indices): + batch_indices = indices[:-3] + c_out_idx, h_out_idx, w_out_idx = indices[-3], indices[-2], indices[-1] + + h_idx = tvm.tir.floordiv(h_out_idx, r) + r1 = h_out_idx % r + + w_idx = tvm.tir.floordiv(w_out_idx, r) + r2 = w_out_idx % r + + c_in_idx = (c_out_idx * r * r) + (r1 * r) + r2 + + index_tuple = batch_indices + (c_in_idx, h_idx, w_idx) + return data[index_tuple] + + return tvm.te.compute(out_shape, _compute, name=name) \ No newline at end of file diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 8c0b86fe5f8e..3519cbcf59b8 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -224,6 +224,70 @@ TVM_REGISTER_OP("relax.nn.pad") .set_attr("FInferStructInfo", InferStructInfoPad) .set_attr("FPurity", Bool(true)); +/* relax.nn.pixel_shuffle */ +TVM_REGISTER_NODE_TYPE(PixelShuffleAttrs); + +Expr pixel_shuffle(Expr data, int upscale_factor) { + auto attrs = make_object(); + attrs->upscale_factor = upscale_factor; + static const Op& op = Op::Get("relax.nn.pixel_shuffle"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.pixel_shuffle").set_body_typed(pixel_shuffle); + +StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + int r = attrs->upscale_factor; + ICHECK_GT(r, 0) << "Upscale factor must be positive"; + + const TensorStructInfo& input = input_sinfo[0]; + int ndim = input->ndim; + ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor"; + + if (!input->shape.defined()) { + return TensorStructInfo(input->dtype, ndim); + } + + const auto* shape = input->shape.as(); + Array in_shape = shape->values; + + int channel_idx = ndim - 3; + int h_idx = ndim - 2; + int w_idx = ndim - 1; + + PrimExpr c_in = in_shape[channel_idx]; + PrimExpr h_in = in_shape[h_idx]; + PrimExpr w_in = in_shape[w_idx]; + + PrimExpr r_expr = IntImm(DataType::Int(32), r); + PrimExpr r_squared = r_expr * r_expr; + + // Output shape: + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i == channel_idx) { + out_shape.push_back(c_in / r_squared); + } else if (i == h_idx) { + out_shape.push_back(h_in * r_expr); + } else if (i == w_idx) { + out_shape.push_back(w_in * r_expr); + } else { + out_shape.push_back(in_shape[i]); + } + } + + return TensorStructInfo(ShapeExpr(out_shape), input->dtype); +} + +TVM_REGISTER_OP("relax.nn.pixel_shuffle") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type() + .set_attr("FInferStructInfo", InferStructInfoPixelShuffle) + .set_attr("FPurity", Bool(true)); + /* relax.nn.batchnorm */ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, const Array& input_sinfo, Array axes) { diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index a9c3dd0a5767..c618059b5ed7 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -75,6 +75,9 @@ Expr softplus(Expr data, double beta, double threshold); /*! \brief LogSoftmax function. */ Expr log_softmax(Expr data, int axis); +/*! \brief Pixel Shuffle function. */ +Expr pixel_shuffle(Expr data, int upscale_factor) + /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // int axis, double epsilon, bool center, bool scale, double momentum, bool training); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 108617991b1f..76004ae175bd 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1985,6 +1985,37 @@ def main( verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) +def test_pixel_shuffle(): + class PixelShuffle1(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor) + + def forward(self, x): + return self.pixel_shuffle(x) + + class PixelShuffle2(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, x): + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + @tvm.script.ir_module + class expected: + @R.function + def main(x: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(x, upscale_factor=2) + gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),) + verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected) + verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected) + def test_einsum(): class Einsum1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cb69398e0a00..d095838916e4 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -592,6 +592,38 @@ def main( verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), input_infos, {}, expected_circular) +def test_pixel_shuffle(): + class PixelShuffle1(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor) + + def forward(self, x): + return self.pixel_shuffle(x) + + class PixelShuffle2(torch.nn.Module): + def __init__(self, upscale_factor=2): + super().__init__() + self.upscale_factor = upscale_factor + + def forward(self, x): + return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) + + @tvm.script.ir_module + class expected: + @R.function + def main(inp_0: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tensor((1, 2, 20, 30), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(inp_0, upscale_factor=2) + gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv + R.output(gv) + return gv + + input_infos = [([1, 8, 10, 15], "float32")] + verify_model(PixelShuffle1(2), input_infos, {}, expected) + verify_model(PixelShuffle2(2), input_infos, {}, expected) + + def test_linear(): # nn.Linear class Dense1(Module): diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 1c03d8fe4649..bb61329da3e0 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -1822,5 +1822,25 @@ def test_pad_infer_struct_info(): ) +def test_pixel_shuffle_infer_struct_info(): + bb = relax.BlockBuilder() + x1 = relax.Var("x1", R.Tensor((1, 8, 10, 15), "float32")) + x2 = relax.Var("x2", R.Tensor((2, 6, 18, 5, 4), "float32")) + + upscale_factor1 = 2 + _check_inference( + bb, + relax.op.nn.pixel_shuffle(x1, upscale_factor1), + relax.TensorStructInfo((1, 2, 20, 30), dtype="float32"), + ) + + upscale_factor2 = 3 + _check_inference( + bb, + relax.op.nn.pixel_shuffle(x2, upscale_factor2), + relax.TensorStructInfo((2, 6, 2, 15, 12), dtype="float32"), + ) + + if __name__ == "__main__": tvm.testing.main() From 891a3d4033a2c06101a9652541c44553f2efe03b Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 24 Apr 2025 07:16:30 +0000 Subject: [PATCH 2/9] fix end of files formatting issue --- python/tvm/topi/nn/pixel_shuffle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py index 6209f706cb12..24cb4eaea315 100644 --- a/python/tvm/topi/nn/pixel_shuffle.py +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -67,4 +67,4 @@ def _compute(*indices): index_tuple = batch_indices + (c_in_idx, h_idx, w_idx) return data[index_tuple] - return tvm.te.compute(out_shape, _compute, name=name) \ No newline at end of file + return tvm.te.compute(out_shape, _compute, name=name) From afd2880d902318558c7179603cfa6d5a23b5b483 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 24 Apr 2025 07:17:26 +0000 Subject: [PATCH 3/9] fix trailing whitespaces issue --- .../relax/frontend/torch/base_fx_graph_translator.py | 2 +- python/tvm/relax/frontend/torch/fx_translator.py | 4 ++-- python/tvm/relax/op/nn/nn.py | 10 +++++----- .../relax/test_frontend_from_exported_program.py | 2 +- tests/python/relax/test_frontend_from_fx.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 69e61f9dbe09..2a244ac0c4e0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -886,7 +886,7 @@ def _pad(self, node: fx.Node) -> relax.Var: def _pixel_shuffle(self, node: fx.Node) -> relax.Var: data = self.env[node.args[0]] upscale_factor = node.args[1] - + return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fb7525a50f81..83a9ad55dfbd 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -459,12 +459,12 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - + def _pixel_shuffle_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] upscale_factor = module.upscale_factor - + return self.block_builder.emit(relax.op.nn.pixel_shuffle(x, upscale_factor)) ########## Linear Interpolation ########## diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 0c18f03e69be..e234e8ad7b18 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -555,18 +555,18 @@ def pixel_shuffle(data: Expr, upscale_factor: int): This operator performs the pixel shuffle operation on the input tensor, which is often used for efficient sub-pixel convolution in image - super-resolution tasks. It rearranges elements in a tensor of shape - (N, C × r^2, H, W) to a tensor of shape (N, C, H × r, W × r), where `r` + super-resolution tasks. It rearranges elements in a tensor of shape + (N, C × r^2, H, W) to a tensor of shape (N, C, H × r, W × r), where `r` is the upscale factor. Parameters ---------- data : relax.Expr - The input tensor to the pixel shuffle operator. It must have 4 dimensions + The input tensor to the pixel shuffle operator. It must have 4 dimensions with the format (N, C * r^2, H, W), where `r` is the upscale factor. upscale_factor : int - The upscaling factor `r`. It determines how much to increase the spatial + The upscaling factor `r`. It determines how much to increase the spatial resolution (height and width) of the input tensor. Returns @@ -576,7 +576,7 @@ def pixel_shuffle(data: Expr, upscale_factor: int): Example ------- - If the input tensor has shape (1, 8, 10, 15) and `upscale_factor` is 2, + If the input tensor has shape (1, 8, 10, 15) and `upscale_factor` is 2, the resulting tensor will have shape (1, 2, 20, 30). """ return _ffi_api.pixel_shuffle(data, upscale_factor) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 76004ae175bd..c0e032f9b494 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2011,7 +2011,7 @@ def main(x: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tuple(R.Tensor((1, 2 gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) R.output(gv) return gv - + example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),) verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected) verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index d095838916e4..8142f938f949 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -608,7 +608,7 @@ def __init__(self, upscale_factor=2): def forward(self, x): return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) - + @tvm.script.ir_module class expected: @R.function @@ -618,7 +618,7 @@ def main(inp_0: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tensor((1, 2, 20 gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv R.output(gv) return gv - + input_infos = [([1, 8, 10, 15], "float32")] verify_model(PixelShuffle1(2), input_infos, {}, expected) verify_model(PixelShuffle2(2), input_infos, {}, expected) From a1c5e7550047a0f296429d070ad4c48a6760f857 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 24 Apr 2025 09:09:46 +0000 Subject: [PATCH 4/9] fix lint issues --- python/tvm/relax/transform/legalize_ops/nn.py | 2 +- python/tvm/topi/nn/pixel_shuffle.py | 19 ++++++++++--------- src/relax/op/nn/nn.h | 2 +- .../test_frontend_from_exported_program.py | 9 +++++++-- tests/python/relax/test_frontend_from_fx.py | 8 ++++++-- 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 1e0584ab0ed0..f18ad6097f06 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -252,7 +252,7 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.nn.pixel_shuffle") def _nn_pixel_shuffle(bb: BlockBuilder, call: Call) -> Expr: upscale_factor = call.attrs.upscale_factor - return bb.call_te(topi.nn.pixel_shuffle, call.args[0], upscale_factor= upscale_factor) + return bb.call_te(topi.nn.pixel_shuffle, call.args[0], upscale_factor=upscale_factor) @register_legalize("relax.nn.max_pool1d") diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py index 24cb4eaea315..cd8bd619611a 100644 --- a/python/tvm/topi/nn/pixel_shuffle.py +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -19,6 +19,7 @@ import tvm + def pixel_shuffle(data, upscale_factor, name="PixelShuffle"): """PixelShuffle operator that rearranges elements in a tensor of shape [..., C * r * r, H, W] to [..., C, H * r, W * r]. @@ -43,12 +44,12 @@ def pixel_shuffle(data, upscale_factor, name="PixelShuffle"): ndim = len(data.shape) assert ndim >= 3, "Input must be at least 3D" - r = tvm.tir.const(upscale_factor, "int32") + upscale_factor_const = tvm.tir.const(upscale_factor, "int32") c_in, h_in, w_in = data.shape[-3], data.shape[-2], data.shape[-1] - c_out = tvm.tir.floordiv(c_in, r * r) - h_out = h_in * r - w_out = w_in * r + c_out = tvm.tir.floordiv(c_in, upscale_factor_const * upscale_factor_const) + h_out = h_in * upscale_factor_const + w_out = w_in * upscale_factor_const out_shape = list(data.shape[:-3]) + [c_out, h_out, w_out] @@ -56,13 +57,13 @@ def _compute(*indices): batch_indices = indices[:-3] c_out_idx, h_out_idx, w_out_idx = indices[-3], indices[-2], indices[-1] - h_idx = tvm.tir.floordiv(h_out_idx, r) - r1 = h_out_idx % r + h_idx = tvm.tir.floordiv(h_out_idx, upscale_factor_const) + h_offset = h_out_idx % upscale_factor_const - w_idx = tvm.tir.floordiv(w_out_idx, r) - r2 = w_out_idx % r + w_idx = tvm.tir.floordiv(w_out_idx, upscale_factor_const) + w_offset = w_out_idx % upscale_factor_const - c_in_idx = (c_out_idx * r * r) + (r1 * r) + r2 + c_in_idx = (c_out_idx * upscale_factor_const * upscale_factor_const) + (h_offset * upscale_factor_const) + w_offset index_tuple = batch_indices + (c_in_idx, h_idx, w_idx) return data[index_tuple] diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index c618059b5ed7..018741430199 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -76,7 +76,7 @@ Expr softplus(Expr data, double beta, double threshold); Expr log_softmax(Expr data, int axis); /*! \brief Pixel Shuffle function. */ -Expr pixel_shuffle(Expr data, int upscale_factor) +Expr pixel_shuffle(Expr data, int upscale_factor); /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index c0e032f9b494..93ecc454902d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2005,9 +2005,13 @@ def forward(self, x): @tvm.script.ir_module class expected: @R.function - def main(x: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")): + def main( + x: R.Tensor((1, 8, 10, 15), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(x, upscale_factor=2) + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( + x, upscale_factor=2 + ) gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) R.output(gv) return gv @@ -2016,6 +2020,7 @@ def main(x: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tuple(R.Tensor((1, 2 verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected) verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected) + def test_einsum(): class Einsum1(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 8142f938f949..2989164f1259 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -612,9 +612,13 @@ def forward(self, x): @tvm.script.ir_module class expected: @R.function - def main(inp_0: R.Tensor((1, 8, 10, 15), dtype="float32")) -> R.Tensor((1, 2, 20, 30), dtype="float32"): + def main( + inp_0: R.Tensor((1, 8, 10, 15), dtype="float32") + ) -> R.Tensor((1, 2, 20, 30), dtype="float32"): with R.dataflow(): - lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(inp_0, upscale_factor=2) + lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( + inp_0, upscale_factor=2 + ) gv: R.Tensor((1, 2, 20, 30), dtype="float32") = lv R.output(gv) return gv From bceaee59a42f2cb1df69dd2f8a59108848e90df4 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 24 Apr 2025 09:32:34 +0000 Subject: [PATCH 5/9] fix long line code formatting --- python/tvm/topi/nn/pixel_shuffle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py index cd8bd619611a..78966ee4d9d7 100644 --- a/python/tvm/topi/nn/pixel_shuffle.py +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -63,7 +63,11 @@ def _compute(*indices): w_idx = tvm.tir.floordiv(w_out_idx, upscale_factor_const) w_offset = w_out_idx % upscale_factor_const - c_in_idx = (c_out_idx * upscale_factor_const * upscale_factor_const) + (h_offset * upscale_factor_const) + w_offset + c_in_idx = ( + (c_out_idx * upscale_factor_const * upscale_factor_const) + + (h_offset * upscale_factor_const) + + w_offset + ) index_tuple = batch_indices + (c_in_idx, h_idx, w_idx) return data[index_tuple] From a62a502315f463719f5e7b4ed752dd14cd3a676d Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Sun, 27 Apr 2025 16:01:13 +0000 Subject: [PATCH 6/9] add arg check condition --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 2a244ac0c4e0..57214ad5b7ee 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -886,6 +886,7 @@ def _pad(self, node: fx.Node) -> relax.Var: def _pixel_shuffle(self, node: fx.Node) -> relax.Var: data = self.env[node.args[0]] upscale_factor = node.args[1] + assert isinstance(upscale_factor, int), "PixelShuffle only accepts an integer upscale_factor." return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) From e01cc48d8352662a3a8ef4786edb881dc6974ade Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Sun, 27 Apr 2025 16:17:49 +0000 Subject: [PATCH 7/9] fix lint issue in base fx graph script --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 57214ad5b7ee..fa64443c9d7d 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -886,7 +886,9 @@ def _pad(self, node: fx.Node) -> relax.Var: def _pixel_shuffle(self, node: fx.Node) -> relax.Var: data = self.env[node.args[0]] upscale_factor = node.args[1] - assert isinstance(upscale_factor, int), "PixelShuffle only accepts an integer upscale_factor." + assert isinstance( + upscale_factor, int + ), "PixelShuffle only accepts an integer upscale_factor." return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor)) From 3c85be9740768096c74138a8905f4e138cd2e320 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 30 Apr 2025 06:11:20 +0000 Subject: [PATCH 8/9] add condition in struct info function --- src/relax/op/nn/nn.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 3519cbcf59b8..0c996ede7ec4 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -264,6 +264,12 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx PrimExpr r_expr = IntImm(DataType::Int(32), r); PrimExpr r_squared = r_expr * r_expr; + const auto* c_in_imm = c_in.as(); + const auto* r2_imm = r_squared.as(); + + ICHECK_EQ(c_in_imm->value % r2_imm->value, 0) + << "Number of input channels must be divisible by the square of the upscale factor"; + // Output shape: Array out_shape; for (int i = 0; i < ndim; ++i) { From b4b8c00c21fdec3bb62bac998287b82b4f626bfa Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Wed, 30 Apr 2025 06:30:45 +0000 Subject: [PATCH 9/9] fix lint issue in struct info func --- src/relax/op/nn/nn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 0c996ede7ec4..16b8f467ff0f 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -268,7 +268,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx const auto* r2_imm = r_squared.as(); ICHECK_EQ(c_in_imm->value % r2_imm->value, 0) - << "Number of input channels must be divisible by the square of the upscale factor"; + << "Number of input channels must be divisible by the square of the upscale factor"; // Output shape: Array out_shape;