Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
support binary ops
  • Loading branch information
mshr-h committed Sep 28, 2024
commit 1807831b427f4a1de97a5d1b0aa7842f2090a0a1
33 changes: 33 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@ def convert(node: fx.Node) -> relax.Var:

return convert

########## Binary Ops ##########

def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
from torch import fx

def convert(node: fx.Node) -> relax.Var:
def promote_binary_op_args(lhs, rhs):
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
return lhs, rhs
elif isinstance(lhs, relax.Expr):
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
return lhs, relax.const(rhs, lhs.struct_info.dtype)
elif isinstance(rhs, relax.Expr):
assert isinstance(rhs.struct_info, relax.TensorStructInfo)
return relax.const(lhs, rhs.struct_info.dtype), rhs
else:
assert False

def call_binary_op(op, lhs, rhs):
lhs, rhs = promote_binary_op_args(lhs, rhs)
return self.block_builder.emit(op(lhs, rhs))

lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return call_binary_op(relax_op, lhs, rhs)
elif isinstance(lhs, relax.expr.Constant):
return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
elif isinstance(rhs, relax.expr.Constant):
return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
return intrinsic_op(lhs, rhs)

return convert

########## Neural Network ##########

def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Callable, Dict, List, Tuple

import torch
Expand Down Expand Up @@ -76,6 +77,8 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr:
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
import operator

return {
# unary
"acos.default": self._unary_op(relax.op.acos),
Expand Down Expand Up @@ -109,6 +112,22 @@ def create_convert_map(
"tanh.default": self._unary_op(relax.op.tanh),
"tril.default": self._tril_triu(relax.op.tril),
"triu.default": self._tril_triu(relax.op.triu),
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
"floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
"lt.Scalar": self._binary_op(relax.op.less, operator.lt),
"lt.Tensor": self._binary_op(relax.op.less, operator.lt),
"matmul.default": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
),
"max.other": self._binary_op(relax.op.maximum, max),
"mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
"pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
"sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
# neural network
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"conv2d.default": self._conv2d,
Expand Down
33 changes: 0 additions & 33 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,39 +96,6 @@ def convert(node: fx.Node) -> relax.Var:

return convert

########## Binary Ops ##########

def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
from torch import fx

def convert(node: fx.Node) -> relax.Var:
def promote_binary_op_args(lhs, rhs):
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
return lhs, rhs
elif isinstance(lhs, relax.Expr):
assert isinstance(lhs.struct_info, relax.TensorStructInfo)
return lhs, relax.const(rhs, lhs.struct_info.dtype)
elif isinstance(rhs, relax.Expr):
assert isinstance(rhs.struct_info, relax.TensorStructInfo)
return relax.const(lhs, rhs.struct_info.dtype), rhs
else:
assert False

def call_binary_op(op, lhs, rhs):
lhs, rhs = promote_binary_op_args(lhs, rhs)
return self.block_builder.emit(op(lhs, rhs))

lhs, rhs = self.retrieve_args(node)
if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
return call_binary_op(relax_op, lhs, rhs)
elif isinstance(lhs, relax.expr.Constant):
return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
elif isinstance(rhs, relax.expr.Constant):
return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
return intrinsic_op(lhs, rhs)

return convert

########## Neural Network ##########

def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
Expand Down
Loading