28 changes: 27 additions & 1 deletion .ci/scripts/export_model_artifact.sh
@@ -415,14 +415,40 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then

# Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues)
echo "::group::Export"
EXPORT_LOG=$(mktemp)
TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
python -m executorch.examples.models.qwen3_5_moe.export \
--prequantized "$LOCAL_MODEL_DIR" \
--output-dir "${OUTPUT_DIR}" \
--dense-prefill dequant \
--moe-activation-dtype int8 2>&1 | tee "$EXPORT_LOG"
EXPORT_RC=${PIPESTATUS[0]}
echo "::endgroup::"

if [ "$EXPORT_RC" -ne 0 ]; then
echo "ERROR: Qwen3.5 MoE export failed (exit $EXPORT_RC)"
rm -f "$EXPORT_LOG"
exit "$EXPORT_RC"
fi

# Gate peak GPU memory so we keep the export viable on consumer GPUs
# (e.g. RTX 4090 with 24 GB). The export script prints a machine-
# parseable marker line "EXPORT_GPU_PEAK_MEMORY_MB: <float>".
EXPORT_GPU_PEAK_MB_LIMIT="${EXPORT_GPU_PEAK_MB_LIMIT:-20480}"
PEAK_LINE=$(grep -E '^EXPORT_GPU_PEAK_MEMORY_MB:' "$EXPORT_LOG" | tail -1)
rm -f "$EXPORT_LOG"
if [ -z "$PEAK_LINE" ]; then
echo "ERROR: export did not emit EXPORT_GPU_PEAK_MEMORY_MB marker; cannot enforce GPU memory budget"
exit 1
fi
PEAK_MB=$(echo "$PEAK_LINE" | awk '{print $2}')
echo "Export GPU peak memory: ${PEAK_MB} MB (limit ${EXPORT_GPU_PEAK_MB_LIMIT} MB)"
if awk -v p="$PEAK_MB" -v l="$EXPORT_GPU_PEAK_MB_LIMIT" 'BEGIN{exit !(p>l)}'; then
echo "ERROR: export exceeded GPU memory budget (${PEAK_MB} MB > ${EXPORT_GPU_PEAK_MB_LIMIT} MB)"
echo " — this would prevent the model from being exported on a 24 GB consumer GPU."
exit 1
fi

test -f "${OUTPUT_DIR}/model.pte"
test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
ls -al "${OUTPUT_DIR}"
38 changes: 34 additions & 4 deletions backends/aoti/aoti_backend.py
@@ -9,7 +9,7 @@
import typing
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Set

import torch
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
@@ -88,8 +88,14 @@ def save_data_externally(cls) -> bool:
return False

@classmethod
def get_extra_aoti_compile_context_manager(
cls, compile_specs: Optional[List[CompileSpec]] = None
):
"""Return extra context manager to apply during aoti_compile stage. By default returns an empty context manager.

Subclasses may inspect ``compile_specs`` to opt into behaviors that
only apply to specific methods/models (e.g. low-memory export).
"""
return contextlib.nullcontext()

@classmethod
@@ -105,6 +111,24 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
"""
return

@classmethod
def release_moved_tensors(
cls,
device_edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> None:
"""Release device memory held by tensors that ``move_to_device_pass``
placed on the target device.

Called at the end of ``preprocess`` so that the next ``preprocess``
call (e.g. for the next method in a multi-method export) can reuse
the freed memory. Override in concrete backends (e.g. ``CudaBackend``)
to actually free device memory.

Default: no-op.
"""
return

@classmethod
@contextlib.contextmanager
def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -208,7 +232,7 @@ def preprocess(
# Compile with fallback kernel collection
with cls.collect_unsupported_fallback_kernels(
missing_fallback_kernels
), torch.no_grad(), cls.get_extra_aoti_compile_context_manager(compile_specs):
paths = torch._inductor.aot_compile(
edge_program_module, tuple(user_input_placeholders), options=options
)
@@ -269,6 +293,12 @@ def preprocess(
os.remove(so_path)
os.remove(blob_path)

# Release device memory held by tensors that ``move_to_device_pass``
# placed on the target device. Default impl is a no-op; concrete
# backends (e.g. CudaBackend) override this to free GPU memory before
# the next preprocess call (e.g. for the next method).
cls.release_moved_tensors(device_edge_program, compile_specs)

return PreprocessResult(
processed_bytes=b"",
debug_handle_map={},
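To see the two new hooks from a subclass's perspective, here is a minimal sketch (not part of the patch). The base-class name `AotiBackend`, the subclass, and the `my_flag` spec key are assumptions for illustration; only the two method signatures come from the diff above.

```python
import contextlib
from typing import List, Optional

import torch
from executorch.exir.backend.compile_spec_schema import CompileSpec


class MyBackend(AotiBackend):  # base-class name assumed; illustration only
    @classmethod
    def get_extra_aoti_compile_context_manager(
        cls, compile_specs: Optional[List[CompileSpec]] = None
    ):
        # Inspect the specs to opt into behavior for specific methods/models,
        # falling back to the base default (an empty context manager).
        for spec in compile_specs or []:
            if spec.key == "my_flag" and spec.value == b"ON":  # hypothetical key
                return torch.no_grad()  # stand-in for a real context manager
        return contextlib.nullcontext()

    @classmethod
    def release_moved_tensors(cls, device_edge_program, compile_specs) -> None:
        # Called at the end of each preprocess(); free device memory held by
        # the program's moved tensors here. The base implementation is a no-op.
        pass
```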
185 changes: 171 additions & 14 deletions backends/cuda/cuda_backend.py
@@ -5,9 +5,11 @@
# LICENSE file in the root directory of this source tree.


import contextlib
import logging
import os
import shutil
import threading
import typing
from importlib import resources
from typing import Any, Dict, final, List, Optional
@@ -27,6 +29,83 @@
from torch.nn.attention import SDPBackend


# ---------------------------------------------------------------------------
# AOTI compile-time CPU clones for mutated buffers
# ---------------------------------------------------------------------------
#
# Inductor's `_unlift_graph` clones every mutated buffer that gets lifted into
# the AOTI graph. By default it clones on whatever device the original tensor
# lives on — which after `move_to_device_pass` is CUDA. For large models like
# Qwen3.5-MoE that means an extra ~18 GB GPU clone during compile, blowing past
# the 24 GB cap we want to honor for consumer GPUs (RTX 4090 and similar).
#
# The patch below side-steps that by:
# 1. Wrapping `torch._inductor.compile_fx.clone_preserve_strides` so every
# clone the AOTI compile pipeline produces lands on CPU.
# 2. Wrapping `CppWrapperCpu.codegen_device` so the C++ wrapper still records
# the model's original target device (e.g. cuda) in `constants_info_`,
# not the now-CPU storage device. Without this the runtime would refuse
# to load the constants because of a mixed-device mismatch.
#
# The wrappers are scoped via a thread-local guard and are only active while
# `_compile_time_cpu_clones(...)` is on the call stack — they are inert
# anywhere else in the process.

_CPU_CLONE_GUARD = threading.local()


def _is_cpu_clone_active() -> bool:
return getattr(_CPU_CLONE_GUARD, "active", False)


@contextlib.contextmanager
def _compile_time_cpu_clones(target_device: torch.device):
"""Force AOTI's mutated-buffer clones onto CPU while preserving the
serialized constants' target device."""
from torch._inductor import compile_fx as _cfx
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp

orig_clone = _cfx.clone_preserve_strides
orig_codegen_device = _Cpp.codegen_device

def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
# `clone_preserve_strides` is shared by `_unlift_graph` (clones
# lifted buffers — can be safely kept on CPU) and by autotuning code
# in `triton_heuristics.py` (clones for benchmark — must stay on
# GPU for Triton). Discriminate by caller frame so we only force
# CPU clones for the buffer-lifting path.
import sys

caller = sys._getframe(1).f_code.co_name
if caller == "_unlift_graph":
return orig_clone(x).cpu()
return orig_clone(x)

def _codegen_device_target_aware(self, device):
# Translate accidental CPU device strings back to the model target
# device only when a constant we forced to CPU is being serialized.
# Other code paths (extern op args etc.) are pass-through.
if (
_is_cpu_clone_active()
and self.device != "cpu"
and isinstance(device, torch.device)
and device.type == "cpu"
):
device = target_device
return orig_codegen_device(self, device)

_cfx.clone_preserve_strides = _cpu_clone_preserve_strides
_Cpp.codegen_device = _codegen_device_target_aware
prev_active = getattr(_CPU_CLONE_GUARD, "active", False)
_CPU_CLONE_GUARD.active = True
try:
yield
finally:
_CPU_CLONE_GUARD.active = prev_active
_cfx.clone_preserve_strides = orig_clone
_Cpp.codegen_device = orig_codegen_device


@final
@experimental(
"This API and all of cuda backend related functionality are experimental."
@@ -253,19 +332,97 @@ def get_aoti_compile_options(
return options

@classmethod
def get_extra_aoti_compile_context_manager(
cls, compile_specs: Optional[List[CompileSpec]] = None
):
"""
Combine all extra context managers needed during AOTInductor
compilation for the CUDA backend. Each manager is documented at
its own `enter_context` call site below.

The low-memory export monkey-patch (CPU clones for mutated buffers)
is gated on the ``low_memory_mode`` compile spec — only models that
explicitly opt in (currently Qwen3.5 MoE) get it. Other models go
through the unmodified AOTI codepath, which avoids regressions in
their cuda CI exports.
"""
# Parse compile_specs for low_memory_mode (default OFF). compile_specs
# may be None when called without specs (parity with base default).
low_memory_mode = "OFF"
for spec in compile_specs or []:
if spec.key == "low_memory_mode":
mode = spec.value.decode("utf-8").upper()
if mode not in ["ON", "OFF"]:
raise ValueError(
f"Invalid low_memory_mode: {mode}. Expected 'ON' or 'OFF'."
)
low_memory_mode = mode

@contextlib.contextmanager
def _combined():
with contextlib.ExitStack() as stack:
# Force any remaining PyTorch SDPA ops to use the MATH
# backend during compilation so AOTI can lower / decompose
# them. SDPA ops already replaced by Triton kernels via
# `ReplaceEdgeOpWithTritonOpPass` are unaffected; this is
# only the fallback for the `triton_kernel_mode="OFF"` path.
stack.enter_context(torch.nn.attention.sdpa_kernel([SDPBackend.MATH]))
if low_memory_mode == "ON":
# Force AOTI's mutated-buffer clones onto CPU during
# compile so we stay under tight GPU memory caps (e.g.
# 24 GB on a consumer 4090). See
# `_compile_time_cpu_clones` for details. Only enabled
# for models that explicitly opt in via the
# `low_memory_mode="ON"` compile spec, since the
# monkey-patch can interact poorly with other models'
# AOTI compile pipelines.
stack.enter_context(
_compile_time_cpu_clones(torch.device(cls.get_device_name()))
)
yield

return _combined()

@staticmethod
def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
"""Return True if any compile spec opts into low-memory export."""
for spec in compile_specs:
if spec.key == "low_memory_mode":
return spec.value.decode("utf-8").upper() == "ON"
return False

@classmethod
def release_moved_tensors(
cls,
device_edge_program,
compile_specs: List[CompileSpec],
) -> None:
"""
Free GPU memory held by tensors that ``move_to_device_pass`` placed
on CUDA (params, buffers, and constants of ``device_edge_program``).

Resizing the underlying storage to 0 returns those bytes to PyTorch's
caching allocator, so the next ``preprocess`` call (e.g. for the
next method in a multi-method export) can reuse them when its own
``move_to_device_pass`` runs.
"""
"""
if not torch.cuda.is_available():
return

pools = []
state_dict = getattr(device_edge_program, "state_dict", None)
if state_dict:
pools.append(state_dict.values())
constants = getattr(device_edge_program, "constants", None)
if constants:
pools.append(constants.values())

for pool in pools:
for tensor in pool:
if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
try:
tensor.untyped_storage().resize_(0)
except Exception:
# Some storages may be shared / non-resizable; skip
# them rather than failing the export.
pass
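The memory-release trick in `release_moved_tensors` can be demonstrated standalone; a minimal sketch, assuming a CUDA device is present (the tensor shape and printout are arbitrary):

```python
import torch

if torch.cuda.is_available():
    t = torch.empty(1024, 1024, device="cuda")  # ~4 MiB held by the allocator
    before = torch.cuda.memory_allocated()
    # Resizing the underlying storage to 0 returns those bytes to PyTorch's
    # caching allocator; the tensor object stays alive but holds no data.
    t.untyped_storage().resize_(0)
    after = torch.cuda.memory_allocated()
    print(f"allocated: {before} -> {after} bytes")  # after < before
```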
17 changes: 17 additions & 0 deletions examples/models/qwen3_5_moe/export.py
@@ -934,6 +934,7 @@ def _export_cuda(model, config, args):
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.passes import MemoryPlanningPass
from torch.export import Dim, export

@@ -1007,13 +1008,15 @@ def _export_cuda(model, config, args):
CudaPartitioner(
[
CudaBackend.generate_method_name_compile_spec("decode"),
CompileSpec("low_memory_mode", b"ON"),
]
)
],
"prefill": [
CudaPartitioner(
[
CudaBackend.generate_method_name_compile_spec("prefill"),
CompileSpec("low_memory_mode", b"ON"),
]
)
],
@@ -1166,6 +1169,13 @@ def main(): # noqa: C901
# Register FLA Triton kernel (CUDA only)
import executorch.backends.cuda.triton.kernels # noqa: F401

# Reset peak GPU memory stats so we can report the actual peak
# consumed during the export pipeline (load + quantize + lowering)
# at the very end. CI gates the reported value to make sure low-VRAM
# GPUs (e.g. RTX 4090, 24 GB) can still complete the export.
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(0)

if args.backend == "mlx":
if args.prequantized:
parser.error("--prequantized is not supported with --backend mlx")
@@ -1207,6 +1217,13 @@ def main(): # noqa: C901

export_and_lower(model, config, args)

# Report peak GPU memory consumed during the export so CI / users can
# gate this against a known budget (e.g. 24 GB consumer GPUs).
if args.backend == "cuda" and torch.cuda.is_available():
peak_mb = torch.cuda.max_memory_allocated(0) / (1024 * 1024)
# Stable, machine-parseable marker for CI grep.
print(f"EXPORT_GPU_PEAK_MEMORY_MB: {peak_mb:.2f}")


if __name__ == "__main__":
main()
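For completeness, the CI contract this marker enables looks roughly like the following in Python; a sketch mirroring the bash gate in `.ci/scripts/export_model_artifact.sh`, with `log_path` and the 20480 MB default as placeholder assumptions:

```python
import re
import sys


def gate_peak_memory(log_path: str, limit_mb: float = 20480.0) -> None:
    """Scan an export log for the last EXPORT_GPU_PEAK_MEMORY_MB marker
    and fail if it exceeds the budget, like the bash gate above."""
    peak_mb = None
    with open(log_path) as f:
        for line in f:
            m = re.match(r"EXPORT_GPU_PEAK_MEMORY_MB:\s*([0-9.]+)", line)
            if m:
                peak_mb = float(m.group(1))  # keep the last occurrence
    if peak_mb is None:
        sys.exit("export did not emit EXPORT_GPU_PEAK_MEMORY_MB marker")
    print(f"Export GPU peak memory: {peak_mb:.2f} MB (limit {limit_mb} MB)")
    if peak_mb > limit_mb:
        sys.exit(f"export exceeded GPU memory budget ({peak_mb:.2f} MB > {limit_mb} MB)")
```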