diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index d99c1f4db6ed..75c17f7fa936 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for index operators.""" from tvm import topi, tir, te -from ...op import call_pure_packed +from ...op import tensor_to_shape from ...block_builder import BlockBuilder from ...expr import Call, Expr from ...struct_info import ShapeStructInfo, PrimStructInfo @@ -109,17 +109,8 @@ def get_length(begin, end, strides, length): ) # 2. Convert tensor to shape and match cast with new symbolic vars - # Get shape length ndim = int(output_shape.struct_info.shape[0]) - output_shape = bb.emit( - # TODO(@relax-team): Ideally, we should use the tensor_to_shape op here to - # address the issue with purity, but that introduces a staging issue: - # we need to apply DecomposeOpsForInference in that case - # and it's unclear when in the build it should happen - call_pure_packed( - "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim) - ) - ) + output_shape = bb.emit(tensor_to_shape(output_shape)) output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)] bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars)) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index efa7f4dfff28..a6e53dab4d42 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -669,9 +669,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((4,), dtype="int64"), ) - gv1: R.Shape(ndim=4) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),) - ) + gv1: R.Shape(ndim=4) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast( gv1, R.Shape([s, s_1, s_2, s_3]) ) @@ -868,9 +866,7 @@ def main( (x, begin, end, strides), out_sinfo=R.Tensor((2,), dtype="int64"), ) - gv1: R.Shape(ndim=2) = R.call_pure_packed( - "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),) - ) + gv1: R.Shape(ndim=2) = R.tensor_to_shape(gv) gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1])) gv_1 = R.call_tir( Expected.dynamic_strided_slice,