
Commit fc64d70

tlopex authored and ShiboXing committed
[Relax][PyTorch] Add support for where, cumprod and reciprocal ops (apache#17788)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
1 parent e49be56 · commit fc64d70

3 files changed: 89 additions & 0 deletions


python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 17 additions & 0 deletions
@@ -949,6 +949,12 @@ def convert(node: fx.Node):

         return convert

+    def _where(self, node: fx.Node) -> relax.Var:
+        condition = self.env[node.args[0]]
+        x = self.env[node.args[1]]
+        y = self.env[node.args[2]]
+        return self.block_builder.emit(relax.op.where(condition, x, y))
+
     ########## Manipulation ##########

     def _cat(self, node: fx.Node) -> relax.Var:
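As an illustration of what _where consumes (a sketch, not part of this commit): in a traced FX graph, torch.where(condition, x, y) arrives as a call_function node whose positional node.args line up one-to-one with the three self.env lookups above.

import torch
from torch import fx

class WhereModel(torch.nn.Module):
    def forward(self, cond, x, y):
        return torch.where(cond, x, y)

gm = fx.symbolic_trace(WhereModel())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # node.args is (cond, x, y), matching node.args[0..2] in _where
        print(node.target, node.args)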
@@ -967,6 +973,17 @@ def _chunk(self, node: fx.Node) -> relax.Var:
             relax.op.split(x=x, indices_or_sections=n_sections, axis=dim)
         )

+    def _cumprod(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+
+        return self.block_builder.emit(relax.op.cumprod(x, dim, dtype))
+
     def _cumsum(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
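Why _cumprod checks both node.args and node.kwargs: in torch.cumprod(input, dim, *, dtype=None), dim is usually passed positionally while dtype is keyword-only, so a traced call splits them across args and kwargs. A small sketch (illustrative, not from this commit):

import torch
from torch import fx

class CumprodModel(torch.nn.Module):
    def forward(self, x):
        return torch.cumprod(x, 0, dtype=torch.float64)

gm = fx.symbolic_trace(CumprodModel())
node = next(n for n in gm.graph.nodes if n.op == "call_function")
print(node.args)    # (x, 0): dim is found at node.args[1]
print(node.kwargs)  # {'dtype': torch.float64}: dtype is found in node.kwargs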

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 7 additions & 0 deletions
@@ -62,6 +62,10 @@ def _fetch_attr(self, model, target: str):

     ########## Unary Ops ##########

+    def _reciprocal(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
+
     def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
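The converter expresses reciprocal as a division rather than a dedicated Relax op, emitting divide(const(1.0, dtype), x). A quick PyTorch-side sanity check of that identity (illustration only):

import torch

x = torch.rand(4, 4) + 0.5  # keep values safely away from zero
# torch.reciprocal(x) is elementwise 1/x, the same computation the
# emitted divide performs on the Relax side
assert torch.allclose(torch.reciprocal(x), 1.0 / x)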
@@ -708,6 +712,7 @@ def create_convert_map(
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
@@ -784,11 +789,13 @@
             # search
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
+            "where": self._where,
             # tensor manipulation
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
             "contiguous": lambda node: self.env[node.args[0]],
+            "cumprod": self._cumprod,
             "cumsum": self._cumsum,
             "expand": self._expand,
             "expand_as.default": self._expand_as,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 65 additions & 0 deletions
@@ -2339,6 +2339,27 @@ def main(
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)

+    # reciprocal
+    class Reciprocal(Module):
+        def forward(self, input):
+            return torch.reciprocal(input)
+
+    @tvm.script.ir_module
+    class expected_reciprocal:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    R.const(1.0, "float32"), input_1
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Reciprocal(), input_info, {}, expected_reciprocal)
+
     # relu
     class ReLU0(Module):
         def __init__(self):
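verify_model is a helper defined earlier in this test file. A hedged sketch of what it presumably does for these cases (the name verify_model_sketch and the handling of binding are assumptions, not the actual helper):

import tvm
from torch import fx
from tvm.relax.frontend.torch import from_fx

def verify_model_sketch(model, input_info, binding, expected):
    # trace the torch.nn.Module, import it, and compare structurally
    graph_module = fx.symbolic_trace(model)
    mod = from_fx(graph_module, input_info)
    # the real helper presumably also uses binding to bind model parameters
    tvm.ir.assert_structural_equal(mod, expected)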
@@ -4315,5 +4336,49 @@ def main(
     verify_model(Prod(), [([5, 3], "float32")], {}, Expected)


+def test_cumprod():
+    class Cumprod(Module):
+        def forward(self, x):
+            return torch.cumprod(x, 0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cumprod(), [([5, 3], "float32")], {}, Expected)
+
+
+def test_where():
+    class Where(Module):
+        def forward(self, condition, x, y):
+            return torch.where(condition, x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="bool"),
+            inp_1: R.Tensor((5, 3), dtype="float32"),
+            inp_2: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2)
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Where(), [([5, 3], "bool"), ([5, 3], "float32"), ([5, 3], "float32")], {}, Expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
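For a concrete reading of the expected modules above (illustration only, not part of the commit): cumprod with exclusive=False is an inclusive scan, and where selects elementwise between its two value operands.

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
cond = x > 2.0

# R.cumprod(..., axis=0, exclusive=False) mirrors torch's inclusive scan
print(torch.cumprod(x, 0))  # [[1., 2., 3.], [4., 10., 18.]]
# R.where takes from the second operand where cond holds, else the third
print(torch.where(cond, x, torch.zeros_like(x)))  # [[0., 0., 3.], [4., 5., 6.]]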
