Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[Relax][ONNX] Refactor constant Unsqueeze scalar path and add regress…
…ion test

- refactor constant Unsqueeze lowering to build target shape then reshape
  instead of scalar special-casing plus repeated expand_dims
- add Unsqueeze scalar-input regression test to cover the new path
- restore helper used by structural ONNX frontend IR checks

Validation:
- python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pre_commit run --files python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py
- python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze_scalar_input or unsqueeze_dynamic_axes or unsqueeze_duplicate_axes_validation or slice_dynamic_inputs_ir or slice_dynamic_inputs_length_validation or slice_zero_step_validation" -v

Result:
- 8 passed
  • Loading branch information
Aharrypotter committed Mar 30, 2026
commit 8242e43bbba5a814f4bd8136e8f3f22608febe67
14 changes: 9 additions & 5 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,11 +707,15 @@ def _impl_v13(cls, bb, inputs, attr, params):
)
constant_axes = sorted(constant_axes)
expanded = data.data.numpy()
if len(expanded.shape) == 0:
expanded = [expanded]
constant_axes = [axis - 1 for axis in constant_axes if axis != 0]
for axis in constant_axes:
expanded = _np.expand_dims(expanded, axis=axis)
output_rank = expanded.ndim + len(constant_axes)
new_shape = []
input_dims_iter = iter(expanded.shape)
for i in range(output_rank):
if i in constant_axes:
new_shape.append(1)
else:
new_shape.append(next(input_dims_iter))
expanded = expanded.reshape(new_shape)
return relax.const(expanded, data.struct_info.dtype)

if isinstance(axes, relax.Constant):
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ def _check_output(tvm_out, ort_out):
_check_output(tvm_out, ort_out)


def collect_relax_call_ops(relax_func: relax.Function) -> set[str]:
call_ops = set()

def _visit(expr):
if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
call_ops.add(expr.op.name)

relax.analysis.post_order_visit(relax_func.body, _visit)
return call_ops


@pytest.mark.parametrize(
"input_names, expected_names",
[
Expand Down Expand Up @@ -904,6 +915,22 @@ def test_unsqueeze():
check_correctness(model)


def test_unsqueeze_scalar_input():
unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])

graph = helper.make_graph(
[unsqueeze_node],
"unsqueeze_scalar_input",
inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [])],
initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1])],
outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1])],
)

model = helper.make_model(graph, producer_name="unsqueeze_scalar_input_test")
inputs = {"a": np.array(3.0, dtype="float32")}
check_correctness(model, inputs, opset=13)


def test_unsqueeze_dynamic_axes():
unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])

Expand Down