From 7c20a000f0ac25f70d75c28a1a1d0d83bb44230d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 15:23:11 -0500 Subject: [PATCH 1/2] [FFI] Allow IntImm arguments to PackedFunc with int parameter TVM containers, such as tvm::runtime::Array, require the contained objects to inherit from `ObjectRef`. As a result, the wrapper types `IntImm`, `FloatImm`, and `StringImm` are often used to allow native types in the TVM containers. Conversions into these wrapper type may be required when using a container, and may be performed automatically when passing an object across the FFI. By also providing conversion to an unwrapped type, these automatic conversions are transparent become transparent to users. The trait can be specialized to add type specific conversion logic from the TVMArgvalue and TVMRetValue. --- include/tvm/ir/expr.h | 36 +++++++++++++++ include/tvm/runtime/packed_func.h | 74 +++++++++++++++++++++++++++++++ tests/cpp/packed_func_test.cc | 59 ++++++++++++++++++++++++ 3 files changed, 169 insertions(+) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..68d6fc92b0cf 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -531,6 +531,24 @@ class IntImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Constant floating point literals in the program. * \sa FloatImm @@ -578,6 +596,24 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter< + FloatType, std::enable_if_t>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Boolean constant. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 4159c4b2e764..ba9ac047b743 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -537,6 +538,42 @@ struct ObjectTypeChecker> { } }; +class TVMPODValue_; + +/*! + * \brief Type trait to specify special value conversion rules from + * ObjectRef to primitive types. + * + * TVM containers, such as tvm::runtime::Array, require the contained + * objects to inherit from ObjectRef. As a result, the wrapper types + * IntImm, FloatImm, and StringImm are often used to hold primitive + * types inside a TVM container. Conversions into this type may be + * required when using a container, and may be performed + * automatically when passing an object across the FFI. By also + * handling conversions from wrapped to unwrapped types, these + * conversions can be transparent to users. + * + * The trait can be specialized to add type specific conversion logic + * from the TVMArgvalue and TVMRetValue. + * + * \tparam T The type (e.g. int64_t) which may be contained within the + * ObjectRef. + * + * \tparam (anonymous) An anonymous and unused type parameter, which + * may be used for SFINAE. + */ +template +struct PackedFuncObjectRefConverter { + /*! + * \brief Attempt to convert an ObjectRef from an argument value. + * + * \param obj The ObjectRef which may be convertible to T + * + * \return The converted result, or std::nullopt if not convertible. + */ + static std::optional TryFrom(const ObjectRef& obj) { return std::nullopt; } +}; + /*! * \brief Internal base class to * handle conversion to POD values. @@ -549,25 +586,41 @@ class TVMPODValue_ { // the frontend while the API expects a float. if (type_code_ == kDLInt) { return static_cast(value_.v_int64); + } else if (auto opt = ThroughObjectRef()) { + return opt.value(); + } else if (auto opt = ThroughObjectRef()) { + return opt.value(); } TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64; } operator int64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator uint64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator int() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); ICHECK_LE(value_.v_int64, std::numeric_limits::max()); ICHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64 != 0; } @@ -638,6 +691,27 @@ class TVMPODValue_ { TVMValue value_; /*! \brief the type code */ int type_code_; + + private: + /* \brief A utility function to check for conversions through + * PackedFuncObjectRefConverter + * + * \tparam T The type to attempt to convert into + * + * \return The converted type, or std::nullopt if the value cannot + * be converted into T. + */ + template + std::optional ThroughObjectRef() const { + if (IsObjectRef()) { + if (std::optional from_obj = + PackedFuncObjectRefConverter::TryFrom(AsObjectRef())) { + return from_obj.value(); + } + } + + return std::nullopt; + } }; /*! diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 183aca1385a7..778017938eef 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -319,3 +319,62 @@ TEST(TypedPackedFunc, RValue) { tf(1, true); } } + +TEST(TypedPackedFunc, IntImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](int x) {}; + PackedFunc func = typed_func; + + // Integer argument may be provided + func(5); + + // IntImm argument may be provided, automatically unwrapped. + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + + // Unwrapping of IntImm argument works for rvalues as well + func(tvm::IntImm(DataType::Int(32), 10)); +} + +TEST(TypedPackedFunc, FloatImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](double x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as a floating point. If provided as an + // integer, it will be converted to a float. + func(static_cast(5.0)); + func(static_cast(5)); + + // IntImm and FloatImm arguments may be provided, and are + // automatically unwrapped. These arguments work correctly for + // either lvalue or rvalue arguments. + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5); + + func(lvalue_intimm); + func(lvalue_floatimm); + func(tvm::IntImm(DataType::Int(32), 10)); + func(tvm::FloatImm(DataType::Float(32), 10.5)); +} + +TEST(TypedPackedFunc, BoolWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](bool x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as an IntImm, or as its subclass Bool. + func(true); + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + func(tvm::IntImm(DataType::Int(32), 10)); + + tvm::Bool lvalue_bool(false); + func(lvalue_bool); + func(tvm::Bool(true)); +} From d950754291190605d4a10a292809468265ebca36 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 31 Oct 2023 11:14:08 -0500 Subject: [PATCH 2/2] Use explicit static_cast when converting from int64_t --- include/tvm/runtime/packed_func.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index ba9ac047b743..6a7d1dc724a1 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -589,7 +589,7 @@ class TVMPODValue_ { } else if (auto opt = ThroughObjectRef()) { return opt.value(); } else if (auto opt = ThroughObjectRef()) { - return opt.value(); + return static_cast(opt.value()); } TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64;