solve comments
Gasoonjia committed Apr 21, 2026
commit ab1cbe1f88035ed588b5e1040b0a390b6f68017d
267 changes: 196 additions & 71 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -648,8 +648,105 @@ class ET_EXPERIMENTAL CudaBackend final
// Only constants not in the cache are loaded from the blob and added
// to the cache. This avoids duplicate GPU allocations when multiple
// methods (e.g., prefill/decode) share the same weights.
//
// ASSUMPTIONS / LIMITATIONS:
// * Constants with the same FQN across methods are assumed to be the
// SAME logical tensor (i.e. the same parameter/buffer of the same
// source model). We validate shape/dtype/strides/device on every
// reuse to catch silent mismatches (see check_cached_constant_match
// below). However, we cannot detect two unrelated models that
// happen to share an FQN.
// * Constants are assumed to be IMMUTABLE (parameters or read-only
// buffers). The AOTI shim today does not expose a mutability bit
// through GetConstantOriginalFQN, so we cannot detect or refuse
// to share mutable buffers (e.g. a per-method KV cache). If a
// future model exports the same FQN as a mutable buffer in
// multiple methods, mutations from one method WILL be visible to
// the other through the shared GPU memory. Callers that need
// per-method mutable state must currently use distinct FQNs.
// TODO: when AOTInductor exposes a constant-type / mutability
// query, refuse to share entries that are not PARAMETER or
// non-mutable BUFFER.
// ---------------------------------------------------------------
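// Illustrative sketch of that hazard (hypothetical FQN, not from any
// real model): if both prefill and decode exported a mutable buffer
// under the FQN "decoder.kv_cache", the cache below would hand both
// containers the same GPU tensor:
//
//   // method 1 (prefill): cache miss: loads and inserts handle h
//   // method 2 (decode):  cache hit:  receives the same handle h
//
// After that, a kernel writing through prefill's view mutates the
// bytes decode reads; there is no second allocation to isolate them.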

// Validates that a cached constant tensor is compatible with what the
// new container expects for the same FQN (i.e. same dtype, dim, sizes,
// strides, and device). Both handles point to SlimTensors in our shim
// layer, so we can introspect them directly.
//
// Returns Error::Ok on a match. On mismatch, logs the offending field
// and returns Error::Internal so callers can chain via
// ET_CHECK_OK_OR_RETURN_ERROR and fail loudly instead of silently
// pointing the new container at a wrong-shape buffer.
static Error check_cached_constant_match(
const std::string& fqn,
AtenTensorHandle cached_handle,
AtenTensorHandle new_handle) {
ET_CHECK_OR_RETURN_ERROR(
cached_handle != nullptr && new_handle != nullptr,
Internal,
"Constant '%s': null AtenTensorHandle (cached=%p, new=%p)",
fqn.c_str(),
cached_handle,
new_handle);

auto* cached = reinterpret_cast<SlimTensor*>(cached_handle);
auto* fresh = reinterpret_cast<SlimTensor*>(new_handle);

ET_CHECK_OR_RETURN_ERROR(
cached->dtype() == fresh->dtype(),
Internal,
"Constant '%s': dtype mismatch (cached=%d, new=%d)",
fqn.c_str(),
static_cast<int>(cached->dtype()),
static_cast<int>(fresh->dtype()));

ET_CHECK_OR_RETURN_ERROR(
cached->dim() == fresh->dim(),
Internal,
"Constant '%s': dim mismatch (cached=%zu, new=%zu)",
fqn.c_str(),
cached->dim(),
fresh->dim());

auto cached_sizes = cached->sizes();
auto fresh_sizes = fresh->sizes();
for (size_t i = 0; i < cached->dim(); ++i) {
ET_CHECK_OR_RETURN_ERROR(
cached_sizes[i] == fresh_sizes[i],
Internal,
"Constant '%s': size mismatch at dim %zu (cached=%lld, new=%lld)",
fqn.c_str(),
i,
static_cast<long long>(cached_sizes[i]),
static_cast<long long>(fresh_sizes[i]));
}
auto cached_strides = cached->strides();
auto fresh_strides = fresh->strides();
for (size_t i = 0; i < cached->dim(); ++i) {
ET_CHECK_OR_RETURN_ERROR(
cached_strides[i] == fresh_strides[i],
Internal,
"Constant '%s': stride mismatch at dim %zu (cached=%lld, new=%lld)",
fqn.c_str(),
i,
static_cast<long long>(cached_strides[i]),
static_cast<long long>(fresh_strides[i]));
}
ET_CHECK_OR_RETURN_ERROR(
cached->device().type() == fresh->device().type() &&
cached->device().index() == fresh->device().index(),
Internal,
"Constant '%s': device mismatch (cached=%d:%d, new=%d:%d)",
fqn.c_str(),
static_cast<int>(cached->device().type()),
cached->device().index(),
static_cast<int>(fresh->device().type()),
fresh->device().index());

return Error::Ok;
}
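// Usage sketch: this mirrors the call site in the loading path below,
// where a cache hit is validated before the handle is shared:
//
//   ET_CHECK_OK_OR_RETURN_ERROR(
//       check_cached_constant_match(fqn, cached_it->second, fresh_handle),
//       "Constant '%s' is incompatible with the cached copy",
//       fqn.c_str());
//
// (fresh_handle is a placeholder name; the real call site passes the
// freshly-extracted extracted_it->second.)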

// Load constants for a method using per-weight caching.
// Returns Error::Ok on success.
//
@@ -683,23 +780,28 @@ class ET_EXPERIMENTAL CudaBackend final
return Error::Ok;
}

// Build FQN → internal_name mapping and determine cache hits/misses.
std::unordered_map<std::string, std::string> fqn_to_name;
std::vector<std::string> uncached_fqns;

// Phase 1 (lock-free): enumerate constants from the container.
for (size_t i = 0; i < num_constants; i++) {
const char* name = nullptr;
const char* fqn = nullptr;
handle->get_constant_name(handle->container_handle, i, &name);
handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
if (name && fqn && fqn[0] != '\0') {
fqn_to_name[fqn] = name;
}
}
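// (Rationale for the phase split, hedged: the two getter calls per
// constant only read this method's own container, so holding the
// backend-wide mutex across them would serialize concurrent method
// loads for no benefit; the lock below covers only the shared-map
// lookups.)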

// Phase 2 (locked): pure cache lookup against shared_constant_tensors_.
{
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
for (const auto& [fqn, _] : fqn_to_name) {
if (shared_constant_tensors_.find(fqn) ==
shared_constant_tensors_.end()) {
uncached_fqns.push_back(fqn);
}
}
}
@@ -713,55 +815,72 @@
num_cached,
uncached_fqns.size());

// Step 2: Load uncached constants from blob (if any).
std::unordered_map<std::string, AtenTensorHandle> extracted_map;

if (!uncached_fqns.empty()) {
// Need to load from blob — use update_constants_from_blob for all,
// then extract the new constants into the cache.
std::string weights_blob_key =
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());

ET_CHECK_OR_RETURN_ERROR(
buffer_res.ok() && handle->update_constants_from_blob != nullptr,
NotFound,
"weights_blob '%s' not found or update fn is null",
weights_blob_key.c_str());

ET_LOG(
Info,
"Loading constants from blob '%s' for method '%s'",
weights_blob_key.c_str(),
method_name.c_str());
const void* weights_blob = buffer_res->data();
ET_CHECK_OK_OR_RETURN_ERROR(
handle->update_constants_from_blob(
handle->container_handle,
static_cast<const uint8_t*>(weights_blob)),
"update_constants_from_blob failed for method '%s'",
method_name.c_str());
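// Hedged note: the synchronize below is presumably what makes the
// Free() safe: if update_constants_from_blob enqueues asynchronous
// host-to-device copies out of the blob, releasing the host buffer
// before the device has drained them would let the GPU read freed
// memory.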
cudaDeviceSynchronize();
buffer_res->Free();

// Extract all constants from the freshly-loaded container.
ET_CHECK_OK_OR_RETURN_ERROR(
handle->extract_constants_map(
handle->container_handle,
reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
/*use_inactive=*/false),
"Failed to extract constants from '%s'",
method_name.c_str());

// Validate cache hits against the freshly-extracted tensors, and
// populate the cache with newly-loaded entries.
{
std::lock_guard<std::mutex> guard(shared_constants_mutex_);
for (const auto& [fqn, _] : fqn_to_name) {
// extract_constants_map returns entries keyed by FQN.
auto extracted_it = extracted_map.find(fqn);
if (extracted_it == extracted_map.end()) {
// Container did not surface this FQN — skip; the user-managed
// pair build below will simply omit it.
continue;
}
auto cached_it = shared_constant_tensors_.find(fqn);
if (cached_it == shared_constant_tensors_.end()) {
// New constant — add to cache.
shared_constant_tensors_[fqn] = extracted_it->second;
} else {
// Same FQN seen before — verify the cached tensor is still
// compatible with what THIS method expects. On mismatch the
// helper logs the offending field and returns an error.
ET_CHECK_OK_OR_RETURN_ERROR(
check_cached_constant_match(
fqn, cached_it->second, extracted_it->second),
"Constant '%s' in method '%s' is incompatible with the "
"cached version from a previous method. Refusing to share.",
fqn.c_str(),
method_name.c_str());
}
}
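// Side note (an equivalent idiom, not what the loop above does):
// std::unordered_map::try_emplace folds the find-and-insert into a
// single lookup and constructs nothing on the hit path:
//
//   auto [it, inserted] =
//       shared_constant_tensors_.try_emplace(fqn, extracted_it->second);
//   if (!inserted) {
//     // cache hit: validate against it->second as above
//   }
//
// For a pointer-valued map the two forms cost the same; the explicit
// find() keeps the hit and miss paths easy to read side by side.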
ET_LOG(
@@ -770,15 +889,13 @@ class ET_EXPERIMENTAL CudaBackend final
uncached_fqns.size(),
method_name.c_str(),
shared_constant_tensors_.size());
} else {
// All constants are cached — skip blob loading entirely.
// NOTE: in this branch we cannot independently verify the cache
// against the new container's expectations (no extract source).
// We rely on update_user_managed_constant_buffer_pairs below,
// which the AOTI runtime validates internally.
ET_LOG(
Info,
"All %zu constants cached — skipping blob load for method '%s'",
@@ -800,20 +917,15 @@
}

if (!pairs.empty()) {
ET_CHECK_OK_OR_RETURN_ERROR(
handle->update_user_managed_constant_buffer_pairs(
handle->container_handle,
pairs.data(),
pairs.size(),
/*use_inactive=*/false,
/*validate_full_update=*/false),
"Failed to set cached constants for method '%s'",
method_name.c_str());
ET_LOG(
Info,
"Shared %zu cached constants into method '%s'",
@@ -853,6 +965,19 @@ class ET_EXPERIMENTAL CudaBackend final
}
return Error::Ok;
}
// Guards the singleton FQN → AtenTensorHandle cache below.
//
// The mutex guards init().
// The CudaBackend instance is a process-wide singleton (registered
// once via register_backend()), and shared_constant_tensors_ is a
// shared-across-handles map. ExecuTorch hosts CAN call init() from
// multiple threads when:
// * a multi-threaded application loads two Modules concurrently, or
// * a single Module is loaded from a thread pool.
// Without the mutex, two concurrent init()s could race on
// shared_constant_tensors_ (rehash during insert, double-insert with
// different handles, etc.). The cost is a one-time lock during init,
// which is negligible.
mutable std::mutex shared_constants_mutex_;
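// Sketch of the interleaving the lock rules out (hypothetical timeline):
//
//   T1: find("model.weight") -> miss    T2: find("model.weight") -> miss
//   T1: insert("model.weight", h1)      T2: insert("model.weight", h2)
//
// Two different handles now race for one FQN, and a concurrent insert
// can rehash the table while the other thread is still walking it:
// exactly the double-insert/rehash hazard described above.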

// FQN → AtenTensorHandle from the source (first) container.