
[Frontend][ONNX] Support select_last_index for ArgMax and ArgMin#18969

Merged
tlopex merged 2 commits into apache:main from
OmarAzizi:relax-onnx-argmax-argmin-select-last-index
Apr 3, 2026

Conversation

@OmarAzizi
Contributor

Summary

This PR implements the select_last_index attribute (introduced in opset 12) for the ArgMax and ArgMin ONNX operators.

Previously, setting select_last_index=1 raised OpAttributeUnImplemented. This removes the limitation tracked in the ONNX frontend issue.

Implementation

When select_last_index=1, the input tensor is reversed along the reduction axis using relax.op.flip, argmax/argmin is computed on the flipped copy, and the result is remapped back to the original index space via last_idx = (axis_size - 1) - flipped_idx.

Closes part of #18945
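The flip-and-remap trick can be illustrated outside of TVM with a small NumPy sketch (illustrative only; the actual implementation goes through relax.op.flip and relax.op.argmax/argmin):

```python
import numpy as np

def argmax_last(data, axis):
    """Last-occurrence argmax via flip-and-remap.

    np.argmax returns the first occurrence of the maximum (the default
    ONNX behavior). Flipping along the reduction axis, taking argmax on
    the flipped copy, and remapping the index back to the original index
    space yields the last occurrence instead.
    """
    flipped_idx = np.argmax(np.flip(data, axis=axis), axis=axis)
    return (data.shape[axis] - 1) - flipped_idx

x = np.array([2, 5, 1, 5, 0])   # the maximum 5 is tied at indices 1 and 3
print(np.argmax(x))             # first occurrence: 1
print(argmax_last(x, axis=0))   # last occurrence: 3
```

The same remapping applies verbatim to argmin; only the underlying reduction op changes.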

Implements the select_last_index attribute (opset 12) for ArgMax and
ArgMin ONNX operators. Previously this attribute raised
OpAttributeUnImplemented.

The implementation reverses the input tensor along the reduction axis,
runs argmax/argmin on the flipped copy, then remaps the index back via
(axis_size - 1) - flipped_idx.

Adds tests for correctness with ties, correctness without ties, and
IR structure verification.

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements the select_last_index attribute for ArgMax and ArgMin operators in the ONNX frontend, allowing the retrieval of the last occurrence of extreme values via a flip-and-subtract strategy. New unit tests have been added to verify correctness in tie-breaking scenarios and to ensure the expected Relax IR lowering. Review feedback points out that the current implementation relies on static shape information, which will fail for models with dynamic dimensions; it is recommended to use relax.op.shape_of instead. Additionally, the logic for both operators is nearly identical and should be refactored into a shared helper function to improve maintainability.

Comment on lines 3791 to 3798
    axis_size = int(inputs[0].struct_info.shape[axis])
    offset = bb.normalize(
        relax.op.full(
            flipped_idx.struct_info.shape,
            relax.const(axis_size - 1, "int64"),
            dtype="int64",
        )
    )

Priority: high

The use of int(inputs[0].struct_info.shape[axis]) assumes that the size of the reduction axis is a static integer. This will fail for models with dynamic shapes, where a shape dimension can be a tir.Var. To correctly handle dynamic shapes, you should use relax.op.shape_of to obtain the axis size at runtime and create the offset as a relax.PrimValue.

            offset = bb.normalize(
                relax.op.full(
                    flipped_idx.struct_info.shape,
                    relax.PrimValue(relax.op.shape_of(data)[axis] - 1),
                    dtype="int64",
                )
            )

Comment on lines 3833 to 3845
    if select_last_index:
        data_flipped = relax.op.flip(data, axis=axis)
        flipped_idx = bb.normalize(relax.op.argmin(data_flipped, axis, keepdims))
        axis_size = int(inputs[0].struct_info.shape[axis])
        offset = bb.normalize(
            relax.op.full(
                flipped_idx.struct_info.shape,
                relax.const(axis_size - 1, "int64"),
                dtype="int64",
            )
        )
        return relax.op.subtract(offset, flipped_idx)
    return relax.op.argmin(data, axis, keepdims)

Priority: medium

This implementation for select_last_index is nearly identical to the one for ArgMax. To improve maintainability and avoid code duplication, consider refactoring this logic into a shared helper function. This would also allow you to apply the fix for the dynamic shape issue (mentioned in the ArgMax comment) in a single place.

A possible refactoring could look like this:

def _get_last_index(bb, data, axis, keepdims, op):
    data_flipped = relax.op.flip(data, axis=axis)
    flipped_idx = bb.normalize(op(data_flipped, axis, keepdims))
    offset = bb.normalize(
        relax.op.full(
            flipped_idx.struct_info.shape,
            relax.PrimValue(relax.op.shape_of(data)[axis] - 1),
            dtype="int64",
        )
    )
    return relax.op.subtract(offset, flipped_idx)

...
# in ArgMax._impl_v12
if select_last_index:
    return _get_last_index(bb, data, axis, keepdims, relax.op.argmax)
...

# in ArgMin._impl_v12
if select_last_index:
    return _get_last_index(bb, data, axis, keepdims, relax.op.argmin)
...

- Refactored shared logic into _argreduce_select_last_index helper to
  avoid code duplication between ArgMax and ArgMin
- Fixed dynamic shape handling by branching on whether axis_size is a
  static IntImm or a symbolic Var, using shape_to_tensor + take for
  the dynamic case

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
Member

@tlopex tlopex left a comment


LGTM Thanks!

@tlopex tlopex merged commit 10ba3c2 into apache:main Apr 3, 2026
9 checks passed
@OmarAzizi OmarAzizi deleted the relax-onnx-argmax-argmin-select-last-index branch April 8, 2026 11:11