Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
prelu op support and test script added
  • Loading branch information
deivanayakisankaralingam authored and deivanayakisankaralingam committed Apr 17, 2025
commit 8e384a03b10502eabd809be8e488be3f6311df18
10 changes: 10 additions & 0 deletions include/tvm/relax/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,16 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
}
};

/*! \brief Attributes used in PReLU operator */
struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
int axis;

TVM_DECLARE_ATTRS(PReluAttrs, "relax.attrs.PReluAttrs") {
TVM_ATTR_FIELD(axis)
.describe("The axis along which the alpha values are applied.");
}
};

/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
int axis;
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def _log_softmax(self, node: fx.Node) -> relax.Var:
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))

def _prelu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
alpha = self.env[node.args[1]]
axis = 0 if len(x.struct_info.shape) == 1 else 1
return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))

def _round(self, node: fx.Node) -> relax.Expr:
if node.kwargs.get("decimals", 0) != 0:
raise ValueError("specifying decimals for round is not supported yet")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def create_convert_map(
"log1p.default": self._log1p,
"log_softmax.int": self._log_softmax,
"neg.default": self._unary_op(relax.op.negative),
"prelu.default": self._prelu,
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
"relu_.default": self._unary_op(relax.op.nn.relu),
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var:
dim = module.dim
assert dim is not None
return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))

def _prelu_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
alpha_tensor = module.weight.numpy()
alpha = relax.const(alpha_tensor, dtype="float32")
axis = 0 if len(x.struct_info.shape) == 1 else 1
return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))

def _softmax_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
Expand Down Expand Up @@ -595,6 +603,7 @@ def create_convert_map(
nn.Identity: lambda node: self.env[node.args[0]],
nn.LeakyReLU: self._leakyrelu_module,
nn.LogSoftmax: self._log_softmax_module,
nn.PReLU: self._prelu_module,
nn.ReLU: self._unary_op(relax.op.nn.relu),
nn.ReLU6: lambda node: self.block_builder.emit(
relax.op.clip(self.env[node.args[0]], 0, 6)
Expand Down Expand Up @@ -657,6 +666,7 @@ def create_convert_map(
"logical_not": self._unary_op(relax.op.logical_not),
"log_softmax": self._log_softmax,
"neg": self._unary_op(relax.op.negative),
"prelu":self._prelu,
"reciprocal": self._reciprocal,
"relu": self._unary_op(relax.op.nn.relu),
"round": self._round,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
max_pool3d,
nll_loss,
pad,
prelu,
relu,
rms_norm,
selu,
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,32 @@ def log_softmax(data: Expr, axis: int = -1) -> Expr:
return _ffi_api.log_softmax(data, axis) # type: ignore


def prelu(data: Expr, alpha: Expr, axis: int = 1) -> Expr:
r"""Parametric Rectified Linear Unit (PReLU).

.. math::
PReLU(x) = x \text{ if } x > 0 \text{ else } \alpha * x

Parameters
----------
data : relax.Expr
The input tensor.

alpha : relax.Expr
The learnable slope tensor, applied channel-wise.

axis : int
The axis along which the `alpha` values are applied.
Default is 1 (assuming NCHW format).

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.prelu(data, alpha, axis)


def batch_norm(
data: Expr,
gamma: Expr,
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ def _nn_leakyrelu(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.nn.leaky_relu, call.args[0], call.attrs.alpha)


@register_legalize("relax.nn.prelu")
def _nn_prelu(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)


@register_legalize("relax.nn.gelu")
def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr:
def te_gelu(x: te.Tensor):
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/topi/nn/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,15 @@ def prelu(x, slope, axis=1):

assert len(slope.shape) == 1
assert axis < len(x.shape)
slope = te.compute(
(get_const_int(x.shape[axis]),),
lambda c: slope[0],
name="slope_broadcasted"
)
assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis])

def _compute_channelwise(*indices):
xval = x(*indices)
return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis]))

return te.compute(x.shape, _compute_channelwise)
return te.compute(x.shape, _compute_channelwise)
20 changes: 20 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ TVM_REGISTER_OP("relax.nn.softplus")
.set_attrs_type<SoftplusAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoUnaryArith</*require_float_dtype=*/true>)

/* relax.nn.prelu */
TVM_REGISTER_NODE_TYPE(PReluAttrs);

Expr prelu(Expr data, Expr alpha, int axis = 1) {
auto attrs = make_object<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("relax.nn.prelu");
return Call(op, {data, alpha}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu);

TVM_REGISTER_OP("relax.nn.prelu")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
.set_attrs_type<PReluAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.softmax */
Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ Expr gelu(Expr data);
/*! \brief Gaussian Error Linear Units function approximated by tanh. */
Expr gelu_tanh(Expr data);

/*! \brief Parametric Rectified Linear Unit function.*/
Expr prelu(Expr data, Expr alpha, int axis);

/*! \brief Scaled Exponential Linear Unit function. */
Expr selu(Expr data);

Expand Down
32 changes: 32 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,38 @@ def main(
verify_model(LogSoftmax2(), example_args, {}, expected1)


def test_prelu():
class Prelu1(Module):
def __init__(self, num_parameters=1, alpha=0.25):
super().__init__()
self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha)

def forward(self, x):
return self.prelu(x)

class Prelu2(torch.nn.Module):
def __init__(self):
super(Prelu2, self).__init__()
self.alpha = torch.nn.Parameter(torch.tensor([0.25]))

def forward(self, x):
return torch.nn.functional.prelu(x, self.alpha)

@tvm.script.ir_module
class expected:
@R.function
def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(Prelu1(), example_args, {}, expected)
verify_model(Prelu2(), example_args, {}, expected)


def test_softmax():
class Softmax(Module):
def __init__(self):
Expand Down
35 changes: 35 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,38 @@ def main(
verify_model(LeakyReLU1(), input_info, {}, expected)


def test_prelu():
class Prelu1(Module):
def __init__(self, num_parameters=1, alpha=0.25):
super().__init__()
self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha)

def forward(self, x):
return self.prelu(x)

class Prelu2(torch.nn.Module):
def __init__(self):
super(Prelu2, self).__init__()
self.alpha = torch.nn.Parameter(torch.tensor([0.25]))

def forward(self, x):
return torch.nn.functional.prelu(x, self.alpha)

@tvm.script.ir_module
class expected:
@R.function
def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(x, R.const([0.25], dtype="float32"), axis=1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

input_info = [([1, 3, 10, 10], "float32")]
verify_model(Prelu1(), input_info, {}, expected)
verify_model(Prelu2(), input_info, {}, expected)


def test_maxpool2d():
input_info = [([1, 3, 10, 10], "float32")]

Expand Down Expand Up @@ -2266,6 +2298,9 @@ def main(
# softplus
test_softplus()

# prelu
test_prelu()

# log2
class Log2(Module):
def forward(self, x):
Expand Down
4 changes: 4 additions & 0 deletions tests/python/relax/test_op_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def test_op_correctness():
assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout")
assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad")

x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
alpha = relax.Var("alpha", R.Tensor((3,), "float32"))
assert relax.op.nn.prelu(x, alpha, axis=1) == Op.get("relax.nn.prelu")

x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
beta = relax.Var("beta", R.Tensor((3,), "float32"))
Expand Down