From 7d99b85ed7e7a5a56007aada798f47f3c66104cf Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 29 Nov 2025 22:29:06 +0900 Subject: [PATCH] [Relax][PyTorch] Add support for binary scalar operations in ExportedProgram frontend and corresponding tests --- .../torch/exported_program_translator.py | 2 + .../test_frontend_from_exported_program.py | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index fc0ca1820940..3a33a58f8c38 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1253,6 +1253,7 @@ def create_convert_map( "trunc.default": self._unary_op(relax.op.trunc), # binary "add.Tensor": self._binary_op(relax.op.add, operator.add), + "add.Scalar": self._binary_op(relax.op.add, operator.add), "add_.Tensor": self._binary_op(relax.op.add, operator.add), "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), @@ -1306,6 +1307,7 @@ def create_convert_map( "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow), "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), + "sub.Scalar": self._binary_op(relax.op.subtract, operator.sub), "__and__.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), "__and__.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), "__or__.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 091f0a4a29c5..48ca5f3209c2 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1429,6 +1429,45 @@ def main( verify_model(Binary2(op), example_args2, {}, expected2) +operator_binary_scalar = [ + (torch.ops.aten.add.Scalar, R.add), + (torch.ops.aten.bitwise_and.Scalar, R.bitwise_and), + (torch.ops.aten.bitwise_or.Scalar, R.bitwise_or), + (torch.ops.aten.bitwise_xor.Scalar, R.bitwise_xor), + (torch.ops.aten.div.Scalar, R.divide), + (torch.ops.aten.sub.Scalar, R.subtract), + (torch.ops.aten.mul.Scalar, R.multiply), + (torch.ops.aten.remainder.Scalar, R.floor_mod), +] + + +@pytest.mark.parametrize("op, relax_op", operator_binary_scalar) +def test_binary_scalar(op, relax_op): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + class BinaryScalar(Module): + def __init__(self, op): + super().__init__() + self.op = op + + def forward(self, lhs): + return self.op(lhs, 1.0) + + @tvm.script.ir_module + class expected_binary_scalar: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(lhs, R.const(1.0)) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(BinaryScalar(op), example_args, {}, expected_binary_scalar) + + operator_binary_promote = [ (operator.add, R.add), (operator.sub, R.subtract),