diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 3e6d98da54e..8b8783d0db8 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -397,6 +397,27 @@ if [ -n "$EXPECTED_OUTPUT" ]; then else echo "SUCCESS: Runner completed successfully" fi + +# Validate GPU peak memory usage for models with known memory budgets. +# The runner prints "GPU peak memory usage: XXXX.X MiB" at the end. +case "$MODEL_NAME" in + qwen3_5_moe) + MAX_MEMORY_MIB=20480 # 20 GB — must fit on a single GPU (e.g. 4090) + PEAK_MEM=$(echo "$OUTPUT" | grep -oP 'GPU peak memory usage: \K[0-9.]+' || true) + if [ -n "$PEAK_MEM" ]; then + # Compare as integers (truncate decimals) + PEAK_MEM_INT=${PEAK_MEM%%.*} + if [ "$PEAK_MEM_INT" -gt "$MAX_MEMORY_MIB" ]; then + echo "FAIL: GPU peak memory ${PEAK_MEM} MiB exceeds budget ${MAX_MEMORY_MIB} MiB" + exit 1 + else + echo "Success: GPU peak memory ${PEAK_MEM} MiB within budget (max ${MAX_MEMORY_MIB} MiB)" + fi + else + echo "WARNING: GPU peak memory usage not found in output" + fi + ;; +esac echo "::endgroup::" popd diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py index 0eb775e3459..f9b4b947506 100644 --- a/backends/aoti/aoti_backend.py +++ b/backends/aoti/aoti_backend.py @@ -25,7 +25,6 @@ class COMPILE_SPEC_KEYS(Enum): METHOD_NAME = "method_name" - SHARE_KV_CACHE_ACROSS_METHODS = "share_kv_cache_across_methods" @experimental( @@ -287,13 +286,3 @@ def method_name_from_compile_specs( raise RuntimeError( f"Could not find method name in compile specs: {compile_specs}" ) - - @classmethod - def generate_share_kv_cache_compile_spec(cls) -> CompileSpec: - """ - Generate a CompileSpec to enable cross-method KV cache sharing. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.SHARE_KV_CACHE_ACROSS_METHODS.value, - bytes([1]), - ) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index cf628d3e6fa..6beece6adcd 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -82,9 +83,10 @@ namespace { constexpr char kSkipCopyOutputToCpuForMethod[] = "skip_copy_output_to_cpu_for_method"; constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream"; + constexpr char kEnableCudaGraphForMethod[] = "enable_cuda_graph_for_method"; constexpr int kCudaGraphWarmupSteps = 3; -constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods"; +constexpr char kWeightSharingAcrossMethods[] = "weight_sharing_across_methods"; } // anonymous namespace class ET_EXPERIMENTAL CudaBackend final @@ -192,6 +194,16 @@ class ET_EXPERIMENTAL CudaBackend final return shared_cuda_stream_ != nullptr; } + // Enable cross-method per-FQN weight caching. Set via the + // kWeightSharingAcrossMethods runtime backend option. 
+  void set_weight_sharing_across_methods(bool enabled) {
+    weight_sharing_across_methods_.store(enabled, std::memory_order_relaxed);
+  }
+
+  bool is_weight_sharing_across_methods_enabled() const {
+    return weight_sharing_across_methods_.load(std::memory_order_relaxed);
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -283,6 +295,14 @@ class ET_EXPERIMENTAL CudaBackend final
         ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
         return Error::InvalidArgument;
       }
+    } else if (std::strcmp(option.key, kWeightSharingAcrossMethods) == 0) {
+      if (auto* val = std::get_if<bool>(&option.value)) {
+        set_weight_sharing_across_methods(*val);
+      } else {
+        ET_LOG(
+            Error,
+            "Option %s must be a boolean.",
+            kWeightSharingAcrossMethods);
+        return Error::InvalidArgument;
+      }
     } else if (std::strcmp(option.key, kEnableCudaGraphForMethod) == 0) {
       if (auto* val = std::get_if>(
               &option.value)) {
@@ -317,17 +337,11 @@ class ET_EXPERIMENTAL CudaBackend final
     ArrayRef<CompileSpec> compile_specs // This will be my empty list
   ) const override {
     std::string method_name;
-    bool share_kv_cache = false;
     for (const CompileSpec& spec : compile_specs) {
       if (std::strcmp(spec.key, "method_name") == 0) {
         method_name.assign(
             static_cast<const char*>(spec.value.buffer),
             spec.value.nbytes); // no nullptr guarantee, so pass size
-      } else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) {
-        if (spec.value.nbytes >= 1) {
-          share_kv_cache =
-              static_cast(spec.value.buffer)[0] != 0;
-        }
       }
     }
@@ -398,29 +412,19 @@ class ET_EXPERIMENTAL CudaBackend final
     handle->container_handle = container_handle;
 
-    // Look into named data map for constant data
-    std::string weights_blob_key =
-        method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
-    auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
-    if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
-      ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
-      const void* weights_blob = buffer_res->data();
-      // Feed the weights blob into the container. Under the hood it's copying
-      // weights, so we should free the buffer immediately.
-      auto update_err = handle->update_constants_from_blob(
-          handle->container_handle, static_cast(weights_blob));
-      if (update_err != Error::Ok) {
-        ET_LOG(Error, "update_constants_from_blob failed");
-        return update_err;
-      }
-      // Ensure all weight transfers are complete before execution
-      cudaDeviceSynchronize();
-      buffer_res->Free();
+    // Load constants. When weight_sharing_across_methods is enabled (opt-in
+    // via the kWeightSharingAcrossMethods runtime backend option set by the
+    // runner), use the per-weight FQN cache so methods that share weights
+    // (e.g. prefill/decode) avoid duplicate GPU allocations. Otherwise fall
+    // back to the legacy per-method blob load — required for models whose
+    // methods are independent sub-graphs that may have FQN collisions
+    // (e.g. parakeet).
+    if (is_weight_sharing_across_methods_enabled()) {
+      ET_CHECK_OK_OR_RETURN_ERROR(
+          load_constants_with_cache(handle, named_data_map, method_name));
     } else {
-      ET_LOG(
-          Info,
-          "weights_blob '%s' not found or update fn is null",
-          weights_blob_key.c_str());
+      ET_CHECK_OK_OR_RETURN_ERROR(
+          load_constants_legacy(handle, named_data_map, method_name));
     }
 
     // Use shared CUDA stream if enabled via options, otherwise create one.
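The cache/legacy split in the hunk above is the heart of this change: a per-weight FQN cache replaces the old all-or-nothing KV-cache sharing. The sketch below is a simplified, self-contained model of what load_constants_with_cache (defined later in this diff) does; GpuTensor and load_from_blob are hypothetical stand-ins for AtenTensorHandle and the AOTI blob-loading machinery, not real ExecuTorch APIs.

// Simplified model of the per-FQN constant cache (illustration only).
#include <memory>
#include <string>
#include <unordered_map>

struct GpuTensor {};  // stand-in for AtenTensorHandle (a GPU-resident tensor)

// Stand-in for loading a single constant out of a method's weights blob.
std::shared_ptr<GpuTensor> load_from_blob(const std::string& /*fqn*/) {
  return std::make_shared<GpuTensor>();
}

class FqnConstantCache {
 public:
  // Cache hit: an earlier method already loaded this FQN, so reuse the same
  // GPU tensor (no new allocation). Cache miss: load it once and remember it.
  std::shared_ptr<GpuTensor> get_or_load(const std::string& fqn) {
    auto it = cache_.find(fqn);
    if (it != cache_.end()) {
      return it->second;
    }
    auto tensor = load_from_blob(fqn);
    cache_.emplace(fqn, tensor);
    return tensor;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<GpuTensor>> cache_;
};

The real path differs in two ways: ownership stays with the container that first loaded a constant, and subsequent containers are pointed at the cached handles through update_user_managed_constant_buffer_pairs (user-managed, so no copy). The legacy path never consults the cache, which is why models whose methods are unrelated sub-graphs with colliding FQNs must leave the option off.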
@@ -448,119 +452,6 @@ class ET_EXPERIMENTAL CudaBackend final method_name.c_str()); } - // --------------------------------------------------------------- - // Cross-method constant sharing (e.g., KV cache between prefill/decode). - // - // Only enabled when share_kv_cache_across_methods compile spec is set. - // The first container to initialize extracts its constants (keyed by - // original FQN) and stores the AtenTensorHandle's. Subsequent containers - // with matching FQNs are updated to point to the same GPU tensors via - // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy, - // the source container retains ownership). - // --------------------------------------------------------------- - if (share_kv_cache && handle->get_num_constants && - handle->get_constant_name && handle->get_constant_original_fqn && - handle->extract_constants_map && - handle->update_user_managed_constant_buffer_pairs) { - size_t num_constants = 0; - handle->get_num_constants(handle->container_handle, &num_constants); - - if (num_constants > 0) { - // Build FQN → internal_name mapping for this container. - std::unordered_map fqn_to_name; - for (size_t i = 0; i < num_constants; i++) { - const char* name = nullptr; - const char* fqn = nullptr; - handle->get_constant_name(handle->container_handle, i, &name); - handle->get_constant_original_fqn(handle->container_handle, i, &fqn); - if (name && fqn && fqn[0] != '\0') { - fqn_to_name[fqn] = name; - } - } - - std::lock_guard guard(shared_constants_mutex_); - - if (!constants_extracted_) { - // First container: extract its constants and store by FQN. - std::unordered_map extracted_map; - auto extract_err = handle->extract_constants_map( - handle->container_handle, - reinterpret_cast(&extracted_map), - /*use_inactive=*/false); - - if (extract_err == Error::Ok) { - for (const auto& [fqn, internal_name] : fqn_to_name) { - auto it = extracted_map.find(fqn); - if (it != extracted_map.end()) { - shared_constant_tensors_[fqn] = it->second; - } - } - constants_extracted_ = true; - ET_LOG( - Info, - "Extracted %zu shared constants from method '%s'", - shared_constant_tensors_.size(), - method_name.c_str()); - } else { - ET_LOG( - Error, - "Failed to extract constants from '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } - } else { - // Subsequent container: share matching constants from the first. - std::vector pairs; - for (const auto& [fqn, internal_name] : fqn_to_name) { - auto it = shared_constant_tensors_.find(fqn); - if (it != shared_constant_tensors_.end()) { - // UpdateUserManagedConstantBufferPairs matches against the - // codegen constant name (underscored), not the original FQN. 
- pairs.push_back({internal_name.c_str(), it->second}); - } - } - - if (!pairs.empty()) { - auto update_err = handle->update_user_managed_constant_buffer_pairs( - handle->container_handle, - pairs.data(), - pairs.size(), - /*use_inactive=*/false, - /*validate_full_update=*/false); - - if (update_err == Error::Ok) { - ET_LOG( - Info, - "Shared %zu constants into method '%s'", - pairs.size(), - method_name.c_str()); - } else { - ET_LOG( - Error, - "Failed to share constants into '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } - } - } - } - } else if (share_kv_cache) { - ET_LOG( - Error, - "share_kv_cache_across_methods requested but constant sharing APIs " - "not available for method '%s'", - method_name.c_str()); - delete handle; - return Error::Internal; - } else { - ET_LOG( - Info, - "Constant sharing not requested for method '%s'", - method_name.c_str()); - } - // Initialize CUDA graph state if enabled for this method. if (should_use_cuda_graph_for_method(method_name)) { handle->cuda_graph_state.phase = CudaGraphPhase::Warmup; @@ -990,6 +881,11 @@ class ET_EXPERIMENTAL CudaBackend final mutable std::mutex cuda_stream_mutex_; std::shared_ptr shared_cuda_stream_ = nullptr; + // Whether to enable cross-method per-FQN weight caching at init time. + // Toggled by the kWeightSharingAcrossMethods runtime backend option. Default + // OFF — see set_weight_sharing_across_methods() for safety constraints. + std::atomic weight_sharing_across_methods_{false}; + // Cached output tensors for skip-copy optimization. // When skip-copy is enabled, output SlimTensors are cached here to keep // the underlying GPU memory alive while the caller processes the results. @@ -999,11 +895,345 @@ class ET_EXPERIMENTAL CudaBackend final unordered_map> cached_outputs_; - // Cross-method constant sharing state. - // When multiple AOTI containers share mutable buffers (e.g., KV cache), - // the first container's constants are extracted and stored here. Subsequent - // containers with matching FQNs share the same GPU tensors via - // UpdateUserManagedConstantBufferPairs. + // --------------------------------------------------------------- + // Per-weight constant cache. + // + // Maintains a singleton FQN → AtenTensorHandle cache across methods. + // When loading constants for a method, constants already in the cache + // are reused (zero-copy via update_user_managed_constant_buffer_pairs). + // Only constants not in the cache are loaded from the blob and added + // to the cache. This avoids duplicate GPU allocations when multiple + // methods (e.g., prefill/decode) share the same weights. + // + // ASSUMPTIONS / LIMITATIONS: + // * Constants with the same FQN across methods are assumed to be the + // SAME logical tensor (i.e. the same parameter/buffer of the same + // source model). We validate shape/dtype/strides/device on every + // reuse to catch silent mismatches (see check_cached_constant_match + // below). However, we cannot detect two unrelated models that + // happen to share an FQN. + // * Constants are assumed to be IMMUTABLE (parameters or read-only + // buffers). The AOTI shim today does not expose a mutability bit + // through GetConstantOriginalFQN, so we cannot detect or refuse + // to share mutable buffers (e.g. a per-method KV cache). If a + // future model exports the same FQN as a mutable buffer in + // multiple methods, mutations from one method WILL be visible to + // the other through the shared GPU memory. 
Callers that need + // per-method mutable state must currently use distinct FQNs. + // TODO: when AOTInductor exposes a constant-type / mutability + // query, refuse to share entries that are not PARAMETER or + // non-mutable BUFFER. + // --------------------------------------------------------------- + + // Validates that a cached constant tensor is compatible with what the + // new container expects for the same FQN (i.e. same dtype, dim, sizes, + // strides, and device). Both handles point to SlimTensors in our shim + // layer, so we can introspect them directly. + // + // Returns Error::Ok on a match. On mismatch, logs the offending field + // and returns Error::Internal so callers can chain via + // ET_CHECK_OK_OR_RETURN_ERROR and fail loudly instead of silently + // pointing the new container at a wrong-shape buffer. + static Error check_cached_constant_match( + const std::string& fqn, + AtenTensorHandle cached_handle, + AtenTensorHandle new_handle) { + ET_CHECK_OR_RETURN_ERROR( + cached_handle != nullptr && new_handle != nullptr, + Internal, + "Constant '%s': null AtenTensorHandle (cached=%p, new=%p)", + fqn.c_str(), + cached_handle, + new_handle); + + auto* cached = reinterpret_cast(cached_handle); + auto* fresh = reinterpret_cast(new_handle); + + ET_CHECK_OR_RETURN_ERROR( + cached->dtype() == fresh->dtype(), + Internal, + "Constant '%s': dtype mismatch (cached=%d, new=%d)", + fqn.c_str(), + static_cast(cached->dtype()), + static_cast(fresh->dtype())); + + ET_CHECK_OR_RETURN_ERROR( + cached->dim() == fresh->dim(), + Internal, + "Constant '%s': dim mismatch (cached=%zu, new=%zu)", + fqn.c_str(), + cached->dim(), + fresh->dim()); + + auto cached_sizes = cached->sizes(); + auto fresh_sizes = fresh->sizes(); + for (size_t i = 0; i < cached->dim(); ++i) { + ET_CHECK_OR_RETURN_ERROR( + cached_sizes[i] == fresh_sizes[i], + Internal, + "Constant '%s': size mismatch at dim %zu (cached=%lld, new=%lld)", + fqn.c_str(), + i, + static_cast(cached_sizes[i]), + static_cast(fresh_sizes[i])); + } + auto cached_strides = cached->strides(); + auto fresh_strides = fresh->strides(); + for (size_t i = 0; i < cached->dim(); ++i) { + ET_CHECK_OR_RETURN_ERROR( + cached_strides[i] == fresh_strides[i], + Internal, + "Constant '%s': stride mismatch at dim %zu (cached=%lld, new=%lld)", + fqn.c_str(), + i, + static_cast(cached_strides[i]), + static_cast(fresh_strides[i])); + } + ET_CHECK_OR_RETURN_ERROR( + cached->device().type() == fresh->device().type() && + cached->device().index() == fresh->device().index(), + Internal, + "Constant '%s': device mismatch (cached=%d:%d, new=%d:%d)", + fqn.c_str(), + static_cast(cached->device().type()), + cached->device().index(), + static_cast(fresh->device().type()), + fresh->device().index()); + + return Error::Ok; + } + + // Load constants for a method using per-weight caching. + // Returns Error::Ok on success. + // + // Flow: + // 1. Enumerate this method's constants and their FQNs. + // 2. For each constant: + // - If FQN is in shared_constant_tensors_ → reuse (cache hit). + // - Otherwise → mark as needing loading (cache miss). + // 3. If all constants are cached → skip blob loading entirely. + // Otherwise → call update_constants_from_blob to load all, then + // extract and cache the new constants. + // 4. For cached constants, call update_user_managed_constant_buffer_pairs + // to point the container to the shared GPU tensors. 
+ Error load_constants_with_cache( + cuda::CudaDelegateHandle* handle, + const NamedDataMap* named_data_map, + const std::string& method_name) const { + // Check if the required APIs are available + if (!handle->get_num_constants || !handle->get_constant_name || + !handle->get_constant_original_fqn || !handle->extract_constants_map || + !handle->update_user_managed_constant_buffer_pairs) { + // Fall back to the legacy path + return load_constants_legacy(handle, named_data_map, method_name); + } + + // Step 1: Enumerate constants and partition into cached/uncached + size_t num_constants = 0; + handle->get_num_constants(handle->container_handle, &num_constants); + if (num_constants == 0) { + ET_LOG(Info, "No constants for method '%s'", method_name.c_str()); + return Error::Ok; + } + + // Build FQN → internal_name mapping and determine cache hits/misses. + std::unordered_map fqn_to_name; + std::vector uncached_fqns; + + // Phase 1 (lock-free): enumerate constants from the container. + for (size_t i = 0; i < num_constants; i++) { + const char* name = nullptr; + const char* fqn = nullptr; + handle->get_constant_name(handle->container_handle, i, &name); + handle->get_constant_original_fqn(handle->container_handle, i, &fqn); + if (name && fqn && fqn[0] != '\0') { + fqn_to_name[fqn] = name; + } + } + + // Phase 2 (locked): pure cache lookup against shared_constant_tensors_. + { + std::lock_guard guard(shared_constants_mutex_); + for (const auto& [fqn, _] : fqn_to_name) { + if (shared_constant_tensors_.find(fqn) == + shared_constant_tensors_.end()) { + uncached_fqns.push_back(fqn); + } + } + } + + size_t num_cached = fqn_to_name.size() - uncached_fqns.size(); + ET_LOG( + Info, + "Method '%s': %zu constants, %zu cached, %zu uncached", + method_name.c_str(), + fqn_to_name.size(), + num_cached, + uncached_fqns.size()); + + // Step 2: Load uncached constants from blob (if any). + std::unordered_map extracted_map; + + if (!uncached_fqns.empty()) { + // Need to load from blob — use update_constants_from_blob for all, + // then extract the new constants into the cache. + std::string weights_blob_key = + method_name.empty() ? "weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + + ET_CHECK_OR_RETURN_ERROR( + buffer_res.ok() && handle->update_constants_from_blob != nullptr, + NotFound, + "weights_blob '%s' not found or update fn is null", + weights_blob_key.c_str()); + + ET_LOG( + Info, + "Loading constants from blob '%s' for method '%s'", + weights_blob_key.c_str(), + method_name.c_str()); + const void* weights_blob = buffer_res->data(); + ET_CHECK_OK_OR_RETURN_ERROR( + handle->update_constants_from_blob( + handle->container_handle, + static_cast(weights_blob)), + "update_constants_from_blob failed for method '%s'", + method_name.c_str()); + cudaDeviceSynchronize(); + buffer_res->Free(); + + // Extract all constants from the freshly-loaded container. + ET_CHECK_OK_OR_RETURN_ERROR( + handle->extract_constants_map( + handle->container_handle, + reinterpret_cast(&extracted_map), + /*use_inactive=*/false), + "Failed to extract constants from '%s'", + method_name.c_str()); + + // Validate cache hits against the freshly-extracted tensors, and + // populate the cache with newly-loaded entries. 
+ { + std::lock_guard guard(shared_constants_mutex_); + for (const auto& [fqn, _] : fqn_to_name) { + auto extracted_it = extracted_map.find(fqn); + if (extracted_it == extracted_map.end()) { + // Container did not surface this FQN — skip; the user-managed + // pair build below will simply omit it. + continue; + } + auto cached_it = shared_constant_tensors_.find(fqn); + if (cached_it == shared_constant_tensors_.end()) { + // New constant — add to cache. + shared_constant_tensors_[fqn] = extracted_it->second; + } else { + // Same FQN seen before — verify the cached tensor is still + // compatible with what THIS method expects. On mismatch the + // helper logs the offending field and returns an error. + ET_CHECK_OK_OR_RETURN_ERROR( + check_cached_constant_match( + fqn, cached_it->second, extracted_it->second), + "Constant '%s' in method '%s' is incompatible with the " + "cached version from a previous method. Refusing to share.", + fqn.c_str(), + method_name.c_str()); + } + } + ET_LOG( + Info, + "Cached %zu new constants from method '%s' (total cache: %zu)", + uncached_fqns.size(), + method_name.c_str(), + shared_constant_tensors_.size()); + } + } else { + // All constants are cached — skip blob loading entirely. + // NOTE: in this branch we cannot independently verify the cache + // against the new container's expectations (no extract source). + // We rely on update_user_managed_constant_buffer_pairs below, + // which the AOTI runtime validates internally. + ET_LOG( + Info, + "All %zu constants cached — skipping blob load for method '%s'", + fqn_to_name.size(), + method_name.c_str()); + } + + // Step 3: Point the container to cached tensors via user_managed pairs + if (num_cached > 0 || uncached_fqns.empty()) { + std::vector pairs; + { + std::lock_guard guard(shared_constants_mutex_); + for (const auto& [fqn, internal_name] : fqn_to_name) { + auto it = shared_constant_tensors_.find(fqn); + if (it != shared_constant_tensors_.end()) { + pairs.push_back({internal_name.c_str(), it->second}); + } + } + } + + if (!pairs.empty()) { + ET_CHECK_OK_OR_RETURN_ERROR( + handle->update_user_managed_constant_buffer_pairs( + handle->container_handle, + pairs.data(), + pairs.size(), + /*use_inactive=*/false, + /*validate_full_update=*/false), + "Failed to set cached constants for method '%s'", + method_name.c_str()); + ET_LOG( + Info, + "Shared %zu cached constants into method '%s'", + pairs.size(), + method_name.c_str()); + } + } + + return Error::Ok; + } + + // Legacy constant loading: load the entire blob without caching. + // Used as fallback when constant management APIs are unavailable. + Error load_constants_legacy( + cuda::CudaDelegateHandle* handle, + const NamedDataMap* named_data_map, + const std::string& method_name) const { + std::string weights_blob_key = + method_name.empty() ? 
"weights_blob" : method_name + "_weights_blob"; + auto buffer_res = named_data_map->get_data(weights_blob_key.c_str()); + if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) { + ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str()); + const void* weights_blob = buffer_res->data(); + auto update_err = handle->update_constants_from_blob( + handle->container_handle, static_cast(weights_blob)); + if (update_err != Error::Ok) { + ET_LOG(Error, "update_constants_from_blob failed"); + return update_err; + } + cudaDeviceSynchronize(); + buffer_res->Free(); + } else { + ET_LOG( + Info, + "weights_blob '%s' not found or update fn is null", + weights_blob_key.c_str()); + } + return Error::Ok; + } + // Guards the singleton FQN → AtenTensorHandle cache below. + // + // The mutex guards init(). + // The CudaBackend instance is a process-wide singleton (registered + // once via register_backend()), and shared_constant_tensors_ is a + // shared-across-handles map. ExecuTorch hosts CAN call init() from + // multiple threads when: + // * a multi-threaded application loads two Modules concurrently, or + // * a single Module is loaded from a thread pool. + // Without the mutex, two concurrent init()s could race on + // shared_constant_tensors_ (rehash during insert, double-insert with + // different handles, etc.). The cost is a one-time lock during init, + // which is negligible. mutable std::mutex shared_constants_mutex_; // FQN → AtenTensorHandle from the source (first) container. @@ -1011,9 +1241,6 @@ class ET_EXPERIMENTAL CudaBackend final // explicitly deleted — see destroy() comment). mutable std::unordered_map shared_constant_tensors_; - - // Whether we've already extracted constants from a source container. - mutable bool constants_extracted_ = false; }; } // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index a86aa173d43..ac6c112c08c 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -822,7 +822,6 @@ def _export_cuda(model, config, args): CudaPartitioner( [ CudaBackend.generate_method_name_compile_spec("decode"), - CudaBackend.generate_share_kv_cache_compile_spec(), ] ) ], @@ -830,7 +829,6 @@ def _export_cuda(model, config, args): CudaPartitioner( [ CudaBackend.generate_method_name_compile_spec("prefill"), - CudaBackend.generate_share_kv_cache_compile_spec(), ] ) ], diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 54055e2065d..c5024890645 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -143,6 +143,34 @@ int main(int argc, char** argv) { printf("Loading methods...\n"); + // Enable cross-method per-FQN weight sharing in the CUDA backend so that + // prefill and decode (which share KV cache and other mutable buffers / + // weights) avoid duplicate GPU allocations. This is critical for fitting + // Qwen 3.5 MoE on a single GPU. MUST be set BEFORE load_method, since the + // backend reads this flag during init() to decide between the per-weight + // cache path and the legacy per-method blob load. 
+ { + executorch::runtime::BackendOptions<1> backend_options; + auto set_err = + backend_options.set_option("weight_sharing_across_methods", true); + if (set_err != Error::Ok) { + ET_LOG( + Error, + "Failed to construct weight_sharing_across_methods option: %d", + static_cast(set_err)); + return 1; + } + const auto opt_err = + executorch::runtime::set_option("CudaBackend", backend_options.view()); + if (opt_err != Error::Ok) { + ET_LOG( + Error, + "Failed to enable weight_sharing_across_methods: %d", + static_cast(opt_err)); + return 1; + } + } + auto err = module->load_method("prefill"); if (err != Error::Ok) { ET_LOG(Error, "Failed to load prefill method");
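The excerpt ends here, but the decode method is loaded the same way right after prefill. The fragment below is a hypothetical sketch of that continuation, not the actual contents of main.cpp; it reuses only the Module API, Error type, and ET_LOG already visible above. Because the backend option is set before either load, prefill's init() populates the per-FQN cache, and decode's init() should report most of its constants as cached (via the cached/uncached count that load_constants_with_cache logs).

  // Hypothetical continuation sketch, not part of this patch. Assumes
  // `module` is the same object used for the "prefill" load above.
  auto decode_err = module->load_method("decode");
  if (decode_err != Error::Ok) {
    ET_LOG(Error, "Failed to load decode method");
    return 1;
  }
  // At this point the CUDA backend has routed decode's constants through the
  // per-FQN cache: FQNs already loaded by prefill reuse the same GPU tensors,
  // so only decode-specific constants consume additional memory.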