Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
support argmax and argmin
  • Loading branch information
mshr-h committed Sep 28, 2024
commit 18794984300087413a9d4ea75430bc3462efa9ce
13 changes: 13 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,19 @@ def _sum(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
return self.block_builder.emit(relax.op.sum(args[0], args[1]))

########## Search ##########

def _argmax_argmin(self, op: Callable) -> Callable:
from torch import fx

def convert(node: fx.Node):
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
return self.block_builder.emit(op(x, dim, keepdim))

return convert

########## Manipulation ##########

def _reshape(self, node: fx.Node) -> relax.Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def create_convert_map(
# statistical
"mean.dim": self._mean,
"sum.dim_IntList": self._sum,
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"view.default": self._reshape,
}
Expand Down
13 changes: 0 additions & 13 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,19 +761,6 @@ def _unbind(self, node: fx.Node) -> relax.Var:
ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
return self.block_builder.emit(relax.Tuple(ret))

########## Search ##########

def _argmax_argmin(self, op: Callable) -> Callable:
    """Return a converter that lowers a torch argmax/argmin fx node via *op*.

    *op* is expected to be a relax reduction builder (e.g. relax.op.argmax)
    called as ``op(x, dim, keepdim)``; the result is emitted through the
    block builder.
    """
    from torch import fx

    def convert(node: fx.Node):
        # Input tensor was registered in self.env when its producer node was visited.
        x = self.env[node.args[0]]
        # Positional fx args take precedence over kwargs; fall back to the
        # torch defaults (dim=None reduces over all axes, keepdim=False).
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(op(x, dim, keepdim))

    return convert

########## Manipulation ##########

def _cat(self, node: fx.Node) -> relax.Var:
Expand Down
86 changes: 86 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,92 @@ def main(
verify_model(Sum(), example_args, {}, expected1)


def test_argmax_argmin():
    """Check that torch.argmax/argmin lower to relax R.argmax/R.argmin.

    Covers four combinations: explicit dim with and without keepdim
    (argmax), and full reduction (dim omitted -> axis=None) with and
    without keepdim (argmin).
    """
    example_args = (torch.randn(256, 256, dtype=torch.float32),)

    # argmax with an explicit reduction dim, keepdim left at its default (False).
    class Argmax1(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmax(input, dim=-1)

    # argmax with keepdim=True: the reduced axis is kept with extent 1.
    class Argmax2(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmax(input, dim=-1, keepdim=True)

    @tvm.script.ir_module
    class expected_argmax1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256,), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False)
                gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_argmax2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True)
                gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Argmax1(), example_args, {}, expected_argmax1)
    verify_model(Argmax2(), example_args, {}, expected_argmax2)

    # argmin with no dim: full reduction to a scalar (axis=None).
    class Argmin1(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmin(input)

    # argmin with keepdim=True and no dim: all axes kept with extent 1.
    class Argmin2(Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, input):
            return torch.argmin(input, keepdim=True)

    @tvm.script.ir_module
    class expected_argmin1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected_argmin2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True)
                gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,)
                R.output(gv)
            return gv

    verify_model(Argmin1(), example_args, {}, expected_argmin1)
    verify_model(Argmin2(), example_args, {}, expected_argmin2)


def test_view():
class View(Module):
def forward(self, x):
Expand Down