diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b9650e6e9a9c..35f74544b833 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2672,6 +2672,12 @@ def logical_and(self, inputs, input_types): return _op.logical_and(lhs, rhs) + def logical_or(self, inputs, input_types): + lhs = _op.cast(inputs[0], "bool") + rhs = _op.cast(inputs[1], "bool") + + return _op.logical_or(lhs, rhs) + def nonzero(self, inputs, input_types, is_numpy_style=False): data = inputs[0] ret = _op.transform.argwhere(data) @@ -4238,6 +4244,7 @@ def create_convert_map(self): "aten::unbind": self.unbind, "aten::__and__": self.logical_and, "aten::logical_and": self.logical_and, + "aten::logical_or": self.logical_or, "aten::_shape_as_tensor": self.shape_as_tensor, "aten::nonzero": self.nonzero, "aten::nonzero_numpy": self.nonzero_numpy, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9bf40cfcdd85..bf96c21399f0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4882,6 +4882,21 @@ def test_fn(x, y): verify_model(test_fn, [a, b]) +def test_logical_or(): + """test_logical_or""" + + def test_fn(x, y): + return torch.logical_or(x, y) + + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + verify_model(test_fn, [a, b]) + + a = torch.tensor([True, False, True]) + b = torch.tensor([True, False, False]) + verify_model(test_fn, [a, b]) + + def test_masked_select(): """test_masked_select"""