From 05f91f1dd2f47c4c7ac465addef96e1a87cc6da0 Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 12 Jan 2026 15:11:07 +0700 Subject: [PATCH 1/3] [Relax][Onnx] Handle slope and axis argument with different slope shapes :(1xCx1x1) or (S,) or (1,1) etc.. --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 26 ++++++++++++++++++- python/tvm/topi/nn/elemwise.py | 7 ++--- tests/python/relax/test_frontend_onnx.py | 3 +++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 2212fa6c68ea..1479d6f23913 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1127,7 +1127,31 @@ class PRelu(OnnxOpConverter): def _impl_v1(cls, bb, inputs, attr, params): x = inputs[0] slope = inputs[1] - return relax.op.nn.prelu(x, slope) + + x_shape = x.struct_info.shape + slope_shape = slope.struct_info.shape + + ndim = len(x_shape) + s_ndim = len(slope_shape) + + if all(ss == 1 for ss in slope_shape) or s_ndim == 1: + slope = relax.op.reshape(slope, (slope_shape[0],)) + return relax.op.nn.prelu(x, slope, ndim - 1) + + if s_ndim == ndim: + non_one_axes = [i for i, ss in enumerate(slope_shape) if ss != 1] + + # Must have only ONE non-broadcast axis + if len(non_one_axes) != 1: + raise ValueError( + f"Invalid PRelu slope shape (multiple non-broadcast dims): {slope_shape}" + ) + axis = non_one_axes[0] + + slope = relax.op.reshape(slope, (slope_shape[axis],)) + return relax.op.nn.prelu(x, slope, axis) + + raise ValueError(f"Unsupported PRelu slope shape: {slope_shape}") class ThresholdedRelu(OnnxOpConverter): diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 59cc3598e9f2..332636185c02 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -129,9 +129,10 @@ def prelu(x, slope, axis=1): assert len(slope.shape) == 1 assert axis < len(x.shape) - slope = te.compute( - (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted" - ) + if slope.shape[0] == 1: + slope = te.compute( + (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted" + ) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) def _compute_channelwise(*indices): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index eb4c557e754c..f967b3c4c666 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1068,6 +1068,9 @@ def test_mish(): def test_prelu(): verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32]) + verify_binary("PRelu", [3, 32, 32], [1, 1], [3, 32, 32]) + verify_binary("PRelu", [3, 32, 32], [32], [3, 32, 32]) + verify_binary("PRelu", [3, 32, 32], [3, 1, 1], [3, 32, 32]) def test_thresholded_relu(): From ac37de51a4c7d03af819ed64f56569bb7767d1bb Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 12 Jan 2026 16:57:08 +0700 Subject: [PATCH 2/3] rerun --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1479d6f23913..423ee682ba19 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1130,6 +1130,7 @@ def _impl_v1(cls, bb, inputs, attr, params): x_shape = x.struct_info.shape slope_shape = slope.struct_info.shape + ndim = len(x_shape) s_ndim = len(slope_shape) From 20972677c30101accab40b9c4eff2c4feb728a6e Mon Sep 17 00:00:00 2001 From: locnd182644 Date: Mon, 12 Jan 2026 16:57:19 +0700 Subject: [PATCH 3/3] rerun --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 423ee682ba19..1479d6f23913 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1130,7 +1130,6 @@ def _impl_v1(cls, bb, inputs, attr, params): x_shape = x.struct_info.shape slope_shape = slope.struct_info.shape - ndim = len(x_shape) s_ndim = len(slope_shape)