2525#include < tvm/ffi/container/map.h>
2626#include < tvm/ffi/container/ndarray.h>
2727#include < tvm/ffi/container/shape.h>
28+ #include < tvm/ffi/extra/structural_equal.h>
2829#include < tvm/ffi/reflection/accessor.h>
29- #include < tvm/ffi/reflection/structural_equal.h>
3030#include < tvm/ffi/string.h>
3131
3232#include < cmath>
3333#include < unordered_map>
3434
3535namespace tvm {
3636namespace ffi {
37- namespace reflection {
3837
3938/* *
4039 * \brief Internal Handler class for structural equal comparison.
@@ -135,11 +134,11 @@ class StructEqualHandler {
135134 bool success = true ;
136135 if (custom_s_equal[type_info->type_index ] == nullptr ) {
137136 // We recursively compare the fields the object
138- ForEachFieldInfoWithEarlyStop (type_info, [&](const TVMFFIFieldInfo* field_info) {
137+ reflection:: ForEachFieldInfoWithEarlyStop (type_info, [&](const TVMFFIFieldInfo* field_info) {
139138 // skip fields that are marked as structural eq hash ignore
140139 if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore ) return false ;
141140 // get the field value from both side
142- FieldGetter getter (field_info);
141+ reflection:: FieldGetter getter (field_info);
143142 Any lhs_value = getter (lhs);
144143 Any rhs_value = getter (rhs);
145144 // field is in def region, enable free var mapping
@@ -155,9 +154,9 @@ class StructEqualHandler {
155154 // record the first mismatching field if we sub-rountine compare failed
156155 if (mismatch_lhs_reverse_path_ != nullptr ) {
157156 mismatch_lhs_reverse_path_->emplace_back (
158- AccessStep::ObjectField (String (field_info->name )));
157+ reflection:: AccessStep::ObjectField (String (field_info->name )));
159158 mismatch_rhs_reverse_path_->emplace_back (
160- AccessStep::ObjectField (String (field_info->name )));
159+ reflection:: AccessStep::ObjectField (String (field_info->name )));
161160 }
162161 // return true to indicate early stop
163162 return true ;
@@ -185,8 +184,10 @@ class StructEqualHandler {
185184 if (!success) {
186185 if (mismatch_lhs_reverse_path_ != nullptr ) {
187186 String field_name_str = field_name.cast <String>();
188- mismatch_lhs_reverse_path_->emplace_back (AccessStep::ObjectField (field_name_str));
189- mismatch_rhs_reverse_path_->emplace_back (AccessStep::ObjectField (field_name_str));
187+ mismatch_lhs_reverse_path_->emplace_back (
188+ reflection::AccessStep::ObjectField (field_name_str));
189+ mismatch_rhs_reverse_path_->emplace_back (
190+ reflection::AccessStep::ObjectField (field_name_str));
190191 }
191192 }
192193 return success;
@@ -235,16 +236,16 @@ class StructEqualHandler {
235236 auto it = rhs.find (rhs_key);
236237 if (it == rhs.end ()) {
237238 if (mismatch_lhs_reverse_path_ != nullptr ) {
238- mismatch_lhs_reverse_path_->emplace_back (AccessStep::MapKey (kv.first ));
239- mismatch_rhs_reverse_path_->emplace_back (AccessStep::MapKeyMissing (rhs_key));
239+ mismatch_lhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItem (kv.first ));
240+ mismatch_rhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItemMissing (rhs_key));
240241 }
241242 return false ;
242243 }
243244 // now recursively compare value
244245 if (!CompareAny (kv.second , (*it).second )) {
245246 if (mismatch_lhs_reverse_path_ != nullptr ) {
246- mismatch_lhs_reverse_path_->emplace_back (AccessStep::MapKey (kv.first ));
247- mismatch_rhs_reverse_path_->emplace_back (AccessStep::MapKey (rhs_key));
247+ mismatch_lhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItem (kv.first ));
248+ mismatch_rhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItem (rhs_key));
248249 }
249250 return false ;
250251 }
@@ -258,8 +259,8 @@ class StructEqualHandler {
258259 auto it = lhs.find (lhs_key);
259260 if (it == lhs.end ()) {
260261 if (mismatch_lhs_reverse_path_ != nullptr ) {
261- mismatch_lhs_reverse_path_->emplace_back (AccessStep::MapKeyMissing (lhs_key));
262- mismatch_rhs_reverse_path_->emplace_back (AccessStep::MapKey (kv.first ));
262+ mismatch_lhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItemMissing (lhs_key));
263+ mismatch_rhs_reverse_path_->emplace_back (reflection:: AccessStep::MapItem (kv.first ));
263264 }
264265 return false ;
265266 }
@@ -276,20 +277,22 @@ class StructEqualHandler {
276277 for (size_t i = 0 ; i < std::min (lhs.size (), rhs.size ()); ++i) {
277278 if (!CompareAny (lhs[i], rhs[i])) {
278279 if (mismatch_lhs_reverse_path_ != nullptr ) {
279- mismatch_lhs_reverse_path_->emplace_back (AccessStep::ArrayIndex (i));
280- mismatch_rhs_reverse_path_->emplace_back (AccessStep::ArrayIndex (i));
280+ mismatch_lhs_reverse_path_->emplace_back (reflection:: AccessStep::ArrayItem (i));
281+ mismatch_rhs_reverse_path_->emplace_back (reflection:: AccessStep::ArrayItem (i));
281282 }
282283 return false ;
283284 }
284285 }
285286 if (lhs.size () == rhs.size ()) return true ;
286287 if (mismatch_lhs_reverse_path_ != nullptr ) {
287288 if (lhs.size () > rhs.size ()) {
288- mismatch_lhs_reverse_path_->emplace_back (AccessStep::ArrayIndex (rhs.size ()));
289- mismatch_rhs_reverse_path_->emplace_back (AccessStep::ArrayIndexMissing (rhs.size ()));
289+ mismatch_lhs_reverse_path_->emplace_back (reflection::AccessStep::ArrayItem (rhs.size ()));
290+ mismatch_rhs_reverse_path_->emplace_back (
291+ reflection::AccessStep::ArrayItemMissing (rhs.size ()));
290292 } else {
291- mismatch_lhs_reverse_path_->emplace_back (AccessStep::ArrayIndexMissing (lhs.size ()));
292- mismatch_rhs_reverse_path_->emplace_back (AccessStep::ArrayIndex (lhs.size ()));
293+ mismatch_lhs_reverse_path_->emplace_back (
294+ reflection::AccessStep::ArrayItemMissing (lhs.size ()));
295+ mismatch_rhs_reverse_path_->emplace_back (reflection::AccessStep::ArrayItem (lhs.size ()));
293296 }
294297 }
295298 return false ;
@@ -354,8 +357,8 @@ class StructEqualHandler {
354357 // whether we compare ndarray data
355358 bool skip_ndarray_content_{false };
356359 // the root lhs for result printing
357- std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr ;
358- std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr ;
360+ std::vector<reflection:: AccessStep>* mismatch_lhs_reverse_path_ = nullptr ;
361+ std::vector<reflection:: AccessStep>* mismatch_rhs_reverse_path_ = nullptr ;
359362 // lazily initialize custom equal function
360363 ffi::Function s_equal_callback_ = nullptr ;
361364 // map from lhs to rhs
@@ -372,32 +375,31 @@ bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars,
372375 return handler.CompareAny (lhs, rhs);
373376}
374377
375- Optional<AccessPathPair> StructuralEqual::GetFirstMismatch (const Any& lhs, const Any& rhs,
376- bool map_free_vars,
377- bool skip_ndarray_content) {
378+ Optional<reflection::AccessPathPair> StructuralEqual::GetFirstMismatch (const Any& lhs,
379+ const Any& rhs,
380+ bool map_free_vars,
381+ bool skip_ndarray_content) {
378382 StructEqualHandler handler;
379383 handler.map_free_vars_ = map_free_vars;
380384 handler.skip_ndarray_content_ = skip_ndarray_content;
381- std::vector<AccessStep> lhs_reverse_path;
382- std::vector<AccessStep> rhs_reverse_path;
385+ std::vector<reflection:: AccessStep> lhs_reverse_path;
386+ std::vector<reflection:: AccessStep> rhs_reverse_path;
383387 handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path;
384388 handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path;
385389 if (handler.CompareAny (lhs, rhs)) {
386390 return std::nullopt ;
387391 }
388- AccessPath lhs_path (lhs_reverse_path.rbegin (), lhs_reverse_path.rend ());
389- AccessPath rhs_path (rhs_reverse_path.rbegin (), rhs_reverse_path.rend ());
390- return AccessPathPair (lhs_path, rhs_path);
392+ reflection:: AccessPath lhs_path (lhs_reverse_path.rbegin (), lhs_reverse_path.rend ());
393+ reflection:: AccessPath rhs_path (rhs_reverse_path.rbegin (), rhs_reverse_path.rend ());
394+ return reflection:: AccessPathPair (lhs_path, rhs_path);
391395}
392396
393397TVM_FFI_STATIC_INIT_BLOCK ({
394398 namespace refl = tvm::ffi::reflection;
395- refl::GlobalDef ().def (" ffi.reflection.GetFirstStructuralMismatch" ,
396- StructuralEqual::GetFirstMismatch);
399+ refl::GlobalDef ().def (" ffi.GetFirstStructuralMismatch" , StructuralEqual::GetFirstMismatch);
397400 // ensure the type attribute column is presented in the system even if it is empty.
398401 refl::EnsureTypeAttrColumn (" __s_equal__" );
399402});
400403
401- } // namespace reflection
402404} // namespace ffi
403405} // namespace tvm
0 commit comments