From 6bde7624e3e4ee581214cf6d7c4b45f8b56f883b Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 15:25:17 +0800 Subject: [PATCH 01/10] limit entropy calculation --- tests/trainer/trainer_test.py | 2 -- .../algorithm/entropy_loss_fn/entropy_loss_fn.py | 15 ++++++++++++--- trinity/trainer/verl/dp_actor.py | 2 +- trinity/trainer/verl/megatron_actor.py | 2 +- trinity/trainer/verl/verl_config.py | 2 +- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 63bdca5b5d8..d45b795faaf 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -282,8 +282,6 @@ def test_trainer(self): self.config.buffer.total_epochs = 1 self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.trainer.trainer_strategy = self.strategy - self.config.model.max_response_tokens = 512 - self.config.model.max_model_len = 2048 self.config.check_and_update() self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 75069c37398..4c716aa3130 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -11,6 +11,9 @@ class EntropyLossFn(ABC): Entropy loss function. """ + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef + @abstractmethod def __call__( self, @@ -36,6 +39,12 @@ def default_args(cls) -> Dict: """ return {"entropy_coef": 0.0} + def enable(self) -> bool: + """ + Returns: + bool: Whether the entropy loss is enabled. + """ + return self.entropy_coef > 0.0 class DefaultEntropyLossFn(EntropyLossFn): """ @@ -43,7 +52,7 @@ class DefaultEntropyLossFn(EntropyLossFn): """ def __init__(self, entropy_coef: float): - self.entropy_coef = entropy_coef + super().__init__(entropy_coef) def __call__( self, @@ -62,7 +71,7 @@ class MixEntropyLossFn(EntropyLossFn): """ def __init__(self, entropy_coef: float): - self.entropy_coef = entropy_coef + super().__init__(entropy_coef) def __call__( self, @@ -89,7 +98,7 @@ class DummyEntropyLossFn(EntropyLossFn): """ def __init__(self, entropy_coef: float): - self.entropy_coef = entropy_coef + super().__init__(entropy_coef) def __call__( self, diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 254b9cc4d76..940de754ab9 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -517,7 +517,7 @@ def update_policy(self, data: DataProto): # noqa: C901 loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # all return: (bsz, response_length) - calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn + calculate_entropy = self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False outputs = self._forward_micro_batch( micro_batch=model_inputs, temperature=temperature, diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 5cd8b2f2233..0c0af3365e0 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -514,7 +514,7 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: # if use distributed optimizer, zero grad buffer will be handled by optimizer chunk.zero_grad_buffer() - calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn + calculate_entropy = self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False if data.meta_info.get("micro_batch_size", None) is not None: micro_batch_size = data.meta_info["micro_batch_size"] else: diff --git a/trinity/trainer/verl/verl_config.py b/trinity/trainer/verl/verl_config.py index bc5146ea601..02703aa0b80 100644 --- a/trinity/trainer/verl/verl_config.py +++ b/trinity/trainer/verl/verl_config.py @@ -180,7 +180,7 @@ class Actor: clip_ratio: float = 0.2 clip_ratio_low: Optional[float] = None clip_ratio_high: Optional[float] = None - entropy_coeff: float = 0.001 + entropy_coeff: float = 0 use_kl_loss: bool = False From 756d31cfd99e658582d78a35eb2a767f58eec336 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 15:38:13 +0800 Subject: [PATCH 02/10] limit entropy --- trinity/trainer/verl/megatron_actor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 0c0af3365e0..000147109c5 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -68,6 +68,7 @@ def __init__( self.policy_loss_fn = None self.kl_loss_fn = None self.entropy_loss_fn = None + self.calculate_entropy = False def set_algorithm(self, algorithm_config: AlgorithmConfig): self.loss_agg_mode = algorithm_config.loss_agg_mode @@ -78,6 +79,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig): self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( **algorithm_config.entropy_loss_fn_args ) + self.calculate_entropy = self.entropy_loss_fn.enable() def forward_backward_batch( # noqa: C901 self, @@ -514,7 +516,6 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: # if use distributed optimizer, zero grad buffer will be handled by optimizer chunk.zero_grad_buffer() - calculate_entropy = self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False if data.meta_info.get("micro_batch_size", None) is not None: micro_batch_size = data.meta_info["micro_batch_size"] else: @@ -527,7 +528,7 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: ) metric_micro_batch = self.forward_backward_batch( data, - calculate_entropy=calculate_entropy, + calculate_entropy=self.calculate_entropy, use_dynamic_bsz=self.config.use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len, From feca31d6d768c2295600edafcba4731f7ed32c50 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 16:25:06 +0800 Subject: [PATCH 03/10] fix entropy calculation oom --- tests/trainer/trainer_test.py | 13 +- trinity/trainer/verl/megatron_actor.py | 188 ++++++++++++++++--------- 2 files changed, 132 insertions(+), 69 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index d45b795faaf..668f9ae6dd9 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -251,12 +251,12 @@ def tearDown(self): @parameterized_class( - ("strategy", "offloading", "engine_type"), + ("strategy", "offloading", "engine_type", "entropy_coef"), [ - ("megatron", False, "vllm"), - ("fsdp2", False, "vllm"), - ("megatron", True, "sglang"), - ("fsdp2", True, "sglang"), + ("megatron", False, "vllm", 0.0), + ("fsdp2", False, "vllm", 0.0), + ("megatron", True, "sglang", 0.001), + ("fsdp2", True, "sglang", 0.001), ], ) class TestTrainerGSM8K(BaseTrainerCase): @@ -269,6 +269,9 @@ def test_trainer(self): self.config.algorithm.advantage_fn_args = { "epsilon": 1e-6, } + self.config.algorithm.entropy_loss_fn_args = { + "entropy_coef": self.entropy_coef, + } if self.offloading: self.config.model.model_path = get_api_model_path() # self.config.algorithm.repeat_times = 8 # TODO: used for real testing diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 000147109c5..7e1be7d368b 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -23,12 +23,12 @@ import os from functools import partial -from typing import Iterable, Tuple +from typing import Iterable import torch from megatron.core import parallel_state as mpu from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy +from megatron.core.tensor_parallel.utils import VocabUtility from verl import DataProto from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator @@ -39,10 +39,7 @@ reorder_and_merge_vpp_layers, set_router_replay_data, ) -from verl.utils.megatron.tensor_parallel import ( - vocab_parallel_entropy, - vocab_parallel_log_probs_from_logits, -) +from verl.utils.megatron.tensor_parallel import vocab_parallel_log_probs_from_logits from verl.utils.megatron_utils import get_megatron_mtp_loss, unwrap_model from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict @@ -57,6 +54,122 @@ from trinity.common.config import AlgorithmConfig +class _VocabParallelLogProbsAndEntropy(torch.autograd.Function): + """Compute TP-sharded target log-probs and entropy in one pass. + + This avoids the verl #1970 failure mode where entropy saves logits for + backward and Megatron cross-entropy later mutates the same logits buffer + in-place. The implementation keeps the entropy-safe path local to the + calculate_entropy branch and reuses a single fp32 buffer to limit peak + memory. + """ + + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + @torch.compile(dynamic=True) + def mul_reduce(a, b): + return (a * b).sum(dim=-1) + + if vocab_parallel_logits.dtype == torch.float32: + logits = vocab_parallel_logits.clone() + else: + logits = vocab_parallel_logits.float() + tp_group = mpu.get_tensor_model_parallel_group() + + logits_max = logits.max(dim=-1).values + torch.distributed.all_reduce( + logits_max, + op=torch.distributed.ReduceOp.MAX, + group=tp_group, + ) + + logits.sub_(logits_max.unsqueeze(dim=-1)) + partition_vocab_size = logits.size(-1) + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, + mpu.get_tensor_model_parallel_rank(), + mpu.get_tensor_model_parallel_world_size(), + ) + + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(logits_2d.size(0), device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + torch.distributed.all_reduce( + predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + # Reuse the fp32 buffer through exp -> softmax to avoid keeping + # normalized_logits / exp_logits / softmax as separate large tensors. + logits.exp_() + sum_exp_logits = logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group) + + log_sum_exp = sum_exp_logits.log() + logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + softmax = logits + + # Consume the softmax-weighted logits reduction immediately so the + # temporary product tensor does not live longer than necessary. + sum_softmax_times_logits = mul_reduce(softmax, vocab_parallel_logits) + torch.distributed.all_reduce( + sum_softmax_times_logits, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + entropy = logits_max + log_sum_exp - sum_softmax_times_logits + + ctx.partition_vocab_size = partition_vocab_size + ctx.save_for_backward(softmax, target_mask, masked_target_1d, entropy) + + return predicted_logits - log_sum_exp, entropy + + @staticmethod + def backward(ctx, grad_log_probs, grad_entropy): + softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors + grad_input = softmax + if grad_entropy is not None: + # Keep only one temporary vocab-sized tensor in backward: reuse the + # saved softmax buffer for grad_input and materialize log_softmax + # only long enough to build the entropy coefficient. + log_softmax = softmax.log() + log_softmax.add_(entropy.unsqueeze(dim=-1)) + log_softmax.mul_(grad_entropy.unsqueeze(dim=-1)) + if grad_log_probs is not None: + log_softmax.add_(grad_log_probs.unsqueeze(dim=-1)) + grad_input.mul_(log_softmax) + grad_input.mul_(-1) + elif grad_log_probs is not None: + grad_input.mul_(grad_log_probs.unsqueeze(dim=-1)) + grad_input.mul_(-1) + else: + grad_input.zero_() + + if grad_log_probs is not None: + grad_2d = grad_input.view(-1, ctx.partition_vocab_size) + arange_1d = torch.arange(grad_2d.size(0), device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] += ( + grad_log_probs.reshape(-1) * (~target_mask).view(-1).to(grad_input.dtype) + ) + + return grad_input, None + + +def vocab_parallel_log_probs_and_entropy( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return _VocabParallelLogProbsAndEntropy.apply(vocab_parallel_logits, target) + + class MegatronPPOActor(OldMegatronPPOActor): def __init__( self, @@ -371,22 +484,12 @@ def logits_processor(logits, label, label_mask): logits.div_(temperature) ret = {} if calculate_entropy: - # The veRL fix consumes more GPU memory than our implementation - # (.clone() v.s. monkey patch on megatron function); - # therefore, we have temporarily commented out the veRL fix. - # logits_bak = logits.clone() - # # disable the hint until the fused_kernel is optimized for triton>=3.3 - # logger.warning_once( - # "For memory-efficient computation, enable fused kernels via " - # "`actor_rollout_ref.model.use_fused_kernels=True`. " - # "The current `clone()` operation ensures correctness but increases memory usage." - # ) - entropy = vocab_parallel_entropy(logits) + # Only use the safe path when entropy is enabled. This avoids + # issue #1970 without changing the default Megatron CE path. + log_probs, entropy = vocab_parallel_log_probs_and_entropy(logits, label) ret["entropy"] = entropy - # else: - # logits_bak = logits - # log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label) - log_probs = vocab_parallel_log_probs_from_logits(logits, label) + else: + log_probs = vocab_parallel_log_probs_from_logits(logits, label) log_probs = log_probs.masked_fill(~label_mask, 0.0) ret["log_probs"] = log_probs return ret @@ -563,46 +666,3 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: self.actor_optimizer.zero_grad() get_torch_device().empty_cache() return metrics - - -def calculate_predicted_logits( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - logits_max: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Calculates predicted logits. - Modified from megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits - """ - - # No In-place subtraction !!! - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - partition_vocab_size = vocab_parallel_logits.size()[-1] - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - - return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits - - -# bug fix for https://github.com/volcengine/verl/issues/1970 -VocabParallelCrossEntropy.calculate_predicted_logits = calculate_predicted_logits From b3b006f5957a5cfa6f294e4f1dd93d3062ab4d72 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 16:25:54 +0800 Subject: [PATCH 04/10] fix pre-commit --- trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py | 1 + trinity/trainer/verl/dp_actor.py | 5 +++-- trinity/trainer/verl/megatron_actor.py | 11 ++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 4c716aa3130..13613489373 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -46,6 +46,7 @@ def enable(self) -> bool: """ return self.entropy_coef > 0.0 + class DefaultEntropyLossFn(EntropyLossFn): """ Basic entropy loss function. diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 940de754ab9..59ca980a0af 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -41,7 +41,6 @@ from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig @@ -517,7 +516,9 @@ def update_policy(self, data: DataProto): # noqa: C901 loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # all return: (bsz, response_length) - calculate_entropy = self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False + calculate_entropy = ( + self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False + ) outputs = self._forward_micro_batch( micro_batch=model_inputs, temperature=temperature, diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 7e1be7d368b..643f1d1978f 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -49,7 +49,6 @@ from verl.workers.megatron_workers import logger from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig @@ -111,7 +110,9 @@ def mul_reduce(a, b): # normalized_logits / exp_logits / softmax as separate large tensors. logits.exp_() sum_exp_logits = logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group) + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group + ) log_sum_exp = sum_exp_logits.log() logits.div_(sum_exp_logits.unsqueeze(dim=-1)) @@ -156,9 +157,9 @@ def backward(ctx, grad_log_probs, grad_entropy): if grad_log_probs is not None: grad_2d = grad_input.view(-1, ctx.partition_vocab_size) arange_1d = torch.arange(grad_2d.size(0), device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] += ( - grad_log_probs.reshape(-1) * (~target_mask).view(-1).to(grad_input.dtype) - ) + grad_2d[arange_1d, masked_target_1d] += grad_log_probs.reshape(-1) * ( + ~target_mask + ).view(-1).to(grad_input.dtype) return grad_input, None From a69ac85a84e09b2ae6a23c38d52fb0f8ca6d5eb7 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 16:51:34 +0800 Subject: [PATCH 05/10] add enable entropy arg --- .../entropy_loss_fn/entropy_loss_fn.py | 22 ++++++++++--------- trinity/trainer/verl/megatron_actor.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 13613489373..0ee5c7084f8 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch @@ -11,8 +11,10 @@ class EntropyLossFn(ABC): Entropy loss function. """ - def __init__(self, entropy_coef: float): + def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): self.entropy_coef = entropy_coef + # enable entropy calculation if entropy_coef > 0.0, or explicitly set by enable_entropy + self._enable_entropy = enable_entropy or self.entropy_coef > 0.0 @abstractmethod def __call__( @@ -39,12 +41,12 @@ def default_args(cls) -> Dict: """ return {"entropy_coef": 0.0} - def enable(self) -> bool: + def enable_entropy(self) -> bool: """ Returns: bool: Whether the entropy loss is enabled. """ - return self.entropy_coef > 0.0 + return self._enable_entropy class DefaultEntropyLossFn(EntropyLossFn): @@ -52,8 +54,8 @@ class DefaultEntropyLossFn(EntropyLossFn): Basic entropy loss function. """ - def __init__(self, entropy_coef: float): - super().__init__(entropy_coef) + def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): + super().__init__(entropy_coef, enable_entropy) def __call__( self, @@ -71,8 +73,8 @@ class MixEntropyLossFn(EntropyLossFn): Basic entropy loss function for mix algorithm. """ - def __init__(self, entropy_coef: float): - super().__init__(entropy_coef) + def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): + super().__init__(entropy_coef, enable_entropy) def __call__( self, @@ -98,8 +100,8 @@ class DummyEntropyLossFn(EntropyLossFn): Dummy entropy loss function. """ - def __init__(self, entropy_coef: float): - super().__init__(entropy_coef) + def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): + super().__init__(entropy_coef, enable_entropy) def __call__( self, diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 643f1d1978f..8ffa8d1d4d4 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -193,7 +193,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig): self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( **algorithm_config.entropy_loss_fn_args ) - self.calculate_entropy = self.entropy_loss_fn.enable() + self.calculate_entropy = self.entropy_loss_fn.enable_entropy() def forward_backward_batch( # noqa: C901 self, From 172508baa77978da0dd1f12498ff3210ea7324ab Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 17:15:08 +0800 Subject: [PATCH 06/10] simplify entropy --- tests/trainer/trainer_test.py | 14 +++++----- .../entropy_loss_fn/entropy_loss_fn.py | 26 +++++-------------- trinity/trainer/verl/megatron_actor.py | 2 +- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 668f9ae6dd9..8286365b50a 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -251,12 +251,12 @@ def tearDown(self): @parameterized_class( - ("strategy", "offloading", "engine_type", "entropy_coef"), + ("strategy", "offloading", "engine_type", "entropy_loss_fn"), [ - ("megatron", False, "vllm", 0.0), - ("fsdp2", False, "vllm", 0.0), - ("megatron", True, "sglang", 0.001), - ("fsdp2", True, "sglang", 0.001), + ("megatron", False, "vllm", "none"), + ("fsdp2", False, "vllm", "none"), + ("megatron", True, "sglang", "default"), + ("fsdp2", True, "sglang", "default"), ], ) class TestTrainerGSM8K(BaseTrainerCase): @@ -269,9 +269,7 @@ def test_trainer(self): self.config.algorithm.advantage_fn_args = { "epsilon": 1e-6, } - self.config.algorithm.entropy_loss_fn_args = { - "entropy_coef": self.entropy_coef, - } + self.config.algorithm.entropy_loss_fn = self.entropy_loss_fn if self.offloading: self.config.model.model_path = get_api_model_path() # self.config.algorithm.repeat_times = 8 # TODO: used for real testing diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 0ee5c7084f8..75069c37398 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import torch @@ -11,11 +11,6 @@ class EntropyLossFn(ABC): Entropy loss function. """ - def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): - self.entropy_coef = entropy_coef - # enable entropy calculation if entropy_coef > 0.0, or explicitly set by enable_entropy - self._enable_entropy = enable_entropy or self.entropy_coef > 0.0 - @abstractmethod def __call__( self, @@ -41,21 +36,14 @@ def default_args(cls) -> Dict: """ return {"entropy_coef": 0.0} - def enable_entropy(self) -> bool: - """ - Returns: - bool: Whether the entropy loss is enabled. - """ - return self._enable_entropy - class DefaultEntropyLossFn(EntropyLossFn): """ Basic entropy loss function. """ - def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): - super().__init__(entropy_coef, enable_entropy) + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef def __call__( self, @@ -73,8 +61,8 @@ class MixEntropyLossFn(EntropyLossFn): Basic entropy loss function for mix algorithm. """ - def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): - super().__init__(entropy_coef, enable_entropy) + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef def __call__( self, @@ -100,8 +88,8 @@ class DummyEntropyLossFn(EntropyLossFn): Dummy entropy loss function. """ - def __init__(self, entropy_coef: float, enable_entropy: Optional[bool] = None): - super().__init__(entropy_coef, enable_entropy) + def __init__(self, entropy_coef: float): + self.entropy_coef = entropy_coef def __call__( self, diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 8ffa8d1d4d4..e8bf7c2a014 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -193,7 +193,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig): self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( **algorithm_config.entropy_loss_fn_args ) - self.calculate_entropy = self.entropy_loss_fn.enable_entropy() + self.calculate_entropy = algorithm_config.entropy_loss_fn != "none" def forward_backward_batch( # noqa: C901 self, From 56ef6fabf255df30742d276ca1e9de00b065bdf8 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 17:37:00 +0800 Subject: [PATCH 07/10] fix fsdp --- trinity/trainer/verl/dp_actor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 59ca980a0af..30cbe390298 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -105,6 +105,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig): self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( **algorithm_config.entropy_loss_fn_args ) + self.calculate_entropy = algorithm_config.entropy_loss_fn != "none" def _forward_micro_batch( # noqa: C901 self, @@ -516,16 +517,13 @@ def update_policy(self, data: DataProto): # noqa: C901 loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # all return: (bsz, response_length) - calculate_entropy = ( - self.entropy_loss_fn.enable() if self.entropy_loss_fn is not None else False - ) outputs = self._forward_micro_batch( micro_batch=model_inputs, temperature=temperature, - calculate_entropy=calculate_entropy, + calculate_entropy=self.calculate_entropy, ) log_prob = outputs["log_probs"] - entropy = outputs["entropys"] if calculate_entropy else None + entropy = outputs["entropys"] if self.calculate_entropy else None pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore logprob=log_prob, **model_inputs From d276009b63b27baebcd9a6802bc0184f33994df2 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 18:19:28 +0800 Subject: [PATCH 08/10] update empty_cache --- trinity/trainer/verl/fsdp_workers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 4124621d2c2..795679b3739 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -292,7 +292,7 @@ def _fsdp_offload_context(self): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) torch.distributed.barrier() - torch.cuda.empty_cache() + get_torch_device().empty_cache() @kimi_vl_monkey_patch_decorator def _build_model_optimizer( # noqa: C901 @@ -793,6 +793,8 @@ def init_model(self): trust_remote_code=trust_remote_code, ) + get_torch_device().empty_cache() + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def setup_weight_sync_group(self): if self.config.synchronizer.sync_method == SyncMethod.NCCL: @@ -962,7 +964,7 @@ def update_actor(self, data: DataProto): # backward passes. Without this, memory_reserved grows monotonically and # eventually starves vLLM during weight sync in colocate mode. # Matches the pattern in megatron_workers.py update_actor(). - torch.cuda.empty_cache() + get_torch_device().empty_cache() return output @@ -1049,7 +1051,7 @@ def compute_ref_log_prob(self, data: DataProto): # Release reserved GPU memory after ref model forward pass. # Without this, memory_reserved grows after each ref_log_prob call, # eventually causing OOM in subsequent training steps. - torch.cuda.empty_cache() + get_torch_device().empty_cache() # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes # unshard the root FSDP module @@ -1641,6 +1643,8 @@ def init_model(self): trust_remote_code=self.config.model.get("trust_remote_code", False), ) + get_torch_device().empty_cache() + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) @DistProfiler.annotate(color="cyan", role="compute_values") def compute_values(self, data: DataProto): From dcf912ae6c42331f033be227cca9916116725eed Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 18:19:44 +0800 Subject: [PATCH 09/10] update empty_cache --- trinity/trainer/verl/fsdp_workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 795679b3739..3d94e9a8bff 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -820,7 +820,7 @@ def setup_weight_sync_group(self): (realname, str(param.dtype).split(".")[-1], tuple(param.shape)) ) param = None - torch.cuda.empty_cache() + get_torch_device().empty_cache() else: # fsdp2 for name, param in model.named_parameters(): self.state_dict_meta.append( From 0ffef3f40d34fea8db011876887eceef4e320919 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 May 2026 18:32:43 +0800 Subject: [PATCH 10/10] enhance trainer test --- tests/trainer/trainer_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 8286365b50a..4ce018ba1be 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -305,6 +305,12 @@ def test_trainer(self): actor_metrics = parser.metric_list("actor") self.assertGreater(len(actor_metrics), 0) self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + entropy_loss_metrics = parser.metric_list("actor/entropy_loss") + if self.entropy_loss_fn == "none": + self.assertEqual(len(entropy_loss_metrics), 0) + else: + self.assertGreater(len(entropy_loss_metrics), 0) + self.assertEqual(parser.metric_max_step(entropy_loss_metrics[0]), 4) response_metrics = parser.metric_list("response_length") self.assertGreater(len(response_metrics), 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)