From bd57a7e7f45794b8f4d43e94a3c89075ef00ab5d Mon Sep 17 00:00:00 2001 From: HongHongHongL Date: Mon, 9 Sep 2024 15:02:46 +0800 Subject: [PATCH 1/2] fix params name bug --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index c3116f9988ce..462d1cf92c01 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -91,7 +91,7 @@ def get_constant( # Convert if possible if isinstance(var, relax.Var) and var.name_hint in params: # When converting a parameter to a constant, update references to it as well. - _, value = params.pop(var.name_hint) + _, value = params[var.name_hint] const_value = relax.const(value) graph_nodes[var.name_hint] = const_value return const_value @@ -2152,7 +2152,7 @@ def _parse_graph_initializers(self, graph: onnx.onnx_ml_pb2.GraphProto): init_var = self._new_var(var_name, shape=array.shape, dtype=array.dtype) self._nodes[init_tensor.name] = init_var # We need to keep track of both the real value and variable for this variable. - self._params[init_tensor.name] = (init_var, array) + self._params[var_name] = (init_var, array) # Otherwise we can use the weight as a constant. else: self._nodes[init_tensor.name] = relax.const(array) From 046ba98a2c8b7082d969d71a5fe525c07ced9f9f Mon Sep 17 00:00:00 2001 From: HongHongHongL Date: Tue, 10 Sep 2024 16:36:06 +0800 Subject: [PATCH 2/2] add test_multi_ops_with_same_params and test_params_names_start_with_onnx --- tests/python/relax/test_frontend_onnx.py | 43 ++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 3ea987973578..8f4e9881f497 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1909,5 +1909,48 @@ def test_multi_inputs_with_same_symbolic_shape(): check_correctness(model) +def test_multi_ops_with_same_params(): + reshape_node_1 = helper.make_node("Reshape", ["a", "x"], ["b"]) + reshape_node_2 = helper.make_node("Reshape", ["b", "x"], ["c"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node_1, reshape_node_2], + "test_multi_ops_with_same_params", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_multi_ops_with_same_params") + check_correctness(model) + + +def test_params_names_start_with_onnx(): + reshape_node = helper.make_node("Reshape", ["a", "onnx::x"], ["b"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node], + "test_params_names_start_with_onnx", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("onnx::x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_params_names_start_with_onnx") + check_correctness(model) + + if __name__ == "__main__": tvm.testing.main()