[IR] Implemented Variant<...> container#15672
Merged
csullivan merged 3 commits intoapache:mainfrom Sep 13, 2023
Merged
Conversation
This commit introduces a new container, `Variant`, which is analogous to the `std::variant` introduced in C++17, the `enum` in Rust, or a tagged union in C. The `Variant` class is templated over the types that it may contain (e.g. `Variant<String, Expr>`), where each type is a distinct option that can be stored within the container. `Variant` is implemented as a subclass of `ObjectRef` with no additional data members, similar to the implementation of `Optional<T>`. It can be constructed from any of its contained types, and the contents can be inspected using the usual `my_object.as<T>()` and `Downcast<T>(my_object)` methods. This is intended to allow for drop-in replacement of `ObjectRef` with `Variant<Type1, Type2, ...>` in places that previously used a common base class. To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a `static_assert` explaining the limitation. This condition is necessary to mimic the semantics of `std::variant`, whose active member depends on the compile-time type of an object. Without this condition, the expression `Variant<PrimExpr, tir::Var> variant = PrimExpr(...)` could populate either of the variants depending on the run-time type of an object. Because the `Variant` class is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility. There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around various strategies. (This PR does not alter any existing implementations, instead introducing the `Variant` container that can be used in subsequent PRs, if desired.) * Workaround: Store a common base class. For example, the type of `relax::TensorStructInfoNode::shape` is `Optional<Expr>`, with a comment stating that it should be only `NullOpt`, `ShapeExpr`, or `Var`. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as as `Optional<Variant<Var,ShapeExpr>>`, these errors could be automatically caught. * Workaround: Use additional data structures. For example, a `PrimFunc` parameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to a `DLTensor*` argument and appropriate unpacking code. However, these two types are represented as an `Array<tir::Var>` and a `Map<tir::Var, tir::Buffer>`, which together represent a `Array<Variant<tir::Var, tir::Buffer>>`. The separate data structures must be kept in sync whenever modified, such as when removing a parameter. * Workaround: Use `std::variant`. For example, the `tvm::tir::IdentifyMemCpyImpl` utility function returns a `std::variant` with the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI.
387eadc to
422eafd
Compare
csullivan
reviewed
Sep 11, 2023
Comment on lines
+858
to
+876
| TEST(Variant, Construct) { | ||
| Variant<PrimExpr, String> variant; | ||
| variant = PrimExpr(1); | ||
| ICHECK(variant.as<PrimExpr>()); | ||
| ICHECK(!variant.as<String>()); | ||
|
|
||
| variant = String("hello"); | ||
| ICHECK(variant.as<String>()); | ||
| ICHECK(!variant.as<PrimExpr>()); | ||
| } | ||
|
|
||
| TEST(Variant, InvalidTypeThrowsError) { | ||
| auto expected_to_throw = []() { | ||
| ObjectPtr<Object> node = make_object<Object>(); | ||
| Variant<PrimExpr, String> variant(node); | ||
| }; | ||
|
|
||
| EXPECT_THROW(expected_to_throw(), InternalError); | ||
| } |
Contributor
There was a problem hiding this comment.
A rather small set of tests, albeit for a fairly small API surface as compared to Array and Map. Are there other tests we could add? Maybe check assignment?
TEST(Variant, Assignment) {
Variant<PrimExpr, String> variant;
Variant<PrimExpr, String> variant2 = String("foo");
variant = PrimExpr(1);
variant2 = variant;
ICHECK(variant2.as<PrimExpr>());
# check the value of variant2
}
Contributor
Author
There was a problem hiding this comment.
Good point. I made the API surface as small as possible, but there were additional tests that should be included. I've added tests to validate that reference equality is preserved across Variant assignments, and that the values are correctly preserved.
Contributor
|
cc @junrushao |
csullivan
approved these changes
Sep 12, 2023
Comment on lines
+858
to
+876
| TEST(Variant, Construct) { | ||
| Variant<PrimExpr, String> variant; | ||
| variant = PrimExpr(1); | ||
| ICHECK(variant.as<PrimExpr>()); | ||
| ICHECK(!variant.as<String>()); | ||
|
|
||
| variant = String("hello"); | ||
| ICHECK(variant.as<String>()); | ||
| ICHECK(!variant.as<PrimExpr>()); | ||
| } | ||
|
|
||
| TEST(Variant, InvalidTypeThrowsError) { | ||
| auto expected_to_throw = []() { | ||
| ObjectPtr<Object> node = make_object<Object>(); | ||
| Variant<PrimExpr, String> variant(node); | ||
| }; | ||
|
|
||
| EXPECT_THROW(expected_to_throw(), InternalError); | ||
| } |
Lunderberg
added a commit
to Lunderberg/tvm
that referenced
this pull request
Jun 14, 2024
Prior to the implementation of `Variant<...>` in apache#15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied.
Lunderberg
added a commit
to Lunderberg/tvm
that referenced
this pull request
Jun 14, 2024
Prior to the implementation of `Variant<...>` in apache#15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This commit introduces a new container,
Variant, which is analogous to thestd::variantintroduced in C++17, theenumin Rust, or a tagged union in C. TheVariantclass is templated over the types that it may contain (e.g.Variant<String, Expr>), where each type is a distinct option that can be stored within the container.Variantis implemented as a subclass ofObjectRefwith no additional data members, similar to the implementation ofOptional<T>. It can be constructed from any of its contained types, and the contents can be inspected using the usualmy_object.as<T>()andDowncast<T>(my_object)methods. This is intended to allow for drop-in replacement ofObjectRefwithVariant<Type1, Type2, ...>in places that previously used a common base class.To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a
static_assertexplaining the limitation. This condition is necessary to mimic the semantics ofstd::variant, whose active member depends on the compile-time type of an object. Without this condition, the expressionVariant<PrimExpr, tir::Var> variant = PrimExpr(...)could populate either of the variants depending on the run-time type of an object. Because theVariantclass is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility.There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around various strategies. (This PR does not alter any existing implementations, instead introducing the
Variantcontainer that can be used in subsequent PRs, if desired.)Workaround: Store a common base class. For example, the type of
relax::TensorStructInfoNode::shapeisOptional<Expr>, with a comment stating that it should be onlyNullOpt,ShapeExpr, orVar. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as asOptional<Variant<Var,ShapeExpr>>, these errors could be automatically caught.Workaround: Use additional data structures. For example, a
PrimFuncparameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to aDLTensor*argument and appropriate unpacking code. However, these two types are represented as anArray<tir::Var>and aMap<tir::Var, tir::Buffer>, which together represent aArray<Variant<tir::Var, tir::Buffer>>. The separate data structures must be kept in sync whenever modified, such as when removing a parameter.Workaround: Use
std::variant. For example, thetvm::tir::IdentifyMemCpyImplutility function returns astd::variantwith the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI.