Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,12 @@ def tearDown(self):


@parameterized_class(
("strategy", "offloading", "engine_type"),
("strategy", "offloading", "engine_type", "entropy_loss_fn"),
[
("megatron", False, "vllm"),
("fsdp2", False, "vllm"),
("megatron", True, "sglang"),
("fsdp2", True, "sglang"),
("megatron", False, "vllm", "none"),
("fsdp2", False, "vllm", "none"),
("megatron", True, "sglang", "default"),
("fsdp2", True, "sglang", "default"),
],
)
class TestTrainerGSM8K(BaseTrainerCase):
Expand All @@ -269,6 +269,7 @@ def test_trainer(self):
self.config.algorithm.advantage_fn_args = {
"epsilon": 1e-6,
}
self.config.algorithm.entropy_loss_fn = self.entropy_loss_fn
if self.offloading:
self.config.model.model_path = get_api_model_path()
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
Expand All @@ -282,8 +283,6 @@ def test_trainer(self):
self.config.buffer.total_epochs = 1
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.trainer.trainer_strategy = self.strategy
self.config.model.max_response_tokens = 512
self.config.model.max_model_len = 2048
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
Expand All @@ -306,6 +305,12 @@ def test_trainer(self):
actor_metrics = parser.metric_list("actor")
self.assertGreater(len(actor_metrics), 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
entropy_loss_metrics = parser.metric_list("actor/entropy_loss")
if self.entropy_loss_fn == "none":
self.assertEqual(len(entropy_loss_metrics), 0)
else:
self.assertGreater(len(entropy_loss_metrics), 0)
self.assertEqual(parser.metric_max_step(entropy_loss_metrics[0]), 4)
response_metrics = parser.metric_list("response_length")
self.assertGreater(len(response_metrics), 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
Expand Down
7 changes: 3 additions & 4 deletions trinity/trainer/verl/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor

from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig
Expand Down Expand Up @@ -106,6 +105,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
**algorithm_config.entropy_loss_fn_args
)
self.calculate_entropy = algorithm_config.entropy_loss_fn != "none"

def _forward_micro_batch( # noqa: C901
self,
Expand Down Expand Up @@ -517,14 +517,13 @@ def update_policy(self, data: DataProto): # noqa: C901
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

# all return: (bsz, response_length)
calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
outputs = self._forward_micro_batch(
micro_batch=model_inputs,
temperature=temperature,
calculate_entropy=calculate_entropy,
calculate_entropy=self.calculate_entropy,
)
log_prob = outputs["log_probs"]
entropy = outputs["entropys"] if calculate_entropy else None
entropy = outputs["entropys"] if self.calculate_entropy else None

pg_loss, pg_loss_metrics = self.policy_loss_fn( # type: ignore
logprob=log_prob, **model_inputs
Expand Down
12 changes: 8 additions & 4 deletions trinity/trainer/verl/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _fsdp_offload_context(self):
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
torch.distributed.barrier()
torch.cuda.empty_cache()
get_torch_device().empty_cache()

@kimi_vl_monkey_patch_decorator
def _build_model_optimizer( # noqa: C901
Expand Down Expand Up @@ -793,6 +793,8 @@ def init_model(self):
trust_remote_code=trust_remote_code,
)

get_torch_device().empty_cache()

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def setup_weight_sync_group(self):
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
Expand All @@ -818,7 +820,7 @@ def setup_weight_sync_group(self):
(realname, str(param.dtype).split(".")[-1], tuple(param.shape))
)
param = None
torch.cuda.empty_cache()
get_torch_device().empty_cache()
else: # fsdp2
for name, param in model.named_parameters():
self.state_dict_meta.append(
Expand Down Expand Up @@ -962,7 +964,7 @@ def update_actor(self, data: DataProto):
# backward passes. Without this, memory_reserved grows monotonically and
# eventually starves vLLM during weight sync in colocate mode.
# Matches the pattern in megatron_workers.py update_actor().
torch.cuda.empty_cache()
get_torch_device().empty_cache()

return output

Expand Down Expand Up @@ -1049,7 +1051,7 @@ def compute_ref_log_prob(self, data: DataProto):
# Release reserved GPU memory after ref model forward pass.
# Without this, memory_reserved grows after each ref_log_prob call,
# eventually causing OOM in subsequent training steps.
torch.cuda.empty_cache()
get_torch_device().empty_cache()

# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
Expand Down Expand Up @@ -1641,6 +1643,8 @@ def init_model(self):
trust_remote_code=self.config.model.get("trust_remote_code", False),
)

get_torch_device().empty_cache()

@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="cyan", role="compute_values")
def compute_values(self, data: DataProto):
Expand Down
196 changes: 129 additions & 67 deletions trinity/trainer/verl/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

import os
from functools import partial
from typing import Iterable, Tuple
from typing import Iterable

import torch
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility
from verl import DataProto
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.megatron.pipeline_parallel import make_batch_generator
Expand All @@ -39,10 +39,7 @@
reorder_and_merge_vpp_layers,
set_router_replay_data,
)
from verl.utils.megatron.tensor_parallel import (
vocab_parallel_entropy,
vocab_parallel_log_probs_from_logits,
)
from verl.utils.megatron.tensor_parallel import vocab_parallel_log_probs_from_logits
from verl.utils.megatron_utils import get_megatron_mtp_loss, unwrap_model
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
Expand All @@ -52,11 +49,128 @@
from verl.workers.megatron_workers import logger

