From 5b5113516264181b6abed3ce1fe38c6754083f8d Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 17 Jul 2025 10:30:00 -0400 Subject: [PATCH] [FFI] Structural equal and hash based on reflection This PR add initial support for structural equal and hash via the new reflection mechanism. It will helps us to streamline the structural equality/hash with broader support and clean error reports via AccessPath. It also gives us ability to unify all struct equal/hash registration into the extra meta-data in reflection registration. --- ffi/CMakeLists.txt | 3 + ffi/include/tvm/ffi/c_api.h | 104 +++++- ffi/include/tvm/ffi/object.h | 2 + ffi/include/tvm/ffi/reflection/access_path.h | 108 ++++++ ffi/include/tvm/ffi/reflection/registry.h | 40 +- .../tvm/ffi/reflection/structural_equal.h | 80 ++++ .../tvm/ffi/reflection/structural_hash.h | 58 +++ ffi/include/tvm/ffi/string.h | 3 +- ffi/src/ffi/container.cc | 3 +- ffi/src/ffi/reflection/access_path.cc | 34 ++ ffi/src/ffi/reflection/structural_equal.cc | 349 ++++++++++++++++++ ffi/src/ffi/reflection/structural_hash.cc | 265 +++++++++++++ ...lection.cc => test_reflection_accessor.cc} | 21 +- .../test_reflection_structural_equal_hash.cc | 172 +++++++++ ffi/tests/cpp/testing_object.h | 79 +++- 15 files changed, 1298 insertions(+), 23 deletions(-) create mode 100644 ffi/include/tvm/ffi/reflection/access_path.h create mode 100644 ffi/include/tvm/ffi/reflection/structural_equal.h create mode 100644 ffi/include/tvm/ffi/reflection/structural_hash.h create mode 100644 ffi/src/ffi/reflection/access_path.cc create mode 100644 ffi/src/ffi/reflection/structural_equal.cc create mode 100644 ffi/src/ffi/reflection/structural_hash.cc rename ffi/tests/cpp/{test_reflection.cc => test_reflection_accessor.cc} (88%) create mode 100644 ffi/tests/cpp/test_reflection_structural_equal_hash.cc diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 2aafacf5ac8f..ba0237f0e434 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -57,6 +57,9 @@ add_library(tvm_ffi_objs OBJECT "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc" ) set_target_properties( tvm_ffi_objs PROPERTIES diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 545604a47395..188f5303b8a6 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -56,6 +56,27 @@ #define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) #endif +/*! + * \brief Marks the API as extra c++ api that is defined in cc files. + * + * These APIs are extra features that depend on, but are not required to + * support essential core functionality, such as function calling and object + * access. + * + * They are implemented in cc files to reduce compile-time overhead. + * The input/output only uses POD/Any/ObjectRef for ABI stability. + * However, these extra APIs may have an issue across MSVC/Itanium ABI, + * + * Related features are also available through reflection based function + * that is fully based on C API + * + * The project aims to minimize the number of extra C++ APIs and only + * restrict the use to non-core functionalities. + */ +#ifndef TVM_FFI_EXTRA_CXX_API +#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL +#endif + #ifdef __cplusplus extern "C" { #endif @@ -326,12 +347,89 @@ typedef enum { kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1, /*! \brief The field is a static method. */ kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, + /*! + * \brief The field should be ignored when performing structural eq/hash + * + * This is an optional meta-data for structural eq/hash. + */ + kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, + /*! + * \brief The field enters a def region where var can be defined/matched. + * + * This is an optional meta-data for structural eq/hash. + */ + kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, #ifdef __cplusplus }; #else } TVMFFIFieldFlagBitMask; #endif +/*! + * \brief Optional meta-data for structural eq/hash. + * + * This meta-data is only useful when we want to leverage the information + * to perform richer semantics aware structural comparison and hash. + * It can be safely ignored if such information is not needed. + * + * The meta-data record comparison method in tree node and DAG node. + * + * \code + * x = VarNode() + * v0 = AddNode(x, 1) + * v1 = AddNode(x, 1) + * v2 = AddNode(v0, v0) + * v3 = AddNode(v1, v0) + * \endcode + * + * Consider the construct sequence of AddNode below, + * if AddNode is treated as a tree node, then v2 and v3 + * structural equals to each other, but if AddNode is + * treated as a DAG node, then v2 and v3 does not + * structural equals to each other. + */ +#ifdef __cplusplus +enum TVMFFISEqHashKind : int32_t { +#else +typedef enum { +#endif + /*! \brief Do not support structural eq/hash. */ + kTVMFFISEqHashKindUnsupported = 0, + /*! + * \brief The object be compared as a tree node. + */ + kTVMFFISEqHashKindTreeNode = 1, + /*! + * \brief The object is treated as a free variable that can be mapped + * to another free variable in the definition region. + */ + kTVMFFISEqHashKindFreeVar = 2, + /*! + * \brief The field should be compared as a DAG node. + */ + kTVMFFISEqHashKindDAGNode = 3, + /*! + * \brief The object is treated as a constant tree node. + * + * Same as tree node, but the object does not contain free var + * as any of its nested children. + * + * That means we can use pointer equality for equality. + */ + kTVMFFISEqHashKindConstTreeNode = 4, + /*! + * \brief One can simply use pointer equality for equality. + * + * This is useful for "singleton"-style object that can + * is only an unique copy of each value. + */ + kTVMFFISEqHashKindUniqueInstance = 5, +#ifdef __cplusplus +}; +#else +} TVMFFISEqHashKind; +#endif + /*! * \brief Information support for optional object reflection. */ @@ -431,7 +529,11 @@ typedef struct { * * This field is set optional and set to 0 if not registered. */ - int64_t total_size; + int32_t total_size; + /*! + * \brief Optional meta-data for structural eq/hash. + */ + TVMFFISEqHashKind structural_eq_hash_kind; } TVMFFITypeExtraInfo; /*! diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 0dba58acf487..05b936ea9077 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -212,6 +212,8 @@ class Object { static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; // the static type depth of the class static constexpr int32_t _type_depth = 0; + // the structural equality and hash kind of the type + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; // extra fields used by plug-ins for attribute visiting // and structural information static constexpr const bool _type_has_method_sequal_reduce = false; diff --git a/ffi/include/tvm/ffi/reflection/access_path.h b/ffi/include/tvm/ffi/reflection/access_path.h new file mode 100644 index 000000000000..e37b3f410cbc --- /dev/null +++ b/ffi/include/tvm/ffi/reflection/access_path.h @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/registry.h + * \brief Registry of reflection metadata. + */ +#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_ +#define TVM_FFI_REFLECTION_ACCESS_PATH_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +enum class AccessKind : int32_t { + kObjectField = 0, + kArrayIndex = 1, + kMapKey = 2, + // the following two are used for error reporting when + // the supposed access field is not available + kArrayIndexMissing = 3, + kMapKeyMissing = 4, +}; + +/*! + * \brief Represent a single step in object field, map key, array index access. + */ +class AccessStepObj : public Object { + public: + /*! + * \brief The kind of the access pattern. + */ + AccessKind kind; + /*! + * \brief The access key + * \note for array access, it will always be integer + * for field access, it will be string + */ + Any key; + + AccessStepObj(AccessKind kind, Any key) : kind(kind), key(key) {} + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("kind", &AccessStepObj::kind) + .def_ro("key", &AccessStepObj::key); + } + + static constexpr const char* _type_key = "tvm.ffi.reflection.AccessStep"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AccessStepObj, Object); +}; + +/*! + * \brief ObjectRef class of AccessStepObj. + * + * \sa AccessStepObj + */ +class AccessStep : public ObjectRef { + public: + AccessStep(AccessKind kind, Any key) : ObjectRef(make_object(kind, key)) {} + + static AccessStep ObjectField(String field_name) { + return AccessStep(AccessKind::kObjectField, field_name); + } + + static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); } + + static AccessStep ArrayIndexMissing(int64_t index) { + return AccessStep(AccessKind::kArrayIndexMissing, index); + } + + static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); } + + static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj); +}; + +using AccessPath = Array; +using AccessPathPair = Tuple; + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/ffi/include/tvm/ffi/reflection/registry.h b/ffi/include/tvm/ffi/reflection/registry.h index 27e0b0877c7c..abd1851498c4 100644 --- a/ffi/include/tvm/ffi/reflection/registry.h +++ b/ffi/include/tvm/ffi/reflection/registry.h @@ -55,6 +55,39 @@ class DefaultValue : public FieldInfoTrait { Any value_; }; +/* + * \brief Trait that can be used to attach field flag + */ +class AttachFieldFlag : public FieldInfoTrait { + public: + /*! + * \brief Attach a field flag to the field + * + * \param flag The flag to be set + * + * \return The trait object. + */ + explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} + + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); + } + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); + } + + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; } + + private: + int32_t flag_; +}; + /*! * \brief Get the byte offset of a class member field. * @@ -83,7 +116,11 @@ class ReflectionDefBase { template static int FieldSetter(void* field, const TVMFFIAny* value) { TVM_FFI_SAFE_CALL_BEGIN(); - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + if constexpr (std::is_same_v) { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); + } else { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + } TVM_FFI_SAFE_CALL_END(); } @@ -346,6 +383,7 @@ class ObjectDef : public ReflectionDefBase { void RegisterExtraInfo(ExtraArgs&&... extra_args) { TVMFFITypeExtraInfo info; info.total_size = sizeof(Class); + info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; info.creator = nullptr; info.doc = TVMFFIByteArray{nullptr, 0}; if constexpr (std::is_default_constructible_v) { diff --git a/ffi/include/tvm/ffi/reflection/structural_equal.h b/ffi/include/tvm/ffi/reflection/structural_equal.h new file mode 100644 index 000000000000..860222644c95 --- /dev/null +++ b/ffi/include/tvm/ffi/reflection/structural_equal.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/structural_equal.h + * \brief Structural equal implementation + */ +#ifndef TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ +#define TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ + +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { +/* + * \brief Structural equality comparators + */ +class StructuralEqual { + public: + /** + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. + * \param skip_ndarray_content Whether to skip comparingn darray data content, + * useful for cases where we don't care about parameters content + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_EXTRA_CXX_API static bool Equal(const Any& lhs, const Any& rhs, + bool map_free_vars = false, + bool skip_ndarray_content = false); + /** + * \brief Get the first mismatch AccessPath pair when running + * structural equal comparison between two Any values. + * + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. + * \param skip_ndarray_content Whether to skip comparing ndarray data content, + * useful for cases where we don't care about parameters content + * \return If comparison fails, return the first mismatch AccessPath pair, + * otherwise return std::nullopt. + */ + TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( + const Any& lhs, const Any& rhs, bool map_free_vars = false, + bool skip_ndarray_content = false); + + /* + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_INLINE bool operator()(const Any& lhs, const Any& rhs) const { + return Equal(lhs, rhs, false, true); + } +}; + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_ diff --git a/ffi/include/tvm/ffi/reflection/structural_hash.h b/ffi/include/tvm/ffi/reflection/structural_hash.h new file mode 100644 index 000000000000..b0d17cf8bfbc --- /dev/null +++ b/ffi/include/tvm/ffi/reflection/structural_hash.h @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/structural_hash.h + * \brief Structural hash + */ +#ifndef TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ +#define TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ + +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +/* + * \brief Structural hash + */ +class StructuralHash { + public: + /*! + * \brief Hash an Any value. + * \param value The Any value to hash. + * \param map_free_vars Whether to map free variables. + * \param skip_ndarray_content Whether to skip comparingn darray data content, + * useful for cases where we don't care about parameters content. + * \return The hash value. + */ + TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any& value, bool map_free_vars = false, + bool skip_ndarray_content = false); + /*! + * \brief Hash an Any value. + * \param value The Any value to hash. + * \return The hash value. + */ + TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } +}; + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_ diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 25cdc7c7db70..ed654e8557e0 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -165,7 +165,6 @@ class Bytes : public ObjectRef { TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, BytesObj); - private: /*! * \brief Compare two char sequence * @@ -178,7 +177,7 @@ class Bytes : public ObjectRef { */ static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); - friend struct AnyEqual; + private: friend class String; }; diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc index 76b9cd6671f9..858cbd47c771 100644 --- a/ffi/src/ffi/container.cc +++ b/ffi/src/ffi/container.cc @@ -18,8 +18,7 @@ * under the License. */ /* - * \file src/ffi/ffi_api.cc - * \brief Extra ffi apis for frontend to access containers. + * \file src/ffi/container.cc */ #include #include diff --git a/ffi/src/ffi/reflection/access_path.cc b/ffi/src/ffi/reflection/access_path.cc new file mode 100644 index 000000000000..17b8abb062ff --- /dev/null +++ b/ffi/src/ffi/reflection/access_path.cc @@ -0,0 +1,34 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/reflection/access_path.cc + */ + +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +TVM_FFI_STATIC_INIT_BLOCK({ AccessStepObj::RegisterReflection(); }); + +} // namespace reflection +} // namespace ffi +} // namespace tvm diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc new file mode 100644 index 000000000000..622a487fd923 --- /dev/null +++ b/ffi/src/ffi/reflection/structural_equal.cc @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/reflection/structural_equal.cc + * + * \brief Structural equal implementation. + */ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +/** + * \brief Internal Handler class for structural equal comparison. + */ +class StructEqualHandler { + public: + StructEqualHandler() = default; + + bool CompareAny(ffi::Any lhs, ffi::Any rhs) { + using ffi::details::AnyUnsafe; + const TVMFFIAny* lhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(lhs); + const TVMFFIAny* rhs_data = AnyUnsafe::TVMFFIAnyPtrFromAny(rhs); + if (lhs_data->type_index != rhs_data->type_index) { + return false; + } + if (lhs_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // this is POD data, we can just compare the value + return lhs_data->v_int64 == rhs_data->v_int64; + } + switch (lhs_data->type_index) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: { + // compare bytes + const BytesObjBase* lhs_str = + AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); + const BytesObjBase* rhs_str = + AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); + return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0; + } + case TypeIndex::kTVMFFIArray: { + return CompareArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), + AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); + } + case TypeIndex::kTVMFFIMap: { + return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(lhs)), + AnyUnsafe::MoveFromAnyAfterCheck>(std::move(rhs))); + } + case TypeIndex::kTVMFFIShape: { + return CompareShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), + AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); + } + case TypeIndex::kTVMFFINDArray: { + return CompareNDArray(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), + AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); + } + default: { + return CompareObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)), + AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs))); + } + } + } + + bool CompareObject(ObjectRef lhs, ObjectRef rhs) { + // NOTE: invariant: lhs and rhs are already the same type + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(lhs->type_index()); + if (type_info->extra_info == nullptr) { + return lhs.same_as(rhs); + } + auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind; + + if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported || + structural_eq_hash_kind == kTVMFFISEqHashKindUniqueInstance) { + // use pointer comparison + return lhs.same_as(rhs); + } + if (structural_eq_hash_kind == kTVMFFISEqHashKindConstTreeNode) { + // fast path: constant tree node, pointer equality indicate equality and avoid content + // comparison if false, we should still run content comparison + if (lhs.same_as(rhs)) return true; + } + // check recorded mapping for DAG and fre var + if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode || + structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + // if there is pre-recorded mapping, need to cross check the pointer equality after mapping + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) { + return it->second.same_as(rhs); + } + // if rhs is mapped but lhs is not, it means lhs is a free var, return false + if (equal_map_rhs_.count(rhs)) { + return false; + } + } + + bool success = true; + if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + // we are in a free var case that is not yet mapped. + // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be set + if (!lhs.same_as(rhs) && !map_free_vars_) { + success = false; + } + } else { + // We recursively compare the fields the object + ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) { + // skip fields that are marked as structural eq hash ignore + if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false; + // get the field value from both side + FieldGetter getter(field_info); + Any lhs_value = getter(lhs); + Any rhs_value = getter(rhs); + // field is in def region, enable free var mapping + if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { + bool allow_free_var = true; + std::swap(allow_free_var, map_free_vars_); + success = CompareAny(lhs_value, rhs_value); + std::swap(allow_free_var, map_free_vars_); + } else { + success = CompareAny(lhs_value, rhs_value); + } + if (!success) { + // record the first mismatching field if we sub-rountine compare failed + if (mismatch_lhs_reverse_path_ != nullptr) { + mismatch_lhs_reverse_path_->emplace_back( + AccessStep::ObjectField(String(field_info->name))); + mismatch_rhs_reverse_path_->emplace_back( + AccessStep::ObjectField(String(field_info->name))); + } + // return true to indicate early stop + return true; + } else { + // return false to continue checking other fields + return false; + } + }); + } + if (success) { + // if we have a success mapping and in graph/var mode, record the equality mapping + if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode || + structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + // record the equality + equal_map_lhs_[lhs] = rhs; + equal_map_rhs_[rhs] = lhs; + } + return true; + } else { + return false; + } + } + + bool CompareMap(Map lhs, Map rhs) { + if (lhs.size() != rhs.size()) { + // size mismatch, and there is no path tracing + // return false since we don't need informative error message + if (mismatch_lhs_reverse_path_ == nullptr) return false; + } + // compare key and value pair by pair + for (auto kv : lhs) { + Any rhs_key = this->MapLhsToRhs(kv.first); + auto it = rhs.find(rhs_key); + if (it == rhs.end()) { + if (mismatch_lhs_reverse_path_ != nullptr) { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(rhs_key)); + } + return false; + } + // now recursively compare value + if (!CompareAny(kv.second, (*it).second)) { + if (mismatch_lhs_reverse_path_ != nullptr) { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(rhs_key)); + } + return false; + } + } + // fast path, all contents equals to each other + if (lhs.size() == rhs.size()) return true; + // slow path, cross check every key from rhs in lhs to find the missing + // key for better error reporting + for (auto kv : rhs) { + Any lhs_key = this->MapRhsToLhs(kv.first); + auto it = lhs.find(lhs_key); + if (it == lhs.end()) { + if (mismatch_lhs_reverse_path_ != nullptr) { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(lhs_key)); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first)); + } + return false; + } + } + return false; + } + + bool CompareArray(ffi::Array lhs, ffi::Array rhs) { + if (lhs.size() != rhs.size()) { + // fast path, size mismatch, and there is no path tracing + // return false since we don't need informative error message + if (mismatch_lhs_reverse_path_ == nullptr) return false; + } + for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) { + if (!CompareAny(lhs[i], rhs[i])) { + if (mismatch_lhs_reverse_path_ != nullptr) { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i)); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i)); + } + return false; + } + } + if (lhs.size() == rhs.size()) return true; + if (mismatch_lhs_reverse_path_ != nullptr) { + if (lhs.size() > rhs.size()) { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(rhs.size())); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(rhs.size())); + } else { + mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(lhs.size())); + mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(lhs.size())); + } + } + return false; + } + + bool CompareShape(Shape lhs, Shape rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; + } + + bool CompareNDArray(NDArray lhs, NDArray rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->ndim != rhs->ndim) return false; + for (int i = 0; i < lhs->ndim; ++i) { + if (lhs->shape[i] != rhs->shape[i]) return false; + } + if (lhs->dtype != rhs->dtype) return false; + if (!skip_ndarray_content_) { + TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; + TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor"; + TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor"; + TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor"; + size_t data_size = GetDataSize(*(lhs.operator->())); + return std::memcmp(lhs->data, rhs->data, data_size) == 0; + } else { + return true; + } + } + + Any MapLhsToRhs(Any lhs) const { + if (lhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { + return lhs; + } + ObjectRef lhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(lhs)); + auto it = equal_map_lhs_.find(lhs_obj); + if (it != equal_map_lhs_.end()) { + return it->second; + } + return lhs_obj; + } + + Any MapRhsToLhs(Any rhs) const { + if (rhs.type_index() < TypeIndex::kTVMFFIStaticObjectBegin) { + return rhs; + } + ObjectRef rhs_obj = ffi::details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(rhs)); + auto it = equal_map_rhs_.find(rhs_obj); + if (it != equal_map_rhs_.end()) { + return it->second; + } + return rhs_obj; + } + // whether we map free variables that are not defined + bool map_free_vars_{false}; + // whether we compare ndarray data + bool skip_ndarray_content_{false}; + // the root lhs for result printing + std::vector* mismatch_lhs_reverse_path_ = nullptr; + std::vector* mismatch_rhs_reverse_path_ = nullptr; + // map from lhs to rhs + std::unordered_map equal_map_lhs_; + // map from rhs to lhs + std::unordered_map equal_map_rhs_; +}; + +bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars, + bool skip_ndarray_content) { + StructEqualHandler handler; + handler.map_free_vars_ = map_free_vars; + handler.skip_ndarray_content_ = skip_ndarray_content; + return handler.CompareAny(lhs, rhs); +} + +Optional StructuralEqual::GetFirstMismatch(const Any& lhs, const Any& rhs, + bool map_free_vars, + bool skip_ndarray_content) { + StructEqualHandler handler; + handler.map_free_vars_ = map_free_vars; + handler.skip_ndarray_content_ = skip_ndarray_content; + std::vector lhs_reverse_path; + std::vector rhs_reverse_path; + handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path; + handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path; + if (handler.CompareAny(lhs, rhs)) { + return std::nullopt; + } + AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend()); + AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend()); + return AccessPathPair(lhs_path, rhs_path); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch", + StructuralEqual::GetFirstMismatch); +}); + +} // namespace reflection +} // namespace ffi +} // namespace tvm diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/reflection/structural_hash.cc new file mode 100644 index 000000000000..dd4167ce3a86 --- /dev/null +++ b/ffi/src/ffi/reflection/structural_hash.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/ffi/reflection/structural_equal.cc + * + * \brief Structural equal implementation. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { +/** + * \brief Internal Handler class for structural hash. + */ +class StructuralHashHandler { + public: + StructuralHashHandler() = default; + + uint64_t HashAny(ffi::Any src) { + using ffi::details::AnyUnsafe; + const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); + + if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // this is POD data, we can just hash the value + return details::StableHashCombine(src_data->type_index, src_data->v_uint64); + } + + switch (src_data->type_index) { + case TypeIndex::kTVMFFIStr: + case TypeIndex::kTVMFFIBytes: { + // return same hash as AnyHash + const BytesObjBase* src_str = + AnyUnsafe::CopyFromAnyViewAfterCheck(src); + return details::StableHashCombine(src_data->type_index, + details::StableHashBytes(src_str->data, src_str->size)); + } + case TypeIndex::kTVMFFIArray: { + return HashArray(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); + } + case TypeIndex::kTVMFFIMap: { + return HashMap(AnyUnsafe::MoveFromAnyAfterCheck>(std::move(src))); + } + case TypeIndex::kTVMFFIShape: { + return HashShape(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); + } + case TypeIndex::kTVMFFINDArray: { + return HashNDArray(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); + } + default: { + return HashObject(AnyUnsafe::MoveFromAnyAfterCheck(std::move(src))); + } + } + } + + uint64_t HashObject(ObjectRef obj) { + // NOTE: invariant: lhs and rhs are already the same type + const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(obj->type_index()); + if (type_info->extra_info == nullptr) { + // Fallback to pointer hash + return std::hash()(obj.get()); + } + auto structural_eq_hash_kind = type_info->extra_info->structural_eq_hash_kind; + if (structural_eq_hash_kind == kTVMFFISEqHashKindUnsupported) { + // Fallback to pointer hash + return std::hash()(obj.get()); + } + // return recored hash value if it is already computed + auto it = hash_memo_.find(obj); + if (it != hash_memo_.end()) { + return it->second; + } + + // compute the hash value + uint64_t hash_value = obj->GetTypeKeyHash(); + if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) { + if (map_free_vars_) { + // use lexical order of free var and its type + hash_value = details::StableHashCombine(hash_value, free_var_counter_++); + } else { + // Fallback to pointer hash, we are not mapping free var. + return std::hash()(obj.get()); + } + } else { + // go over the content and hash the fields + ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) { + // skip fields that are marked as structural eq hash ignore + if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) { + // get the field value from both side + FieldGetter getter(field_info); + Any field_value = getter(obj); + // field is in def region, enable free var mapping + if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) { + bool allow_free_var = true; + std::swap(allow_free_var, map_free_vars_); + hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); + std::swap(allow_free_var, map_free_vars_); + } else { + hash_value = details::StableHashCombine(hash_value, HashAny(field_value)); + } + } + }); + // if it is a DAG node, also record the lexical order of graph counter + // this helps to distinguish DAG from trees. + if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) { + hash_value = details::StableHashCombine(hash_value, graph_node_counter_++); + } + } + // record the hash value for this object + hash_memo_[obj] = hash_value; + return hash_value; + } + + uint64_t HashArray(Array arr) { + uint64_t hash_value = details::StableHashCombine(arr->GetTypeKeyHash(), arr.size()); + for (size_t i = 0; i < arr.size(); ++i) { + hash_value = details::StableHashCombine(hash_value, HashAny(arr[i])); + } + return hash_value; + } + + // Find an order independent hash value for a given Any. + // Order independent hash value means the hash value will remain stable independent + // of the order we hash the content at the current context. + // This property is needed to support stable hash for map. + std::optional FindOrderIndependentHash(Any src) { + using ffi::details::AnyUnsafe; + const TVMFFIAny* src_data = AnyUnsafe::TVMFFIAnyPtrFromAny(src); + + if (src_data->type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + // this is POD data, we can just hash the value + return details::StableHashCombine(src_data->type_index, src_data->v_uint64); + } else { + if (src_data->type_index == TypeIndex::kTVMFFIStr || + src_data->type_index == TypeIndex::kTVMFFIBytes) { + const BytesObjBase* src_str = + AnyUnsafe::CopyFromAnyViewAfterCheck(src); + // return same hash as AnyHash + return details::StableHashCombine(src_data->type_index, + details::StableHashBytes(src_str->data, src_str->size)); + } else { + // if the hash of the object is already computed, return it + auto it = hash_memo_.find(src.cast()); + if (it != hash_memo_.end()) { + return it->second; + } + return std::nullopt; + } + } + } + + uint64_t HashMap(Map map) { + // Compute a deterministic hash value for the map. + uint64_t hash_value = details::StableHashCombine(map->GetTypeKeyHash(), map.size()); + std::vector> items; + for (auto [key, value] : map) { + // if we cannot find order independent hash, we skip the key + if (auto hash_key = FindOrderIndependentHash(key)) { + items.emplace_back(*hash_key, value); + } + } + // sort the items by the hash key, so the hash value is deterministic + // and independent of the order of insertion + std::sort(items.begin(), items.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + + for (size_t i = 0; i < items.size();) { + size_t k = i + 1; + for (; k < items.size() && items[k].first == items[i].first; ++k) { + } + // detect ties, which are rare, but we need to skip value hash during ties + // to make sure that the hash value is deterministic. + if (k == i + 1) { + // no ties, we just hash the key and value + hash_value = details::StableHashCombine(hash_value, items[i].first); + hash_value = details::StableHashCombine(hash_value, HashAny(items[i].second)); + } else { + // ties occur, we skip the value hash to make sure that the hash value is deterministic. + hash_value = details::StableHashCombine(hash_value, items[i].first); + } + i = k; + } + return hash_value; + } + + uint64_t HashShape(Shape shape) { + uint64_t hash_value = details::StableHashCombine(shape->GetTypeKeyHash(), shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + hash_value = details::StableHashCombine(hash_value, shape[i]); + } + return hash_value; + } + + uint64_t HashNDArray(NDArray ndarray) { + uint64_t hash_value = details::StableHashCombine(ndarray->GetTypeKeyHash(), ndarray->ndim); + for (int i = 0; i < ndarray->ndim; ++i) { + hash_value = details::StableHashCombine(hash_value, ndarray->shape[i]); + } + TVMFFIAny temp; + temp.v_uint64 = 0; + temp.v_dtype = ndarray->dtype; + hash_value = details::StableHashCombine(hash_value, temp.v_int64); + + if (!skip_ndarray_content_) { + TVM_FFI_ICHECK_EQ(ndarray->device.device_type, kDLCPU) << "can only hash CPU tensor"; + TVM_FFI_ICHECK(ndarray.IsContiguous()) << "Can only hash contiguous tensor"; + size_t data_size = GetDataSize(*(ndarray.operator->())); + uint64_t data_hash = + details::StableHashBytes(static_cast(ndarray->data), data_size); + hash_value = details::StableHashCombine(hash_value, data_hash); + } + return hash_value; + } + + bool map_free_vars_{false}; + bool skip_ndarray_content_{false}; + // free var counter. + uint32_t free_var_counter_{0}; + // graph node counter. + uint32_t graph_node_counter_{0}; + // map from lhs to rhs + std::unordered_map hash_memo_; +}; + +uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_ndarray_content) { + StructuralHashHandler handler; + handler.map_free_vars_ = map_free_vars; + handler.skip_ndarray_content_ = skip_ndarray_content; + return handler.HashAny(value); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash); +}); + +} // namespace reflection +} // namespace ffi +} // namespace tvm diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection_accessor.cc similarity index 88% rename from ffi/tests/cpp/test_reflection.cc rename to ffi/tests/cpp/test_reflection_accessor.cc index 3f6b5def0107..b657f5ff12f8 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection_accessor.cc @@ -50,22 +50,11 @@ struct TestObjADerived : public TestObjA { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0)) - .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; }) - .def("add", &TFloatObj::Add, "add method"); - - refl::ObjectDef() - .def_ro("value", &TIntObj::value) - .def_static("static_add", &TInt::StaticAdd, "static add method"); - - refl::ObjectDef() - .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) - .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) - .def("sub", [](TPrimExprObj* self, double other) -> double { - // this is ok because TPrimExprObj is declared asmutable - return self->value - other; - }); + TIntObj::RegisterReflection(); + TFloatObj::RegisterReflection(); + TPrimExprObj::RegisterReflection(); + TVarObj::RegisterReflection(); + TFuncObj::RegisterReflection(); refl::ObjectDef().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y); refl::ObjectDef().def_ro("z", &TestObjADerived::z); diff --git a/ffi/tests/cpp/test_reflection_structural_equal_hash.cc b/ffi/tests/cpp/test_reflection_structural_equal_hash.cc new file mode 100644 index 000000000000..8646c43c6197 --- /dev/null +++ b/ffi/tests/cpp/test_reflection_structural_equal_hash.cc @@ -0,0 +1,172 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; +namespace refl = tvm::ffi::reflection; + +TEST(StructuralEqualHash, Array) { + Array a = {1, 2, 3}; + Array b = {1, 2, 3}; + EXPECT_TRUE(refl::StructuralEqual()(a, b)); + EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + + Array c = {1, 3}; + EXPECT_FALSE(refl::StructuralEqual()(a, c)); + EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); + auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + + // first directly interepret diff, + EXPECT_TRUE(diff_a_c.has_value()); + EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayIndex); + EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayIndex); + EXPECT_EQ((*diff_a_c).get<0>()[0]->key.cast(), 1); + EXPECT_EQ((*diff_a_c).get<1>()[0]->key.cast(), 1); + EXPECT_EQ((*diff_a_c).get<0>().size(), 1); + EXPECT_EQ((*diff_a_c).get<1>().size(), 1); + + // use structural equal for checking in future parts + // given we have done some basic checks above by directly interepret diff, + Array d = {1, 2}; + auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d); + auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::ArrayIndex(2), + }), + refl::AccessPath({ + refl::AccessStep::ArrayIndexMissing(2), + })); + // then use structural equal to check it + EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d)); +} + +TEST(StructuralEqualHash, Map) { + // same map but different insertion order + Map a = {{"a", 1}, {"b", 2}, {"c", 3}}; + Map b = {{"b", 2}, {"c", 3}, {"a", 1}}; + EXPECT_TRUE(refl::StructuralEqual()(a, b)); + EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + + Map c = {{"a", 1}, {"b", 2}, {"c", 4}}; + EXPECT_FALSE(refl::StructuralEqual()(a, c)); + EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); + + auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::MapKey("c"), + }), + refl::AccessPath({ + refl::AccessStep::MapKey("c"), + })); + EXPECT_TRUE(diff_a_c.has_value()); + EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c)); +} + +TEST(StructuralEqualHash, NestedMapArray) { + Map> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; + Map> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}}; + EXPECT_TRUE(refl::StructuralEqual()(a, b)); + EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b)); + + Map> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}}; + EXPECT_FALSE(refl::StructuralEqual()(a, c)); + EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c)); + + auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c); + auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::MapKey("b"), + refl::AccessStep::ArrayIndex(1), + }), + refl::AccessPath({ + refl::AccessStep::MapKey("b"), + refl::AccessStep::ArrayIndex(1), + })); + EXPECT_TRUE(diff_a_c.has_value()); + EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c)); + + Map> d = {{"a", {1, 2, 3}}}; + auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d); + auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::MapKey("b"), + }), + refl::AccessPath({ + refl::AccessStep::MapKeyMissing("b"), + })); + EXPECT_TRUE(diff_a_d.has_value()); + EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d)); + + auto diff_d_a = refl::StructuralEqual::GetFirstMismatch(d, a); + auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::MapKeyMissing("b"), + }), + refl::AccessPath({ + refl::AccessStep::MapKey("b"), + })); +} + +TEST(StructuralEqualHash, FreeVar) { + TVar a = TVar("a"); + TVar b = TVar("b"); + EXPECT_TRUE(refl::StructuralEqual::Equal(a, b, /*map_free_vars=*/true)); + EXPECT_FALSE(refl::StructuralEqual::Equal(a, b)); + + EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(b)); + EXPECT_EQ(refl::StructuralHash::Hash(a, /*map_free_vars=*/true), + refl::StructuralHash::Hash(b, /*map_free_vars=*/true)); +} + +TEST(StructuralEqualHash, FuncDefAndIgnoreField) { + TVar x = TVar("x"); + TVar y = TVar("y"); + // comment fields are ignored + TFunc fa = TFunc({x}, {TInt(1), x}, "comment a"); + TFunc fb = TFunc({y}, {TInt(1), y}, "comment b"); + + TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c"); + + EXPECT_TRUE(refl::StructuralEqual()(fa, fb)); + EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb)); + + EXPECT_FALSE(refl::StructuralEqual()(fa, fc)); + auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc); + auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({ + refl::AccessStep::ObjectField("body"), + refl::AccessStep::ArrayIndex(1), + }), + refl::AccessPath({ + refl::AccessStep::ObjectField("body"), + refl::AccessStep::ArrayIndex(1), + })); + EXPECT_TRUE(diff_fa_fc.has_value()); + EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc)); +} + +} // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index c72c61d00289..8786d194b41c 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -20,6 +20,7 @@ #ifndef TVM_FFI_TESTING_OBJECT_H_ #define TVM_FFI_TESTING_OBJECT_H_ +#include #include #include #include @@ -43,6 +44,7 @@ class TNumberObj : public BasePad, public Object { public: // declare as one slot, with float as overflow static constexpr uint32_t _type_child_slots = 1; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const char* _type_key = "test.Number"; TVM_FFI_DECLARE_BASE_OBJECT_INFO(TNumberObj, Object); }; @@ -62,6 +64,8 @@ class TIntObj : public TNumberObj { static constexpr const char* _type_key = "test.Int"; + inline static void RegisterReflection(); + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); }; @@ -74,6 +78,13 @@ class TInt : public TNumber { TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); }; +inline void TIntObj::RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &TIntObj::value) + .def_static("static_add", &TInt::StaticAdd, "static add method"); +} + class TFloatObj : public TNumberObj { public: double value; @@ -102,7 +113,6 @@ class TFloat : public TNumber { TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj); }; -// TPrimExpr is used for testing FallbackTraits class TPrimExprObj : public Object { public: std::string dtype; @@ -110,7 +120,19 @@ class TPrimExprObj : public Object { TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_rw("dtype", &TPrimExprObj::dtype, "dtype field", refl::DefaultValue("float")) + .def_ro("value", &TPrimExprObj::value, "value field", refl::DefaultValue(0)) + .def("sub", [](TPrimExprObj* self, double other) -> double { + // this is ok because TPrimExprObj is declared asmutable + return self->value - other; + }); + } + static constexpr const char* _type_key = "test.PrimExpr"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr bool _type_mutable = true; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object); }; @@ -123,6 +145,61 @@ class TPrimExpr : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS(TPrimExpr, ObjectRef, TPrimExprObj); }; + +class TVarObj : public Object { + public: + std::string name; + + TVarObj(std::string name) : name(name) {} + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name", &TVarObj::name); + } + + static constexpr const char* _type_key = "test.Var"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TVarObj, Object); +}; + +class TVar : public ObjectRef { + public: + explicit TVar(std::string name) { data_ = make_object(name); } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TVar, ObjectRef, TVarObj); +}; + +class TFuncObj : public Object { + public: + Array params; + Array body; + String comment; + + TFuncObj(Array params, Array body, String comment) + : params(params), body(body), comment(comment) {} + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("params", &TFuncObj::params, refl::AttachFieldFlag::SEqHashDef()) + .def_ro("body", &TFuncObj::body) + .def_ro("comment", &TFuncObj::comment, refl::AttachFieldFlag::SEqHashIgnore()); + } + + static constexpr const char* _type_key = "test.Func"; + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFuncObj, Object); +}; + +class TFunc : public ObjectRef { + public: + explicit TFunc(Array params, Array body, String comment) { + data_ = make_object(params, body, comment); + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFunc, ObjectRef, TFuncObj); +}; + } // namespace testing template <>