From da115c57f616785fae4a7f6a2600fc0d4a5977d8 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:05:18 +0900 Subject: [PATCH 01/17] support batchnorm2d and getitem --- .../torch/base_fx_graph_translator.py | 81 +++++++++++++++++++ .../torch/exported_program_translator.py | 31 +++++++ .../tvm/relax/frontend/torch/fx_translator.py | 81 ------------------- .../test_frontend_from_exported_program.py | 53 ++++++++++++ 4 files changed, 165 insertions(+), 81 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a41b9b6d4f9a..148c53b5dd08 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -357,6 +357,87 @@ def _reshape(self, node: fx.Node) -> relax.Var: ########## Others ########## + def _getitem(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + take_indices = [] + take_axes = [] + stride_begin = [] + stride_end = [] + stride = [] + stride_axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + non_ellipsis_cnt = 0 + for index in node.args[1]: + if isinstance(index, (int, slice, torch.fx.Node)): + non_ellipsis_cnt += 1 + for index in node.args[1]: + if isinstance(index, int): + stride_begin.append(index) + stride_end.append(index + 1) + stride.append(1) + stride_axes.append(i) + i = i + 1 + elif isinstance(index, slice): + stride_begin.append(0 if index.start is None else index.start) + stride_end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + stride_axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(len(stride_axes) + len(expand_dim)) + elif index is Ellipsis: + for _ in range(len(shape) - non_ellipsis_cnt): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + elif isinstance(index, torch.fx.Node): + node_index = self.env[index] + if not isinstance(node_index, relax.Expr): + raise ValueError( + "Unsupported index type for relax.op.take: " + str(type(node_index)) + ) + take_indices.append(node_index) + take_axes.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + taken = x + if len(take_indices) > 1: + raise ValueError("Multiple tensors as index not yet supported") + for each_index, each_axis in zip(take_indices, take_axes): + taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) + sliced = self.block_builder.emit( + relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) + ) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + elif isinstance(x, relax.Constant): + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) + else: + assert False + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 11594690cdc2..834dcd1a6a14 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -74,6 +74,34 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0) return self.block_builder.emit(relax.op.clip(x, min_val, max_val)) + ########## Neural Network ########## + + def _native_batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + + return self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + momentum=momentum, + ) + ) + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -129,6 +157,7 @@ def create_convert_map( "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network + "_native_batch_norm_legit_no_training.default": self._native_batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "conv2d.default": self._conv2d, "linear.default": self._linear, @@ -141,6 +170,8 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "view.default": self._reshape, + # other + "getitem": self._getitem, } def from_exported_program( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index dc6ebc2eb34f..12fe0815d252 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1054,87 +1054,6 @@ def _getattr(self, node: fx.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _getitem(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): - return x[node.args[1]] - elif isinstance(x, relax.Var): - if isinstance(x.struct_info, relax.TupleStructInfo): - return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) - - assert isinstance(x.struct_info, relax.TensorStructInfo) - take_indices = [] - take_axes = [] - stride_begin = [] - stride_end = [] - stride = [] - stride_axes = [] - expand_dim = [] - i = 0 - shape = self.shape_of(x) - non_ellipsis_cnt = 0 - for index in node.args[1]: - if isinstance(index, (int, slice, torch.fx.Node)): - non_ellipsis_cnt += 1 - for index in node.args[1]: - if isinstance(index, int): - stride_begin.append(index) - stride_end.append(index + 1) - stride.append(1) - stride_axes.append(i) - i = i + 1 - elif isinstance(index, slice): - stride_begin.append(0 if index.start is None else index.start) - stride_end.append(shape[i] if index.stop is None else index.stop) - stride.append(1 if index.step is None else index.step) - stride_axes.append(i) - i = i + 1 - elif index is None: - expand_dim.append(len(stride_axes) + len(expand_dim)) - elif index is Ellipsis: - for _ in range(len(shape) - non_ellipsis_cnt): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - elif isinstance(index, torch.fx.Node): - node_index = self.env[index] - if not isinstance(node_index, relax.Expr): - raise ValueError( - "Unsupported index type for relax.op.take: " + str(type(node_index)) - ) - take_indices.append(node_index) - take_axes.append(i) - i = i + 1 - else: - raise ValueError("Unsupported index type: " + str(type(index))) - while i < len(shape): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - taken = x - if len(take_indices) > 1: - raise ValueError("Multiple tensors as index not yet supported") - for each_index, each_axis in zip(take_indices, take_axes): - taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) - sliced = self.block_builder.emit( - relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) - ) - sliced_shape = list(self.shape_of(sliced)) - for i in expand_dim: - sliced_shape.insert(i, 1) - return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) - elif isinstance(x, relax.Constant): - dtype = x.struct_info.dtype - return relax.const(x.data.numpy()[node.args[1]], dtype) - else: - assert False - def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 25e6dbfae308..9cecbb78fb8b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1156,6 +1156,59 @@ def main( verify_model(Sub2(), example_args2, {}, expected_sub2) +def test_batchnorm2d(): + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = BatchNorm2d().eval() + binding = { + "w1": model.bn.weight.detach().numpy(), + "w2": model.bn.bias.detach().numpy(), + "w3": model.bn.running_mean.detach().numpy(), + "w4": model.bn.running_var.detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): From 5b4d9a29e8c667ac62b36d9aa808c417ceebb0ec Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:06:31 +0900 Subject: [PATCH 02/17] support addmm --- .../torch/base_fx_graph_translator.py | 22 +++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 22 ------- .../test_frontend_from_exported_program.py | 59 +++++++++++++++++++ 4 files changed, 82 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 148c53b5dd08..8f563994df24 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -227,6 +227,28 @@ def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + def _addmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) + return res + def _conv2d_impl( self, x: relax.Expr, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 834dcd1a6a14..b798168b7220 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -159,6 +159,7 @@ def create_convert_map( # neural network "_native_batch_norm_legit_no_training.default": self._native_batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "addmm.default": self._addmm, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 12fe0815d252..d0db6ed33d92 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -107,28 +107,6 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _addmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - y = self.env[node.args[1]] - z = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) - return res - def _avg_pool2d_impl( self, x: relax.Expr, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 9cecbb78fb8b..b4486446c05b 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1242,6 +1242,65 @@ def main( verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) +def test_addmm(): + class Addmm1(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + class Addmm2(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32")) + lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32")) + lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + + verify_model(Addmm1(), example_args, {}, expected1) + verify_model(Addmm2(), example_args, {}, expected2) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 388630ecc7687a2ff6eb94747a2e70e90c8f829f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:09:01 +0900 Subject: [PATCH 03/17] support avg_pool2d --- .../torch/base_fx_graph_translator.py | 29 ++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 29 ------ .../test_frontend_from_exported_program.py | 93 +++++++++++++++++++ 4 files changed, 123 insertions(+), 29 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 8f563994df24..8b38470358a0 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -249,6 +249,35 @@ def _addmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _conv2d_impl( self, x: relax.Expr, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index b798168b7220..4a069fd359c0 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -160,6 +160,7 @@ def create_convert_map( "_native_batch_norm_legit_no_training.default": self._native_batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, + "avg_pool2d.default": self._avg_pool2d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d0db6ed33d92..6a38e0f728a0 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -107,35 +107,6 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _avg_pool2d_impl( - self, - x: relax.Expr, - kernel_size: Union[int, Tuple[int, int]] = (1, 1), - stride: Optional[Union[int, Tuple[int, int]]] = None, - padding: Optional[int] = 0, - ceil_mode: Optional[bool] = False, - ) -> relax.Var: - stride = kernel_size if stride is None or stride == [] else stride - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel_size, - strides=stride, - padding=padding, - ceil_mode=ceil_mode, - layout="NCHW", - ) - ) - - def _avg_pool2d(self, node: fx.Node) -> relax.Var: - args, kwargs = node.normalized_arguments(node) - x = self.env[args[0]] - kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] - stride = args[2] if len(args) > 2 else kwargs.get("stride", None) - padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) - ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) - return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index b4486446c05b..4908919af219 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1301,6 +1301,99 @@ def main( verify_model(Addmm2(), example_args, {}, expected2) +def test_avg_pool2d(): + class AvgPool2d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AvgPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool2d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d( + input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + ceil_mode=True, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class AvgPool2d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[2, 1], + strides=[2, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AvgPool2d1(), example_args, {}, expected1) + verify_model(AvgPool2d2(), example_args, {}, expected2) + verify_model(AvgPool2d3(), example_args, {}, expected2) + verify_model(AvgPool2d4(), example_args, {}, expected3) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From d7a0c21826b700ccd8f7031e059a8c83e0aae519 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:16:14 +0900 Subject: [PATCH 04/17] support baddbmm --- .../torch/base_fx_graph_translator.py | 22 ++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 22 ---- .../test_frontend_from_exported_program.py | 102 ++++++++++++++++++ 4 files changed, 125 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 8b38470358a0..6cb845b6f2c9 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -278,6 +278,28 @@ def _avg_pool2d(self, node: fx.Node) -> relax.Var: ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _baddbmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + batch1 = self.env[node.args[1]] + batch2 = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(batch1, batch2)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + def _conv2d_impl( self, x: relax.Expr, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4a069fd359c0..0165d3e224d6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -161,6 +161,7 @@ def create_convert_map( "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, + "baddbmm.default": self._baddbmm, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 6a38e0f728a0..50ee9b488cb9 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -116,28 +116,6 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _baddbmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - a = self.env[node.args[1]] - b = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.matmul(a, b)) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) - return res - def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 4908919af219..8fb7446cc0f2 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1394,6 +1394,108 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): verify_model(AvgPool2d4(), example_args, {}, expected3) +def test_baddbmm(): + class BAddBMM1(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM2(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM3(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=3) + + @tvm.script.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + inp_0, R.const(3, "float32") + ) + lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 512, dtype=torch.float32), + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BAddBMM1(), + example_args, + {}, + Expected1, + ) + + verify_model( + BAddBMM2(), + example_args, + {}, + Expected2, + ) + + verify_model( + BAddBMM3(), + example_args, + {}, + Expected3, + ) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 340b4769885211f7cc74dfe7feb03b4039cb5e64 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:17:37 +0900 Subject: [PATCH 05/17] support bmm --- .../torch/exported_program_translator.py | 3 ++ .../test_frontend_from_exported_program.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 0165d3e224d6..ca30f4cb9c7c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -162,6 +162,9 @@ def create_convert_map( "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, "baddbmm.default": self._baddbmm, + "bmm.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8fb7446cc0f2..46cf07bbeb1f 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1496,6 +1496,42 @@ def main( ) +def test_bmm(): + class BMM(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((4, 128, 256), dtype="float32"), + input_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BMM(), + example_args, + {}, + Expected, + ) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 65f329a9b0d72c695e2998155fa34f4bc6c3eaf5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:21:28 +0900 Subject: [PATCH 06/17] support conv_transpose1d --- .../torch/base_fx_graph_translator.py | 50 ++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 56 +---------- .../test_frontend_from_exported_program.py | 92 +++++++++++++++++++ 4 files changed, 146 insertions(+), 53 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6cb845b6f2c9..e052702e2c03 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -300,6 +300,56 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + + def _conv_transpose1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv2d_impl( self, x: relax.Expr, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index ca30f4cb9c7c..3818ba809694 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -165,6 +165,7 @@ def create_convert_map( "bmm.default": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), + "conv_transpose1d.default": self._conv_transpose1d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 50ee9b488cb9..8175afeb88fe 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -139,57 +139,7 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _conv1d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] @@ -1028,7 +978,7 @@ def create_convert_map( nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, - nn.ConvTranspose1d: self._conv1d_transpose_module, + nn.ConvTranspose1d: self._conv_transpose1d_module, nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy_module, nn.GroupNorm: self._group_norm_module, @@ -1094,7 +1044,7 @@ def create_convert_map( "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose, + "conv_transpose1d": self._conv_transpose1d, "conv_transpose2d": self._conv2d_transpose, "conv1d": self._conv1d, "conv2d": self._conv2d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 46cf07bbeb1f..955e915251ed 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1532,6 +1532,98 @@ def main( ) +def test_conv_transpose1d(): + class ConvTranspose1d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose1d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 6, 3]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose1d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 6, 4, dtype=torch.float32),) + + model = ConvTranspose1d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 730841e1f10c7fb2dc7cde00b4ec5c4bc9a913ea Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:28:21 +0900 Subject: [PATCH 07/17] support conv_transpose2d --- .../torch/base_fx_graph_translator.py | 54 ++++++++++- .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 60 +----------- .../test_frontend_from_exported_program.py | 92 +++++++++++++++++++ 4 files changed, 150 insertions(+), 57 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e052702e2c03..5c669fdcfdf4 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -300,7 +300,7 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res - def _conv1d_transpose_impl( + def _conv_transpose1d_impl( self, x: relax.Expr, weight: relax.Expr, @@ -340,7 +340,57 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( + return self._conv_transpose1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv_transpose2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + + def _conv_transpose2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv_transpose2d_impl( x, weight, bias=bias, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3818ba809694..b78f63b4ae6c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -166,6 +166,7 @@ def create_convert_map( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), "conv_transpose1d.default": self._conv_transpose1d, + "conv_transpose2d.input": self._conv_transpose2d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8175afeb88fe..e76183481eee 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -145,7 +145,7 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv1d_transpose_impl( + return self._conv_transpose1d_impl( x, weight, bias=bias, @@ -155,63 +155,13 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv2d_transpose_impl( + return self._conv_transpose2d_impl( x, weight, bias=bias, @@ -979,7 +929,7 @@ def create_convert_map( nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, nn.ConvTranspose1d: self._conv_transpose1d_module, - nn.ConvTranspose2d: self._conv2d_transpose_module, + nn.ConvTranspose2d: self._conv_transpose2d_module, nn.CrossEntropyLoss: self._cross_entropy_module, nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, @@ -1045,7 +995,7 @@ def create_convert_map( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), "conv_transpose1d": self._conv_transpose1d, - "conv_transpose2d": self._conv2d_transpose, + "conv_transpose2d": self._conv_transpose2d, "conv1d": self._conv1d, "conv2d": self._conv2d, "conv3d": self._conv3d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 955e915251ed..e1dfd7dc370a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1624,6 +1624,98 @@ def main( verify_model(model, example_args, binding, expected2) +def test_conv_transpose2d(): + class ConvTranspose2d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose2d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[3, 3, 7, 7]) + self.bias = torch.randn(size=[3]) + + def forward(self, input): + return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose2d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = ConvTranspose2d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 7f0d48361465b51918fe92d4f4aedfdbccfce0fe Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:30:20 +0900 Subject: [PATCH 08/17] support conv1d --- .../torch/base_fx_graph_translator.py | 49 ++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 49 ---------- .../test_frontend_from_exported_program.py | 92 +++++++++++++++++++ 4 files changed, 142 insertions(+), 49 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 5c669fdcfdf4..c58367bacf00 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -400,6 +400,55 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv2d_impl( self, x: relax.Expr, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index b78f63b4ae6c..60ef130f6efc 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -167,6 +167,7 @@ def create_convert_map( ), "conv_transpose1d.default": self._conv_transpose1d, "conv_transpose2d.input": self._conv_transpose2d, + "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e76183481eee..0bc140bb2c4e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -171,55 +171,6 @@ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e1dfd7dc370a..bdcd77049c1d 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1716,6 +1716,98 @@ def main( verify_model(model, example_args, binding, expected2) +def test_conv1d(): + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv1D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + w1: R.Tensor((6, 3, 7), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + + model = Conv1D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + def test_conv2d(): class Conv2D1(Module): def __init__(self): From 01431253e680f7e4fee7aa692ed05c8154a95df1 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:32:02 +0900 Subject: [PATCH 09/17] support conv3d --- .../torch/base_fx_graph_translator.py | 49 ++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 49 ---------- .../test_frontend_from_exported_program.py | 92 +++++++++++++++++++ 4 files changed, 142 insertions(+), 49 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c58367bacf00..6a252247a885 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -498,6 +498,55 @@ def _conv2d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv3d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) + + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 60ef130f6efc..93c62e254ced 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -169,6 +169,7 @@ def create_convert_map( "conv_transpose2d.input": self._conv_transpose2d, "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, + "conv3d.default": self._conv3d, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, # statistical diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 0bc140bb2c4e..f8366c154f42 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -203,55 +203,6 @@ def _conv2d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv3d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) - - def _conv3d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index bdcd77049c1d..cefd023dfec9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1900,6 +1900,98 @@ def main( verify_model(model, example_args, binding, expected2) +def test_conv3d(): + class Conv3D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv3D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv3d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv3D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) + + model = Conv3D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv3D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv3D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + def test_linear(): class Dense1(Module): def __init__(self): From cc7c28b4bc9562f8d5b2086ed655733a7e31a3a2 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:32:55 +0900 Subject: [PATCH 10/17] support einsum --- .../torch/base_fx_graph_translator.py | 7 +++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 7 --- .../test_frontend_from_exported_program.py | 48 +++++++++++++++++++ 4 files changed, 56 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 6a252247a885..34ec23b4815b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -547,6 +547,13 @@ def _conv3d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 93c62e254ced..7746650843e8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -170,6 +170,7 @@ def create_convert_map( "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, "conv3d.default": self._conv3d, + "einsum.default": self._einsum, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, # statistical diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index f8366c154f42..828e065d6010 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -253,13 +253,6 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: ) ) - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.einsum(operands, args[0])) - def _embedding_impl( self, x, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cefd023dfec9..ce281f6cc0d3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1992,6 +1992,54 @@ def main( verify_model(model, example_args, binding, expected2) +def test_einsum(): + class Einsum1(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.einsum("ii", x) + + class Einsum2(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.einsum("i,j->ij", x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") + gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32") + ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((5, 4), dtype="float32") = R.einsum( + (inp_0, inp_1), subscripts="i,j->ij" + ) + gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(4, 4, dtype=torch.float32),) + verify_model(Einsum1(), example_args, {}, Expected1) + + example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32)) + verify_model(Einsum2(), example_args, {}, Expected2) + + def test_linear(): class Dense1(Module): def __init__(self): From ecf016c8594c65563241b304a75c66906923a67a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:34:03 +0900 Subject: [PATCH 11/17] support embedding --- .../torch/base_fx_graph_translator.py | 17 +++++++++++ .../torch/exported_program_translator.py | 3 ++ .../tvm/relax/frontend/torch/fx_translator.py | 17 ----------- .../test_frontend_from_exported_program.py | 30 +++++++++++++++++++ 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 34ec23b4815b..ffc5c7230dbd 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -554,6 +554,23 @@ def _einsum(self, node: fx.Node) -> relax.Var: operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.einsum(operands, args[0])) + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7746650843e8..2ff7c233631d 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -171,6 +171,9 @@ def create_convert_map( "conv2d.default": self._conv2d, "conv3d.default": self._conv3d, "einsum.default": self._einsum, + "embedding.default": lambda node: self._embedding_impl( + self.env[node.args[1]], self.env[node.args[0]] + ), "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, # statistical diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 828e065d6010..be35ce81cfe3 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -253,23 +253,6 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: ) ) - def _embedding_impl( - self, - x, - weight, - ) -> relax.Var: - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _embedding_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ce281f6cc0d3..cc090b764fce 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2040,6 +2040,36 @@ def main( verify_model(Einsum2(), example_args, {}, Expected2) +def test_embedding(): + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,), dtype=torch.int64),) + + model = Embedding() + binding = {"w1": model.embedding.weight.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + def test_linear(): class Dense1(Module): def __init__(self): From 327dfeda1cb842c9e6544d056cf67819c8757e6b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:35:48 +0900 Subject: [PATCH 12/17] support group_norm --- .../torch/exported_program_translator.py | 21 ++++++++ .../test_frontend_from_exported_program.py | 49 +++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2ff7c233631d..80fdda2389ad 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -102,6 +102,26 @@ def _native_batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: ) ) + def _group_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_groups = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -174,6 +194,7 @@ def create_convert_map( "embedding.default": lambda node: self._embedding_impl( self.env[node.args[1]], self.env[node.args[0]] ), + "group_norm.default": self._group_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, # statistical diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cc090b764fce..cefe25bebab3 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2070,6 +2070,55 @@ def main( verify_model(model, example_args, binding, expected1) +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( + input_1, + w1, + w2, + num_groups=3, + channel_axis=1, + axes=[2, 3], + epsilon=1.0000000000000001e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = GroupNorm() + binding = { + "w1": model.gn.weight.detach().numpy(), + "w2": model.gn.bias.detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_linear(): class Dense1(Module): def __init__(self): From 00e32584fe2ddaf0981d104c5b569746141cf1ad Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:37:03 +0900 Subject: [PATCH 13/17] support layer_norm --- .../torch/base_fx_graph_translator.py | 55 +++++++++++++++++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 55 ------------------- .../test_frontend_from_exported_program.py | 42 ++++++++++++++ 4 files changed, 98 insertions(+), 55 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index ffc5c7230dbd..783921ff8883 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -571,6 +571,61 @@ def _embedding_impl( embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 80fdda2389ad..6df451a443f9 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -195,6 +195,7 @@ def create_convert_map( self.env[node.args[1]], self.env[node.args[0]] ), "group_norm.default": self._group_norm, + "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, # statistical diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index be35ce81cfe3..64e8db52b07e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -360,61 +360,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var: ) ) - def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - if isinstance(normalized_shape, (immutable_list, tuple)): - normalized_shape = tuple(normalized_shape) - else: - try: - normalized_shape = self.env[normalized_shape] - except TypeError: - normalized_shape = tuple(normalized_shape) - - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - def _layer_norm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - normalized_shape = node.args[1] - gamma = self.env[node.args[2]] if len(node.args) > 2 else None - beta = self.env[node.args[3]] if len(node.args) > 3 else None - eps = node.args[4] if len(node.args) > 4 else 1e-05 - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - - def _layer_norm_module(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - normalized_shape = module.normalized_shape - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - eps = module.eps - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cefe25bebab3..922960785e77 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2119,6 +2119,48 @@ def main( verify_model(model, example_args, binding, expected1) +def test_layernorm(): + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = LayerNorm() + binding = { + "w1": model.ln.weight.detach().numpy(), + "w2": model.ln.bias.detach().numpy(), + } + verify_model(LayerNorm(), example_args, binding, expected1) + + def test_linear(): class Dense1(Module): def __init__(self): From ccaf7f9debfceaa12b8a3c2f2142fd7fc997e528 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:40:58 +0900 Subject: [PATCH 14/17] support scaled_dot_product_attention --- .../torch/base_fx_graph_translator.py | 22 +++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 22 ----- .../test_frontend_from_exported_program.py | 90 +++++++++++++++++++ 4 files changed, 113 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 783921ff8883..d132f867299f 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -666,6 +666,28 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + transpose_S_H( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + ) + ########## Statistical ########## def _mean(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6df451a443f9..842a62273115 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -198,6 +198,7 @@ def create_convert_map( "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, + "scaled_dot_product_attention.default": self._scaled_dot_product_attention, # statistical "mean.dim": self._mean, "sum.dim_IntList": self._sum, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 64e8db52b07e..c0beb4b4c307 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -378,28 +378,6 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) - dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) - assert dropout_p == 0.0, "Dropout is not supported" - is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) - causal_mask = "TopLeft" if is_causal else None - - if attn_mask is not None: - attn_mask = self.env[attn_mask] - msg = "Only a float mask is supported for the attn_mask input." - assert "float" in attn_mask.struct_info.dtype, msg - - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) - ) - def _unbind(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 922960785e77..3abfe206bcf8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2340,6 +2340,96 @@ def main( verify_model(MaxPool2d3(), example_args, {}, expected3) +def test_scaled_dot_product_attention(): + class Attention1(Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), + ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_0, axes=[0, 2, 1, 3] + ) + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_1, axes=[0, 2, 1, 3] + ) + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_2, axes=[0, 2, 1, 3] + ) + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( + lv, lv1, lv2, scale=None + ) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + class Attention2(Module): + def forward(self, q, k, v, mask): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask) + + @I.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), + inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"), + ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_0, axes=[0, 2, 1, 3] + ) + lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_1, axes=[0, 2, 1, 3] + ) + lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( + inp_2, axes=[0, 2, 1, 3] + ) + lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( + lv, lv1, lv2, inp_3, scale=None + ) + lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( + lv3, axes=[0, 2, 1, 3] + ) + gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + verify_model( + Attention1(), + ( + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + ), + {}, + Expected1, + ) + + verify_model( + Attention2(), + ( + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 64, dtype=torch.float32), + torch.randn(32, 8, 128, 128, dtype=torch.float32), + ), + {}, + Expected2, + ) + + def test_mean(): class Mean(Module): def forward(self, input): From 63a3e43cda959b1d517e1973b14cc040332c51b5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:42:07 +0900 Subject: [PATCH 15/17] support unbind --- .../torch/base_fx_graph_translator.py | 11 +++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 11 --- .../test_frontend_from_exported_program.py | 94 +++++++++++++++++++ 4 files changed, 106 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index d132f867299f..52784dc8c3cd 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -688,6 +688,17 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: ) ) + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + ########## Statistical ########## def _mean(self, node: fx.Node) -> relax.Var: diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 842a62273115..d1e7dc0e7a6c 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -199,6 +199,7 @@ def create_convert_map( "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, + "unbind.int": self._unbind, # statistical "mean.dim": self._mean, "sum.dim_IntList": self._sum, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c0beb4b4c307..9f064bcda473 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -378,17 +378,6 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _unbind(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - assert isinstance(dim, int), "Expected 2nd argument of unbind as int" - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 3abfe206bcf8..187b9c681d33 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2430,6 +2430,100 @@ def main( ) +def test_unbind(): + class Unbind1(Module): + def forward(self, data): + return torch.unbind(data) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((0, 3, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) + lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + class Unbind2(Module): + def forward(self, data): + return torch.unbind(data, dim=1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((3, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 1, 10, 10), dtype="float32"), + R.Tensor((3, 0, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1) + lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] + lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) + lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] + lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) + lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] + lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) + lv7: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv2, lv4, lv6) + lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] + lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] + lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] + gv: R.Tuple( + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + R.Tensor((3, 10, 10), dtype="float32"), + ) = (lv8, lv9, lv10) + R.output(gv) + return gv + + example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) + verify_model(Unbind1(), example_args, {}, expected1) + verify_model(Unbind2(), example_args, {}, expected2) + + def test_mean(): class Mean(Module): def forward(self, input): From 2ca9e77cf8bb39cd990b1930166cc32646c1aca0 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 13:58:15 +0900 Subject: [PATCH 16/17] support interpolate --- .../torch/exported_program_translator.py | 42 ++++++++++++ .../test_frontend_from_exported_program.py | 64 +++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index d1e7dc0e7a6c..11122117fbd6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -122,6 +122,46 @@ def _group_norm(self, node: fx.Node) -> relax.Var: ) ) + def _upsample_impl( + self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str + ) -> relax.Var: + coord_trans = "align_corners" if align_corners else "half_pixel" + + if size is None: + shape = self.shape_of(x) + assert isinstance(shape, relax.ShapeExpr) + if isinstance(scale_factor, (tuple, list)): + assert len(scale_factor) == len(shape) - 2 + size = tuple( + int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + ) + else: + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + return self.block_builder.emit( + relax.op.image.resize2d( + x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "linear") + + def _upsample_nearest2d(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor") + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -200,6 +240,8 @@ def create_convert_map( "max_pool2d.default": self._max_pool2d, "scaled_dot_product_attention.default": self._scaled_dot_product_attention, "unbind.int": self._unbind, + "upsample_bilinear2d.vec": self._upsample_bilinear2d, + "upsample_nearest2d.vec": self._upsample_nearest2d, # statistical "mean.dim": self._mean, "sum.dim_IntList": self._sum, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 187b9c681d33..7c887d9b9610 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -2524,6 +2524,70 @@ def main( verify_model(Unbind2(), example_args, {}, expected2) +def test_interpolate(): + class InterpolateBilinear(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear") + + @tvm.script.ir_module + class expected_bilinear: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="linear", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class InterpolateNearest(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, (224, 224), mode="nearest") + + @tvm.script.ir_module + class expected_nearest: + @R.function + def main( + input: R.Tensor((1, 3, 112, 112), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( + input, + R.shape([224, 224]), + roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="half_pixel", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0.0, + out_dtype="void", + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) + verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) + verify_model(InterpolateNearest(), example_args, {}, expected_nearest) + + def test_mean(): class Mean(Module): def forward(self, input): From 52a930d9bab14d2210603045ec7d82274eaddaba Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sat, 28 Sep 2024 22:53:32 +0900 Subject: [PATCH 17/17] fix lint error --- .../tvm/relax/frontend/torch/exported_program_translator.py | 4 ++-- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 11122117fbd6..64583d750974 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -76,7 +76,7 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: ########## Neural Network ########## - def _native_batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -217,7 +217,7 @@ def create_convert_map( "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network - "_native_batch_norm_legit_no_training.default": self._native_batch_norm_legit_no_training, + "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 9f064bcda473..c60c7c3953b4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck # pylint: disable=import-outside-toplevel """PyTorch FX frontend of Relax.""" -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union from functools import partial, reduce import tvm