From 84ac359c68a1d352ef8ec8d67758187c3c633ec3 Mon Sep 17 00:00:00 2001
From: leon-L20
Date: Tue, 19 May 2026 09:13:23 +0000
Subject: [PATCH] fix(trainer): add torch.cuda.empty_cache() after FSDP update_actor

In colocate mode (vLLM + FSDP on the same GPU), PyTorch's caching allocator
holds onto reserved GPU memory after backward passes without releasing it
back to CUDA. This causes memory_reserved to grow monotonically across
training steps, eventually starving vLLM during weight synchronization.

Observed on L20 (46 GiB) with Qwen3.5-0.8B:
- Without fix: memory_reserved grows from ~30 GiB to 38.2 GiB over 9 steps,
  causing OOM during vLLM weight sync
- With fix: memory_reserved stabilizes at 32.9 GiB from step 5 onward

This matches the existing pattern in megatron_workers.py, which calls
aggressive_empty_cache(force_sync=True) at the end of update_actor.
The FSDP path had no equivalent cache release.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 trinity/trainer/verl/fsdp_workers.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 5787bf2e6e..3d2f0b40a1 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -958,6 +958,12 @@ def update_actor(self, data: DataProto):
                 "After offload actor optimizer during update_actor", logger=self.logger
             )
 
+        # Release reserved GPU memory held by PyTorch's caching allocator after
+        # backward passes. Without this, memory_reserved grows monotonically and
+        # eventually starves vLLM during weight sync in colocate mode.
+        # Matches the pattern in megatron_workers.py update_actor().
+        torch.cuda.empty_cache()
+
         return output
 
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))