Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/scripts/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"]
os.environ['XTUNER_USE_FA3'] = "1"

def parse_args():
parser = argparse.ArgumentParser(description="VLLM Rollout Test Script")
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/ray/config/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class RolloutConfig(BaseModel):
] = "lmdeploy"
model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the SGLang model.")]
model_name: Annotated[str, Parameter(group=infer_group, help="Name of the model to be used in the LMDeploy.")]
tokenizer_path: Annotated[str, Parameter(group=infer_group, help="Path to the tokenizer for the model.")] = ""
tokenizer_path: Annotated[str, Parameter(group=infer_group, help="Path to the tokenizer for the model.")]
api_key: Annotated[
Optional[Union[List[str], str]],
Parameter(
Expand Down
15 changes: 8 additions & 7 deletions xtuner/v1/ray/rollout/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,14 @@ def __init__(
self.num_workers = 0
self.worker_server_urls: List[str] = []
self.active_rollout_workers: List[RolloutWorker] = []
self.tokenizer = (
AutoTokenizer.from_pretrained(infer_config.model_path, trust_remote_code=True)
if infer_config.tokenizer_path
else None
)
self.tokenizer = AutoTokenizer.from_pretrained(infer_config.model_path, trust_remote_code=True)
self.workers_bundle_idx_map = workers_bundle_idx_map
self.engine_mesh_list, self.server_url_dict = self.init_workers()
# todo(@duanyanhui): add router to replace native round robin
self.worker_index = 0 # round robin index
self.sample_params = SampleParams()
self.sample_params = SampleParams(
stops=[self.tokenizer.decode(self.tokenizer.eos_token_id)], stop_token_ids=[self.tokenizer.eos_token_id]
)

def get_rollout_info(self):
"""Get information about the current rollout setup.
Expand Down Expand Up @@ -185,11 +183,14 @@ async def rollout(
"""
index = self.worker_index % len(self.active_rollout_workers)
final_sample_params = sample_params if sample_params else self.sample_params
# note(@duanyanhui): ensure stops and stop_token_ids are set to append eos in response
final_sample_params.stops = final_sample_params.stops or self.sample_params.stops
final_sample_params.stop_token_ids = final_sample_params.stop_token_ids or self.sample_params.stop_token_ids
response_ref = self.active_rollout_workers[index].rollout.remote( # type: ignore[attr-defined]
prompt,
tools=tools,
tool_choice=tool_choice,
sample_params=final_sample_params,
sample_params=final_sample_params.dict(),
extra_params=extra_params,
format=format,
)
Expand Down
33 changes: 23 additions & 10 deletions xtuner/v1/ray/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,29 +285,42 @@ async def rollout_task(
return "", "failed" # 返回明确的失败状态

last_trajectory = ""
async for chunk in response.aiter_text():
if chunk == "":
async for chunk in response.aiter_lines():
if not chunk.startswith("data:"):
continue
try:
if self.paused:
await response.aclose()
self.logger.debug(f"--- get paused request {uid}")
return last_trajectory, "unfinished"
chunk_data = chunk[len("data:") :].strip() # Remove "data: " prefix
if chunk_data == "[DONE]":

chunk_data_str = chunk[len("data:") :].strip()
if chunk_data_str == "[DONE]":
self.logger.debug(f" --- get finished request {uid}")
await response.aclose()
return last_trajectory, "finished"
else:
if not (chunk_data.startswith("{") and chunk_data.endswith("}")):
continue
last_trajectory += json.loads(chunk_data)["choices"][0]["delta"]["content"]

if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")):
continue

chunk_data = json.loads(chunk_data_str)

delta_content = chunk_data["choices"][0]["delta"].get("content")
if delta_content:
last_trajectory += delta_content

# todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids.
finish_reason = chunk_data["choices"][0].get("finish_reason")
if finish_reason == "stop":
assert len(sample_params["stops"]) == 1
last_trajectory += sample_params["stops"][0]

except json.JSONDecodeError as e:
self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}")
continue # 选择跳过这个损坏的块
continue
except Exception as e:
self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}")
return last_trajectory, "failed" # 出现意外错误时,终止并返回失败
return last_trajectory, "failed"
except httpx.RequestError as e:
self.logger.error(f"Request {uid} failed with a network error: {e}")
return "", "failed"
Expand Down