From 180802a1009a533b0b9bd0a1589d8f7af39ba5eb Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 5 Feb 2024 22:19:31 +0000
Subject: [PATCH 1/7] Add nn frontend support for conv3d and related operators

---
 python/tvm/relax/frontend/nn/modules.py | 87 +++++++++++++++++++++----
 python/tvm/relax/frontend/nn/op.py      | 85 ++++++++++++++++++++++++
 2 files changed, 161 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 29b9c7fcca48..4a4ce06db3e8 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -218,7 +218,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: int,
+        kernel_size: Union[List[int], int],
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
@@ -229,7 +229,6 @@ def __init__(  # pylint: disable=too-many-arguments
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.kernel_size = kernel_size
         self.stride = stride
         self.padding = padding
         self.dilation = dilation
@@ -241,15 +240,16 @@ def __init__(  # pylint: disable=too-many-arguments
         else:
             in_channels = tir.floordiv(self.in_channels, self.groups)

-        self.weight = Parameter(
-            (
-                self.out_channels,
-                in_channels,
-                self.kernel_size,
-                self.kernel_size,
-            ),
-            dtype,
-        )
+        # Expand kernel size if provided an integer.
+        if isinstance(kernel_size, int):
+            self.kernel_size = [kernel_size] * 2
+        else:
+            self.kernel_size = kernel_size
+
+        kernel_shape = [self.out_channels, in_channels] + list(self.kernel_size)
+
+        self.weight = Parameter(kernel_shape, dtype)
+
         if bias:
             self.bias = Parameter((self.out_channels,), dtype)
         else:
@@ -274,6 +274,71 @@ def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
         )


+class Conv3D(Module):
+    """
+    Module for conv3d layer.
+    """
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[List[int], int],
+        stride: Union[List[int], int] = 1,
+        padding: Union[List[int], int] = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        dtype: Optional[str] = None,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+
+        # Allow dynamic input channels.
+        if isinstance(self.in_channels, int):
+            in_channels = int(self.in_channels / self.groups)
+        else:
+            in_channels = tir.floordiv(self.in_channels, self.groups)
+
+        # Expand kernel size if given an integer.
+        if isinstance(kernel_size, int):
+            self.kernel_size = [kernel_size] * 3
+        else:
+            self.kernel_size = kernel_size
+
+        kernel_shape = [self.out_channels, in_channels] + list(self.kernel_size)
+
+        self.weight = Parameter(kernel_shape, dtype)
+
+        if bias:
+            self.bias = Parameter((self.out_channels,), dtype)
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
+        """
+        Forward method for conv3d layer.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input tensor.
+
+        Returns
+        -------
+        ret : Tensor
+            The output tensor for the conv2d layer.
+        """
+        return op.conv3d(
+            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+        )
+
+
 class ConvTranspose1D(Module):
     """
     Module for ConvTranspose1D layer.
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index ae880190ad46..44453f94aa9c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -421,6 +421,64 @@ def conv2d(
     return wrap_nested(conv_out, name)


+def conv3d(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    stride: Optional[Union[int, Tuple]] = 1,
+    padding: Optional[Union[int, Tuple, str]] = 0,
+    dilation: Optional[Union[int, Tuple]] = 1,
+    groups: Optional[int] = 1,
+    name: str = "conv3d",
+) -> Tensor:
+    """Applies a 3D convolution over an input image composed of several input planes
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor of shape [B, N, D, H, W]
+
+    weight : Tensor
+        Filters of shape [O, N/groups, kD, kH, kW]
+
+    bias : Optional[Tensor]
+        Optional bias tensor of shape [O].
+
+    stride : Optional[Union[int, Tuple]]
+        The stride of the convolving kernel. Can be a single number
+        or a tuple of (sD, sH, sW).
+
+    padding : Optional[Union[int, Tuple, str]]
+        Implicit paddings on both sides of the input.
+
+    dilation : Optional[Union[int, Tuple]]
+        The spacing between kernel elements. Can be a single number or a tuple (dD, dH, dW).
+
+    groups : Optional[int]
+        Split input into a number of groups.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result with shape [B, O, oD, oH, oW].
+    """
+    conv_out = _op.nn.conv3d(
+        data=x._expr,
+        weight=weight._expr,
+        strides=stride,
+        padding=padding,
+        dilation=dilation,
+        groups=groups,
+    )
+    if bias is not None:
+        conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1]))
+
+    return wrap_nested(conv_out, name)
+
+
 def conv1d_transpose(
     x: Tensor,
     weight: Tensor,
@@ -1486,6 +1544,33 @@ def interpolate(
     )


+def where(condition: Tensor, input: Tensor, other: Tensor, name: str = "where") -> Tensor:
+    """Return a tensor of elements selected from input or other based on condition.
+
+    Parameters
+    ----------
+    condition : Tensor
+        When True, yield input, otherwise yield other.
+
+    input : Tensor
+        Value or values selected at indices where condition is True.
+
+    other : Tensor
+        Value or values selected at indices where condition is False.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    # Cast condition to boolean.
+ condition = astype(condition, "bool") + return wrap_nested(_op.where(condition._expr, input._expr, other._expr), name) + + def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): """CCL Allreduce operator From 04e518d837e3dbf536342501ddcaa664c778ba2a Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 9 Feb 2024 17:17:59 +0000 Subject: [PATCH 2/7] Expose conv3d attrs --- python/tvm/relax/op/op_attrs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index a3d46428c53a..4658950f511a 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -59,6 +59,11 @@ class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" +@tvm._ffi.register_object("relax.attrs.Conv3DAttrs") +class Conv3DAttrs(Attrs): + """Attributes for nn.conv3d""" + + @tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes for nn.conv2d_transpose""" From 776bfb442b50215514f4bc344107e992ca4dcaa5 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 21 Feb 2024 16:28:50 +0000 Subject: [PATCH 3/7] Allow NHWC in nn frontend --- python/tvm/relax/frontend/nn/modules.py | 22 ++++++- python/tvm/relax/frontend/nn/op.py | 46 +++++++++++--- src/relax/op/image/resize.cc | 19 ++++-- src/tir/schedule/state.cc | 2 +- .../relax/test_transform_convert_layout.py | 62 ++++++++++++++++++- 5 files changed, 134 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 4a4ce06db3e8..a46b3be6a16d 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -225,6 +225,7 @@ def __init__( # pylint: disable=too-many-arguments groups: int = 1, bias: bool = True, dtype: Optional[str] = None, + data_layout: str = "NCHW", ): super().__init__() self.in_channels = in_channels @@ -233,6 +234,7 @@ def __init__( # pylint: disable=too-many-arguments self.padding = padding self.dilation = dilation self.groups = groups + self.data_layout = data_layout # Allow dynamic input channels. if isinstance(self.in_channels, int): @@ -270,7 +272,14 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name The output tensor for the conv2d layer. """ return op.conv2d( - x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.data_layout, ) @@ -290,6 +299,7 @@ def __init__( # pylint: disable=too-many-arguments groups: int = 1, bias: bool = True, dtype: Optional[str] = None, + data_layout: str = "NCDHW", ): super().__init__() self.in_channels = in_channels @@ -298,6 +308,7 @@ def __init__( # pylint: disable=too-many-arguments self.padding = padding self.dilation = dilation self.groups = groups + self.data_layout = data_layout # Allow dynamic input channels. if isinstance(self.in_channels, int): @@ -335,7 +346,14 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name The output tensor for the conv2d layer. 
""" return op.conv3d( - x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.data_layout, ) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 44453f94aa9c..1fe65cc64aa7 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -371,6 +371,7 @@ def conv2d( padding: Optional[Union[int, Tuple, str]] = 0, dilation: Optional[Union[int, Tuple]] = 1, groups: Optional[int] = 1, + data_layout: Optional[str] = "NCHW", name: str = "conv2d", ) -> Tensor: """Applies a 2D convolution over an input image composed of sevaral input planes @@ -399,6 +400,9 @@ def conv2d( groups : Optional[int] Split input into a number of groups. + data_layout : Optional[str] + Layout of input and output data. + name : str Name hint. @@ -413,10 +417,16 @@ def conv2d( strides=stride, padding=padding, dilation=dilation, + data_layout=data_layout, groups=groups, ) if bias is not None: - conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1])) + if data_layout == "NCHW": + conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1])) + elif data_layout == "NHWC": + conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, -1])) + else: + raise NotImplementedError(f"Dont know how to handle layout {data_layout}.") return wrap_nested(conv_out, name) @@ -429,6 +439,7 @@ def conv3d( padding: Optional[Union[int, Tuple, str]] = 0, dilation: Optional[Union[int, Tuple]] = 1, groups: Optional[int] = 1, + data_layout: Optional[str] = "NCDHW", name: str = "conv3d", ) -> Tensor: """Applies a 3D convolution over an input image composed of sevaral input planes @@ -457,6 +468,9 @@ def conv3d( groups : Optional[int] Split input into a number of groups. + data_layout : Optional[str] + Optional layout of the input and output data. + name : str Name hint. @@ -472,9 +486,15 @@ def conv3d( padding=padding, dilation=dilation, groups=groups, + data_layout=data_layout, ) if bias is not None: - conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1])) + if data_layout == "NCDHW": + conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1])) + elif data_layout == "NDHWC": + conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1])) + else: + raise NotImplemented(f"Dont know how to handle layout {data_layout}.") return wrap_nested(conv_out, name) @@ -1485,6 +1505,7 @@ def interpolate( align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: Optional[bool] = None, + data_layout: Optional[str] = "NCHW", name: str = "interpolate", ): """Resize a tensor using the specified mode. @@ -1506,6 +1527,8 @@ def interpolate( Recompute the scale_factor for use in interpolation. antialias : Optional[bool] Apply antialiasing to output. + data_layout : Optional[str] + Layout of the input and output data. name : str Name hint for this operation. @@ -1518,11 +1541,14 @@ def interpolate( assert antialias is None, "antialias is not supported." if size is None: - shape = x.shape - if isinstance(scale_factor, (list, tuple)): - size = tuple(int(shape[i] * scale_factor[i]) for i in range(2, len(shape))) - else: - size = tuple(int(shape[i] * scale_factor) for i in range(2, len(shape))) + size = [] + for i, dim in enumerate(data_layout): + # Only upscale spatial dimensions. 
+            if dim not in ["N", "C"]:
+                if isinstance(scale_factor, (list, tuple)):
+                    size.append(int(x.shape[i] * scale_factor[len(size)]))
+                else:
+                    size.append(int(x.shape[i] * scale_factor))

     if mode.startswith("nearest"):
         mode = "nearest_neighbor"
@@ -1538,7 +1564,11 @@ def interpolate(

     return wrap_nested(
         _op.image.resize2d(
-            x._expr, size, layout="NCHW", method=mode, coordinate_transformation_mode=coord_trans
+            x._expr,
+            size,
+            layout=data_layout,
+            method=mode,
+            coordinate_transformation_mode=coord_trans,
         ),
         name,
     )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 8b92f34edd81..0ea8cdec12b5 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -105,14 +105,25 @@ StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
 InferLayoutOutput InferLayoutResize2d(const Call& call,
                                       const Map<String, Array<String>>& desired_layouts,
                                       const VarLayoutMap& var_layout_map) {
-  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto& it = desired_layouts.find("relax.image.resize2d");
   const auto* attrs = call->attrs.as<Resize2DAttrs>();
   ICHECK(attrs) << "Invalid Call";

-  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision data_layout;
   ObjectPtr<Resize2DAttrs> new_attrs = make_object<Resize2DAttrs>(*attrs);
-  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
-  return InferLayoutOutput({layout, InitialNLayout(call->args[1])}, {layout}, Attrs(new_attrs));
+
+  if (it != desired_layouts.end()) {
+    // We have a desired layout for resize2d.
+    Layout desired_data_layout = (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(4), attrs->layout, desired_data_layout);
+    new_attrs->layout = (*it).second[0];
+  } else {
+    // We don't have a desired layout for resize2d, propagate from the input instead.
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout}, Attrs(new_attrs));
 }

 TVM_REGISTER_OP("relax.image.resize2d")
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index ecb857a4c353..4fac6f22d26d 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -384,7 +384,7 @@ ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
     const BaseFunc& base_func = kv.second;
     if (auto opt = base_func.as<PrimFunc>()) {
       auto func = opt.value();
-      VerifyWellFormed(func);
+      //VerifyWellFormed(func);
       BlockInfoCollector::Collect(self, func->body);
     }
   }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 417a5519e0b9..56b59ba23867 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -21,8 +21,10 @@ from tvm.script.parser import ir as I, relax as R, tir as T


-def verify(input, expected):
-    mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(input)
+def verify(input, expected, extra_ops={}):
+    desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
+    desired_layouts.update(extra_ops)
+    mod = ConvertLayout(desired_layouts)(input)
     mod = Normalize()(mod)
     tvm.ir.assert_structural_equal(mod, expected)

@@ -1303,6 +1305,62 @@ def main(
     verify(Input, Expected)


+def test_resize2d_conv2d():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv = R.image.resize2d(x, (52, 52), layout="NCHW")
+                gv2: R.Tensor((2, 4, 50, 50), "float32") = R.nn.conv2d(gv, w, out_dtype="float32")
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor((2, 4, 50, 50), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 52, 52, 3), dtype="float32") = R.image.resize2d(
+                    lv,
+                    R.shape([52, 52]),
+                    roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
+                    layout="NHWC",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="void",
+                )
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                lv2: R.Tensor((2, 50, 50, 4), dtype="float32") = R.nn.conv2d(
+                    gv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                gv2: R.Tensor((2, 4, 50, 50), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected, extra_ops={"relax.image.resize2d": ["NHWC"]})
+
+
 def test_conv2d_unknown_bias_dim():
     @I.ir_module
     class Input:

From c6c07bca72d6c3c7f25392dea47d5b5f14fe9f15 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 28 Feb 2024 17:09:03 +0000
Subject: [PATCH 4/7] Add tests.
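
Add a structural-equality test for the new Conv3D module and restore the
VerifyWellFormed check that was temporarily commented out in the previous
commit. The new test can be run on its own with, for example:

    pytest tests/python/relax/test_frontend_nn_modules.py::test_conv3d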
---
 src/tir/schedule/state.cc                    |  2 +-
 .../python/relax/test_frontend_nn_modules.py | 31 +++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 4fac6f22d26d..ecb857a4c353 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -384,7 +384,7 @@ ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
     const BaseFunc& base_func = kv.second;
     if (auto opt = base_func.as<PrimFunc>()) {
       auto func = opt.value();
-      //VerifyWellFormed(func);
+      VerifyWellFormed(func);
       BlockInfoCollector::Collect(self, func->body);
     }
   }
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index f438f387056c..dd78ad5a5545 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -246,6 +246,37 @@ def forward(
     assert_structural_equal(tvm_mod["forward"], forward, True)


+def test_conv3d():
+    @R.function
+    def forward(
+        x: R.Tensor((1, 3, 32, 32, 32), dtype="float32"),
+        _io: R.Object,
+        weight: R.Tensor((32, 3, 3, 3, 3), dtype="float32"),
+        bias: R.Tensor((32,), dtype="float32"),
+    ) -> R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            lv1: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.nn.conv3d(x, weight)
+            lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1, 1]))
+            conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2)
+            gv1: R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)) = conv3d, (
+                _io,
+            )
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Conv3D(3, 32, 3, bias=True)
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": spec.Tensor([1, 3, 32, 32, 32], "float32"),
+            }
+        },
+        debug=True,
+    )
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_conv2d_dynamic():
     @R.function
     def forward(

From 2a8bca7a16e288508a1f1a9b5bc38bcb278458cd Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Wed, 28 Feb 2024 17:09:24 +0000
Subject: [PATCH 5/7] Formatting.
--- src/relax/op/image/resize.cc | 3 ++- tests/python/relax/test_frontend_nn_modules.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 0ea8cdec12b5..202702d78746 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -123,7 +123,8 @@ InferLayoutOutput InferLayoutResize2d(const Call& call, data_layout = GetLayoutDecision(var_layout_map, call->args[0]); new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name(); } - return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout}, Attrs(new_attrs)); + return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout}, + Attrs(new_attrs)); } TVM_REGISTER_OP("relax.image.resize2d") diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index dd78ad5a5545..6966a5f2a927 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -257,11 +257,13 @@ def forward( R.func_attr({"num_input": 2}) with R.dataflow(): lv1: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.nn.conv3d(x, weight) - lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1, 1])) - conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2) - gv1: R.Tuple(R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object)) = conv3d, ( - _io, + lv2: R.Tensor((1, 32, 1, 1, 1), dtype="float32") = R.reshape( + bias, R.shape([1, 32, 1, 1, 1]) ) + conv3d: R.Tensor((1, 32, 30, 30, 30), dtype="float32") = R.add(lv1, lv2) + gv1: R.Tuple( + R.Tensor((1, 32, 30, 30, 30), dtype="float32"), R.Tuple(R.Object) + ) = conv3d, (_io,) R.output(gv1) return gv1 From 47a99c46a43c86f1d57859bc062a0c9a3af62fd4 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 28 Feb 2024 18:19:33 +0000 Subject: [PATCH 6/7] Fix typo --- python/tvm/relax/frontend/nn/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index a46b3be6a16d..e69660f70880 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -343,7 +343,7 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name Returns ------- ret : Tensor - The output tensor for the conv2d layer. + The output tensor for the conv3d layer. 
""" return op.conv3d( x, From 1c80d0c02b061a9177da321289b71821e79239d9 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 28 Feb 2024 18:55:04 +0000 Subject: [PATCH 7/7] Fix lint issues --- python/tvm/relax/frontend/nn/op.py | 31 +++--------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 1fe65cc64aa7..d299d3943944 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -494,7 +494,7 @@ def conv3d( elif data_layout == "NDHWC": conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1])) else: - raise NotImplemented(f"Dont know how to handle layout {data_layout}.") + raise NotImplementedError(f"Dont know how to handle layout {data_layout}.") return wrap_nested(conv_out, name) @@ -1574,33 +1574,6 @@ def interpolate( ) -def where(condition: Tensor, input: Tensor, other: Tensor, name: str = "where") -> Tensor: - """Return a tensor of elemends selected from input or other based on condition. - - Parameters - ---------- - condition : Tensor - When True, yield input, otherwise yield other. - - input : Tensor - Value or values selected at indices where condition is True. - - other : Tensor - Value or values selected at indices where condition is False. - - name : str - Name hint. - - Returns - ------- - result : Tensor - The computed result. - """ - # Cast condition to boolean. - condition = astype(condition, "bool") - return wrap_nested(_op.where(condition._expr, input._expr, other._expr), name) - - def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): """CCL Allreduce operator @@ -2106,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Ten result : Tensor The result tensor. """ + # Cast condition to boolean. + condition = astype(condition, "bool") return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)