From 39b5abd0effa88c380a4c3d47783ec7d67811aff Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Fri, 16 Jan 2026 12:20:13 +0800 Subject: [PATCH 1/2] Add FDataDependent operator attribute for LegalizeOps --- src/relax/op/tensor/index.cc | 3 ++- src/relax/transform/legalize_ops.cc | 15 +++++++++------ ...ransform_legalize_ops_index_linear_algebra.py | 16 +++++++++++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index b346b6681b03..91cb69f84047 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -575,7 +575,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice") .set_attr("FInferStructInfo", InferStructInfoDynStridedSlice) .set_attr("FRelaxInferLayout", InferLayoutDynStridedSlice) .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) - .set_attr("FPurity", Bool(true)); + .set_attr("FPurity", Bool(true)) + .set_attr("FDataDependent", Bool(true)); } // namespace relax } // namespace tvm diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 75e0776418ed..723c2814038a 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -287,8 +287,15 @@ class LegalizeMutator : public ExprMutator { return false; } - std::string op_name(op->name); - bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos); + bool is_data_dependent_op = [&]() -> bool { + if (Op::HasAttrMap("FDataDependent")) { + auto op_map = Op::GetAttrMap("FDataDependent"); + if (op_map.count(op)) { + return op_map[op]->value; + } + } + return false; + }(); bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call)); if (!is_data_dependent_op && !ret_shape_defined) { // This operator cannot be legalized, because legalization by @@ -303,10 +310,6 @@ class LegalizeMutator : public ExprMutator { // data-dependent op, and match cast to define symbolic output // shapes. These symbolic output shapes at compile time can // be by later operations to refer to the runtime shape. - // - // TODO(Lunderberg): Make a new operator attribute - // `.set_attr("DataDependent")`, rather than relying on - // the name of the operator. return false; } diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index a6e53dab4d42..e51f0645a6a5 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -16,10 +16,12 @@ # under the License. import tvm -from tvm.relax.transform import LegalizeOps -from tvm.script import relax as R, tir as T, ir as I import tvm.testing - +from tvm.ir import Op +from tvm.relax.transform import LegalizeOps +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T ##################### Indexing ##################### @@ -1197,5 +1199,13 @@ def einsum( tvm.ir.assert_structural_equal(mod, Expected) +def test_data_dependent_attribute(): + dynamic_strided_slice_op = Op.get("relax.dynamic_strided_slice") + assert dynamic_strided_slice_op.get_attr("FDataDependent") + + strided_slice_op = Op.get("relax.strided_slice") + assert not strided_slice_op.has_attr("FDataDependent") + + if __name__ == "__main__": tvm.testing.main() From 523ddbd177a2710aa4c49180bc572aac3da0c95e Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 16 Jan 2026 14:30:09 +0800 Subject: [PATCH 2/2] Update tests --- .../relax/test_transform_legalize_ops_index_linear_algebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index e51f0645a6a5..44419e51e7dc 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -1204,7 +1204,7 @@ def test_data_dependent_attribute(): assert dynamic_strided_slice_op.get_attr("FDataDependent") strided_slice_op = Op.get("relax.strided_slice") - assert not strided_slice_op.has_attr("FDataDependent") + assert strided_slice_op.get_attr("FDataDependent") is None if __name__ == "__main__":