From 126829958ac87623959568388a238df8004928e7 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 26 Dec 2025 21:25:54 +0800 Subject: [PATCH 1/2] Replaced call_pure_packed with tensor_to_shape operator --- python/tvm/relax/transform/legalize_ops/index.py | 13 ++----------- ...t_transform_legalize_ops_index_linear_algebra.py | 8 ++------ 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index d99c1f4db6ed..ffdd2d8edca5 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for index operators.""" from tvm import topi, tir, te -from ...op import call_pure_packed +from ...op import call_pure_packed, tensor_to_shape from ...block_builder import BlockBuilder from ...expr import Call, Expr from ...struct_info import ShapeStructInfo, PrimStructInfo @@ -109,17 +109,8 @@ def get_length(begin, end, strides, length): ) # 2. Convert tensor to shape and match cast with new symbolic vars - # Get shape length ndim = int(output_shape.struct_info.shape[0]) - output_shape = bb.emit( - # TODO(@relax-team): Ideally, we should use the tensor_to_shape op here to - # address the issue with purity, but that introduces a staging issue: - # we need to apply DecomposeOpsForInference in that case - # and it's unclear when in the build it should happen - call_pure_packed( - "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim) - ) - ) + output_shape = bb.emit(tensor_to_shape(output_shape)) output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)] bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars)) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index efa7f4dfff28..a6e53dab4d42 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -669,9 +669,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((4,), dtype="int64"), ) - gv1: R.Shape(ndim=4) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),) - ) + gv1: R.Shape(ndim=4) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast( gv1, R.Shape([s, s_1, s_2, s_3]) ) @@ -868,9 +866,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((2,), dtype="int64"), ) - gv1: R.Shape(ndim=2) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),) - ) + gv1: R.Shape(ndim=2) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1])) gv_1 = R.call_tir( Expected.dynamic_strided_slice, From 5a933b009473c7493e0d475c2dfdabbd6af90e70 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 26 Dec 2025 21:43:08 +0800 Subject: [PATCH 2/2] Update python/tvm/relax/transform/legalize_ops/index.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/relax/transform/legalize_ops/index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index ffdd2d8edca5..75c17f7fa936 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for index operators.""" from tvm import topi, tir, te -from ...op import call_pure_packed, tensor_to_shape +from ...op import tensor_to_shape from ...block_builder import BlockBuilder from ...expr import Call, Expr from ...struct_info import ShapeStructInfo, PrimStructInfo