Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions extension/tensor/tensor_ptr.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -10,6 +10,8 @@

#include <numeric>

#include <c10/util/safe_numerics.h>

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace executorch {
Expand Down Expand Up @@ -147,11 +149,22 @@
std::vector<executorch::aten::StridesType> strides,
executorch::aten::ScalarType type,
executorch::aten::TensorShapeDynamism dynamism) {
const ssize_t numel =
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
size_t nbytes;
ET_CHECK_MSG(
!c10::mul_overflows(
static_cast<size_t>(numel),
executorch::aten::elementSize(type),
&nbytes),
"Overflow computing nbytes: numel=%zd element_size=%zu",
numel,
executorch::aten::elementSize(type));
ET_CHECK_MSG(
data.size() ==
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
executorch::aten::elementSize(type),
"Data size does not match tensor size.");
data.size() == nbytes,
"Data size (%zu) does not match tensor size (%zu).",
data.size(),
nbytes);
auto data_ptr = data.data();
return make_tensor_ptr(
std::move(sizes),
Expand Down Expand Up @@ -205,7 +218,14 @@
runtime::canCast(tensor_type, type),
"Cannot cast tensor type to desired type.");
const auto tensor_numel = static_cast<size_t>(tensor.numel());
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
size_t clone_nbytes;
ET_CHECK_MSG(
!c10::mul_overflows(
tensor_numel, aten::elementSize(type), &clone_nbytes),
"Overflow computing clone nbytes: numel=%zu element_size=%zu",
tensor_numel,
aten::elementSize(type));
std::vector<uint8_t> data(clone_nbytes);

// Create a minimal context for error handling in ET_SWITCH
struct {
Expand Down
5 changes: 3 additions & 2 deletions extension/tensor/tensor_ptr.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ inline TensorPtr make_tensor_ptr(
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) {
ET_CHECK_MSG(
data.size() ==
executorch::aten::compute_numel(sizes.data(), sizes.size()),
static_cast<size_t>(executorch::aten::compute_numel_overflow(
sizes.data(), sizes.size())),
"Data size does not match tensor size.");
if (type != deduced_type) {
ET_CHECK_MSG(
Expand Down Expand Up @@ -359,7 +360,7 @@ inline TensorPtr make_tensor_ptr(
const auto same_shape = same_rank &&
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
const auto element_count =
executorch::aten::compute_numel(sizes.data(), sizes.size());
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
const auto parent_element_count = tensor.numel();
ET_CHECK_MSG(
element_count <= parent_element_count,
Expand Down
15 changes: 13 additions & 2 deletions extension/tensor/tensor_ptr_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <random>

#include <c10/util/safe_numerics.h>

namespace executorch {
namespace extension {
namespace {
Expand Down Expand Up @@ -111,9 +113,18 @@ TensorPtr empty_strided(
std::vector<executorch::aten::StridesType> strides,
executorch::aten::ScalarType type,
executorch::aten::TensorShapeDynamism dynamism) {
std::vector<uint8_t> data(
executorch::aten::compute_numel(sizes.data(), sizes.size()) *
const ssize_t numel =
executorch::aten::compute_numel_overflow(sizes.data(), sizes.size());
size_t nbytes;
ET_CHECK_MSG(
!c10::mul_overflows(
static_cast<size_t>(numel),
executorch::aten::elementSize(type),
&nbytes),
"Overflow computing nbytes: numel=%zd element_size=%zu",
numel,
executorch::aten::elementSize(type));
std::vector<uint8_t> data(nbytes);
return make_tensor_ptr(
std::move(sizes),
std::move(data),
Expand Down
Loading