Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 85 additions & 15 deletions rl/buffer_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from datetime import datetime
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set

# Add rl directory to path for utils import
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -78,6 +78,10 @@
# Pending items by instance_id (for grouping)
pending_items_by_instance: Dict[str, List[Dict[str, Any]]] = {}

# Completed session ids by instance_id. The DB may contain every environment
# step, but this server only reads rows marked is_trainable by the rollout side.
completed_sessions_by_instance: Dict[str, Set[str]] = {}

# Group size (set by /start_rollout)
group_size: int = 1

Expand Down Expand Up @@ -137,6 +141,8 @@ def _build_item_from_row(row: Dict[str, Any]) -> Dict[str, Any]:
"group_id": group_id,
"weight_version": weight_version,
"truncated": row.get("truncated", False),
"step_pk": row.get("step_pk"),
"is_session_completed": bool(row.get("is_session_completed", False)),
}

return {
Expand All @@ -148,8 +154,29 @@ def _build_item_from_row(row: Dict[str, Any]) -> Dict[str, Any]:
}


def _propagate_terminal_rewards(group: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Overwrite each item's reward with its session's terminal reward.

    The reward of an item whose ``extra_info`` carries a truthy
    ``is_session_completed`` flag is taken as the terminal reward of that
    item's ``session_id``. Every item belonging to such a session then gets
    that value as its ``reward``; the value it carried before is kept in
    ``extra_info["step_reward"]`` and the propagated value is also stored in
    ``extra_info["terminal_reward"]``. Items without a session id, or whose
    session has no completed item in ``group``, keep their reward untouched.

    Args:
        group: Items to update. The list and its dicts are mutated in place.

    Returns:
        The same ``group`` list that was passed in.
    """
    terminal_rewards_by_session: Dict[str, float] = {}
    for item in group:
        extra = item.get("extra_info") or {}
        session_id = str(extra.get("session_id") or "")
        if session_id and extra.get("is_session_completed", False):
            # If a session has several completed items, the last one wins.
            terminal_rewards_by_session[session_id] = float(item.get("reward", 0.0))

    for item in group:
        # setdefault preserves the existing side effect: items lacking the key
        # end up with an empty "extra_info" dict.
        extra = item.setdefault("extra_info", {})
        if not extra:
            # An empty or None extra_info has no session id. The first pass
            # tolerates None, so this pass must not crash on it either.
            continue
        session_id = str(extra.get("session_id") or "")
        terminal_reward = terminal_rewards_by_session.get(session_id)
        if terminal_reward is None:
            continue
        extra["step_reward"] = float(item.get("reward", 0.0))
        extra["terminal_reward"] = terminal_reward
        item["reward"] = terminal_reward

    return group


async def fetch_new_items_from_db(limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""Fetch new completed steps from the database using cursor-based pagination."""
"""Fetch new trainable rows from the database using cursor-based pagination."""
global data_manager, last_served_id

if data_manager is None:
Expand Down Expand Up @@ -180,9 +207,9 @@ async def fetch_new_items_from_db(limit: Optional[int] = None) -> List[Dict[str,
return items


def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]]) -> tuple:
def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]], max_groups: Optional[int] = None) -> tuple:
"""Accumulate items and return ready groups."""
global pending_items_by_instance, group_size
global pending_items_by_instance, completed_sessions_by_instance, group_size

ready_groups = []
finished_instance_ids = []
Expand All @@ -193,34 +220,76 @@ def accumulate_and_pop_ready_groups(new_items: List[Dict[str, Any]]) -> tuple:
if not instance_id:
continue
pending_items_by_instance.setdefault(instance_id, []).append(item)
extra = item.get("extra_info") or {}
session_id = str(extra.get("session_id") or "")
if session_id and bool(extra.get("is_session_completed", False)):
completed_sessions_by_instance.setdefault(instance_id, set()).add(session_id)

# Check for complete groups
# Check for complete groups. A complete group means all K trajectories for
# this prompt group have emitted their final row.
to_delete = []
for instance_id, bucket in pending_items_by_instance.items():
while len(bucket) >= group_size:
group = bucket[:group_size]
del bucket[:group_size]
ready_groups.append((instance_id, list(group)))
if max_groups is not None and len(ready_groups) >= max_groups:
break
completed_sessions = completed_sessions_by_instance.get(instance_id, set())
if len(completed_sessions) >= group_size:
if len(completed_sessions) > group_size:
logger.warning(
"Group %s has %d completed sessions, expected group_size=%d",
instance_id,
len(completed_sessions),
group_size,
)
group = sorted(
bucket,
key=lambda item: (
str((item.get("extra_info") or {}).get("session_id") or ""),
int((item.get("extra_info") or {}).get("steps") or 0),
int((item.get("extra_info") or {}).get("step_pk") or 0),
),
)
group = _propagate_terminal_rewards(group)
ready_groups.append((instance_id, group))
finished_instance_ids.append(instance_id)
if not bucket:
to_delete.append(instance_id)

