diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5896f13471..e0e4cb68ae 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2127,7 +2127,7 @@ def aten_conv3d( if bias is None: weight_dim_0 = op.Shape(weight, start=0, end=1) - bias_shape = op.Concat(weight_dim_0, op.Constant(value_ints=[2]), axis=0) + bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1])) zero = op.CastLike(0.0, input) bias = op.Expand(zero, bias_shape) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 440e2316c1..1a393d6fd7 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -5,6 +5,7 @@ import math import unittest +import onnx import parameterized # TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo @@ -1004,6 +1005,32 @@ def forward(self, x): got = onnx_program.call_reference({"x": inputs[0]}) torch.testing.assert_close(expected, got[0]) + def test_conv3d_without_bias_produces_1d_bias(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 4, kernel_size=2, bias=False) + + def forward(self, x): + return self.conv(x) + + onnx_program = torch.onnx.export( + Model().eval(), (torch.randn(1, 3, 8, 8, 8),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + # The bias synthesized for a bias-less conv must be 1D ([out_channels]) to + # match the ONNX Conv spec. See https://github.com/microsoft/onnxscript/issues/2931. + inferred = onnx.shape_inference.infer_shapes(onnx_program.model_proto, data_prop=True) + shape_ranks = { + value_info.name: len(value_info.type.tensor_type.shape.dim) + for value_info in inferred.graph.value_info + if value_info.type.tensor_type.HasField("shape") + } + conv_nodes = [node for node in inferred.graph.node if node.op_type == "Conv"] + self.assertEqual(len(conv_nodes), 1) + self.assertEqual(shape_ranks[conv_nodes[0].input[2]], 1) + if __name__ == "__main__": unittest.main()