Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/relax/op/tensor/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutDynStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<Bool>("FDataDependent", Bool(true));

} // namespace relax
} // namespace tvm
15 changes: 9 additions & 6 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,15 @@ class LegalizeMutator : public ExprMutator {
return false;
}

std::string op_name(op->name);
bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
bool is_data_dependent_op = [&]() -> bool {
if (Op::HasAttrMap("FDataDependent")) {
auto op_map = Op::GetAttrMap<Bool>("FDataDependent");
if (op_map.count(op)) {
return op_map[op]->value;
}
}
return false;
}();
Comment on lines +290 to +298
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This immediately-invoked lambda expression can be simplified to a more concise and idiomatic one-liner using Op::GetAttrMap and the .get() method with a default value. This will improve readability.

      bool is_data_dependent_op = Op::GetAttrMap<Bool>("FDataDependent").get(op, Bool(false))->value;

bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call));
if (!is_data_dependent_op && !ret_shape_defined) {
// This operator cannot be legalized, because legalization by
Expand All @@ -303,10 +310,6 @@ class LegalizeMutator : public ExprMutator {
// data-dependent op, and match cast to define symbolic output
// shapes. These symbolic output shapes at compile time can
// be by later operations to refer to the runtime shape.
//
// TODO(Lunderberg): Make a new operator attribute
// `.set_attr<Bool>("DataDependent")`, rather than relying on
// the name of the operator.
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.

import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing

from tvm.ir import Op
from tvm.relax.transform import LegalizeOps
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

##################### Indexing #####################

Expand Down Expand Up @@ -1197,5 +1199,13 @@ def einsum(
tvm.ir.assert_structural_equal(mod, Expected)


def test_data_dependent_attribute():
dynamic_strided_slice_op = Op.get("relax.dynamic_strided_slice")
assert dynamic_strided_slice_op.get_attr("FDataDependent")

strided_slice_op = Op.get("relax.strided_slice")
assert strided_slice_op.get_attr("FDataDependent") is None


if __name__ == "__main__":
tvm.testing.main()
Loading