-
Notifications
You must be signed in to change notification settings - Fork 964
Add safe_numel() #19074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add safe_numel() #19074
Changes from all commits
25e8f81
e9450de
86633f6
b3db241
f09383c
e2b7cdb
1302631
0fb439e
d17950c
34b3e45
1303417
788683f
ae06ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,7 +8,10 @@ | |
|
|
||
| #pragma once | ||
|
|
||
| #include <executorch/runtime/core/error.h> // @manual | ||
| #include <executorch/runtime/core/result.h> // @manual | ||
| #include <executorch/runtime/core/tensor_shape_dynamism.h> // @manual | ||
| #include <executorch/runtime/platform/assert.h> // @manual | ||
| #include <executorch/runtime/platform/compiler.h> | ||
| #ifdef USE_ATEN_LIB | ||
| #include <ATen/Tensor.h> // @manual | ||
|
|
@@ -28,6 +31,7 @@ | |
| #include <c10/util/quint2x4.h> // @manual | ||
| #include <c10/util/quint4x2.h> // @manual | ||
| #include <c10/util/quint8.h> // @manual | ||
| #include <c10/util/safe_numerics.h> // @manual | ||
| #include <c10/util/string_view.h> // @manual | ||
| #include <torch/torch.h> | ||
| #else // use executor | ||
|
|
@@ -110,6 +114,32 @@ inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) { | |
| c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim))); | ||
| } | ||
|
|
||
| inline ::executorch::runtime::Result<ssize_t> safe_numel( | ||
| const SizesType* sizes, | ||
| ssize_t dim) { | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| dim == 0 || sizes != nullptr, | ||
| InvalidArgument, | ||
| "Sizes must be provided for non-scalar tensors"); | ||
| ssize_t numel = 1; | ||
| for (ssize_t i = 0; i < dim; i++) { | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| sizes[i] >= 0, | ||
| InvalidArgument, | ||
| "Size must be non-negative, got %zd at dimension %zd", | ||
| static_cast<ssize_t>(sizes[i]), | ||
| i); | ||
| ssize_t next_numel; | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel), | ||
| InvalidArgument, | ||
|
Comment on lines
+126
to
+135
|
||
| "Overflow computing numel at dimension %zd", | ||
| i); | ||
| numel = next_numel; | ||
| } | ||
| return numel; | ||
| } | ||
|
lucylq marked this conversation as resolved.
|
||
|
|
||
| #undef ET_PRI_TENSOR_SIZE | ||
| #define ET_PRI_TENSOR_SIZE PRId64 | ||
|
|
||
|
|
@@ -158,6 +188,7 @@ using OptionalArrayRef = | |
| using OptionalIntArrayRef = OptionalArrayRef<int64_t>; | ||
|
|
||
| using torch::executor::compute_numel; | ||
| using torch::executor::safe_numel; | ||
|
|
||
| #endif // Use ExecuTorch types | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| #include <cstdint> | ||
|
|
||
| #include <c10/util/irange.h> | ||
| #include <c10/util/safe_numerics.h> | ||
|
|
||
| #include <executorch/runtime/core/exec_aten/util/dim_order_util.h> | ||
| #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> | ||
|
|
@@ -43,6 +44,32 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) { | |
| return numel; | ||
| } | ||
|
|
||
| ::executorch::runtime::Result<ssize_t> safe_numel( | ||
| const TensorImpl::SizesType* sizes, | ||
| ssize_t dim) { | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| dim == 0 || sizes != nullptr, | ||
| InvalidArgument, | ||
| "Sizes must be provided for non-scalar tensors"); | ||
| ssize_t numel = 1; | ||
| for (const auto i : c10::irange(dim)) { | ||
|
Comment on lines
+50
to
+55
|
||
| ET_CHECK_OR_RETURN_ERROR( | ||
| sizes[i] >= 0, | ||
| InvalidArgument, | ||
| "Size must be non-negative, got %zd at dimension %zd", | ||
| static_cast<ssize_t>(sizes[i]), | ||
| i); | ||
| ssize_t next_numel; | ||
| ET_CHECK_OR_RETURN_ERROR( | ||
| !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel), | ||
| InvalidArgument, | ||
| "Overflow computing numel at dimension %zd", | ||
| i); | ||
| numel = next_numel; | ||
|
Comment on lines
+62
to
+68
|
||
| } | ||
| return numel; | ||
| } | ||
|
Comment on lines
+47
to
+71
|
||
|
|
||
| TensorImpl::TensorImpl( | ||
| ScalarType type, | ||
| ssize_t dim, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exec_aten::safe_numel() doesn’t reject negative
dim. For malformed metadata, a negative dim will skip the loop and return1, which is indistinguishable from a scalar. Consider explicitly returning InvalidArgument whendim < 0(in addition to the existing sizes!=nullptr check for dim>0).