Skip to content

Commit 166a482

Browse files
author
Siyuan Feng
committed
[Relax] Update ONNX frontend for unique, nonzero and compress
This PR updates the ONNX frontend: - Add match cast for unique and nonzero operators, enabling further import of ONNX models. - Add support for compress operator. - Fix the shape of the output tensor for nonzero operator.
1 parent d5b9f5c commit 166a482

5 files changed

Lines changed: 81 additions & 9 deletions

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,32 @@ def _impl_v18(cls, bb, inputs, attr, params):
833833
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
834834

835835

836+
class Compress(OnnxOpConverter):
837+
"""Convert an onnx Compress node into an equivalent Relax expression."""
838+
839+
@classmethod
840+
def _impl_v11(cls, bb, inputs, attr, params):
841+
tensor, condition = inputs
842+
axis = attr.get("axis", None)
843+
844+
# Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
845+
if condition.struct_info.dtype != "bool":
846+
raise ValueError("Condition tensor is expected to be a boolean tensor")
847+
if condition.struct_info.ndim != 1:
848+
raise ValueError("Condition tensor is expected to be a 1D boolean tensor")
849+
indices = relax.op.nonzero(condition)
850+
num_nonzero = tir.Var("num_nonzero", "int64")
851+
indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64"))
852+
indices = relax.op.reshape(indices, [-1])
853+
854+
if axis is not None:
855+
return relax.op.take(tensor, indices, axis=axis)
856+
857+
# if axis is None, flatten input tensor before selection
858+
tensor = relax.op.reshape(tensor, (-1,))
859+
return relax.op.take(tensor, indices, axis=0)
860+
861+
836862
class Size(OnnxOpConverter):
837863
"""Convert an onnx Size node into an equivalent Relax expression."""
838864

@@ -2726,15 +2752,35 @@ def _impl_v11(cls, bb, inputs, attr, params):
27262752
axis = attr.get("axis", None)
27272753
sorted = bool(attr.get("sorted", 1))
27282754
# TODO(tvm-team): Add support for return_index, return_inverse, return_counts
2729-
return relax.op.unique(data, sorted=sorted, axis=axis)
2755+
unique = relax.op.unique(data, sorted=sorted, axis=axis)
2756+
unique_numbers = tir.Var("unique_numbers", "int64")
2757+
input_shape = data.struct_info.shape
2758+
dtype = data.struct_info.dtype
2759+
2760+
if axis is None:
2761+
# flatten the input tensor
2762+
return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))
2763+
2764+
axis = axis if axis >= 0 else len(input_shape) + axis
2765+
if axis < 0 or axis >= len(input_shape):
2766+
raise ValueError(f"Axis {axis} is out of bounds")
2767+
output_shape = [
2768+
input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
2769+
]
2770+
return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
27302771

27312772

27322773
class NonZero(OnnxOpConverter):
27332774
"""Converts an onnx NonZero node into an equivalent Relax expression."""
27342775

27352776
@classmethod
27362777
def _impl_v9(cls, bb, inputs, attr, params):
2737-
return relax.op.nonzero(inputs[0])
2778+
ndim = inputs[0].struct_info.ndim
2779+
ndim = 1 if ndim == 0 else ndim
2780+
nonzero_numbers = tir.Var("nonzero_numbers", "int64")
2781+
return bb.match_cast(
2782+
relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64")
2783+
)
27382784

27392785

27402786
class HardSigmoid(OnnxOpConverter):
@@ -3075,7 +3121,7 @@ def _get_convert_map():
30753121
"Scatter": Scatter,
30763122
"ScatterElements": ScatterElements,
30773123
"ScatterND": ScatterND,
3078-
# "Compress": Compress,
3124+
"Compress": Compress,
30793125
"Size": Size,
30803126
"EyeLike": EyeLike,
30813127
# Normalization

python/tvm/relax/op/set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def nonzero(x: Expr) -> Expr:
123123
Returns
124124
-------
125125
result : relax.Expr
126-
A (n+1)-D tensor containing indices of non-zero elements.
126+
A 2-D tensor containing indices of non-zero elements.
127127
128128
Note
129129
----

src/relax/op/tensor/set.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero);
148148

149149
StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) {
150150
TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
151-
// Cheat zero dim scalar as 1-dim.
152-
int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1;
153-
return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice);
151+
return TensorStructInfo(DataType::Int(64), 2, data_sinfo->vdevice);
154152
}
155153

156154
TVM_REGISTER_OP("relax.nonzero")

tests/python/relax/test_frontend_onnx.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,34 @@ def verify_scatter_nd(data_shape, indices_shape, updates_shape):
601601
verify_scatter_nd([10], [5, 1], [5])
602602

603603

604+
@pytest.mark.parametrize("tensor_shape", [[32, 32]])
605+
@pytest.mark.parametrize("condition_shape", [None, [8], [16]])
606+
@pytest.mark.parametrize("axis", [None, 0, 1])
607+
def test_compress(
608+
tensor_shape: List[int],
609+
condition_shape: Optional[List[int]],
610+
axis: Optional[int],
611+
):
612+
if condition_shape is None and axis is None:
613+
pytest.skip("Either condition_shape or axis must be specified")
614+
if condition_shape is None:
615+
condition_shape = [tensor_shape[axis]]
616+
compress_node = helper.make_node("Compress", ["tensor", "condition"], ["output"], axis=axis)
617+
graph = helper.make_graph(
618+
[compress_node],
619+
"compress_test",
620+
inputs=[
621+
helper.make_tensor_value_info("tensor", TensorProto.FLOAT, tensor_shape),
622+
helper.make_tensor_value_info("condition", TensorProto.BOOL, condition_shape),
623+
],
624+
outputs=[
625+
helper.make_tensor_value_info("output", TensorProto.FLOAT, [])
626+
], # shape is unknown
627+
)
628+
model = helper.make_model(graph, producer_name="compress_test")
629+
check_correctness(model, opset=11)
630+
631+
604632
def test_size():
605633
test_node = helper.make_node("Size", ["x"], ["y"])
606634
graph = helper.make_graph(
@@ -2478,7 +2506,7 @@ def test_unique(axis: Optional[int], sorted: int):
24782506
check_correctness(model)
24792507

24802508

2481-
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)])
2509+
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6), (7, 8, 9, 10)])
24822510
def test_nonzero(shape):
24832511
verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64)
24842512

tests/python/relax/test_op_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def test_nonzero_infer_struct_info(shape):
875875
_check_inference(
876876
bb,
877877
relax.op.nonzero(x0),
878-
relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"),
878+
relax.TensorStructInfo(ndim=2, dtype="int64"),
879879
)
880880

881881

0 commit comments

Comments
 (0)