diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 3292ce57ba5c..bada827c813d 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -43,7 +43,11 @@ #include #include +#include "tvm/ir/expr.h" +#include "tvm/runtime/data_type.h" #include "tvm/tir/expr.h" +#include "tvm/tir/op.h" +#include "tvm/tir/var.h" namespace tvm { namespace topi { @@ -635,6 +639,55 @@ inline Array split(const Tensor& x, Array split_indices, int a return result; } +inline PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) { + auto idx_var = index.as(); + auto extent_var = extent.as(); + + if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) { + return index; + } + + PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0); + PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent); + + if (!(index->IsInstance() && GetConstInt(index) >= 0)) { + index = tvm::if_then_else(index < 0, index + extent, index); + } + + return tvm::min(tvm::max(index, begin_range), end_range); +} + +inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) { + int64_t begin_range = stride < 0 ? -1 : 0; + int64_t end_range = stride < 0 ? extent - 1 : extent; + if (index < 0) { + index += extent; + } + return std::min(std::max(index, begin_range), end_range); +} + +inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) { + if (index->IsInstance() && extent->IsInstance() && + stride->IsInstance()) { + return tvm::IntImm( + tvm::DataType::Int(64), + StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), GetConstInt(stride))); + } + return DynamicCanonicalizeIndex(index, extent, stride); +} + +inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent, + bool assume_inbound = true) { + if (assume_inbound) { + return ceildiv(end - begin, stride); + } else { + begin = CanonicalizeIndex(begin, extent, stride); + end = CanonicalizeIndex(end, extent, stride); + return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride), + ceildiv(end - begin, stride)); + } +} + /*! * \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic * @@ -644,6 +697,7 @@ inline Array split(const Tensor& x, Array split_indices, int a * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis * \param axes Specifies which axes will be updated. + * \param assume_inbound Specifies if all indices are assumed to be inbound * \param name The name of the operation * \param tag The tag to mark the operation * @@ -651,7 +705,7 @@ inline Array split(const Tensor& x, Array split_indices, int a */ inline Tensor dynamic_strided_slice_with_axes( const Tensor& x, const Array& begin, const Array& end, - const Array& strides, const Array& axes, + const Array& strides, const Array& axes, bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); ICHECK_EQ(begin.size(), end.size()); @@ -669,7 +723,8 @@ inline Tensor dynamic_strided_slice_with_axes( Array out_shape = x->shape; for (size_t i = 0; i < begin.size(); i++) { int axis = axes[i]->value; - PrimExpr new_shape = analyzer.Simplify(ceildiv(end[i] - begin[i], strides[i])); + PrimExpr new_shape = + analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound)); out_shape.Set(axis, new_shape); } @@ -697,6 +752,7 @@ inline Tensor dynamic_strided_slice_with_axes( * \param end Indices indicating end of the slice * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis + * \param assume_inbound Specifies if all indices are assumed to be inbound * \param name The name of the operation * \param tag The tag to mark the operation * @@ -704,6 +760,7 @@ inline Tensor dynamic_strided_slice_with_axes( */ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, const Array& end, const Array& strides, + bool assume_inbound = true, std::string name = "T_dynamic_strided_slice", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); @@ -721,7 +778,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi // Check ProducerLoad to keep backward compatibility for Relay. if (!begin[i]->IsInstance() && !end[i]->IsInstance() && !strides[i]->IsInstance()) { - out_shape.push_back(analyzer.Simplify(ceildiv(end[i] - begin[i], strides[i]))); + out_shape.push_back( + analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound))); } else { out_shape.push_back(tvm::tir::Var("dim")); } @@ -755,6 +813,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi * \param end Indices indicating end of the slice * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis + * \param assume_inbound Specifies if all indices are assumed to be inbound * \param name The name of the operation * \param tag The tag to mark the operation * @@ -762,6 +821,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begi */ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, const te::Tensor& end, const te::Tensor& strides, + bool assume_inbound = true, std::string name = "T_strided_slice_dynamic", std::string tag = topi::kInjective) { DataType index_dtype = begin->shape[0]->dtype; @@ -776,7 +836,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b end_expr.push_back(end(ind)); strides_expr.push_back(strides(ind)); } - return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); + return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, assume_inbound, name, tag); } /*! diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py b/python/tvm/relax/transform/legalize_ops/distributed.py index d540628e0e2b..2ca283d1ec09 100644 --- a/python/tvm/relax/transform/legalize_ops/distributed.py +++ b/python/tvm/relax/transform/legalize_ops/distributed.py @@ -40,4 +40,5 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr: axes=[axis], begin=[worker_id_symbol * split_axis_size // num_workers], end=[(worker_id_symbol + 1) * split_axis_size // num_workers], + assume_inbound=True, ) diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index a4fac46a13b1..8d0ac535f626 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -67,6 +67,7 @@ def _relax_tuple_to_tir(relax_tuple): strides, axes, slice_mode="end", + assume_inbound=call.attrs.assume_inbound, ) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 3b007a632599..686311fbee86 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0): return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis) -def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): +def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end", assume_inbound=True): """Slice of an array. Parameters @@ -200,6 +200,9 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): the sizeof a slice starting at the location specified by begin. If end[i] is -1, all remaining elements in that dimension are included in the slice. + assume_inbound: bool, optional + A flag to indicate if all indices are assumed to be inbound + Returns ------- ret : tvm.te.Tensor @@ -223,7 +226,7 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): strides = [] if axes is None: axes = [] - return cpp.strided_slice(a, begin, end, strides, axes, slice_mode) + return cpp.strided_slice(a, begin, end, strides, axes, slice_mode, assume_inbound) def dynamic_strided_slice(a, begin, end, strides, output_shape): diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 36527c35841e..e62dbe89d08a 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -25,6 +25,7 @@ #include "index.h" #include +#include #include #include @@ -171,29 +172,6 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, Optional strid TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); -inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr stride) { - // Handle Python-style negative indices - index = if_then_else(index < 0, index + extent, index); - // Clamp the result to valid indices - PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0); - PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent); - index = tvm::min(tvm::max(index, lower_bound), upper_bound); - - return index; -} - -PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr extent, - bool assume_inbound) { - if (assume_inbound) { - return ceildiv(end - begin, stride); - } else { - begin = CanonicalizeIndex(begin, extent, stride); - end = CanonicalizeIndex(end, extent, stride); - return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride), - ceildiv(end - begin, stride)); - } -} - /* \brief Helper function to unpack a relax::Tuple * * A `relax::Tuple` may be provided to an operator as an in-line @@ -424,7 +402,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx PrimExpr end = end_tuple[i]; PrimExpr output_dim = - GetLength(begin, end, strides_tuple[i], input_dim, attrs->assume_inbound); + topi::GetLength(begin, end, strides_tuple[i], input_dim, attrs->assume_inbound); arith::Analyzer* analyzer = ctx->GetAnalyzer(); std::optional> context; diff --git a/src/topi/transform.cc b/src/topi/transform.cc index a84e3dce500c..d844739568bc 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -27,6 +27,10 @@ #include #include +#include + +#include "tvm/ir/expr.h" + namespace tvm { namespace topi { @@ -179,6 +183,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* Array end = args[2]; Array strides = args[3]; Array axes = args[4]; + bool assume_inbound = args[6]; if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides) && IsConstIntArray(x->shape)) { Array begin_static = args[1]; @@ -192,9 +197,9 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* } } else { if (axes.size()) { - *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes); + *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, assume_inbound); } else { - *rv = dynamic_strided_slice(x, begin, end, strides); + *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound); } } }); diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index 57e7a14b7056..31245de59960 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -21,6 +21,7 @@ from tvm import TVMError from tvm.ir import Op, VDevice from tvm.script import ir as I, relax as R, tir as T +import numpy as np def test_op_correctness(): @@ -1010,5 +1011,44 @@ def strided_slice( tvm.ir.assert_structural_equal(expected, after) +def test_legalize_dynamic_begin_inf_end(): + """relax.op.strided_slice FLegalize must support dynamic begin/end""" + + @I.ir_module + class before: + @R.function + def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1, 16)): + index = T.int64() + return R.strided_slice( + A, [0], [index], [T.int64(np.iinfo(np.int64).max)], assume_inbound=False + ) + + # fmt: off + @I.ir_module + class expected: + @T.prim_func(private=True) + def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64): + T.func_attr({"tir.noalias": T.bool(True)}) + T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16))) + # with T.block("root"): + for ax0, ax1 in T.grid(T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)): + with T.block("T_dynamic_strided_slice_with_axes"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0 + index, v_ax1]) + T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1]) + T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0 + index, v_ax1] + + @R.function + def main(A: R.Tensor((16, 16), dtype="float32"), B: R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0)", 16), dtype="float32"): + index = T.int64() + cls = expected + gv = R.call_tir(cls.strided_slice, (A,), out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index])) + return gv + # fmt: on + + after = tvm.relax.transform.LegalizeOps()(before) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 2f4da5cf0653..90643694c1e8 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -263,7 +263,7 @@ class StridedSlice: @R.function def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"): n = T.int64() - gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3]) + gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], begin=[1], end=[8], strides=[3], assume_inbound=True) return gv @I.ir_module