diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index c3116f9988ce..462d1cf92c01 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -91,7 +91,7 @@ def get_constant( # Convert if possible if isinstance(var, relax.Var) and var.name_hint in params: # When converting a parameter to a constant, update references to it as well. - _, value = params.pop(var.name_hint) + _, value = params[var.name_hint] const_value = relax.const(value) graph_nodes[var.name_hint] = const_value return const_value @@ -2152,7 +2152,7 @@ def _parse_graph_initializers(self, graph: onnx.onnx_ml_pb2.GraphProto): init_var = self._new_var(var_name, shape=array.shape, dtype=array.dtype) self._nodes[init_tensor.name] = init_var # We need to keep track of both the real value and variable for this variable. - self._params[init_tensor.name] = (init_var, array) + self._params[var_name] = (init_var, array) # Otherwise we can use the weight as a constant. else: self._nodes[init_tensor.name] = relax.const(array) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 3ea987973578..8f4e9881f497 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1909,5 +1909,48 @@ def test_multi_inputs_with_same_symbolic_shape(): check_correctness(model) +def test_multi_ops_with_same_params(): + reshape_node_1 = helper.make_node("Reshape", ["a", "x"], ["b"]) + reshape_node_2 = helper.make_node("Reshape", ["b", "x"], ["c"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node_1, reshape_node_2], + "test_multi_ops_with_same_params", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_multi_ops_with_same_params") + check_correctness(model) + + +def test_params_names_start_with_onnx(): + reshape_node = helper.make_node("Reshape", ["a", "onnx::x"], ["b"]) + + a_shape = [16] + output_shape = [1, 16] + + graph = helper.make_graph( + [reshape_node], + "test_params_names_start_with_onnx", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, a_shape), + ], + initializer=[ + helper.make_tensor("onnx::x", TensorProto.INT64, [2], output_shape), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, output_shape)], + ) + model = helper.make_model(graph, producer_name="test_params_names_start_with_onnx") + check_correctness(model) + + if __name__ == "__main__": tvm.testing.main()