diff --git a/runtime/core/exec_aten/exec_aten.h b/runtime/core/exec_aten/exec_aten.h
index 8c06045927e..f539414aec9 100644
--- a/runtime/core/exec_aten/exec_aten.h
+++ b/runtime/core/exec_aten/exec_aten.h
@@ -8,7 +8,10 @@
 
 #pragma once
 
+#include <c10/util/safe_numerics.h> // @manual
+#include <executorch/runtime/core/error.h> // @manual
 #include <executorch/runtime/core/tensor_shape_dynamism.h> // @manual
+#include <executorch/runtime/core/result.h> // @manual
 #include <sys/types.h>
 #ifdef USE_ATEN_LIB
 #include <ATen/Tensor.h> // @manual
@@ -28,6 +31,7 @@
 #include <c10/core/Scalar.h> // @manual
 #include <c10/core/ScalarType.h> // @manual
 #include <c10/util/ArrayRef.h> // @manual
+#include <c10/util/safe_numerics.h> // @manual
 #include <c10/util/string_view.h> // @manual
 #include <optional>
 #else // use executor
@@ -110,6 +114,32 @@ inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) {
   return static_cast<ssize_t>(
       c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim)));
 }
 
+inline ::executorch::runtime::Result<ssize_t> safe_numel(
+    const SizesType* sizes,
+    ssize_t dim) {
+  ET_CHECK_OR_RETURN_ERROR(
+      dim == 0 || sizes != nullptr,
+      InvalidArgument,
+      "Sizes must be provided for non-scalar tensors");
+  ssize_t numel = 1;
+  for (ssize_t i = 0; i < dim; i++) {
+    ET_CHECK_OR_RETURN_ERROR(
+        sizes[i] >= 0,
+        InvalidArgument,
+        "Size must be non-negative, got %zd at dimension %zd",
+        static_cast<ssize_t>(sizes[i]),
+        i);
+    ssize_t next_numel;
+    ET_CHECK_OR_RETURN_ERROR(
+        !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
+        InvalidArgument,
+        "Overflow computing numel at dimension %zd",
+        i);
+    numel = next_numel;
+  }
+  return numel;
+}
+
 #undef ET_PRI_TENSOR_SIZE
 #define ET_PRI_TENSOR_SIZE PRId64
@@ -158,6 +188,7 @@ using OptionalArrayRef =
     c10::OptionalArrayRef<T>;
 using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
 using torch::executor::compute_numel;
+using torch::executor::safe_numel;
 #endif
 
 // Use ExecuTorch types
diff --git a/runtime/core/exec_aten/targets.bzl b/runtime/core/exec_aten/targets.bzl
index df4a87ef033..7499d3b0bea 100644
--- a/runtime/core/exec_aten/targets.bzl
+++ b/runtime/core/exec_aten/targets.bzl
@@ -16,6 +16,9 @@ def define_common_targets():
         exported_headers = ["exec_aten.h"],
         exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
         visibility = ["PUBLIC"],
-        exported_deps = ["//executorch/runtime/core:tensor_shape_dynamism"] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]),
+        exported_deps = [
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/core:tensor_shape_dynamism",
+        ] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]),
         exported_external_deps = ["libtorch"] if aten_mode else [],
     )
diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp
index 17243fca0fd..affc5821fed 100644
--- a/runtime/core/portable_type/tensor_impl.cpp
+++ b/runtime/core/portable_type/tensor_impl.cpp
@@ -12,6 +12,7 @@
 
 #include <cstdint>
 #include <cstring>
+#include <c10/util/safe_numerics.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/core/portable_type/qint_types.h>
 
@@ -43,6 +44,32 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
   return numel;
 }
 
+::executorch::runtime::Result<ssize_t> safe_numel(
+    const TensorImpl::SizesType* sizes,
+    ssize_t dim) {
+  ET_CHECK_OR_RETURN_ERROR(
+      dim == 0 || sizes != nullptr,
+      InvalidArgument,
+      "Sizes must be provided for non-scalar tensors");
+  ssize_t numel = 1;
+  for (const auto i : c10::irange(dim)) {
+    ET_CHECK_OR_RETURN_ERROR(
+        sizes[i] >= 0,
+        InvalidArgument,
+        "Size must be non-negative, got %zd at dimension %zd",
+        static_cast<ssize_t>(sizes[i]),
+        i);
+    ssize_t next_numel;
+    ET_CHECK_OR_RETURN_ERROR(
+        !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
+        InvalidArgument,
+        "Overflow computing numel at dimension %zd",
+        i);
+    numel = next_numel;
+  }
+  return numel;
+}
+
 TensorImpl::TensorImpl(
     ScalarType type,
     ssize_t dim,
diff --git a/runtime/core/portable_type/tensor_impl.h b/runtime/core/portable_type/tensor_impl.h
index ea2cde5aeb0..b01d8fa6c52 100644
--- a/runtime/core/portable_type/tensor_impl.h
+++ b/runtime/core/portable_type/tensor_impl.h
@@ -12,7 +12,9 @@
 #include <executorch/runtime/core/array_ref.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/portable_type/scalar_type.h>
+#include <executorch/runtime/core/result.h>
 #include <executorch/runtime/core/tensor_shape_dynamism.h>
+#include <sys/types.h>
 
 // Forward declaration of a helper that provides access to internal resizing
 // methods of TensorImpl. Real definition is in
@@ -293,6 +295,16 @@ ssize_t compute_numel(
     const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
     ssize_t dim);
 
+/**
+ * Compute the number of elements based on the sizes of a tensor.
+ * Returns Error::InvalidArgument if any intermediate multiplication would
+ * overflow ssize_t, or if a size is negative. Prefer this over compute_numel()
+ * for paths that can propagate an Error upward.
+ */
+::executorch::runtime::Result<ssize_t> safe_numel(
+    const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
+    ssize_t dim);
+
 /// Appropriate format specifier for the result of calling
 /// size(). Must be used instead of using zd directly to support ATen
 /// mode.
@@ -322,6 +334,7 @@ namespace executor {
 // TODO(T197294990): Remove these deprecated aliases once all users have moved
 // to the new `::executorch` namespaces.
 using ::executorch::runtime::etensor::compute_numel;
+using ::executorch::runtime::etensor::safe_numel;
 using ::executorch::runtime::etensor::TensorImpl;
 } // namespace executor
 } // namespace torch