@@ -833,6 +833,32 @@ def _impl_v18(cls, bb, inputs, attr, params):
833833 return relax .op .scatter_nd (inputs [0 ], inputs [1 ], inputs [2 ], reduction )
834834
835835
836+ class Compress (OnnxOpConverter ):
837+ """Convert an onnx Compress node into an equivalent Relax expression."""
838+
839+ @classmethod
840+ def _impl_v11 (cls , bb , inputs , attr , params ):
841+ tensor , condition = inputs
842+ axis = attr .get ("axis" , None )
843+
844+ # Change one hot tensor to indices e.g. [0, 1, 1, 0, 1] -> [1, 2, 4]
845+ if condition .struct_info .dtype != "bool" :
846+ raise ValueError ("Condition tensor is expected to be a boolean tensor" )
847+ if condition .struct_info .ndim != 1 :
848+ raise ValueError ("Condition tensor is expected to be a 1D boolean tensor" )
849+ indices = relax .op .nonzero (condition )
850+ num_nonzero = tir .Var ("num_nonzero" , "int64" )
851+ indices = bb .match_cast (indices , relax .TensorStructInfo ([1 , num_nonzero ], "int64" ))
852+ indices = relax .op .reshape (indices , [- 1 ])
853+
854+ if axis is not None :
855+ return relax .op .take (tensor , indices , axis = axis )
856+
857+ # if axis is None, flatten input tensor before selection
858+ tensor = relax .op .reshape (tensor , (- 1 ,))
859+ return relax .op .take (tensor , indices , axis = 0 )
860+
861+
836862class Size (OnnxOpConverter ):
837863 """Convert an onnx Size node into an equivalent Relax expression."""
838864
@@ -2726,15 +2752,35 @@ def _impl_v11(cls, bb, inputs, attr, params):
27262752 axis = attr .get ("axis" , None )
27272753 sorted = bool (attr .get ("sorted" , 1 ))
27282754 # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
2729- return relax .op .unique (data , sorted = sorted , axis = axis )
2755+ unique = relax .op .unique (data , sorted = sorted , axis = axis )
2756+ unique_numbers = tir .Var ("unique_numbers" , "int64" )
2757+ input_shape = data .struct_info .shape
2758+ dtype = data .struct_info .dtype
2759+
2760+ if axis is None :
2761+ # flatten the input tensor
2762+ return bb .match_cast (unique , relax .TensorStructInfo ((unique_numbers ,), dtype ))
2763+
2764+ axis = axis if axis >= 0 else len (input_shape ) + axis
2765+ if axis < 0 or axis >= len (input_shape ):
2766+ raise ValueError (f"Axis { axis } is out of bounds" )
2767+ output_shape = [
2768+ input_shape [i ] if i != axis else unique_numbers for i in range (len (input_shape ))
2769+ ]
2770+ return bb .match_cast (unique , relax .TensorStructInfo (output_shape , dtype ))
27302771
27312772
27322773class NonZero (OnnxOpConverter ):
27332774 """Converts an onnx NonZero node into an equivalent Relax expression."""
27342775
27352776 @classmethod
27362777 def _impl_v9 (cls , bb , inputs , attr , params ):
2737- return relax .op .nonzero (inputs [0 ])
2778+ ndim = inputs [0 ].struct_info .ndim
2779+ ndim = 1 if ndim == 0 else ndim
2780+ nonzero_numbers = tir .Var ("nonzero_numbers" , "int64" )
2781+ return bb .match_cast (
2782+ relax .op .nonzero (inputs [0 ]), relax .TensorStructInfo ((ndim , nonzero_numbers ), "int64" )
2783+ )
27382784
27392785
27402786class HardSigmoid (OnnxOpConverter ):
@@ -3075,7 +3121,7 @@ def _get_convert_map():
30753121 "Scatter" : Scatter ,
30763122 "ScatterElements" : ScatterElements ,
30773123 "ScatterND" : ScatterND ,
3078- # "Compress": Compress,
3124+ "Compress" : Compress ,
30793125 "Size" : Size ,
30803126 "EyeLike" : EyeLike ,
30813127 # Normalization
0 commit comments