From 3119e69d1aae9d73e84d4d39b73330c221317ad1 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Tue, 25 Mar 2025 15:12:41 -0700 Subject: [PATCH] Utility helpers to convert between std::vector and NSArray. (#9597) Summary: . Reviewed By: swolchok, kirklandsign Differential Revision: D71752746 --- .../ExecuTorch/Internal/ExecuTorchUtils.h | 67 +++++++++++++------ .../ExecuTorch/Internal/ExecuTorchUtils.mm | 35 ++++++++++ runtime/core/span.h | 1 + 3 files changed, 82 insertions(+), 21 deletions(-) create mode 100644 extension/apple/ExecuTorch/Internal/ExecuTorchUtils.mm diff --git a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h index 02b82d4a989..9add6dbd48d 100644 --- a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h +++ b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h @@ -23,25 +23,7 @@ using namespace runtime; * @param number The NSNumber instance whose scalar type is to be deduced. * @return The corresponding ScalarType. */ -static inline ScalarType deduceType(NSNumber *number) { - auto type = [number objCType][0]; - type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type; - if (type == 'c') { - return ScalarType::Byte; - } else if (type == 's') { - return ScalarType::Short; - } else if (type == 'i') { - return ScalarType::Int; - } else if (type == 'q' || type == 'l') { - return ScalarType::Long; - } else if (type == 'f') { - return ScalarType::Float; - } else if (type == 'd') { - return ScalarType::Double; - } - ET_CHECK_MSG(false, "Unsupported type: %c", type); - return ScalarType::Undefined; -} +ScalarType deduceType(NSNumber *number); /** * Converts the value held in the NSNumber to the specified C++ type T. @@ -51,8 +33,8 @@ static inline ScalarType deduceType(NSNumber *number) { * @return The value converted to type T. */ template -static inline T extractValue(NSNumber *number) { - ET_CHECK_MSG(!(isFloatingType(deduceScalarType(number)) && +T extractValue(NSNumber *number) { + ET_CHECK_MSG(!(isFloatingType(deduceType(number)) && isIntegralType(CppTypeToScalarType::value, true)), "Cannot convert floating point to integral type"); T value; @@ -93,6 +75,49 @@ static inline T extractValue(NSNumber *number) { return value; } +/** + * Converts an NSArray of NSNumber objects to a std::vector of type T. + * + * @tparam T The target C++ numeric type. + * @param array The NSArray containing NSNumber objects. + * @return A std::vector with the values extracted as type T. + */ +template +std::vector toVector(NSArray *array) { + std::vector vector; + vector.reserve(array.count); + for (NSNumber *number in array) { + vector.push_back(extractValue(number)); + } + return vector; +} + +// Trait for types that can be wrapped into an NSNumber. +template +constexpr bool isNSNumberWrapable = + std::is_arithmetic_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + +/** + * Converts a generic container of numeric values to an NSArray of NSNumber objects. + * + * @tparam Container The container type holding numeric values. + * @param container The container whose items are to be converted. + * @return An NSArray populated with NSNumber objects representing the container's items. + */ +template +NSArray *toNSArray(const Container &container) { + static_assert(isNSNumberWrapable, "Invalid container value type"); + const NSUInteger count = std::distance(std::begin(container), std::end(container)); + NSMutableArray *array = [NSMutableArray arrayWithCapacity:count]; + for (const auto &item : container) { + [array addObject:@(item)]; + } + return array; +} + } // namespace executorch::extension::utils #endif // __cplusplus diff --git a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.mm b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.mm new file mode 100644 index 00000000000..8210b366c0f --- /dev/null +++ b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.mm @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import "ExecuTorchUtils.h" + +namespace executorch::extension::utils { +using namespace aten; +using namespace runtime; + +ScalarType deduceType(NSNumber *number) { + auto type = [number objCType][0]; + type = (type >= 'A' && type <= 'Z') ? type + ('a' - 'A') : type; + if (type == 'c') { + return ScalarType::Byte; + } else if (type == 's') { + return ScalarType::Short; + } else if (type == 'i') { + return ScalarType::Int; + } else if (type == 'q' || type == 'l') { + return ScalarType::Long; + } else if (type == 'f') { + return ScalarType::Float; + } else if (type == 'd') { + return ScalarType::Double; + } + ET_CHECK_MSG(false, "Unsupported type: %c", type); + return ScalarType::Undefined; +} + +} // namespace executorch::extension::utils diff --git a/runtime/core/span.h b/runtime/core/span.h index b671f340953..1bcde396ccd 100644 --- a/runtime/core/span.h +++ b/runtime/core/span.h @@ -35,6 +35,7 @@ namespace runtime { template class Span final { public: + using value_type = T; using iterator = T*; using size_type = size_t;