Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,11 +1016,16 @@ def _impl_v13(cls, bb, inputs, attr, params):
data = inputs[0]
new_shape = get_constant(inputs[1], params)

if isinstance(data, relax.ShapeExpr) and isinstance(new_shape, relax.Constant):
new_shape = new_shape.data.numpy().tolist()
if new_shape != [-1]:
raise NotImplementedError("Need to fix this case")
return data
if isinstance(data, relax.ShapeExpr):
# Preserve identity flatten for shape values to keep shape-specialized
# handling in downstream shape-construction patterns.
if isinstance(new_shape, relax.Constant):
new_shape_values = new_shape.data.numpy().tolist()
if new_shape_values == [-1]:
return data

# Other reshape targets follow regular int64 tensor reshape semantics.
data = bb.normalize(relax.op.shape_to_tensor(data))

if isinstance(data, relax.Constant) and isinstance(new_shape, relax.Constant):
out = _np.reshape(data.data.numpy(), new_shape.data.numpy().tolist())
Expand Down
57 changes: 53 additions & 4 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,38 @@ def test_reshape(in_shape, shape, out_shape):
check_correctness(model, inputs=input_values)


@pytest.mark.parametrize(
    "target_shape, output_shape",
    [
        ([-1], [3]),
        ([1, 3], [1, 3]),
        ([3, 1], [3, 1]),
    ],
)
def test_reshape_shape_output(target_shape, output_shape):
    """Check that Reshape applied to a Shape node's int64 output imports and runs correctly."""
    input_dims = [2, 3, 4]

    # Shape -> Reshape chain: the Reshape input is a shape value, which the
    # importer handles specially (see the ONNX frontend Reshape converter).
    nodes = [
        helper.make_node("Shape", ["data"], ["shape_out"]),
        helper.make_node("Reshape", ["shape_out", "target_shape"], ["reshaped"]),
    ]

    # The reshape target is supplied as an initializer so it is a constant at import time.
    target_init = helper.make_tensor(
        "target_shape", TensorProto.INT64, [len(target_shape)], target_shape
    )

    graph = helper.make_graph(
        nodes,
        "reshape_shape_output",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, input_dims)],
        initializer=[target_init],
        outputs=[helper.make_tensor_value_info("reshaped", TensorProto.INT64, output_shape)],
    )

    model = helper.make_model(graph, producer_name="reshape_shape_output")
    check_correctness(
        model, inputs={"data": np.random.randn(*input_dims).astype("float32")}
    )


def test_transpose():
verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})

Expand Down Expand Up @@ -3630,29 +3662,46 @@ def test_optional_get_element_empty_raises():
from_onnx(model, opset=18, keep_params_in_input=True)


def test_symbolic_shape_deduction():
@pytest.mark.parametrize("with_reshape_flatten", [False, True])
def test_symbolic_shape_deduction(with_reshape_flatten):
index_node = helper.make_node(
"Constant",
inputs=[],
outputs=["indices"],
value=helper.make_tensor("indices", TensorProto.INT64, [], [0]),
)
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"

if with_reshape_flatten:
reshape_node = helper.make_node(
"Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
)
nodes.append(reshape_node)
gather_input = "reshaped_shape"

gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
constant_of_shape_node = helper.make_node(
"ConstantOfShape",
["unsqueeze_output"],
["output"],
value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
)
nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
Comment on lines 3673 to +3692
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved readability and conciseness, the construction of the nodes list can be refactored. Creating the nodes directly within nodes.append and nodes.extend calls avoids unnecessary intermediate variables, making the graph's structure more direct and easier to understand.

Suggested change
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"
if with_reshape_flatten:
reshape_node = helper.make_node(
"Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
)
nodes.append(reshape_node)
gather_input = "reshaped_shape"
gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
constant_of_shape_node = helper.make_node(
"ConstantOfShape",
["unsqueeze_output"],
["output"],
value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
)
nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"
if with_reshape_flatten:
nodes.append(
helper.make_node("Reshape", ["shape_output", "target_shape"], ["reshaped_shape"])
)
gather_input = "reshaped_shape"
nodes.extend([
helper.make_node("Gather", [gather_input, "indices"], ["gather_output"]),
helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"]),
helper.make_node(
"ConstantOfShape",
["unsqueeze_output"],
["output"],
value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
),
])


initializers = [helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])]
if with_reshape_flatten:
initializers.append(helper.make_tensor("target_shape", TensorProto.INT64, [1], vals=[-1]))

graph = helper.make_graph(
[index_node, shape_node, gather_node, unsqueeze_node, constant_of_shape_node],
nodes,
"test_shape_deduction",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, ["batch", "seq"]),
],
initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], vals=[0])],
initializer=initializers,
outputs=[helper.make_tensor_value_info("output", TensorProto.INT64, [1])],
)
model = helper.make_model(graph, producer_name="test_shape_deduction")
Expand Down
Loading