from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN
from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig


class _VocabParallelLogProbsAndEntropy(torch.autograd.Function):
"""Compute TP-sharded target log-probs and entropy in one pass.

This avoids the verl #1970 failure mode where entropy saves logits for
backward and Megatron cross-entropy later mutates the same logits buffer
in-place. The implementation keeps the entropy-safe path local to the
calculate_entropy branch and reuses a single fp32 buffer to limit peak
memory.
"""

@staticmethod
def forward(ctx, vocab_parallel_logits, target):
@torch.compile(dynamic=True)
def mul_reduce(a, b):
return (a * b).sum(dim=-1)

if vocab_parallel_logits.dtype == torch.float32:
logits = vocab_parallel_logits.clone()
else:
logits = vocab_parallel_logits.float()
tp_group = mpu.get_tensor_model_parallel_group()

logits_max = logits.max(dim=-1).values
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=tp_group,
)

logits.sub_(logits_max.unsqueeze(dim=-1))
partition_vocab_size = logits.size(-1)
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
partition_vocab_size,
mpu.get_tensor_model_parallel_rank(),
mpu.get_tensor_model_parallel_world_size(),
)

target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0

logits_2d = logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(logits_2d.size(0), device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
torch.distributed.all_reduce(
predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=tp_group,
)

# Reuse the fp32 buffer through exp -> softmax to avoid keeping
# normalized_logits / exp_logits / softmax as separate large tensors.
logits.exp_()
sum_exp_logits = logits.sum(dim=-1)
torch.distributed.all_reduce(
sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group
)

log_sum_exp = sum_exp_logits.log()
logits.div_(sum_exp_logits.unsqueeze(dim=-1))
softmax = logits

# Consume the softmax-weighted logits reduction immediately so the
# temporary product tensor does not live longer than necessary.
sum_softmax_times_logits = mul_reduce(softmax, vocab_parallel_logits)
torch.distributed.all_reduce(
sum_softmax_times_logits,
op=torch.distributed.ReduceOp.SUM,
group=tp_group,
)
entropy = logits_max + log_sum_exp - sum_softmax_times_logits

ctx.partition_vocab_size = partition_vocab_size
ctx.save_for_backward(softmax, target_mask, masked_target_1d, entropy)

return predicted_logits - log_sum_exp, entropy

@staticmethod
def backward(ctx, grad_log_probs, grad_entropy):
softmax, target_mask, masked_target_1d, entropy = ctx.saved_tensors
grad_input = softmax
if grad_entropy is not None:
# Keep only one temporary vocab-sized tensor in backward: reuse the
# saved softmax buffer for grad_input and materialize log_softmax
# only long enough to build the entropy coefficient.
log_softmax = softmax.log()
log_softmax.add_(entropy.unsqueeze(dim=-1))
log_softmax.mul_(grad_entropy.unsqueeze(dim=-1))
if grad_log_probs is not None:
log_softmax.add_(grad_log_probs.unsqueeze(dim=-1))
grad_input.mul_(log_softmax)
grad_input.mul_(-1)
elif grad_log_probs is not None:
grad_input.mul_(grad_log_probs.unsqueeze(dim=-1))
grad_input.mul_(-1)
else:
grad_input.zero_()

if grad_log_probs is not None:
grad_2d = grad_input.view(-1, ctx.partition_vocab_size)
arange_1d = torch.arange(grad_2d.size(0), device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] += grad_log_probs.reshape(-1) * (
~target_mask
).view(-1).to(grad_input.dtype)

return grad_input, None


def vocab_parallel_log_probs_and_entropy(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return _VocabParallelLogProbsAndEntropy.apply(vocab_parallel_logits, target)


class MegatronPPOActor(OldMegatronPPOActor):
def __init__(
self,
Expand All @@ -68,6 +182,7 @@ def __init__(
self.policy_loss_fn = None
self.kl_loss_fn = None
self.entropy_loss_fn = None
self.calculate_entropy = False

def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.loss_agg_mode = algorithm_config.loss_agg_mode
Expand All @@ -78,6 +193,7 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig):
self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)(
**algorithm_config.entropy_loss_fn_args
)
self.calculate_entropy = algorithm_config.entropy_loss_fn != "none"

def forward_backward_batch( # noqa: C901
self,
Expand Down Expand Up @@ -369,22 +485,12 @@ def logits_processor(logits, label, label_mask):
logits.div_(temperature)
ret = {}
if calculate_entropy:
# The veRL fix consumes more GPU memory than our implementation
# (.clone() v.s. monkey patch on megatron function);
# therefore, we have temporarily commented out the veRL fix.
# logits_bak = logits.clone()
# # disable the hint until the fused_kernel is optimized for triton>=3.3
# logger.warning_once(
# "For memory-efficient computation, enable fused kernels via "
# "`actor_rollout_ref.model.use_fused_kernels=True`. "
# "The current `clone()` operation ensures correctness but increases memory usage."
# )
entropy = vocab_parallel_entropy(logits)
# Only use the safe path when entropy is enabled. This avoids
# issue #1970 without changing the default Megatron CE path.
log_probs, entropy = vocab_parallel_log_probs_and_entropy(logits, label)
ret["entropy"] = entropy
# else:
# logits_bak = logits
# log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
log_probs = vocab_parallel_log_probs_from_logits(logits, label)
else:
log_probs = vocab_parallel_log_probs_from_logits(logits, label)
log_probs = log_probs.masked_fill(~label_mask, 0.0)
ret["log_probs"] = log_probs
return ret
Expand Down Expand Up @@ -514,7 +620,6 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()

calculate_entropy = self.entropy_loss_fn != DummyEntropyLossFn
if data.meta_info.get("micro_batch_size", None) is not None:
micro_batch_size = data.meta_info["micro_batch_size"]
else:
Expand All @@ -527,7 +632,7 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
)
metric_micro_batch = self.forward_backward_batch(
data,
calculate_entropy=calculate_entropy,
calculate_entropy=self.calculate_entropy,
use_dynamic_bsz=self.config.use_dynamic_bsz,
micro_batch_size=micro_batch_size,
max_token_len=max_token_len,
Expand Down Expand Up @@ -562,46 +667,3 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict:
self.actor_optimizer.zero_grad()
get_torch_device().empty_cache()
return metrics


def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculates predicted logits.
Modified from megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits
"""

# No In-place subtraction !!!
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)

# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0

# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
partition_vocab_size = vocab_parallel_logits.size()[-1]
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0

exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)

return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits


# bug fix for https://github.com/volcengine/verl/issues/1970
VocabParallelCrossEntropy.calculate_predicted_logits = calculate_predicted_logits
Loading
Loading