diff --git a/rl/buffer_server.py b/rl/buffer_server.py index 76d9132..3117435 100644 --- a/rl/buffer_server.py +++ b/rl/buffer_server.py @@ -11,7 +11,7 @@ import numpy as np from datetime import datetime from logging.handlers import RotatingFileHandler -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set # Add rl directory to path for utils import _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -78,6 +78,10 @@ # Pending items by instance_id (for grouping) pending_items_by_instance: Dict[str, List[Dict[str, Any]]] = {} +# Completed session ids by instance_id. The DB may contain every environment +# step, but this server only reads rows marked is_trainable by the rollout side. +completed_sessions_by_instance: Dict[str, Set[str]] = {} + # Group size (set by /start_rollout) group_size: int = 1 @@ -137,6 +141,8 @@ def _build_item_from_row(row: Dict[str, Any]) -> Dict[str, Any]: "group_id": group_id, "weight_version": weight_version, "truncated": row.get("truncated", False), + "step_pk": row.get("step_pk"), + "is_session_completed": bool(row.get("is_session_completed", False)), } return { @@ -148,8 +154,29 @@ def _build_item_from_row(row: Dict[str, Any]) -> Dict[str, Any]: } +def _propagate_terminal_rewards(group: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + terminal_rewards_by_session: Dict[str, float] = {} + for item in group: + extra = item.get("extra_info") or {} + session_id = str(extra.get("session_id") or "") + if session_id and bool(extra.get("is_session_completed", False)): + terminal_rewards_by_session[session_id] = float(item.get("reward", 0.0)) + + for item in group: + extra = item.setdefault("extra_info", {}) + session_id = str(extra.get("session_id") or "") + if session_id not in terminal_rewards_by_session: + continue + original_reward = float(item.get("reward", 0.0)) + extra["step_reward"] = original_reward + extra["terminal_reward"] = terminal_rewards_by_session[session_id] + item["reward"] = terminal_rewards_by_session[session_id] + + return group + + async def fetch_new_items_from_db(limit: Optional[int] = None) -> List[Dict[str, Any]]: - """Fetch new completed steps from the database using cursor-based pagination.""" + """Fetch new trainable rows from the database using cursor-based pagination.""" global data_manager, last_served_id if data_manager is None: @@ -180,9 +207,9 @@ async def fetch_new_items_from_db(limit: Optional[int] = None) -> List[Dict[str, return items -def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]]) -> tuple: +def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]], max_groups: Optional[int] = None) -> tuple: """Accumulate items and return ready groups.""" - global pending_items_by_instance, group_size + global pending_items_by_instance, completed_sessions_by_instance, group_size ready_groups = [] finished_instance_ids = [] @@ -193,34 +220,76 @@ def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]]) -> tuple: if not instance_id: continue pending_items_by_instance.setdefault(instance_id, []).append(item) + extra = item.get("extra_info") or {} + session_id = str(extra.get("session_id") or "") + if session_id and bool(extra.get("is_session_completed", False)): + completed_sessions_by_instance.setdefault(instance_id, set()).add(session_id) - # Check for complete groups + # Check for complete groups. A complete group means all K trajectories for + # this prompt group have emitted their final row. to_delete = [] for instance_id, bucket in pending_items_by_instance.items(): - while len(bucket) >= group_size: - group = bucket[:group_size] - del bucket[:group_size] - ready_groups.append((instance_id, list(group))) + if max_groups is not None and len(ready_groups) >= max_groups: + break + completed_sessions = completed_sessions_by_instance.get(instance_id, set()) + if len(completed_sessions) >= group_size: + if len(completed_sessions) > group_size: + logger.warning( + "Group %s has %d completed sessions, expected group_size=%d", + instance_id, + len(completed_sessions), + group_size, + ) + group = sorted( + bucket, + key=lambda item: ( + str((item.get("extra_info") or {}).get("session_id") or ""), + int((item.get("extra_info") or {}).get("steps") or 0), + int((item.get("extra_info") or {}).get("step_pk") or 0), + ), + ) + group = _propagate_terminal_rewards(group) + ready_groups.append((instance_id, group)) finished_instance_ids.append(instance_id) - if not bucket: to_delete.append(instance_id) for k in to_delete: pending_items_by_instance.pop(k, None) + completed_sessions_by_instance.pop(k, None) return ready_groups, finished_instance_ids @app.post("/get_rollout_data", response_model=BufferResponse) async def get_rollout_data(request: Request): - global pending_items_by_instance + global pending_items_by_instance, completed_sessions_by_instance + + try: + payload = await request.json() + except Exception: + payload = {} + if not isinstance(payload, dict): + payload = {} + max_groups = payload.get("max_groups") + try: + max_groups = int(max_groups) if max_groups is not None else None + except (TypeError, ValueError): + max_groups = None + if max_groups is not None and max_groups <= 0: + max_groups = None # Fetch new items from database and accumulate groups new_items = await fetch_new_items_from_db(limit=None) - ready_groups, finished_ids = accumulate_and_pop_ready_groups(new_items) + ready_groups, finished_ids = accumulate_and_pop_ready_groups(new_items, max_groups=max_groups) # Log pending status - pending_counts = {k: len(v) for k, v in pending_items_by_instance.items()} + pending_counts = { + k: { + "items": len(v), + "completed_sessions": len(completed_sessions_by_instance.get(k, set())), + } + for k, v in pending_items_by_instance.items() + } logger.info(f"new_items={len(new_items)}, ready_groups={len(ready_groups)}, pending={pending_counts}") # Flatten groups to items @@ -291,15 +360,16 @@ def start_aievobox_process(data: dict): NOTE: LLM Proxy is now hosted in-process by slime_generator. It must already be running before this function is called. """ - global aievobox_process, group_size, last_served_id, pending_items_by_instance, data_manager + global aievobox_process, group_size, last_served_id, pending_items_by_instance, completed_sessions_by_instance, data_manager # Set group size (num_repeat_per_sample) - group_size = int(data.get("num_repeat_per_sample", 16)) + group_size = max(1, int(data.get("num_repeat_per_sample", 16))) # Clear state for new rollout restart_training = data.get("restart_training", False) if restart_training: pending_items_by_instance.clear() + completed_sessions_by_instance.clear() logger.info("restart_training=True, cleared pending items") # Keep a single job_session for both reader and writer process. diff --git a/rl/examples/geo3k_vl/env.sh b/rl/examples/geo3k_vl/env.sh index b6fd60c..eab11bc 100644 --- a/rl/examples/geo3k_vl/env.sh +++ b/rl/examples/geo3k_vl/env.sh @@ -1,13 +1,13 @@ # ------------------------------------------- # AIEvobox (rollout) Settings # ------------------------------------------- -export AIEVOBOX_ROOT=/root/AIEvoBox +export AIEVOBOX_ROOT=/root/Safactory export STORAGE_TYPE=sqlite export AIEVOBOX_DB_URL=sqlite:///${AIEVOBOX_ROOT}/rl/examples/geo3k_vl/geo3k_vl.db export AIEVOBOX_MAX_STEPS=10 -export AIEVOBOX_MESSAGE_CUT=0 +export AIEVOBOX_MESSAGE_CUT=1 # ENV_CONFIG 指定单个 yaml 文件 -export AIEVOBOX_ENV_CONFIG=/root/AIEvoBox/env/geo3k_vl_test/geo3k_vl_test_env_configs.yaml +export AIEVOBOX_ENV_CONFIG=${AIEVOBOX_ROOT}/env/geo3k_vl_test/geo3k_vl_test_env_configs.yaml # ENV_ROOT 指定读取目录下所有子目录的环境 # export AIEVOBOX_ENV_ROOT=/root/AIEvoBox/env export AIEVOBOX_POOL_SIZE=256 @@ -26,6 +26,10 @@ export AIEVOBOX_SQLITE_BULK_INSERT_PAUSE_S=0.01 # RL Settings # ------------------------------------------- export RL_GROUP_SIZE=8 +# RL_ROLLOUT_GROUP_BATCH_SIZE has priority. Leave it empty to derive from +# RL_GLOBAL_BATCH_SIZE / RL_GROUP_SIZE. +export RL_ROLLOUT_GROUP_BATCH_SIZE=64 +export RL_GLOBAL_BATCH_SIZE=512 export RL_EPOCH=1000 export RL_OFF_BY_N=0 @@ -61,6 +65,3 @@ export LLM_PROXY_ENABLE_CONSOLE_LOG=0 # Slime Training Settings (reference RL values) # ------------------------------------------- export SLIME_ROLLBUF_RESTART_TRAINING=True -export SLIME_N_SAMPLES_PER_PROMPT=$RL_GROUP_SIZE -export SLIME_GLOBAL_BATCH_SIZE=512 -export SLIME_ROLLOUT_BATCH_SIZE=$((SLIME_GLOBAL_BATCH_SIZE / RL_GROUP_SIZE)) diff --git a/rl/examples/geo3k_vl/run_slime_generator.sh b/rl/examples/geo3k_vl/run_slime_generator.sh index 2341e65..dca218e 100755 --- a/rl/examples/geo3k_vl/run_slime_generator.sh +++ b/rl/examples/geo3k_vl/run_slime_generator.sh @@ -22,6 +22,19 @@ source "${SCRIPT_DIR}/env.sh" ROLLOUT_BUFFER_URL="http://${BUFFER_SERVER_HOST}:${BUFFER_SERVER_PORT}" LLM_PROXY_URL="http://${LLM_PROXY_HOST}:${LLM_PROXY_PORT}" +if [[ -z "${RL_ROLLOUT_GROUP_BATCH_SIZE:-}" ]]; then + RL_ROLLOUT_GROUP_BATCH_SIZE=$((RL_GLOBAL_BATCH_SIZE / RL_GROUP_SIZE)) +fi + +if [[ -z "${RL_GLOBAL_BATCH_SIZE:-}" ]]; then + RL_GLOBAL_BATCH_SIZE=$((RL_ROLLOUT_GROUP_BATCH_SIZE * RL_GROUP_SIZE)) +fi + +if (( RL_ROLLOUT_GROUP_BATCH_SIZE <= 0 || RL_GLOBAL_BATCH_SIZE <= 0 || RL_GROUP_SIZE <= 0 )); then + echo "RL_ROLLOUT_GROUP_BATCH_SIZE, RL_GLOBAL_BATCH_SIZE, and RL_GROUP_SIZE must be positive" >&2 + exit 1 +fi + export PYTHONBUFFERED=16 NUM_GPUS=${NUM_GPUS:-8} @@ -43,11 +56,12 @@ ROLLOUT_ARGS=( --rollout-buffer-url ${ROLLOUT_BUFFER_URL} --disable-rollout-global-dataset --num-rollout 300 - --rollout-batch-size ${SLIME_ROLLOUT_BATCH_SIZE} - --n-samples-per-prompt ${SLIME_N_SAMPLES_PER_PROMPT} + --rollout-batch-size ${RL_ROLLOUT_GROUP_BATCH_SIZE} + --n-samples-per-prompt ${RL_GROUP_SIZE} --rollout-max-response-len ${LLM_MAX_LENGTH} --rollout-temperature ${LLM_TEMPERATURE} - --global-batch-size ${SLIME_GLOBAL_BATCH_SIZE} + --global-batch-size ${RL_GLOBAL_BATCH_SIZE} + --custom-reward-post-process-path rl.variable_group_rewards.post_process_rewards --loss-mask-type qwen ) @@ -71,6 +85,7 @@ MEGATRON_ARGS=( TRAIN_ARGS=( --use-dynamic-batch-size + --use-dynamic-global-batch-size --max-tokens-per-gpu 5000 --calculate-per-token-loss ) diff --git a/rl/mask/trajectory_mask_builder.py b/rl/mask/trajectory_mask_builder.py index 02ecff6..bde0aa5 100644 --- a/rl/mask/trajectory_mask_builder.py +++ b/rl/mask/trajectory_mask_builder.py @@ -214,6 +214,13 @@ def _build_mm_inputs( return list(input_ids), mm_train_inputs def _render_message_delta_str(self, model_input_message: Dict[str, Any]) -> str: + if model_input_message.get("role") == "system": + return self.tokenizer.apply_chat_template( + [model_input_message], + add_generation_prompt=False, + tokenize=False, + ) + single_message_chat_template_str = self.tokenizer.apply_chat_template( BASE_CHAT_HISTORY + [model_input_message], add_generation_prompt=False, diff --git a/rl/slime_generator.py b/rl/slime_generator.py index fca0a59..6f2970e 100644 --- a/rl/slime_generator.py +++ b/rl/slime_generator.py @@ -316,12 +316,16 @@ def group_by_instance_id(results: List[Dict]) -> List[List[Dict]]: return list(groups.values()) -async def get_rollout_data(api_base_url: str) -> tuple[List[Dict[str, Any]], Dict[str, Any]]: +async def get_rollout_data(api_base_url: str, max_groups: Optional[int] = None) -> tuple[List[Dict[str, Any]], Dict[str, Any]]: start_time = time.time() + payload = {} + if max_groups is not None: + payload["max_groups"] = max(1, int(max_groups)) + async with aiohttp.ClientSession() as session: while True: async with session.post( - f"{api_base_url}/get_rollout_data", json={}, timeout=aiohttp.ClientTimeout(total=120) + f"{api_base_url}/get_rollout_data", json=payload, timeout=aiohttp.ClientTimeout(total=120) ) as response: response.raise_for_status() resp_json = await response.json() @@ -386,50 +390,20 @@ def start_rollout(api_base_url: str, args, metadata): print(f"[start_rollout] Failed to send rollout config: {e}") -def filter_by_weight_version(data_buffer, current_version: int, off_by_n: int = 0): - """根据权重版本过滤 buffer 中的数据。 - - 过滤掉那些权重版本与当前版本差距超过 off_by_n 的样本。 - - Args: - data_buffer: 数据 buffer - current_version: 当前权重版本(当前 pipeline 中通常是 rollout_id + 1) - off_by_n: 允许的最大权重差,默认为 0(只保留当前版本的数据) - """ - buffer_length = data_buffer.get_buffer_length() - if buffer_length == 0: - return - - # 获取所有样本 - all_samples = data_buffer.get_samples(buffer_length) - - # 过滤样本 - filtered_samples = [] - for sample_group in all_samples: - filtered_group = [] - len_sample_group = len(sample_group) - for sample in sample_group: - metadata = getattr(sample, "metadata", None) or {} - sample_version = metadata.get("weight_version", 0) - try: - sample_version = int(sample_version) - except (ValueError, TypeError): - sample_version = 0 - - # 检查权重版本差距是否在允许范围内 - if current_version - sample_version <= off_by_n: - filtered_group.append(sample) - else: - logger.debug( - f"Filtered out sample with weight_version={sample_version}, " - f"current_version={current_version}, off_by_n={off_by_n}" +def record_used_metrics(metrics: MetricsRecorder, sample_groups: List[List[Sample]]) -> None: + for group in sample_groups: + for sample in group: + if sample.reward is None: + raise RuntimeError( + "Encountered reward=None after rollout assembly. " + "The rollout buffer is likely underfilled." ) + metrics.record("used/reward", float(sample.reward), AggType.MEAN) + metrics.record("used/response_length", float(sample.response_length), AggType.MEAN) + meta = getattr(sample, "metadata", {}) or {} + metrics.record("used/weight_version", float(meta.get("weight_version", 0)), AggType.MEAN) + metrics.record("used/count", float(sum(len(g) for g in sample_groups)), AggType.SUM) - if filtered_group and len(filtered_group) == len_sample_group: - filtered_samples.append(filtered_group) - - if filtered_samples: - data_buffer.add_samples(filtered_samples) async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: bool = False) -> Dict[str, Any]: if evaluation: @@ -442,30 +416,8 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: # 根据weight_version过滤已完成的数据 off_by_n = int(get_env("RL_OFF_BY_N")) dapo_filter_enabled = os.environ.get("DAPO_filter", "true").strip().lower() in ("1", "true", "yes", "on") - filter_by_weight_version(data_buffer, current_version=current_version, off_by_n=off_by_n) - buffer_length = data_buffer.get_buffer_length() - needed_groups = max(0, args.rollout_batch_size - buffer_length) - data_number_to_fetch = needed_groups * args.n_samples_per_prompt - print(f"INFO: buffer length: {buffer_length}, data_number_to_fetch: {data_number_to_fetch}") - if needed_groups <= 0: - print( - f"❕buffer length: {data_buffer.get_buffer_length()}, buffer has enough data, return {args.rollout_batch_size} prompts" - ) - final_return_results = data_buffer.get_samples(args.rollout_batch_size) - for group in final_return_results: - for sample in group: - if sample.reward is None: - raise RuntimeError( - "Encountered reward=None after rollout assembly. " - "The rollout buffer is likely underfilled." - ) - metrics.record("used/reward", float(sample.reward), AggType.MEAN) - metrics.record("used/response_length", float(sample.response_length), AggType.MEAN) - meta = getattr(sample, "metadata", {}) or {} - metrics.record("used/weight_version", float(meta.get("weight_version", 0)), AggType.MEAN) - metrics.record("used/count", float(sum(len(g) for g in final_return_results)), AggType.SUM) - metrics.push(step=rollout_id) - return final_return_results + sample_group_buffer: List[List[Sample]] = [] + print(f"INFO: need rollout groups: {args.rollout_batch_size}") base_url = args.rollout_buffer_url tokenizer = TOKENIZER retry_times = 0 @@ -477,23 +429,27 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: ) # 持续获取数据,直到 buffer 中有足够的可训练 groups - while data_buffer.get_buffer_length() < args.rollout_batch_size and ( + while len(sample_group_buffer) < args.rollout_batch_size and ( args.fetch_trajectory_retry_times == -1 or retry_times < args.fetch_trajectory_retry_times ): retry_times += 1 try: - remaining_groups = args.rollout_batch_size - data_buffer.get_buffer_length() - fetch_sample_count = remaining_groups * args.n_samples_per_prompt - print(f"need sample count: fetch_sample_count: {fetch_sample_count}") + remaining_groups = args.rollout_batch_size - len(sample_group_buffer) + print(f"need group count: remaining_groups: {remaining_groups}") raw_results = [] + raw_group_count = 0 - while len(raw_results) < fetch_sample_count: + while raw_group_count < remaining_groups: await asyncio.sleep(5) - data, meta_info = await get_rollout_data(api_base_url=base_url) + data, meta_info = await get_rollout_data( + api_base_url=base_url, + max_groups=remaining_groups - raw_group_count, + ) raw_results.extend(data) + raw_group_count = len(group_by_instance_id(raw_results)) if meta_info: all_meta_info.append(meta_info) - print(f"get rollout data with length: {len(raw_results)}") + print(f"get rollout data with items={len(raw_results)}, groups={raw_group_count}") # 从 extra_info 中获取 weight_version,记录 fetched metrics for record in raw_results: @@ -509,7 +465,11 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: # 按 group 过滤:group 中所有 sample 都必须符合版本要求 valid_groups = [] for group in grouped_results: - rewards = [record.get("reward") for record in group] + reward_records = [ + record for record in group + if (record.get("extra_info") or {}).get("is_session_completed", False) + ] or group + rewards = [record.get("reward") for record in reward_records] if dapo_filter_enabled and len(set(rewards)) == 1: logger.info( f"Filtered out group with rewards={rewards}, " @@ -529,13 +489,13 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: print(f"✅ Valid groups collected this round: {len(valid_groups)}") sample_results = [] - touched_session_ids = set() + touched_session_ids = { + record["extra_info"].get("session_id", "") + for record in raw_results + if record.get("extra_info") and record["extra_info"].get("session_id", "") + } try: flat_records = [record for group_record in valid_groups for record in group_record] - for record in flat_records: - session_id = record["extra_info"].get("session_id", "") - if session_id: - touched_session_ids.add(session_id) loop = asyncio.get_running_loop() traininfo_executor = _get_traininfo_executor() @@ -589,8 +549,11 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: write_debug_to_file(tokenizer, rollout_id, record, oai_messages, token_ids, loss_mask, response_length) metadata = dict(record["extra_info"]) + step_id = metadata.get("steps", 0) + sample_index = f"{session_id}:{step_id}" if session_id else record["instance_id"] sample = Sample( - index=record["instance_id"], + group_index=record["instance_id"], + index=sample_index, prompt=record["uid"], tokens=token_ids, response_length=response_length, @@ -627,10 +590,10 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: if TRAJECTORY_MASK_BUILDER is not None: for session_id in touched_session_ids: TRAJECTORY_MASK_BUILDER.clear_session(session_id) - data_buffer.add_samples(sample_results) + sample_group_buffer.extend(sample_results) print( "✅ Trainable groups added this round: " - f"{len(sample_results)}, buffer length: {data_buffer.get_buffer_length()}/{args.rollout_batch_size}" + f"{len(sample_results)}, buffer length: {len(sample_group_buffer)}/{args.rollout_batch_size}" ) except Exception as err: @@ -641,28 +604,18 @@ async def generate_rollout_async(args, rollout_id: int, data_buffer, evaluation: for item in all_meta_info: finished_groups_instance_id_list.extend(item["finished_groups"]) - data_buffer.update_metadata({str(rollout_id): finished_groups_instance_id_list}) + if hasattr(data_buffer, "update_metadata"): + data_buffer.update_metadata({str(rollout_id): finished_groups_instance_id_list}) - print("finally buffered trainable group count: ", data_buffer.get_buffer_length()) - if data_buffer.get_buffer_length() < args.rollout_batch_size: + print("finally buffered trainable group count: ", len(sample_group_buffer)) + if len(sample_group_buffer) < args.rollout_batch_size: raise RuntimeError( "Insufficient trainable rollout groups after filtering and trajectory matching: " - f"buffer_length={data_buffer.get_buffer_length()}, required={args.rollout_batch_size}" + f"buffer_length={len(sample_group_buffer)}, required={args.rollout_batch_size}" ) - final_return_results = data_buffer.get_samples(args.rollout_batch_size) - for group in final_return_results: - for sample in group: - if sample.reward is None: - raise RuntimeError( - "Encountered reward=None after rollout assembly. " - "The rollout buffer is likely underfilled." - ) - metrics.record("used/reward", float(sample.reward), AggType.MEAN) - metrics.record("used/response_length", float(sample.response_length), AggType.MEAN) - meta = getattr(sample, "metadata", {}) or {} - metrics.record("used/weight_version", float(meta.get("weight_version", 0)), AggType.MEAN) - metrics.record("used/count", float(sum(len(g) for g in final_return_results)), AggType.SUM) + final_return_results = sample_group_buffer + record_used_metrics(metrics, final_return_results) metrics.push(step=rollout_id) return final_return_results diff --git a/rl/variable_group_rewards.py b/rl/variable_group_rewards.py new file mode 100644 index 0000000..f6b7a49 --- /dev/null +++ b/rl/variable_group_rewards.py @@ -0,0 +1,45 @@ +from collections import OrderedDict +import math +from typing import Any, Dict, List, TYPE_CHECKING, Tuple + +if TYPE_CHECKING: + from slime.utils.types import Sample + + +def _group_key(sample: "Sample", fallback_index: int) -> Any: + group_index = getattr(sample, "group_index", None) + if group_index is not None: + return group_index + + metadata = getattr(sample, "metadata", None) or {} + return metadata.get("group_id", fallback_index) + + +def post_process_rewards(args, samples: List["Sample"]) -> Tuple[List[float], List[float]]: + """Normalize rewards by rollout group for message_cut variable-size groups.""" + raw_rewards = [sample.get_reward_value(args) for sample in samples] + if not ( + args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] + and args.rewards_normalization + ): + return raw_rewards, raw_rewards + + groups: Dict[Any, List[int]] = OrderedDict() + for i, sample in enumerate(samples): + groups.setdefault(_group_key(sample, i), []).append(i) + + rewards = [0.0] * len(samples) + for indices in groups.values(): + group_rewards = [float(raw_rewards[i]) for i in indices] + mean = sum(group_rewards) / len(group_rewards) + normalized = [reward - mean for reward in group_rewards] + + if args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization and len(indices) > 1: + variance = sum((reward - mean) ** 2 for reward in group_rewards) / (len(group_rewards) - 1) + std = math.sqrt(variance) + normalized = [reward / (std + 1e-6) for reward in normalized] + + for i, reward in zip(indices, normalized, strict=False): + rewards[i] = reward + + return raw_rewards, rewards