From bc6b02dcfe3299db485aa2d2c4efdb4f9234fdd1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 1 Jul 2024 14:20:50 -0500 Subject: [PATCH] [Bugfix] Restrict CopyOnWrite to _type_final Prior to this commit, the `TVM_DEFINE_OBJECT_REF_COW_METHOD` could be used in any `ObjectRef` subclass to provide a `CopyOnWrite` method. However, the implementation of this method method was invalid if the object's `ContainerType` could itself be subclassed. In that case, using `obj.CopyOnWrite()` when the object contains a subclass, and when a copy is required, would silently convert `obj` to instead contain a base class. This commit adds a `static_assert`, to the `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro, preventing the macro from being used in classes that would have incorrect usage. Compilation with this change found two classes, `relax::Var` and `relax::BindingBlock` that were susceptible to this error, and the macro has been removed from these classes. For backwards-compatibility, the `CopyOnWrite` function for these two classes is provided explicitly. --- include/tvm/relax/expr.h | 7 ++++--- include/tvm/runtime/object.h | 20 +++++++++++-------- src/relax/ir/expr.cc | 38 ++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 401aaa9248ce..60032c34622f 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -427,7 +427,8 @@ class Var : public LeafExpr { TVM_DLL explicit Var(Id vid, Optional struct_info_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); + + VarNode* CopyOnWrite(); }; /*! \brief A sub-type of the variable node used to mark dataflow variables from @@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef { public: TVM_DLL explicit BindingBlock(Array bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); + + BindingBlockNode* CopyOnWrite(); }; -class DataflowBlock; class DataflowBlockNode : public BindingBlockNode { public: bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 172316daae59..4483867f3ccb 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -823,14 +823,18 @@ struct ObjectPtrEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ + ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ } // Implementations details below diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 59b6a0aeb78b..a14ba1d9aaa1 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -265,6 +265,25 @@ Var::Var(Id vid, Optional struct_info_annotation, Span span) { data_ = std::move(n); } +VarNode* Var::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // Var, because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `Var`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr node; + if (auto dataflow_var = as()) { + node = make_object(*dataflow_var); + } else { + node = make_object(*(operator->())); + } + ObjectPtr(std::move(node)).swap(data_); + } + return static_cast(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, Optional struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); @@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array bindings, Span span) { data_ = std::move(n); } +BindingBlockNode* BindingBlock::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // BindingBlock, because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `BindingBlock`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr node; + if (auto dataflow_block = as()) { + node = make_object(*dataflow_block); + } else { + node = make_object(*(operator->())); + } + ObjectPtr(std::move(node)).swap(data_); + } + return static_cast(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array bindings, Span span) { return BindingBlock(bindings, span); });