for k in to_delete:
pending_items_by_instance.pop(k, None)
completed_sessions_by_instance.pop(k, None)

return ready_groups, finished_instance_ids


@app.post("/get_rollout_data", response_model=BufferResponse)
async def get_rollout_data(request: Request):
global pending_items_by_instance
global pending_items_by_instance, completed_sessions_by_instance

try:
payload = await request.json()
except Exception:
payload = {}
if not isinstance(payload, dict):
payload = {}
max_groups = payload.get("max_groups")
try:
max_groups = int(max_groups) if max_groups is not None else None
except (TypeError, ValueError):
max_groups = None
if max_groups is not None and max_groups <= 0:
max_groups = None

# Fetch new items from database and accumulate groups
new_items = await fetch_new_items_from_db(limit=None)
ready_groups, finished_ids = accumulate_and_pop_ready_groups(new_items)
ready_groups, finished_ids = accumulate_and_pop_ready_groups(new_items, max_groups=max_groups)

# Log pending status
pending_counts = {k: len(v) for k, v in pending_items_by_instance.items()}
pending_counts = {
k: {
"items": len(v),
"completed_sessions": len(completed_sessions_by_instance.get(k, set())),
}
for k, v in pending_items_by_instance.items()
}
logger.info(f"new_items={len(new_items)}, ready_groups={len(ready_groups)}, pending={pending_counts}")

