From 0fec6abbca593c3b1a39197b1a4a48a20b0e04ac Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 30 Mar 2026 00:31:29 +0800 Subject: [PATCH 1/5] [Relax][ONNX] Improve Squeeze/Unsqueeze/Slice handling with dynamic axes support --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 404 +++++++++++++++--- tests/python/relax/test_frontend_onnx.py | 196 +++++++++ 2 files changed, 533 insertions(+), 67 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index a1173171252b..106a3c83d748 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -690,38 +690,55 @@ def _impl_v1(cls, bb, inputs, attr, params): def _impl_v13(cls, bb, inputs, attr, params): data = inputs[0] axes = get_constant(inputs[1], params) + data_ndim = _get_known_tensor_rank(data) - # Handle ONNX shape inference if isinstance(data, relax.PrimValue) and isinstance(axes, relax.Constant): - axes = axes.data.numpy().tolist() - if axes == [0]: + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze" + ) + if constant_axes == [0]: return relax.ShapeExpr([data.value]) - else: - raise NotImplementedError( - "Unsqueeze with symbolic axes and non-zero axes is not supported." - ) - # If input is a constant, compute directly + raise NotImplementedError( + "Unsqueeze with symbolic scalar inputs only supports axis 0." + ) if isinstance(data, relax.Constant) and isinstance(axes, relax.Constant): - axes = axes.data.numpy().tolist() + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), + data.data.numpy().ndim + axes.data.numpy().size, + "Unsqueeze", + ) + constant_axes = sorted(constant_axes) expanded = data.data.numpy() if len(expanded.shape) == 0: - # Special case implying input is a scalar, wrap it as a list. 
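# Illustrative sketch (an editorial aside, not part of the patch): the constant
# branch above mirrors opset-13 Unsqueeze semantics, where negative axes are
# normalized against the *output* rank and the insertions are applied in sorted
# order. `unsqueeze_reference` is an assumed name used only for this sketch.
import numpy as np

def unsqueeze_reference(data, axes):
    out_rank = data.ndim + len(axes)  # rank after inserting the new dims
    norm = sorted(a + out_rank if a < 0 else a for a in axes)  # negatives wrap on the output rank
    for axis in norm:
        data = np.expand_dims(data, axis=axis)  # insert size-1 dims front to back
    return data

assert unsqueeze_reference(np.zeros((32, 32)), [-1, 0]).shape == (1, 32, 32, 1)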
- if 0 in axes: - axes.remove(0) expanded = [expanded] - for axis in axes: + constant_axes = [axis - 1 for axis in constant_axes if axis != 0] + for axis in constant_axes: expanded = _np.expand_dims(expanded, axis=axis) return relax.const(expanded, data.struct_info.dtype) if isinstance(axes, relax.Constant): - constant_axes = list(axes.data.numpy()) - constant_axes = list(map(int, constant_axes)) + if data_ndim is None: + raise ValueError("Unsqueeze requires a statically known input rank.") + constant_axes = _normalize_constant_axes( + list(map(int, axes.data.numpy().tolist())), + data_ndim + axes.data.numpy().size, + "Unsqueeze", + ) constant_axes = sorted(constant_axes) for axis in constant_axes: data = relax.op.expand_dims(data, axis=axis) return data - raise NotImplementedError("Unsqueeze with dynamic axes is not supported.") + if data_ndim is None: + raise ValueError("Unsqueeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_unsqueezed_shape_tensor(bb, data_shape_tensor, axes, data_ndim) + output_shape = _tensor_to_shape_expr(bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim") + return relax.op.reshape(data, output_shape) class Concat(OnnxOpConverter): @@ -1440,14 +1457,37 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.const(out_data, data.struct_info.dtype) if isinstance(data, relax.ShapeExpr): - if axis == (0,): + shape_tensor_ndim = 1 + if axis is None: + if len(data) == 1: + return relax.PrimValue(data[0]) + return data + normalized_axes = _normalize_constant_axes(list(axis), shape_tensor_ndim, "Squeeze") + if normalized_axes == [0] and len(data) == 1: return relax.PrimValue(data[0]) - else: - raise NotImplementedError( - "Squeeze with symbolic axes and non-zero axes is not supported." - ) + raise NotImplementedError( + "Squeeze on symbolic shape tensors only supports removing the sole axis." 
+ ) + + if axis is None: + return relax.op.squeeze(data) + + if isinstance(axis, tuple): + return relax.op.squeeze(data, list(axis)) - return relax.op.squeeze(data, axis) + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Squeeze with dynamic axes requires a statically known input rank.") + axes_len = _get_known_tensor_length(axis) + if axes_len is None: + raise ValueError("Squeeze requires a statically known axes length.") + data_shape = bb.normalize(relax.op.shape_of(data)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + output_shape_tensor = _build_squeezed_shape_tensor(bb, data_shape_tensor, axis, data_ndim) + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim - axes_len, "squeeze_dim" + ) + return relax.op.reshape(data, output_shape) class Constant(OnnxOpConverter): @@ -1844,68 +1884,298 @@ def get_prim_value_list(values): return new_values +def _get_known_tensor_rank(expr: relax.Expr) -> int | None: + """Return the statically known rank of an expression when available.""" + + if isinstance(expr, relax.Constant): + return len(expr.data.numpy().shape) + if isinstance(expr, relax.ShapeExpr): + return 1 + if isinstance(expr, relax.PrimValue): + return 0 + struct_info = expr.struct_info + if isinstance(struct_info, relax.TensorStructInfo): + return None if struct_info.ndim == -1 else struct_info.ndim + return None + + +def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: + """Return the statically known length of a 1-D tensor-like expression.""" + + if expr is None: + return None + if isinstance(expr, relax.Constant): + np_value = expr.data.numpy() + if np_value.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={np_value.ndim}.") + return int(np_value.shape[0]) + if isinstance(expr, relax.ShapeExpr): + return len(expr.values) + if isinstance(expr, relax.PrimValue): + return 1 + struct_info = expr.struct_info + if not isinstance(struct_info, relax.TensorStructInfo): + return None + if struct_info.ndim != -1 and struct_info.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") + if struct_info.ndim != 1: + return None + if isinstance(struct_info.shape, relax.ShapeExpr): + dim = struct_info.shape.values[0] + if isinstance(dim, tirx.IntImm): + return int(dim.value) + if isinstance(dim, int): + return dim + return None + + +def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[int]: + """Normalize a list of constant axes and validate their uniqueness.""" + + normalized_axes = [] + for axis in axes: + if axis < 0: + axis += rank + if axis < 0 or axis >= rank: + raise ValueError(f"{op_name} axis {axis} is out of range for rank {rank}.") + normalized_axes.append(axis) + if len(normalized_axes) != len(set(normalized_axes)): + raise ValueError(f"{op_name} axes must be unique.") + return normalized_axes + + +def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr: + """Convert a tensor-like expression to an int64 tensor expression.""" + + if isinstance(expr, relax.ShapeExpr): + return bb.normalize(relax.op.shape_to_tensor(expr)) + if isinstance(expr, relax.PrimValue): + return bb.normalize(relax.op.full((1,), expr, dtype="int64")) + if isinstance(expr, relax.Constant): + if expr.struct_info.dtype == "int64": + return expr + return bb.normalize(relax.op.astype(expr, "int64")) + if isinstance(expr.struct_info, relax.TensorStructInfo) and expr.struct_info.dtype != "int64": + return 
bb.normalize(relax.op.astype(expr, "int64")) + return expr + + +def _tensor_to_shape_expr( + bb: relax.BlockBuilder, shape_tensor: relax.Expr, shape_ndim: int, prefix: str +) -> relax.ShapeExpr: + """Convert a statically sized int64 tensor into a ShapeExpr.""" + + shape_tensor = bb.match_cast(shape_tensor, relax.TensorStructInfo([shape_ndim], "int64")) + shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor)) + shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(shape_ndim)] + bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + return relax.ShapeExpr(shape_vars) + + +def _build_unsqueezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Unsqueeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes_len = _get_known_tensor_length(axes) + if axes_len is None: + raise ValueError("Unsqueeze requires a statically known axes length.") + + output_ndim = data_ndim + axes_len + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(output_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(output_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + insert_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize(relax.op.subtract(relax.const(1, "int64"), insert_mask)) + input_indices = bb.normalize( + relax.op.subtract(relax.op.cumsum(keep_mask, axis=0), relax.const(1, "int64")) + ) + safe_indices = bb.normalize( + relax.op.where( + relax.op.less(input_indices, relax.const(0, "int64")), + relax.const(0, "int64"), + input_indices, + ) + ) + kept_dims = bb.normalize(relax.op.take(data_shape_tensor, safe_indices, axis=0)) + return bb.normalize( + relax.op.where( + relax.op.greater(insert_mask, relax.const(0, "int64")), + relax.const(1, "int64"), + kept_dims, + ) + ) + + +def _build_squeezed_shape_tensor( + bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int +) -> relax.Expr: + """Build the output shape tensor for Squeeze with runtime axes.""" + + axes = _as_int64_tensor(bb, axes) + axes = bb.normalize( + relax.op.where( + relax.op.less(axes, relax.const(0, "int64")), + relax.op.add(axes, relax.const(data_ndim, "int64")), + axes, + ) + ) + positions = relax.op.arange(data_ndim, dtype="int64") + positions = bb.normalize(relax.op.expand_dims(positions, axis=1)) + axes = bb.normalize(relax.op.expand_dims(axes, axis=0)) + remove_mask = bb.normalize( + relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) + ) + keep_mask = bb.normalize( + relax.op.equal(remove_mask, relax.const(0, "int64")) + ) + keep_indices = bb.normalize(relax.op.nonzero(keep_mask)) + num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64") + keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64")) + keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1])) + return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0)) + + class Slice(OnnxOpConverter): """Converts an onnx Splice node into an equivalent Relax expression.""" @classmethod def _impl_v13(cls, bb, inputs, attr, params): - # TODO (jwfromm) currently only supports constant parameters. 
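# Illustrative sketch (not part of the patch): the mask/cumsum construction used
# by _build_unsqueezed_shape_tensor above, restated in NumPy so the index math is
# easy to verify. `unsqueezed_shape` is an assumed name for this sketch only.
import numpy as np

def unsqueezed_shape(in_shape, axes):
    out_ndim = len(in_shape) + len(axes)
    axes = np.asarray([a + out_ndim if a < 0 else a for a in axes])
    pos = np.arange(out_ndim)
    insert_mask = (pos[:, None] == axes[None, :]).sum(axis=1)  # 1 at inserted positions
    keep_mask = 1 - insert_mask
    input_indices = np.cumsum(keep_mask) - 1  # maps each kept position to an input dim
    kept_dims = np.asarray(in_shape)[np.maximum(input_indices, 0)]  # clamp, like safe_indices
    return np.where(insert_mask > 0, 1, kept_dims)

assert unsqueezed_shape([32, 32], [-1, 0]).tolist() == [1, 32, 32, 1]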
data = inputs[0] starts = get_constant(inputs[1], params) ends = get_constant(inputs[2], params) axes = get_constant(inputs[3], params) steps = get_constant(inputs[4], params) - if not all( - [ - ( - isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) - or param is None + all_constant_params = all( + isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) or param is None + for param in [starts, ends, axes, steps] + ) + if all_constant_params: + starts = get_prim_expr_list(starts) + ends = get_prim_expr_list(ends) + if len(starts) != len(ends): + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{len(starts)} and {len(ends)}." ) - for param in [starts, ends, axes, steps] - ] - ): - raise ValueError("Only constant Slice parameters are currently supported.") - # Convert parameters to constant lists. - starts = get_prim_expr_list(starts) - ends = get_prim_expr_list(ends) - if axes is not None: - axes = get_prim_expr_list(axes) - else: - axes = list(range(len(starts))) - # Convert negative axis to positive if needed. - for i, axis in enumerate(axes): - if axis < 0: - axes[i] = axis + len(data.struct_info.shape) - if steps is not None: - steps = get_prim_expr_list(steps) - else: - steps = [1] * len(axes) - # If input is a shape tensor, we can directly extract it. - if isinstance(data, relax.ShapeExpr): - shape_data = list(data) - # Starts, ends, and steps must be 1-d for shape operation. - assert all(len(i) == 1 for i in [starts, ends, steps]) - sliced_values = shape_data[starts[0] : ends[0] : steps[0]] - - if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): - return relax.const([x.value for x in sliced_values], "int64") + if axes is not None: + axes = get_prim_expr_list(axes) + if len(axes) != len(starts): + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{len(axes)} and {len(starts)}." + ) + else: + axes = list(range(len(starts))) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Slice requires a statically known input rank.") + axes = _normalize_constant_axes(list(axes), data_ndim, "Slice") + if steps is not None: + steps = get_prim_expr_list(steps) + if len(steps) != len(starts): + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{len(steps)} and {len(starts)}." + ) else: + steps = [1] * len(axes) + if isinstance(data, relax.ShapeExpr): + shape_data = list(data) + assert all(len(i) == 1 for i in [starts, ends, steps]) + sliced_values = shape_data[starts[0] : ends[0] : steps[0]] + + if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): + return relax.const([x.value for x in sliced_values], "int64") return relax.ShapeExpr(sliced_values) - # If all `starts`, `ends`, and `steps` are constant, use strict mode - # Otherwise, we assume the slice is inbound. 
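# Illustrative sketch (not part of the patch): the dynamic path below widens the
# per-axis starts/ends/steps to full rank with scatter_elements before calling
# dynamic_strided_slice. The same idea in NumPy, with fancy-index assignment
# standing in for the scatter; `full_rank_slice_params` is an assumed name.
import numpy as np

def full_rank_slice_params(data_shape, starts, ends, axes, steps):
    ndim = len(data_shape)
    axes = np.asarray([a + ndim if a < 0 else a for a in axes])
    full_starts = np.zeros(ndim, dtype=np.int64)        # default start is 0
    full_ends = np.asarray(data_shape, dtype=np.int64)  # default end is the dim size
    full_steps = np.ones(ndim, dtype=np.int64)          # default step is 1
    full_starts[axes] = starts
    full_ends[axes] = ends
    full_steps[axes] = steps
    return full_starts, full_ends, full_steps

s, e, st = full_rank_slice_params([20, 10, 5], [0, 0], [3, 10], [0, 1], [1, 1])
assert (s.tolist(), e.tolist(), st.tolist()) == ([0, 0, 0], [3, 10, 5], [1, 1, 1])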
- assume_inbound = not all( - [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] - ) + assume_inbound = not all( + [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] + ) + starts = get_prim_value_list(starts) + ends = get_prim_value_list(ends) + steps = get_prim_value_list(steps) + + return relax.op.strided_slice( + data, axes, starts, ends, steps, assume_inbound=assume_inbound + ) + + data_ndim = _get_known_tensor_rank(data) + if data_ndim is None: + raise ValueError("Slice with dynamic parameters requires a statically known input rank.") - # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr - starts = get_prim_value_list(starts) - ends = get_prim_value_list(ends) - steps = get_prim_value_list(steps) + data_expr = data + if isinstance(data, relax.ShapeExpr): + data_expr = bb.normalize(relax.op.shape_to_tensor(data)) + + starts_tensor = _as_int64_tensor(bb, starts) + ends_tensor = _as_int64_tensor(bb, ends) + axes_len = _get_known_tensor_length(starts_tensor) + if axes_len is None: + raise ValueError("Slice requires a statically known starts length.") + ends_len = _get_known_tensor_length(ends_tensor) + if ends_len is None: + raise ValueError("Slice requires a statically known ends length.") + if ends_len != axes_len: + raise ValueError( + f"Slice expects starts and ends to have the same length, but got " + f"{axes_len} and {ends_len}." + ) - return relax.op.strided_slice( - data, axes, starts, ends, steps, assume_inbound=assume_inbound + if axes is None: + axes_tensor = relax.op.arange(axes_len, dtype="int64") + else: + axes_tensor = _as_int64_tensor(bb, axes) + axes_tensor_len = _get_known_tensor_length(axes_tensor) + if axes_tensor_len is None: + raise ValueError("Slice requires a statically known axes length.") + if axes_tensor_len != axes_len: + raise ValueError( + f"Slice expects axes and starts to have the same length, but got " + f"{axes_tensor_len} and {axes_len}." + ) + if steps is None: + steps_tensor = relax.const(_np.ones((axes_len,), dtype="int64"), "int64") + else: + steps_tensor = _as_int64_tensor(bb, steps) + steps_len = _get_known_tensor_length(steps_tensor) + if steps_len is None: + raise ValueError("Slice requires a statically known steps length.") + if steps_len != axes_len: + raise ValueError( + f"Slice expects steps and starts to have the same length, but got " + f"{steps_len} and {axes_len}." 
+ ) + + axes_tensor = bb.normalize( + relax.op.where( + relax.op.less(axes_tensor, relax.const(0, "int64")), + relax.op.add(axes_tensor, relax.const(data_ndim, "int64")), + axes_tensor, + ) + ) + + data_shape = bb.normalize(relax.op.shape_of(data_expr)) + data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) + full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), "int64") + full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), "int64") + full_starts = bb.normalize(relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor)) + full_ends = bb.normalize( + relax.op.scatter_elements(data_shape_tensor, axes_tensor, ends_tensor) ) + full_steps = bb.normalize(relax.op.scatter_elements(full_steps, axes_tensor, steps_tensor)) + return relax.op.dynamic_strided_slice(data_expr, full_starts, full_ends, full_steps) class Pad(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 887533f26139..de2b6e384f67 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -904,6 +904,66 @@ def test_unsqueeze(): check_correctness(model) +def test_unsqueeze_dynamic_axes(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_test") + inputs = { + "a": rg.standard_normal(size=[32, 32]).astype("float32"), + "axes": np.array([-1, 0], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_unsqueeze_dynamic_axes_ir(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + + +def test_unsqueeze_dynamic_axes_rank_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + def test_unsqueeze_v1(): # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1 unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 3]) @@ -1384,6 +1444,70 @@ def test_dynamic_squeeze(axis, A, B): check_correctness(model, inputs, opset=13) +def 
test_squeeze_dynamic_axes(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_test") + inputs = { + "x": rg.standard_normal(size=shape).astype("float32"), + "axes": np.array([-4, 2], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_squeeze_dynamic_axes_ir(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) + + assert "relax.tensor_to_shape" in call_ops + assert "relax.reshape" in call_ops + assert "relax.squeeze" not in call_ops + + +def test_squeeze_dynamic_axes_rank_validation(): + squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"]) + shape = [1, 32, 1, 32] + + graph = helper.make_graph( + [squeeze_node], + "squeeze_dynamic_axes_rank_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 32])], + ) + + model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_rank_validation_test") + with pytest.raises(ValueError, match="Expected a 1-D tensor"): + from_onnx(model, opset=13, keep_params_in_input=True) + + @pytest.mark.parametrize("axis", [[0]]) @pytest.mark.parametrize("A", [8, 16, 32]) def test_dynamic_shape_squeeze(axis, A): @@ -2287,6 +2411,78 @@ def verify_slice(data_shape, output_shape, starts, ends, axes=None, steps=None): # ) +def test_slice_dynamic_inputs(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_test") + inputs = { + "x": rg.standard_normal(size=[20, 10, 5]).astype("float32"), + "starts": np.array([0, 0], dtype="int64"), + "ends": np.array([3, 10], dtype="int64"), + "axes": np.array([0, 1], dtype="int64"), + "steps": np.array([1, 1], dtype="int64"), + } + check_correctness(model, inputs, opset=13) + + +def test_slice_dynamic_inputs_ir(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + 
"slice_dynamic_inputs_ir", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [2]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_ir_test") + tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + + assert "relax.dynamic_strided_slice" in collect_relax_call_ops(tvm_model["main"]) + + +def test_slice_dynamic_inputs_length_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_dynamic_inputs_length_validation", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]), + helper.make_tensor_value_info("starts", TensorProto.INT64, [2]), + helper.make_tensor_value_info("ends", TensorProto.INT64, [1]), + helper.make_tensor_value_info("axes", TensorProto.INT64, [2]), + helper.make_tensor_value_info("steps", TensorProto.INT64, [2]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model( + graph, producer_name="slice_dynamic_inputs_length_validation_test" + ) + with pytest.raises(ValueError, match="starts and ends to have the same length"): + from_onnx(model, opset=13, keep_params_in_input=True) + + def test_slice_dynamic_shape(): def verify_slice( data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None From 5168a67b10c19363458d5801a4894f53ecf72dd6 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:50:57 +0800 Subject: [PATCH 2/5] [Relax][ONNX] Tighten dynamic Unsqueeze/Slice validation and add structural tests - fix Slice converter docstring typo ("Splice" -> "Slice") for consistency - add explicit validation to reject zero step values in Slice for both constant and dynamic-parameter paths - add Unsqueeze negative test to reject duplicate axes - strengthen structural IR test for dynamic Slice to assert relax.dynamic_strided_slice is used and relax.strided_slice is not - add Slice negative test for zero-step input Validation: - python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py - python -m pre_commit run --files python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py - python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze_dynamic_axes or unsqueeze_duplicate_axes_validation or slice_dynamic_inputs_ir or slice_dynamic_inputs_length_validation or slice_zero_step_validation" -v Result: - 7 passed --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 34 ++++++++------ tests/python/relax/test_frontend_onnx.py | 45 +++++++++++++++++-- 2 files changed, 62 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 106a3c83d748..c921e07d0d72 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -698,9 +698,7 @@ def _impl_v13(cls, bb, inputs, attr, params): ) if constant_axes == [0]: return relax.ShapeExpr([data.value]) - raise NotImplementedError( - "Unsqueeze 
with symbolic scalar inputs only supports axis 0." - ) + raise NotImplementedError("Unsqueeze with symbolic scalar inputs only supports axis 0.") if isinstance(data, relax.Constant) and isinstance(axes, relax.Constant): constant_axes = _normalize_constant_axes( list(map(int, axes.data.numpy().tolist())), @@ -737,7 +735,9 @@ def _impl_v13(cls, bb, inputs, attr, params): data_shape = bb.normalize(relax.op.shape_of(data)) data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) output_shape_tensor = _build_unsqueezed_shape_tensor(bb, data_shape_tensor, axes, data_ndim) - output_shape = _tensor_to_shape_expr(bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim") + output_shape = _tensor_to_shape_expr( + bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim" + ) return relax.op.reshape(data, output_shape) @@ -2036,9 +2036,7 @@ def _build_squeezed_shape_tensor( remove_mask = bb.normalize( relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1) ) - keep_mask = bb.normalize( - relax.op.equal(remove_mask, relax.const(0, "int64")) - ) + keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, "int64"))) keep_indices = bb.normalize(relax.op.nonzero(keep_mask)) num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64") keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64")) @@ -2047,7 +2045,7 @@ def _build_squeezed_shape_tensor( class Slice(OnnxOpConverter): - """Converts an onnx Splice node into an equivalent Relax expression.""" + """Converts an onnx Slice node into an equivalent Relax expression.""" @classmethod def _impl_v13(cls, bb, inputs, attr, params): @@ -2091,6 +2089,12 @@ def _impl_v13(cls, bb, inputs, attr, params): ) else: steps = [1] * len(axes) + if any( + (isinstance(step, int) and step == 0) + or (isinstance(step, tirx.IntImm) and int(step) == 0) + for step in steps + ): + raise ValueError("Slice step values must be non-zero.") if isinstance(data, relax.ShapeExpr): shape_data = list(data) assert all(len(i) == 1 for i in [starts, ends, steps]) @@ -2113,7 +2117,9 @@ def _impl_v13(cls, bb, inputs, attr, params): data_ndim = _get_known_tensor_rank(data) if data_ndim is None: - raise ValueError("Slice with dynamic parameters requires a statically known input rank.") + raise ValueError( + "Slice with dynamic parameters requires a statically known input rank." + ) data_expr = data if isinstance(data, relax.ShapeExpr): @@ -2157,6 +2163,8 @@ def _impl_v13(cls, bb, inputs, attr, params): f"Slice expects steps and starts to have the same length, but got " f"{steps_len} and {axes_len}." 
) + if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0): + raise ValueError("Slice step values must be non-zero.") axes_tensor = bb.normalize( relax.op.where( @@ -2170,7 +2178,9 @@ def _impl_v13(cls, bb, inputs, attr, params): data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape)) full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), "int64") full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), "int64") - full_starts = bb.normalize(relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor)) + full_starts = bb.normalize( + relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor) + ) full_ends = bb.normalize( relax.op.scatter_elements(data_shape_tensor, axes_tensor, ends_tensor) ) @@ -2691,9 +2701,7 @@ def _impl_v20(cls, bb, inputs, attr, params): align_corners = attr.get("align_corners", 0) if align_corners != 1: - raise NotImplementedError( - "AffineGrid with align_corners=0 is not yet supported in TVM" - ) + raise NotImplementedError("AffineGrid with align_corners=0 is not yet supported in TVM") # Extract size values if isinstance(size, relax.Constant): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index de2b6e384f67..be9022366f81 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -964,6 +964,22 @@ def test_unsqueeze_dynamic_axes_rank_validation(): from_onnx(model, opset=13, keep_params_in_input=True) +def test_unsqueeze_duplicate_axes_validation(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_duplicate_axes_validation", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])], + initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 0])], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1, 32, 32])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_duplicate_axes_validation_test") + with pytest.raises(ValueError, match="axes must be unique"): + from_onnx(model, opset=13) + + def test_unsqueeze_v1(): # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1 unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 3]) @@ -2456,8 +2472,10 @@ def test_slice_dynamic_inputs_ir(): model = helper.make_model(graph, producer_name="slice_dynamic_inputs_ir_test") tvm_model = from_onnx(model, opset=13, keep_params_in_input=True) + call_ops = collect_relax_call_ops(tvm_model["main"]) - assert "relax.dynamic_strided_slice" in collect_relax_call_ops(tvm_model["main"]) + assert "relax.dynamic_strided_slice" in call_ops + assert "relax.strided_slice" not in call_ops def test_slice_dynamic_inputs_length_validation(): @@ -2476,13 +2494,32 @@ def test_slice_dynamic_inputs_length_validation(): outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], ) - model = helper.make_model( - graph, producer_name="slice_dynamic_inputs_length_validation_test" - ) + model = helper.make_model(graph, producer_name="slice_dynamic_inputs_length_validation_test") with pytest.raises(ValueError, match="starts and ends to have the same length"): from_onnx(model, opset=13, keep_params_in_input=True) +def test_slice_zero_step_validation(): + slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"]) + + graph = helper.make_graph( + [slice_node], + "slice_zero_step_validation", + 
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5])], + initializer=[ + helper.make_tensor("starts", TensorProto.INT64, [2], vals=[0, 0]), + helper.make_tensor("ends", TensorProto.INT64, [2], vals=[3, 10]), + helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1]), + helper.make_tensor("steps", TensorProto.INT64, [2], vals=[1, 0]), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 5])], + ) + + model = helper.make_model(graph, producer_name="slice_zero_step_validation_test") + with pytest.raises(ValueError, match="step values must be non-zero"): + from_onnx(model, opset=13) + + def test_slice_dynamic_shape(): def verify_slice( data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None From 8242e43bbba5a814f4bd8136e8f3f22608febe67 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:54:59 +0800 Subject: [PATCH 3/5] [Relax][ONNX] Refactor constant Unsqueeze scalar path and add regression test - refactor constant Unsqueeze lowering to build target shape then reshape instead of scalar special-casing plus repeated expand_dims - add Unsqueeze scalar-input regression test to cover the new path - restore helper used by structural ONNX frontend IR checks Validation: - python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py - python -m pre_commit run --files python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py - python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "unsqueeze_scalar_input or unsqueeze_dynamic_axes or unsqueeze_duplicate_axes_validation or slice_dynamic_inputs_ir or slice_dynamic_inputs_length_validation or slice_zero_step_validation" -v Result: - 8 passed --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 14 ++++++---- tests/python/relax/test_frontend_onnx.py | 27 +++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index c921e07d0d72..ee3ec005dd81 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -707,11 +707,15 @@ def _impl_v13(cls, bb, inputs, attr, params): ) constant_axes = sorted(constant_axes) expanded = data.data.numpy() - if len(expanded.shape) == 0: - expanded = [expanded] - constant_axes = [axis - 1 for axis in constant_axes if axis != 0] - for axis in constant_axes: - expanded = _np.expand_dims(expanded, axis=axis) + output_rank = expanded.ndim + len(constant_axes) + new_shape = [] + input_dims_iter = iter(expanded.shape) + for i in range(output_rank): + if i in constant_axes: + new_shape.append(1) + else: + new_shape.append(next(input_dims_iter)) + expanded = expanded.reshape(new_shape) return relax.const(expanded, data.struct_info.dtype) if isinstance(axes, relax.Constant): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index be9022366f81..f8f4508ddc86 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -197,6 +197,17 @@ def _check_output(tvm_out, ort_out): _check_output(tvm_out, ort_out) +def collect_relax_call_ops(relax_func: relax.Function) -> set[str]: + call_ops = set() + + def _visit(expr): + if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op): + call_ops.add(expr.op.name) + + 
relax.analysis.post_order_visit(relax_func.body, _visit) + return call_ops + + @pytest.mark.parametrize( "input_names, expected_names", [ @@ -904,6 +915,22 @@ def test_unsqueeze(): check_correctness(model) +def test_unsqueeze_scalar_input(): + unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) + + graph = helper.make_graph( + [unsqueeze_node], + "unsqueeze_scalar_input", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [])], + initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1])], + outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1])], + ) + + model = helper.make_model(graph, producer_name="unsqueeze_scalar_input_test") + inputs = {"a": np.array(3.0, dtype="float32")} + check_correctness(model, inputs, opset=13) + + def test_unsqueeze_dynamic_axes(): unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"]) From 29dc96522bf344cc7cc1e2cc374f6b2f6528f707 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Tue, 31 Mar 2026 00:42:30 +0800 Subject: [PATCH 4/5] [Relax][ONNX] tighten Slice dynamic checks and improve axis/length validation --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 13 ++++++----- tests/python/relax/test_frontend_onnx.py | 22 +++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index ee3ec005dd81..04a971472c9b 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1920,10 +1920,10 @@ def _get_known_tensor_length(expr: relax.Expr | None) -> int | None: struct_info = expr.struct_info if not isinstance(struct_info, relax.TensorStructInfo): return None - if struct_info.ndim != -1 and struct_info.ndim != 1: - raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") - if struct_info.ndim != 1: + if struct_info.ndim == -1: return None + if struct_info.ndim != 1: + raise ValueError(f"Expected a 1-D tensor, but got ndim={struct_info.ndim}.") if isinstance(struct_info.shape, relax.ShapeExpr): dim = struct_info.shape.values[0] if isinstance(dim, tirx.IntImm): @@ -1938,10 +1938,11 @@ def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[i normalized_axes = [] for axis in axes: + original_axis = axis if axis < 0: axis += rank if axis < 0 or axis >= rank: - raise ValueError(f"{op_name} axis {axis} is out of range for rank {rank}.") + raise ValueError(f"{op_name} axis {original_axis} is out of range for rank {rank}.") normalized_axes.append(axis) if len(normalized_axes) != len(set(normalized_axes)): raise ValueError(f"{op_name} axes must be unique.") @@ -2125,9 +2126,9 @@ def _impl_v13(cls, bb, inputs, attr, params): "Slice with dynamic parameters requires a statically known input rank." 
             )
 
-        data_expr = data
         if isinstance(data, relax.ShapeExpr):
-            data_expr = bb.normalize(relax.op.shape_to_tensor(data))
+            raise ValueError("Slice with dynamic parameters does not support ShapeExpr input.")
+        data_expr = data
 
         starts_tensor = _as_int64_tensor(bb, starts)
         ends_tensor = _as_int64_tensor(bb, ends)
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index f8f4508ddc86..eaeb8573afea 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2526,6 +2526,28 @@ def test_slice_dynamic_inputs_length_validation():
         from_onnx(model, opset=13, keep_params_in_input=True)
 
 
+def test_slice_dynamic_shape_expr_input_validation():
+    shape_node = helper.make_node("Shape", ["x"], ["y"])
+    slice_node = helper.make_node("Slice", ["y", "starts", "ends", "axes", "steps"], ["z"])
+
+    graph = helper.make_graph(
+        [shape_node, slice_node],
+        "slice_dynamic_shape_expr_input_validation",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+            helper.make_tensor_value_info("starts", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("ends", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("steps", TensorProto.INT64, [1]),
+        ],
+        outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [1])],
+    )
+
+    model = helper.make_model(graph, producer_name="slice_dynamic_shape_expr_input_validation_test")
+    with pytest.raises(ValueError, match="does not support ShapeExpr input"):
+        from_onnx(model, opset=13, keep_params_in_input=True)
+
+
 def test_slice_zero_step_validation():
     slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", "steps"], ["y"])
 
From 1b50832a81d57c9b2fe10c88d5cd66a55e5fa717 Mon Sep 17 00:00:00 2001
From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com>
Date: Tue, 31 Mar 2026 13:48:47 +0800
Subject: [PATCH 5/5] [Relax][ONNX] Fix collect_relax_call_ops list semantics

An earlier merge-conflict resolution incorrectly changed the return type of
collect_relax_call_ops from list to set. This caused test failures because:

1. Multiple tests rely on .count() to verify the exact number of operator
   insertions (e.g., ensuring both inputs of MatMulInteger16 are cast).
2. Several tests compare against '== []' to assert that no operators were
   generated (verifying constant folding), which can never hold if a set is
   returned.

Restoring list semantics ensures both structural validation and type
consistency in the tests.
---
 tests/python/relax/test_frontend_onnx.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6647a54508ea..8b1292617896 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -232,15 +232,15 @@ def run_in_tvm(
     return vm.get_outputs("main")
 
 
-def collect_relax_call_ops(relax_func: relax.Function) -> set[str]:
-    call_ops = set()
+def collect_relax_call_ops(func: relax.Function) -> list[str]:
+    op_names: list[str] = []
 
-    def _visit(expr):
+    def fvisit(expr: relax.Expr) -> None:
         if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
-            call_ops.add(expr.op.name)
+            op_names.append(expr.op.name)
 
-    relax.analysis.post_order_visit(relax_func.body, _visit)
-    return call_ops
+    relax.analysis.post_order_visit(func.body, fvisit)
+    return list(op_names)
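A short usage sketch of why the list return type matters for the assertions the
commit message describes; `folded_model` and the expected count of 2 are
illustrative assumptions, not values taken from the test suite:

    ops = collect_relax_call_ops(tvm_model["main"])
    # Multiplicity checks need a list: e.g. when both MatMulInteger16 inputs
    # are cast, the cast op should appear exactly twice.
    assert ops.count("relax.astype") == 2
    # Constant-folded graphs are asserted to contain no calls at all; a set
    # would never compare equal to the empty list.
    assert collect_relax_call_ops(folded_model["main"]) == []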