From 6ab150c82d3871620776afb0e8a2f955af4760e6 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:50:48 +0800 Subject: [PATCH] Add edge padding mode --- python/tvm/relax/frontend/common.py | 4 +- tests/python/relax/test_frontend_common.py | 174 +++++++++++++++++++++ 2 files changed, 176 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index c1e9296ca3a5..5b18d5e27d9b 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -123,5 +123,5 @@ def autopad( topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT" ) else: - # TODO(gigiblender) Support edge mode. - raise NotImplementedError("Pad mode {} not implemented".format(pad_type)) + # edge mode - replicate border values + return bb.emit_te(topi.nn.replicate_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist()) diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py index 21becb2c8590..85424df2f602 100644 --- a/tests/python/relax/test_frontend_common.py +++ b/tests/python/relax/test_frontend_common.py @@ -16,7 +16,11 @@ # under the License. import tvm import tvm.testing +from tvm import relax from tvm.relax.frontend import detach_params +from tvm.relax.frontend.common import autopad +from tvm.script import ir as I +from tvm.script import tir as T from tvm.script.parser import relax as R @@ -37,5 +41,175 @@ def func(x: R.Tensor((2, 3), "float32")): tvm.testing.assert_allclose(detached_params["func"][0].numpy(), param.numpy()) +class TestAutopad: + def _test_autopad(self, pad_type, expected): + bb = relax.BlockBuilder() + input_shape = (1, 1, 4, 4) + x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) + + with bb.function("main", [x]): + with bb.dataflow(): + result = autopad( + bb, + x, + strides=[2, 2], + kernel_shape=[3, 3], + dilations=(1, 1), + pad_type=pad_type, + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, + ) + out = bb.emit_output(result) + bb.emit_func_output(out) + + tvm.ir.assert_structural_equal(bb.get(), expected) + + def test_constant(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + def pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + PadInput: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, v_i2, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + T.int64(0) <= v_i2 + and v_i2 < T.int64(4) + and T.int64(0) <= v_i3 + and v_i3 < T.int64(4), + x[v_i0, v_i1, v_i2, v_i3], + T.float32(0.0), + ) + + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("constant", expected) + + def test_edge(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + def replicate_pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + ReplicatePadInput: T.Buffer( + (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" + ), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("ReplicatePadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads( + x[ + T.int64(0), + T.int64(0), + T.int64(0) : T.int64(4), + T.int64(0) : T.int64(4), + ] + ) + T.writes(ReplicatePadInput[v_i0, v_i1, v_i2, v_i3]) + ReplicatePadInput[v_i0, v_i1, v_i2, v_i3] = x[ + T.if_then_else( + v_i0 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(1) <= v_i0, T.int64(0), v_i0), + ), + T.if_then_else( + v_i1 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(1) <= v_i1, T.int64(0), v_i1), + ), + T.if_then_else( + v_i2 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(4) <= v_i2, T.int64(3), v_i2), + ), + T.if_then_else( + v_i3 < T.int64(0), + T.int64(0), + T.if_then_else(T.int64(4) <= v_i3, T.int64(3), v_i3), + ), + ] + + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.replicate_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("edge", expected) + + def test_reflect(self): + @I.ir_module + class expected: + @T.prim_func(private=True) + def mirror_pad( + x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), + MirrorPadInput: T.Buffer( + (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" + ), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): + with T.block("MirrorPadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, T.int64(0) : T.int64(4), T.int64(0) : T.int64(4)]) + T.writes(MirrorPadInput[v_i0, v_i1, v_i2, v_i3]) + MirrorPadInput[v_i0, v_i1, v_i2, v_i3] = x[ + v_i0, + v_i1, + T.if_then_else( + T.int64(4) <= v_i2, + T.int64(6) - v_i2, + T.if_then_else(v_i2 < T.int64(0), v_i2 * T.int64(-1), v_i2), + ), + T.if_then_else( + T.int64(4) <= v_i3, + T.int64(6) - v_i3, + T.if_then_else(v_i3 < T.int64(0), v_i3 * T.int64(-1), v_i3), + ), + ] + + @R.function + def main( + x: R.Tensor((1, 1, 4, 4), dtype="float32") + ) -> R.Tensor((1, 1, 5, 5), dtype="float32"): + cls = expected + with R.dataflow(): + lv = R.call_tir( + cls.mirror_pad, (x,), out_sinfo=R.Tensor((1, 1, 5, 5), dtype="float32") + ) + gv: R.Tensor((1, 1, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + self._test_autopad("reflect", expected) + + if __name__ == "__main__": tvm.testing.main()