Skip to content

Commit 93b6cb8

Browse files
authored
[FFI][REFACTOR] Isolate out extra API (apache#18177)
This PR formalizes the extra API in FFI. The extra APIs are minimal set of APIs that are not required in core mechanism, but still helpful. Move structural equal/hash to extra API.
1 parent d6d9f78 commit 93b6cb8

9 files changed

Lines changed: 167 additions & 142 deletions

File tree

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ set(tvm_ffi_objs_sources
5959
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
6060
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
6161
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
62+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
6263
)
6364

6465
if (TVM_FFI_USE_EXTRA_CXX_API)
6566
list(APPEND tvm_ffi_objs_sources
66-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
67-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
68-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
67+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc"
68+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc"
6969
)
7070
endif()
7171

include/tvm/ffi/c_api.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,6 @@
5656
#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default")))
5757
#endif
5858

59-
/*!
60-
* \brief Marks the API as extra c++ api that is defined in cc files.
61-
*
62-
* These APIs are extra features that depend on, but are not required to
63-
* support essential core functionality, such as function calling and object
64-
* access.
65-
*
66-
* They are implemented in cc files to reduce compile-time overhead.
67-
* The input/output only uses POD/Any/ObjectRef for ABI stability.
68-
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
69-
*
70-
* Related features are also available through reflection based function
71-
* that is fully based on C API
72-
*
73-
* The project aims to minimize the number of extra C++ APIs and only
74-
* restrict the use to non-core functionalities.
75-
*/
76-
#ifndef TVM_FFI_EXTRA_CXX_API
77-
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
78-
#endif
79-
8059
#ifdef __cplusplus
8160
extern "C" {
8261
#endif

include/tvm/ffi/extra/base.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file tvm/ffi/extra/base.h
21+
* \brief Base header for Extra API.
22+
*
23+
* The extra APIs contains a minmal set of extra APIs that are not
24+
* required to support essential core functionality.
25+
*/
26+
#ifndef TVM_FFI_EXTRA_BASE_H_
27+
#define TVM_FFI_EXTRA_BASE_H_
28+
29+
#include <tvm/ffi/c_api.h>
30+
31+
/*!
32+
* \brief Marks the API as extra c++ api that is defined in cc files.
33+
*
34+
* They are implemented in cc files to reduce compile-time overhead.
35+
* The input/output only uses POD/Any/ObjectRef for ABI stability.
36+
* However, these extra APIs may have an issue across MSVC/Itanium ABI,
37+
*
38+
* Related features are also available through reflection based function
39+
* that is fully based on C API
40+
*
41+
* The project aims to minimize the number of extra C++ APIs to keep things
42+
* lightweight and restrict the use to non-core functionalities.
43+
*/
44+
#ifndef TVM_FFI_EXTRA_CXX_API
45+
#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
46+
#endif
47+
48+
#endif // TVM_FFI_EXTRA_BASE_H_

include/tvm/ffi/reflection/structural_equal.h renamed to include/tvm/ffi/extra/structural_equal.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
* under the License.
1818
*/
1919
/*!
20-
* \file tvm/ffi/reflection/structural_equal.h
20+
* \file tvm/ffi/extra/structural_equal.h
2121
* \brief Structural equal implementation
2222
*/
23-
#ifndef TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
24-
#define TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
23+
#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
24+
#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
2525

2626
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/extra/base.h>
2728
#include <tvm/ffi/optional.h>
2829
#include <tvm/ffi/reflection/access_path.h>
2930

3031
namespace tvm {
3132
namespace ffi {
32-
namespace reflection {
3333
/*
3434
* \brief Structural equality comparators
3535
*/
@@ -59,7 +59,7 @@ class StructuralEqual {
5959
* \return If comparison fails, return the first mismatch AccessPath pair,
6060
* otherwise return std::nullopt.
6161
*/
62-
TVM_FFI_EXTRA_CXX_API static Optional<AccessPathPair> GetFirstMismatch(
62+
TVM_FFI_EXTRA_CXX_API static Optional<reflection::AccessPathPair> GetFirstMismatch(
6363
const Any& lhs, const Any& rhs, bool map_free_vars = false,
6464
bool skip_ndarray_content = false);
6565

@@ -74,7 +74,6 @@ class StructuralEqual {
7474
}
7575
};
7676

77-
} // namespace reflection
7877
} // namespace ffi
7978
} // namespace tvm
80-
#endif // TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
79+
#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_

include/tvm/ffi/reflection/structural_hash.h renamed to include/tvm/ffi/extra/structural_hash.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@
1717
* under the License.
1818
*/
1919
/*!
20-
* \file tvm/ffi/reflection/structural_hash.h
20+
* \file tvm/ffi/extra/structural_hash.h
2121
* \brief Structural hash
2222
*/
23-
#ifndef TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
24-
#define TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
23+
#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
24+
#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
2525

2626
#include <tvm/ffi/any.h>
27+
#include <tvm/ffi/extra/base.h>
2728

2829
namespace tvm {
2930
namespace ffi {
30-
namespace reflection {
3131

3232
/*
3333
* \brief Structural hash
@@ -52,7 +52,6 @@ class StructuralHash {
5252
TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); }
5353
};
5454

55-
} // namespace reflection
5655
} // namespace ffi
5756
} // namespace tvm
58-
#endif // TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
57+
#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_

include/tvm/ffi/reflection/access_path.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ namespace reflection {
3535

3636
enum class AccessKind : int32_t {
3737
kObjectField = 0,
38-
kArrayIndex = 1,
39-
kMapKey = 2,
38+
kArrayItem = 1,
39+
kMapItem = 2,
4040
// the following two are used for error reporting when
4141
// the supposed access field is not available
42-
kArrayIndexMissing = 3,
43-
kMapKeyMissing = 4,
42+
kArrayItemMissing = 3,
43+
kMapItemMissing = 4,
4444
};
4545

4646
/*!
@@ -86,15 +86,15 @@ class AccessStep : public ObjectRef {
8686
return AccessStep(AccessKind::kObjectField, field_name);
8787
}
8888

89-
static AccessStep ArrayIndex(int64_t index) { return AccessStep(AccessKind::kArrayIndex, index); }
89+
static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); }
9090

91-
static AccessStep ArrayIndexMissing(int64_t index) {
92-
return AccessStep(AccessKind::kArrayIndexMissing, index);
91+
static AccessStep ArrayItemMissing(int64_t index) {
92+
return AccessStep(AccessKind::kArrayItemMissing, index);
9393
}
9494

95-
static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey, key); }
95+
static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, key); }
9696

97-
static AccessStep MapKeyMissing(Any key) { return AccessStep(AccessKind::kMapKeyMissing, key); }
97+
static AccessStep MapItemMissing(Any key) { return AccessStep(AccessKind::kMapItemMissing, key); }
9898

9999
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef, AccessStepObj);
100100
};
Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,15 @@
2525
#include <tvm/ffi/container/map.h>
2626
#include <tvm/ffi/container/ndarray.h>
2727
#include <tvm/ffi/container/shape.h>
28+
#include <tvm/ffi/extra/structural_equal.h>
2829
#include <tvm/ffi/reflection/accessor.h>
29-
#include <tvm/ffi/reflection/structural_equal.h>
3030
#include <tvm/ffi/string.h>
3131

3232
#include <cmath>
3333
#include <unordered_map>
3434

3535
namespace tvm {
3636
namespace ffi {
37-
namespace reflection {
3837

3938
/**
4039
* \brief Internal Handler class for structural equal comparison.
@@ -135,11 +134,11 @@ class StructEqualHandler {
135134
bool success = true;
136135
if (custom_s_equal[type_info->type_index] == nullptr) {
137136
// We recursively compare the fields the object
138-
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
137+
reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
139138
// skip fields that are marked as structural eq hash ignore
140139
if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return false;
141140
// get the field value from both side
142-
FieldGetter getter(field_info);
141+
reflection::FieldGetter getter(field_info);
143142
Any lhs_value = getter(lhs);
144143
Any rhs_value = getter(rhs);
145144
// field is in def region, enable free var mapping
@@ -155,9 +154,9 @@ class StructEqualHandler {
155154
// record the first mismatching field if we sub-rountine compare failed
156155
if (mismatch_lhs_reverse_path_ != nullptr) {
157156
mismatch_lhs_reverse_path_->emplace_back(
158-
AccessStep::ObjectField(String(field_info->name)));
157+
reflection::AccessStep::ObjectField(String(field_info->name)));
159158
mismatch_rhs_reverse_path_->emplace_back(
160-
AccessStep::ObjectField(String(field_info->name)));
159+
reflection::AccessStep::ObjectField(String(field_info->name)));
161160
}
162161
// return true to indicate early stop
163162
return true;
@@ -185,8 +184,10 @@ class StructEqualHandler {
185184
if (!success) {
186185
if (mismatch_lhs_reverse_path_ != nullptr) {
187186
String field_name_str = field_name.cast<String>();
188-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
189-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
187+
mismatch_lhs_reverse_path_->emplace_back(
188+
reflection::AccessStep::ObjectField(field_name_str));
189+
mismatch_rhs_reverse_path_->emplace_back(
190+
reflection::AccessStep::ObjectField(field_name_str));
190191
}
191192
}
192193
return success;
@@ -235,16 +236,16 @@ class StructEqualHandler {
235236
auto it = rhs.find(rhs_key);
236237
if (it == rhs.end()) {
237238
if (mismatch_lhs_reverse_path_ != nullptr) {
238-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
239-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(rhs_key));
239+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
240+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key));
240241
}
241242
return false;
242243
}
243244
// now recursively compare value
244245
if (!CompareAny(kv.second, (*it).second)) {
245246
if (mismatch_lhs_reverse_path_ != nullptr) {
246-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
247-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(rhs_key));
247+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
248+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key));
248249
}
249250
return false;
250251
}
@@ -258,8 +259,8 @@ class StructEqualHandler {
258259
auto it = lhs.find(lhs_key);
259260
if (it == lhs.end()) {
260261
if (mismatch_lhs_reverse_path_ != nullptr) {
261-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(lhs_key));
262-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
262+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key));
263+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
263264
}
264265
return false;
265266
}
@@ -276,20 +277,22 @@ class StructEqualHandler {
276277
for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) {
277278
if (!CompareAny(lhs[i], rhs[i])) {
278279
if (mismatch_lhs_reverse_path_ != nullptr) {
279-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
280-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
280+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
281+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
281282
}
282283
return false;
283284
}
284285
}
285286
if (lhs.size() == rhs.size()) return true;
286287
if (mismatch_lhs_reverse_path_ != nullptr) {
287288
if (lhs.size() > rhs.size()) {
288-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(rhs.size()));
289-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(rhs.size()));
289+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size()));
290+
mismatch_rhs_reverse_path_->emplace_back(
291+
reflection::AccessStep::ArrayItemMissing(rhs.size()));
290292
} else {
291-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(lhs.size()));
292-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(lhs.size()));
293+
mismatch_lhs_reverse_path_->emplace_back(
294+
reflection::AccessStep::ArrayItemMissing(lhs.size()));
295+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size()));
293296
}
294297
}
295298
return false;
@@ -354,8 +357,8 @@ class StructEqualHandler {
354357
// whether we compare ndarray data
355358
bool skip_ndarray_content_{false};
356359
// the root lhs for result printing
357-
std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
358-
std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
360+
std::vector<reflection::AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
361+
std::vector<reflection::AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
359362
// lazily initialize custom equal function
360363
ffi::Function s_equal_callback_ = nullptr;
361364
// map from lhs to rhs
@@ -372,32 +375,31 @@ bool StructuralEqual::Equal(const Any& lhs, const Any& rhs, bool map_free_vars,
372375
return handler.CompareAny(lhs, rhs);
373376
}
374377

375-
Optional<AccessPathPair> StructuralEqual::GetFirstMismatch(const Any& lhs, const Any& rhs,
376-
bool map_free_vars,
377-
bool skip_ndarray_content) {
378+
Optional<reflection::AccessPathPair> StructuralEqual::GetFirstMismatch(const Any& lhs,
379+
const Any& rhs,
380+
bool map_free_vars,
381+
bool skip_ndarray_content) {
378382
StructEqualHandler handler;
379383
handler.map_free_vars_ = map_free_vars;
380384
handler.skip_ndarray_content_ = skip_ndarray_content;
381-
std::vector<AccessStep> lhs_reverse_path;
382-
std::vector<AccessStep> rhs_reverse_path;
385+
std::vector<reflection::AccessStep> lhs_reverse_path;
386+
std::vector<reflection::AccessStep> rhs_reverse_path;
383387
handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path;
384388
handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path;
385389
if (handler.CompareAny(lhs, rhs)) {
386390
return std::nullopt;
387391
}
388-
AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend());
389-
AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend());
390-
return AccessPathPair(lhs_path, rhs_path);
392+
reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend());
393+
reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend());
394+
return reflection::AccessPathPair(lhs_path, rhs_path);
391395
}
392396

393397
TVM_FFI_STATIC_INIT_BLOCK({
394398
namespace refl = tvm::ffi::reflection;
395-
refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch",
396-
StructuralEqual::GetFirstMismatch);
399+
refl::GlobalDef().def("ffi.GetFirstStructuralMismatch", StructuralEqual::GetFirstMismatch);
397400
// ensure the type attribute column is presented in the system even if it is empty.
398401
refl::EnsureTypeAttrColumn("__s_equal__");
399402
});
400403

401-
} // namespace reflection
402404
} // namespace ffi
403405
} // namespace tvm

0 commit comments

Comments
 (0)