Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,7 @@ def _impl(inputs, input_types):
def _take():
def _impl(inputs, input_types):
data = inputs[0]
import torch

if isinstance(inputs[1], _expr.Var):
indices = _op.cast(inputs[1], "int32")
elif isinstance(inputs[1], torch.Tensor):
indices = _wrap_const(inputs[1].numpy())
else:
msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
raise AssertionError(msg)
indices = _op.cast(inputs[1], "int32")

return _op.transform.take(data, indices=indices)
return _impl
Expand Down Expand Up @@ -337,6 +329,40 @@ def _impl(inputs, input_types):
return _op.transform.repeat(data, repeats=repeats, axis=axis)
return _impl


def _addcdiv():
def _impl(inputs, input_types):
data = inputs[0]
c = _expr.const(inputs[3])
t1 = inputs[1]
t2 = inputs[2]

return data + (c * (t1 / t2))
return _impl


def _addcmul():
def _impl(inputs, input_types):
data = inputs[0]
c = _expr.const(inputs[3])
t1 = inputs[1]
t2 = inputs[2]

return data + (c * (t1 * t2))
return _impl


def _where():
def _impl(inputs, input_types):
cond = inputs[0]
x = inputs[1]
y = inputs[2]

return _op.where(cond, x, y)

return _impl


def _ones():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1382,16 +1408,7 @@ def _impl(inputs, input_types):
def _bitwise_xor():
def _impl(inputs, input_types):
lhs = inputs[0]

import torch
if isinstance(inputs[1], _expr.Var):
rhs = inputs[1]
elif isinstance(inputs[1], torch.Tensor):
rhs = _wrap_const(inputs[1].numpy())
else:
msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
raise AssertionError(msg)

rhs = inputs[1]
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")

Expand All @@ -1410,17 +1427,7 @@ def _impl(inputs, input_types):
def _logical_xor():
def _impl(inputs, input_types):
lhs = _op.cast(inputs[0], "bool")

import torch
if isinstance(inputs[1], _expr.Var):
rhs = inputs[1]
elif isinstance(inputs[1], torch.Tensor):
rhs = _wrap_const(inputs[1].numpy())
else:
msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
raise AssertionError(msg)

rhs = _op.cast(rhs, "bool")
rhs = _op.cast(inputs[1], "bool")

return _op.logical_xor(lhs, rhs)
return _impl
Expand Down Expand Up @@ -1551,6 +1558,8 @@ def _get_convert_map(prelude):
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::addcdiv" : _addcdiv(),
"aten::addcmul" : _addcmul(),
"aten::ones" : _ones(),
"aten::ones_like" : _ones_like(),
"aten::zeros" : _zeros(),
Expand All @@ -1570,6 +1579,7 @@ def _get_convert_map(prelude):
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::take" : _take(),
"aten::where" : _where(),
"aten::topk" : _topk(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
Expand Down Expand Up @@ -1832,7 +1842,7 @@ def _get_constant(node):
tensor = node.t(attr_name)
if len(tensor.shape) == 0: # tensor(0.1)
return float(tensor)
return tensor
return _wrap_const(tensor.numpy())
elif ty == "DeviceObjType":
return node.s(attr_name)
elif ty == "FunctionType":
Expand Down
69 changes: 69 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,72 @@ def forward(self, *args):
verify_model(Neg1().float().eval(), input_data=input_data)


def test_forward_where():
torch.set_grad_enabled(False)

class Where1(Module):
def forward(self, *args):
y = torch.ones([3, 2])
if torch.cuda.is_available():
y = y.cuda()
return torch.where(args[0] > 0, args[0], y)

class Where2(Module):
def forward(self, *args):
return torch.where(args[0] > 0, args[0], args[1])

x = torch.rand([3, 2]).float()
verify_model(Where1().float().eval(), input_data=[x])
y = torch.rand([3, 2])
verify_model(Where2().float().eval(), input_data=[x, y])


def test_forward_addcdiv():
torch.set_grad_enabled(False)

class Addcdiv1(Module):
def forward(self, *args):
t1 = torch.ones([3, 1])
t2 = torch.ones([1, 3])
if torch.cuda.is_available():
t1 = t1.cuda()
t2 = t2.cuda()
return torch.addcdiv(args[0], 0.1, t1, t2)

class Addcdiv2(Module):
def forward(self, *args):
return torch.addcdiv(args[0], 0.5, args[1], args[2])

input_data = torch.rand([1, 3]).float()
verify_model(Addcdiv1().float().eval(), input_data=input_data)
t1 = torch.rand([3, 1]).float()
t2 = torch.rand([1, 3]).float()
verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2])


def test_forward_addcmul():
torch.set_grad_enabled(False)

class Addcmul1(Module):
def forward(self, *args):
t1 = torch.ones([3, 1])
t2 = torch.ones([1, 3])
if torch.cuda.is_available():
t1 = t1.cuda()
t2 = t2.cuda()
return torch.addcmul(args[0], 0.1, t1, t2)

class Addcmul2(Module):
def forward(self, *args):
return torch.addcmul(args[0], 0.5, args[1], args[2])

input_data = torch.rand([1, 3]).float()
verify_model(Addcmul1().float().eval(), input_data=input_data)
t1 = torch.rand([3, 1]).float()
t2 = torch.rand([1, 3]).float()
verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])


if __name__ == "__main__":
# Single operator tests
test_forward_add()
Expand Down Expand Up @@ -1933,6 +1999,9 @@ def forward(self, *args):
test_forward_select()
test_forward_take()
test_forward_topk()
test_forward_where()
test_forward_addcdiv()
test_forward_addcmul()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()
Expand Down