diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index b346b6681b03..91cb69f84047 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -575,7 +575,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice") .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutDynStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) - .set_attr("FPurity", Bool(true)); + .set_attr("FPurity", Bool(true)) + .set_attr("FDataDependent", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 75e0776418ed..723c2814038a 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -287,8 +287,15 @@ class LegalizeMutator : public ExprMutator { return false; } - std::string op_name(op->name); - bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos); + bool is_data_dependent_op = [&]() -> bool { + if (Op::HasAttrMap("FDataDependent")) { + auto op_map = Op::GetAttrMap("FDataDependent"); + if (op_map.count(op)) { + return op_map[op]->value; + } + } + return false; + }(); bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call)); if (!is_data_dependent_op && !ret_shape_defined) { // This operator cannot be legalized, because legalization by @@ -303,10 +310,6 @@ class LegalizeMutator : public ExprMutator { // data-dependent op, and match cast to define symbolic output // shapes. These symbolic output shapes at compile time can // be by later operations to refer to the runtime shape. - // - // TODO(Lunderberg): Make a new operator attribute - // `.set_attr("DataDependent")`, rather than relying on - // the name of the operator. return false; } diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index a6e53dab4d42..44419e51e7dc 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -16,10 +16,12 @@ # under the License. import tvm -from tvm.relax.transform import LegalizeOps -from tvm.script import relax as R, tir as T, ir as I import tvm.testing - +from tvm.ir import Op +from tvm.relax.transform import LegalizeOps +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T ##################### Indexing ##################### @@ -1197,5 +1199,13 @@ def einsum( tvm.ir.assert_structural_equal(mod, Expected) +def test_data_dependent_attribute(): + dynamic_strided_slice_op = Op.get("relax.dynamic_strided_slice") + assert dynamic_strided_slice_op.get_attr("FDataDependent") + + strided_slice_op = Op.get("relax.strided_slice") + assert strided_slice_op.get_attr("FDataDependent") is None + + if __name__ == "__main__": tvm.testing.main()