[Frontend][ONNX] Support select_last_index for ArgMax and ArgMin #18969
Conversation
Implements the select_last_index attribute (opset 12) for the ArgMax and ArgMin ONNX operators. Previously, this attribute raised OpAttributeUnImplemented. The implementation reverses the input tensor along the reduction axis, runs argmax/argmin on the flipped copy, then remaps the index back via (axis_size - 1) - flipped_idx. Adds tests for correctness with ties, correctness without ties, and IR structure verification.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
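The flip-and-remap identity described above can be checked outside TVM. The following is a plain NumPy sketch of the strategy (illustrative only, not the Relax lowering itself; `argmax_last` is a hypothetical helper name):

```python
import numpy as np

def argmax_last(a):
    """Model of the flip-and-subtract strategy for a 1-D array: take the
    argmax of the reversed array (NumPy returns the *first* occurrence,
    which is the *last* occurrence in the original), then map the index
    back with (axis_size - 1) - flipped_idx."""
    axis_size = a.shape[0]
    flipped_idx = np.argmax(a[::-1])
    return (axis_size - 1) - flipped_idx

a = np.array([1, 3, 2, 3, 0])
print(np.argmax(a))    # 1 -- first occurrence of the maximum
print(argmax_last(a))  # 3 -- last occurrence of the maximum
```

With ties at indices 1 and 3, the flipped array is `[0, 3, 2, 3, 1]`, whose argmax is 1, and `(5 - 1) - 1 = 3`, the last occurrence.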
Code Review
This pull request implements the select_last_index attribute for ArgMax and ArgMin operators in the ONNX frontend, allowing the retrieval of the last occurrence of extreme values via a flip-and-subtract strategy. New unit tests have been added to verify correctness in tie-breaking scenarios and to ensure the expected Relax IR lowering. Review feedback points out that the current implementation relies on static shape information, which will fail for models with dynamic dimensions; it is recommended to use relax.op.shape_of instead. Additionally, the logic for both operators is nearly identical and should be refactored into a shared helper function to improve maintainability.
```python
axis_size = int(inputs[0].struct_info.shape[axis])
offset = bb.normalize(
    relax.op.full(
        flipped_idx.struct_info.shape,
        relax.const(axis_size - 1, "int64"),
        dtype="int64",
    )
)
```
The use of int(inputs[0].struct_info.shape[axis]) assumes that the size of the reduction axis is a static integer. This will fail for models with dynamic shapes, where a shape dimension can be a tir.Var. To correctly handle dynamic shapes, you should use relax.op.shape_of to obtain the axis size at runtime and create the offset as a relax.PrimValue.
```python
offset = bb.normalize(
    relax.op.full(
        flipped_idx.struct_info.shape,
        relax.PrimValue(relax.op.shape_of(data)[axis] - 1),
        dtype="int64",
    )
)
```

```diff
 if select_last_index:
-    # TODO(vvchernov): support attr
-    raise tvm.error.OpAttributeUnImplemented(
-        "'select_last_index' attribute has not been supported yet"
-    )
+    data_flipped = relax.op.flip(data, axis=axis)
+    flipped_idx = bb.normalize(relax.op.argmin(data_flipped, axis, keepdims))
+    axis_size = int(inputs[0].struct_info.shape[axis])
+    offset = bb.normalize(
+        relax.op.full(
+            flipped_idx.struct_info.shape,
+            relax.const(axis_size - 1, "int64"),
+            dtype="int64",
+        )
+    )
+    return relax.op.subtract(offset, flipped_idx)
 return relax.op.argmin(data, axis, keepdims)
```
This implementation for select_last_index is nearly identical to the one for ArgMax. To improve maintainability and avoid code duplication, consider refactoring this logic into a shared helper function. This would also allow you to apply the fix for the dynamic shape issue (mentioned in the ArgMax comment) in a single place.
A possible refactoring could look like this:
```python
def _get_last_index(bb, data, axis, keepdims, op):
    data_flipped = relax.op.flip(data, axis=axis)
    flipped_idx = bb.normalize(op(data_flipped, axis, keepdims))
    offset = bb.normalize(
        relax.op.full(
            flipped_idx.struct_info.shape,
            relax.PrimValue(relax.op.shape_of(data)[axis] - 1),
            dtype="int64",
        )
    )
    return relax.op.subtract(offset, flipped_idx)

...

# in ArgMax._impl_v12
if select_last_index:
    return _get_last_index(bb, data, axis, keepdims, relax.op.argmax)

...

# in ArgMin._impl_v12
if select_last_index:
    return _get_last_index(bb, data, axis, keepdims, relax.op.argmin)
```

- Refactored shared logic into _argreduce_select_last_index helper to avoid code duplication between ArgMax and ArgMin
- Fixed dynamic shape handling by branching on whether axis_size is a static IntImm or a symbolic Var, using shape_to_tensor + take for the dynamic case

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
Summary

This PR implements the select_last_index attribute (introduced in opset 12) for the ArgMax and ArgMin ONNX operators.

Previously, setting select_last_index=1 raised OpAttributeUnImplemented. This closes the limitation tracked in the ONNX frontend issue.

Implementation

When select_last_index=1, the input tensor is reversed along the reduction axis using relax.op.flip, argmax/argmin is computed on the flipped copy, and the result is remapped back to the original index space via last_idx = (axis_size - 1) - flipped_idx.

Closes part of #18945
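The remap also works per-slice along an arbitrary axis. As a hedged NumPy check of the tie-breaking behavior (illustrative sketch only; `argmin_last` is a hypothetical helper, not the Relax lowering):

```python
import numpy as np

def argmin_last(a, axis, keepdims=False):
    """NumPy analogue of the lowering: argmin on the flipped tensor,
    then remap each index via (axis_size - 1) - flipped_idx."""
    axis_size = a.shape[axis]
    flipped_idx = np.argmin(np.flip(a, axis=axis), axis=axis)
    result = (axis_size - 1) - flipped_idx
    if keepdims:
        result = np.expand_dims(result, axis=axis)
    return result

a = np.array([[2, 0, 1, 0],
              [3, 3, 1, 1]])
print(np.argmin(a, axis=1))    # [1 2] -- first occurrence of each row minimum
print(argmin_last(a, axis=1))  # [3 3] -- last occurrence of each row minimum
```

Row 0 has its minimum 0 at indices 1 and 3: flipping gives `[0, 1, 0, 2]` with argmin 0, and `(4 - 1) - 0 = 3`, the last occurrence.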