Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,24 @@ class IntImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
};

/* \brief FFI extention, ObjectRef to integer conversion
*
* If a PackedFunc expects an integer type, and the user passes an
* IntImm as the argument, this specialization allows it to be
* converted by the FFI.
*/
template <typename IntType>
struct runtime::PackedFuncObjectRefConverter<IntType,
std::enable_if_t<std::is_integral_v<IntType>>> {
static std::optional<IntType> TryFrom(const ObjectRef& obj) {
if (auto ptr = obj.as<IntImmNode>()) {
return ptr->value;
} else {
return std::nullopt;
}
}
};

/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
Expand Down Expand Up @@ -578,6 +596,24 @@ class FloatImm : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
};

/* \brief FFI extention, ObjectRef to integer conversion
*
* If a PackedFunc expects an integer type, and the user passes an
* IntImm as the argument, this specialization allows it to be
* converted by the FFI.
*/
template <typename FloatType>
struct runtime::PackedFuncObjectRefConverter<
FloatType, std::enable_if_t<std::is_floating_point_v<FloatType>>> {
static std::optional<FloatType> TryFrom(const ObjectRef& obj) {
if (auto ptr = obj.as<FloatImmNode>()) {
return ptr->value;
} else {
return std::nullopt;
}
}
};

/*!
* \brief Boolean constant.
*
Expand Down
74 changes: 74 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -537,6 +538,42 @@ struct ObjectTypeChecker<Map<K, V>> {
}
};

class TVMPODValue_;

/*!
* \brief Type trait to specify special value conversion rules from
* ObjectRef to primitive types.
*
* TVM containers, such as tvm::runtime::Array, require the contained
* objects to inherit from ObjectRef. As a result, the wrapper types
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trailing whitespace, also in lines 550, 552.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not seeing the trailing whitespace in the commit, either using git show --word-diff-regex="[ ]+|[^ ]+", setting my editor to show whitespace, or in the github diff.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I used the wrong definition, there is a double space before "As"

Copy link
Copy Markdown
Contributor Author

@Lunderberg Lunderberg Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. This is intentional, but is more of a stylistic difference than anything else. Traditionally, (accidental diversion into typography) a space after a sentence would use an em space (the same with as a capital M) after a sentence, wider than the spacing between words. This isn't possible in monospaced fonts, so typewriters would emulate the em space by using two spaces. Exactly which is better has since become the topic of a great many flame wars.

I was curious, and it looks like double-spacing after a sentence is about twice as common as single-spacing in the TVM repo, with the Apache copyright header being the most prominent example.

# Count the number of sentences followed by one space
$ find . \
  \( -path ./3rdparty -o -path "./build*" \) -prune -o \
  \( -name "*.cc" -o -name "*.h" \) \
  -exec grep '[A-Za-z]\{2\}\. [A-Za-z]' {} /dev/null \; \
    | wc --lines
3958

# Count the number of sentences followed by two spaces
$ find . \
  \( -path ./3rdparty -o -path "./build*" \) -prune -o \
  \( -name "*.cc" -o -name "*.h" \) \
  -exec grep '[A-Za-z]\{2\}\.  [A-Za-z]' {} /dev/null \; \
    | wc --lines
7666

* IntImm, FloatImm, and StringImm are often used to hold primitive
* types inside a TVM container. Conversions into this type may be
* required when using a container, and may be performed
* automatically when passing an object across the FFI. By also
* handling conversions from wrapped to unwrapped types, these
* conversions can be transparent to users.
*
* The trait can be specialized to add type specific conversion logic
* from the TVMArgvalue and TVMRetValue.
*
* \tparam T The type (e.g. int64_t) which may be contained within the
* ObjectRef.
*
* \tparam (anonymous) An anonymous and unused type parameter, which
* may be used for SFINAE.
*/
template <typename T, typename = void>
struct PackedFuncObjectRefConverter {
/*!
* \brief Attempt to convert an ObjectRef from an argument value.
*
* \param obj The ObjectRef which may be convertible to T
*
* \return The converted result, or std::nullopt if not convertible.
*/
static std::optional<T> TryFrom(const ObjectRef& obj) { return std::nullopt; }
};

/*!
* \brief Internal base class to
* handle conversion to POD values.
Expand All @@ -549,25 +586,41 @@ class TVMPODValue_ {
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
} else if (auto opt = ThroughObjectRef<double>()) {
return opt.value();
} else if (auto opt = ThroughObjectRef<int64_t>()) {
return static_cast<double>(opt.value());
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
operator int64_t() const {
if (auto opt = ThroughObjectRef<int64_t>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator uint64_t() const {
if (auto opt = ThroughObjectRef<uint64_t>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator int() const {
if (auto opt = ThroughObjectRef<int>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
ICHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
ICHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
if (auto opt = ThroughObjectRef<bool>()) {
return opt.value();
}
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64 != 0;
}
Expand Down Expand Up @@ -638,6 +691,27 @@ class TVMPODValue_ {
TVMValue value_;
/*! \brief the type code */
int type_code_;

private:
/* \brief A utility function to check for conversions through
* PackedFuncObjectRefConverter
*
* \tparam T The type to attempt to convert into
*
* \return The converted type, or std::nullopt if the value cannot
* be converted into T.
*/
template <typename T>
std::optional<T> ThroughObjectRef() const {
if (IsObjectRef<ObjectRef>()) {
if (std::optional<T> from_obj =
PackedFuncObjectRefConverter<T>::TryFrom(AsObjectRef<ObjectRef>())) {
return from_obj.value();
}
}

return std::nullopt;
}
};

/*!
Expand Down
59 changes: 59 additions & 0 deletions tests/cpp/packed_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,62 @@ TEST(TypedPackedFunc, RValue) {
tf(1, true);
}
}

TEST(TypedPackedFunc, IntImmWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(int)> typed_func = [](int x) {};
PackedFunc func = typed_func;

// Integer argument may be provided
func(5);

// IntImm argument may be provided, automatically unwrapped.
tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
func(lvalue_intimm);

// Unwrapping of IntImm argument works for rvalues as well
func(tvm::IntImm(DataType::Int(32), 10));
}

TEST(TypedPackedFunc, FloatImmWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(double)> typed_func = [](double x) {};
PackedFunc func = typed_func;

// Argument may be provided as a floating point. If provided as an
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trailing whitespace.

// integer, it will be converted to a float.
func(static_cast<double>(5.0));
func(static_cast<int>(5));

// IntImm and FloatImm arguments may be provided, and are
// automatically unwrapped. These arguments work correctly for
// either lvalue or rvalue arguments.

tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
tvm::FloatImm lvalue_floatimm(DataType::Float(32), 10.5);

func(lvalue_intimm);
func(lvalue_floatimm);
func(tvm::IntImm(DataType::Int(32), 10));
func(tvm::FloatImm(DataType::Float(32), 10.5));
}

TEST(TypedPackedFunc, BoolWrapper) {
using namespace tvm::runtime;

TypedPackedFunc<void(bool)> typed_func = [](bool x) {};
PackedFunc func = typed_func;

// Argument may be provided as an IntImm, or as its subclass Bool.
func(true);

tvm::IntImm lvalue_intimm(DataType::Int(32), 10);
func(lvalue_intimm);
func(tvm::IntImm(DataType::Int(32), 10));

tvm::Bool lvalue_bool(false);
func(lvalue_bool);
func(tvm::Bool(true));
}