diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py index 4f278d8249e8..910567c66225 100644 --- a/python/tvm/relay/frontend/oneflow.py +++ b/python/tvm/relay/frontend/oneflow.py @@ -1025,15 +1025,17 @@ def _impl_v1(cls, inputs, attrs, params): return out -class ThresholdedRelu(OneFlowOpConverter): - """Operator converter for ThresholdedRelu.""" +class Threshold(OneFlowOpConverter): + """Operator converter for Threshold.""" @classmethod def _impl_v1(cls, inputs, attrs, params): - alpha = float(attrs.get("alpha", 1.0)) - alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha)) - mask = _op.greater(inputs[0], alpha_tensor).astype("float32") - return inputs[0] * mask + threshold = float(attrs.get("threshold_val", 1.0)) + threshold_tensor = _op.full_like(inputs[0], fill_value=_expr.const(threshold)) + value = float(attrs.get("value")) + value_tensor = _op.full_like(inputs[0], fill_value=_expr.const(value)) + mask = _op.greater(inputs[0], threshold_tensor) + return _op.where(mask, inputs[0], value_tensor) class Elu(OneFlowOpConverter): @@ -1422,6 +1424,7 @@ def get_convert_map(): "relu": Renamer("relu"), "leaky_relu": Renamer("leaky_relu"), "prelu": PReLU.get_converter(), + "threshold": Threshold.get_converter(), "selu": Selu.get_converter(), "silu": Silu.get_converter(), "gelu": Gelu.get_converter(), diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index cc9333cd03bd..f7b9b934e124 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -24,6 +24,7 @@ import tvm.testing import tvm.topi.testing from tvm import relay +from packaging import version as package_version MODEL_HOME = "test_model" @@ -702,6 +703,15 @@ def forward(self, x): x = x.softmax(dim=-1) return x + class Threshold(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Threshold(0.5, 0.2) + + def forward(self, x): + x = self.active(x) + return x + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) @@ -738,6 +748,11 @@ def forward(self, x): inputs=flow.tensor(np.random.rand(1, 12, 197, 197).astype(np.float32)), ) + # Threshold was introduced in the version 0.8.0 of oneflow + if package_version.parse(flow.__version__) >= package_version.parse("0.8.0"): + model14 = Threshold().eval() + verify_activation(model14, device="llvm") + @tvm.testing.uses_gpu def test_math():