diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 401aaa9248ce..60032c34622f 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -427,7 +427,8 @@ class Var : public LeafExpr { TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); + + VarNode* CopyOnWrite(); }; /*! \brief A sub-type of the variable node used to mark dataflow variables from @@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef { public: TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); + + BindingBlockNode* CopyOnWrite(); }; -class DataflowBlock; class DataflowBlockNode : public BindingBlockNode { public: bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 172316daae59..4483867f3ccb 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -823,14 +823,18 @@ struct ObjectPtrEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ + ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ } // Implementations details below diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 59b6a0aeb78b..a14ba1d9aaa1 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -265,6 +265,25 @@ Var::Var(Id vid, Optional struct_info_annotation, Span span) { data_ = std::move(n); } +VarNode* Var::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // Var, because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `Var`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr node; + if (auto dataflow_var = as()) { + node = make_object(*dataflow_var); + } else { + node = make_object(*(operator->())); + } + ObjectPtr(std::move(node)).swap(data_); + } + return static_cast(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); @@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array bindings, Span span) { data_ = std::move(n); } +BindingBlockNode* BindingBlock::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // BindingBlock, because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `BindingBlock`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr node; + if (auto dataflow_block = as()) { + node = make_object(*dataflow_block); + } else { + node = make_object(*(operator->())); + } + ObjectPtr(std::move(node)).swap(data_); + } + return static_cast(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { return BindingBlock(bindings, span); });