diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index c8531c88465a..68d6fc92b0cf 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -531,6 +531,24 @@ class IntImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Constant floating point literals in the program. * \sa FloatImm @@ -578,6 +596,24 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; +/* \brief FFI extention, ObjectRef to integer conversion + * + * If a PackedFunc expects an integer type, and the user passes an + * IntImm as the argument, this specialization allows it to be + * converted by the FFI. + */ +template +struct runtime::PackedFuncObjectRefConverter< + FloatType, std::enable_if_t>> { + static std::optional TryFrom(const ObjectRef& obj) { + if (auto ptr = obj.as()) { + return ptr->value; + } else { + return std::nullopt; + } + } +}; + /*! * \brief Boolean constant. * diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 4159c4b2e764..6a7d1dc724a1 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -537,6 +538,42 @@ struct ObjectTypeChecker> { } }; +class TVMPODValue_; + +/*! + * \brief Type trait to specify special value conversion rules from + * ObjectRef to primitive types. + * + * TVM containers, such as tvm::runtime::Array, require the contained + * objects to inherit from ObjectRef. As a result, the wrapper types + * IntImm, FloatImm, and StringImm are often used to hold primitive + * types inside a TVM container. Conversions into this type may be + * required when using a container, and may be performed + * automatically when passing an object across the FFI. By also + * handling conversions from wrapped to unwrapped types, these + * conversions can be transparent to users. + * + * The trait can be specialized to add type specific conversion logic + * from the TVMArgvalue and TVMRetValue. + * + * \tparam T The type (e.g. int64_t) which may be contained within the + * ObjectRef. + * + * \tparam (anonymous) An anonymous and unused type parameter, which + * may be used for SFINAE. + */ +template +struct PackedFuncObjectRefConverter { + /*! + * \brief Attempt to convert an ObjectRef from an argument value. + * + * \param obj The ObjectRef which may be convertible to T + * + * \return The converted result, or std::nullopt if not convertible. + */ + static std::optional TryFrom(const ObjectRef& obj) { return std::nullopt; } +}; + /*! * \brief Internal base class to * handle conversion to POD values. @@ -549,25 +586,41 @@ class TVMPODValue_ { // the frontend while the API expects a float. if (type_code_ == kDLInt) { return static_cast(value_.v_int64); + } else if (auto opt = ThroughObjectRef()) { + return opt.value(); + } else if (auto opt = ThroughObjectRef()) { + return static_cast(opt.value()); } TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64; } operator int64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator uint64_t() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator int() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); ICHECK_LE(value_.v_int64, std::numeric_limits::max()); ICHECK_GE(value_.v_int64, std::numeric_limits::min()); return static_cast(value_.v_int64); } operator bool() const { + if (auto opt = ThroughObjectRef()) { + return opt.value(); + } TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64 != 0; } @@ -638,6 +691,27 @@ class TVMPODValue_ { TVMValue value_; /*! \brief the type code */ int type_code_; + + private: + /* \brief A utility function to check for conversions through + * PackedFuncObjectRefConverter + * + * \tparam T The type to attempt to convert into + * + * \return The converted type, or std::nullopt if the value cannot + * be converted into T. + */ + template + std::optional ThroughObjectRef() const { + if (IsObjectRef()) { + if (std::optional from_obj = + PackedFuncObjectRefConverter::TryFrom(AsObjectRef())) { + return from_obj.value(); + } + } + + return std::nullopt; + } }; /*! diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 183aca1385a7..778017938eef 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -319,3 +319,62 @@ TEST(TypedPackedFunc, RValue) { tf(1, true); } } + +TEST(TypedPackedFunc, IntImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](int x) {}; + PackedFunc func = typed_func; + + // Integer argument may be provided + func(5); + + // IntImm argument may be provided, automatically unwrapped. + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + + // Unwrapping of IntImm argument works for rvalues as well + func(tvm::IntImm(DataType::Int(32), 10)); +} + +TEST(TypedPackedFunc, FloatImmWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](double x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as a floating point. If provided as an + // integer, it will be converted to a float. + func(static_cast(5.0)); + func(static_cast(5)); + + // IntImm and FloatImm arguments may be provided, and are + // automatically unwrapped. These arguments work correctly for + // either lvalue or rvalue arguments. + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5); + + func(lvalue_intimm); + func(lvalue_floatimm); + func(tvm::IntImm(DataType::Int(32), 10)); + func(tvm::FloatImm(DataType::Float(32), 10.5)); +} + +TEST(TypedPackedFunc, BoolWrapper) { + using namespace tvm::runtime; + + TypedPackedFunc typed_func = [](bool x) {}; + PackedFunc func = typed_func; + + // Argument may be provided as an IntImm, or as its subclass Bool. + func(true); + + tvm::IntImm lvalue_intimm(DataType::Int(32), 10); + func(lvalue_intimm); + func(tvm::IntImm(DataType::Int(32), 10)); + + tvm::Bool lvalue_bool(false); + func(lvalue_bool); + func(tvm::Bool(true)); +}