diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index f3a5e9098de8..f32a260f8129 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -48,6 +48,7 @@
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
+from tvm.topi.utils import get_const_tuple
 
 from ..common import autopad
 
@@ -2488,9 +2489,27 @@ def _impl_v17(cls, bb, inputs, attr, params):
         axis = attr.get("axis", -1)
         epsilon = attr.get("epsilon", 1e-05)
 
+        # Per ONNX, gamma (scale) may be N-D: it spans every normalized axis,
+        # starting at `axis`.
+        gamma_shape = get_const_tuple(scale.struct_info.shape)
+
         if bias is None:
-            seq_len = data.struct_info.shape[1].value
-            bias = relax.const([0.0] * seq_len, dtype="float32")
+            # beta is optional; default to zeros shaped like gamma so the
+            # no-bias path also works for N-D gamma and any start axis
+            # (the old `shape[1]` default assumed 1-D gamma over dim 1).
+            bias = relax.const(_np.zeros(gamma_shape, dtype="float32"))
+        else:
+            beta_shape = get_const_tuple(bias.struct_info.shape)
+            if gamma_shape != beta_shape:
+                raise ValueError(
+                    f"gamma shape {gamma_shape} and beta shape {beta_shape} do not match"
+                )
+
+        # Extend the single ONNX start axis into the full list of normalized
+        # axes covered by an N-D gamma (consecutive trailing dimensions).
+        axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]
+        if len(axis) < len(gamma_shape):
+            axis.extend(range(axis[-1] + 1, axis[-1] + 1 + len(gamma_shape) - len(axis)))
 
         output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
 
         # Onnx layernorm has 3 outputs but only the first is used.
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 0d532e07fc33..acd7419183bb 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1282,18 +1282,20 @@ def test_mean_variance_norm():
 
 
 def test_layer_norm():
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
-            helper.make_tensor_value_info("c", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1301,17 +1303,19 @@ def test_layer_norm():
     check_correctness(model)
 
     # Test case with no bias that is an optional input
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1319,6 +1323,49 @@ def test_layer_norm():
     check_correctness(model)
 
 
+def test_layer_norm_with_nd_gamma_beta():
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=1, epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 4, 4]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [3, 4, 4]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3, 4, 4]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 4]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+    # Also exercise the optional-bias path. NOTE(review): this case still uses
+    # a 1-D gamma, so the no-bias path is not covered for N-D gamma here.
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+
 # TODO Enable dynamism
 @pytest.mark.parametrize("dynamic", [False])
 def test_skiplayernormalization(dynamic):