# Flatten groups to items
Expand Down Expand Up @@ -291,15 +360,16 @@ def start_aievobox_process(data: dict):
NOTE: LLM Proxy is now hosted in-process by slime_generator.
It must already be running before this function is called.
"""
global aievobox_process, group_size, last_served_id, pending_items_by_instance, data_manager
global aievobox_process, group_size, last_served_id, pending_items_by_instance, completed_sessions_by_instance, data_manager

# Set group size (num_repeat_per_sample)
group_size = int(data.get("num_repeat_per_sample", 16))
group_size = max(1, int(data.get("num_repeat_per_sample", 16)))

# Clear state for new rollout
restart_training = data.get("restart_training", False)
if restart_training:
pending_items_by_instance.clear()
completed_sessions_by_instance.clear()
logger.info("restart_training=True, cleared pending items")

# Keep a single job_session for both reader and writer process.
Expand Down
13 changes: 7 additions & 6 deletions rl/examples/geo3k_vl/env.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -------------------------------------------
# AIEvobox (rollout) Settings
# -------------------------------------------
export AIEVOBOX_ROOT=/root/AIEvoBox
export AIEVOBOX_ROOT=/root/Safactory
export STORAGE_TYPE=sqlite
export AIEVOBOX_DB_URL=sqlite:///${AIEVOBOX_ROOT}/rl/examples/geo3k_vl/geo3k_vl.db
export AIEVOBOX_MAX_STEPS=10
export AIEVOBOX_MESSAGE_CUT=0
export AIEVOBOX_MESSAGE_CUT=1
# ENV_CONFIG 指定单个 yaml 文件
export AIEVOBOX_ENV_CONFIG=/root/AIEvoBox/env/geo3k_vl_test/geo3k_vl_test_env_configs.yaml
export AIEVOBOX_ENV_CONFIG=${AIEVOBOX_ROOT}/env/geo3k_vl_test/geo3k_vl_test_env_configs.yaml
# ENV_ROOT 指定读取目录下所有子目录的环境
# export AIEVOBOX_ENV_ROOT=/root/AIEvoBox/env
export AIEVOBOX_POOL_SIZE=256
Expand All @@ -26,6 +26,10 @@ export AIEVOBOX_SQLITE_BULK_INSERT_PAUSE_S=0.01
# RL Settings
# -------------------------------------------
export RL_GROUP_SIZE=8
# RL_ROLLOUT_GROUP_BATCH_SIZE has priority. Leave it empty to derive from
# RL_GLOBAL_BATCH_SIZE / RL_GROUP_SIZE.
export RL_ROLLOUT_GROUP_BATCH_SIZE=64
export RL_GLOBAL_BATCH_SIZE=512
export RL_EPOCH=1000
export RL_OFF_BY_N=0

Expand Down Expand Up @@ -61,6 +65,3 @@ export LLM_PROXY_ENABLE_CONSOLE_LOG=0
# Slime Training Settings (reference RL values)
# -------------------------------------------
export SLIME_ROLLBUF_RESTART_TRAINING=True
export SLIME_N_SAMPLES_PER_PROMPT=$RL_GROUP_SIZE
export SLIME_GLOBAL_BATCH_SIZE=512
export SLIME_ROLLOUT_BATCH_SIZE=$((SLIME_GLOBAL_BATCH_SIZE / RL_GROUP_SIZE))
21 changes: 18 additions & 3 deletions rl/examples/geo3k_vl/run_slime_generator.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ source "${SCRIPT_DIR}/env.sh"
ROLLOUT_BUFFER_URL="http://${BUFFER_SERVER_HOST}:${BUFFER_SERVER_PORT}"
LLM_PROXY_URL="http://${LLM_PROXY_HOST}:${LLM_PROXY_PORT}"

# RL_GROUP_SIZE is the divisor/multiplier used to derive whichever batch-size
# variable is missing, so reject an unset, empty, or non-positive value before
# bash can hit a "division by 0" arithmetic error below.
if (( ${RL_GROUP_SIZE:-0} <= 0 )); then
echo "RL_ROLLOUT_GROUP_BATCH_SIZE, RL_GLOBAL_BATCH_SIZE, and RL_GROUP_SIZE must be positive" >&2
exit 1
fi

# An explicit RL_ROLLOUT_GROUP_BATCH_SIZE has priority; derive it from the
# global batch size only when it is empty or unset.
if [[ -z "${RL_ROLLOUT_GROUP_BATCH_SIZE:-}" ]]; then
RL_ROLLOUT_GROUP_BATCH_SIZE=$(( ${RL_GLOBAL_BATCH_SIZE:-0} / RL_GROUP_SIZE ))
fi

# Conversely, derive the global batch size when only the group batch size is set.
if [[ -z "${RL_GLOBAL_BATCH_SIZE:-}" ]]; then
RL_GLOBAL_BATCH_SIZE=$((RL_ROLLOUT_GROUP_BATCH_SIZE * RL_GROUP_SIZE))
fi

# Both sizes must be positive after derivation. This also catches the case where
# neither was provided, or where integer division truncated to zero.
if (( RL_ROLLOUT_GROUP_BATCH_SIZE <= 0 || RL_GLOBAL_BATCH_SIZE <= 0 )); then
echo "RL_ROLLOUT_GROUP_BATCH_SIZE, RL_GLOBAL_BATCH_SIZE, and RL_GROUP_SIZE must be positive" >&2
exit 1
fi

export PYTHONBUFFERED=16
NUM_GPUS=${NUM_GPUS:-8}

Expand All @@ -43,11 +56,12 @@ ROLLOUT_ARGS=(
--rollout-buffer-url ${ROLLOUT_BUFFER_URL}
--disable-rollout-global-dataset
--num-rollout 300
--rollout-batch-size ${SLIME_ROLLOUT_BATCH_SIZE}
--n-samples-per-prompt ${SLIME_N_SAMPLES_PER_PROMPT}
--rollout-batch-size ${RL_ROLLOUT_GROUP_BATCH_SIZE}
--n-samples-per-prompt ${RL_GROUP_SIZE}
--rollout-max-response-len ${LLM_MAX_LENGTH}
--rollout-temperature ${LLM_TEMPERATURE}
--global-batch-size ${SLIME_GLOBAL_BATCH_SIZE}
--global-batch-size ${RL_GLOBAL_BATCH_SIZE}
--custom-reward-post-process-path rl.variable_group_rewards.post_process_rewards
--loss-mask-type qwen
)

Expand All @@ -71,6 +85,7 @@ MEGATRON_ARGS=(

TRAIN_ARGS=(
--use-dynamic-batch-size
--use-dynamic-global-batch-size
--max-tokens-per-gpu 5000
--calculate-per-token-loss
)
Expand Down
7 changes: 7 additions & 0 deletions rl/mask/trajectory_mask_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ def _build_mm_inputs(
return list(input_ids), mm_train_inputs

def _render_message_delta_str(self, model_input_message: Dict[str, Any]) -> str:
if model_input_message.get("role") == "system":
return self.tokenizer.apply_chat_template(
[model_input_message],
add_generation_prompt=False,
tokenize=False,
)

single_message_chat_template_str = self.tokenizer.apply_chat_template(
BASE_CHAT_HISTORY + [model_input_message],
add_generation_prompt=False,
Expand Down
Loading