diff --git a/ci/scripts/test_grpo_trainer.py b/ci/scripts/test_grpo_trainer.py index 6c392b15a..7b0df7c37 100644 --- a/ci/scripts/test_grpo_trainer.py +++ b/ci/scripts/test_grpo_trainer.py @@ -38,6 +38,7 @@ MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +os.environ['XTUNER_USE_FA3'] = "1" def parse_args(): parser = argparse.ArgumentParser(description="VLLM Rollout Test Script") diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py index 2647e8892..cc3e68c7c 100644 --- a/xtuner/v1/ray/config/worker.py +++ b/xtuner/v1/ray/config/worker.py @@ -78,7 +78,7 @@ class RolloutConfig(BaseModel): ] = "lmdeploy" model_path: Annotated[str | Path, Parameter(group=infer_group, help="Path to the SGLang model.")] model_name: Annotated[str, Parameter(group=infer_group, help="Name of the model to be used in the LMDeploy.")] - tokenizer_path: Annotated[str, Parameter(group=infer_group, help="Path to the tokenizer for the model.")] = "" + tokenizer_path: Annotated[str, Parameter(group=infer_group, help="Path to the tokenizer for the model.")] api_key: Annotated[ Optional[Union[List[str], str]], Parameter( diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 3af3a56e0..133a4565d 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -82,16 +82,14 @@ def __init__( self.num_workers = 0 self.worker_server_urls: List[str] = [] self.active_rollout_workers: List[RolloutWorker] = [] - self.tokenizer = ( - AutoTokenizer.from_pretrained(infer_config.model_path, trust_remote_code=True) - if infer_config.tokenizer_path - else None - ) + self.tokenizer = AutoTokenizer.from_pretrained(infer_config.model_path, trust_remote_code=True) self.workers_bundle_idx_map = workers_bundle_idx_map self.engine_mesh_list, self.server_url_dict = self.init_workers() # todo(@duanyanhui): add router to replace native round robin self.worker_index = 0 # round robin index - self.sample_params = SampleParams() + self.sample_params = SampleParams( + stops=[self.tokenizer.decode(self.tokenizer.eos_token_id)], stop_token_ids=[self.tokenizer.eos_token_id] + ) def get_rollout_info(self): """Get information about the current rollout setup. @@ -185,11 +183,14 @@ async def rollout( """ index = self.worker_index % len(self.active_rollout_workers) final_sample_params = sample_params if sample_params else self.sample_params + # note(@duanyanhui): ensure stops and stop_token_ids are set to append eos in response + final_sample_params.stops = final_sample_params.stops or self.sample_params.stops + final_sample_params.stop_token_ids = final_sample_params.stop_token_ids or self.sample_params.stop_token_ids response_ref = self.active_rollout_workers[index].rollout.remote( # type: ignore[attr-defined] prompt, tools=tools, tool_choice=tool_choice, - sample_params=final_sample_params, + sample_params=final_sample_params.dict(), extra_params=extra_params, format=format, ) diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 51c1d3dc3..270831480 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -285,29 +285,42 @@ async def rollout_task( return "", "failed" # 返回明确的失败状态 last_trajectory = "" - async for chunk in response.aiter_text(): - if chunk == "": + async for chunk in response.aiter_lines(): + if not chunk.startswith("data:"): continue try: if self.paused: await response.aclose() self.logger.debug(f"--- get paused request {uid}") return last_trajectory, "unfinished" - chunk_data = chunk[len("data:") :].strip() # Remove "data: " prefix - if chunk_data == "[DONE]": + + chunk_data_str = chunk[len("data:") :].strip() + if chunk_data_str == "[DONE]": self.logger.debug(f" --- get finished request {uid}") await response.aclose() return last_trajectory, "finished" - else: - if not (chunk_data.startswith("{") and chunk_data.endswith("}")): - continue - last_trajectory += json.loads(chunk_data)["choices"][0]["delta"]["content"] + + if not (chunk_data_str.startswith("{") and chunk_data_str.endswith("}")): + continue + + chunk_data = json.loads(chunk_data_str) + + delta_content = chunk_data["choices"][0]["delta"].get("content") + if delta_content: + last_trajectory += delta_content + + # todo(@duanyanhui): remove appending stop tokens manually after lmdeploy support return stop_token_ids. + finish_reason = chunk_data["choices"][0].get("finish_reason") + if finish_reason == "stop": + assert len(sample_params["stops"]) == 1 + last_trajectory += sample_params["stops"][0] + except json.JSONDecodeError as e: self.logger.error(f"JSON decode error for chunk in request {uid}: {chunk}, error: {e}") - continue # 选择跳过这个损坏的块 + continue except Exception as e: self.logger.error(f"Error processing chunk for {uid}: {chunk}, error: {e}") - return last_trajectory, "failed" # 出现意外错误时,终止并返回失败 + return last_trajectory, "failed" except httpx.RequestError as e: self.logger.error(f"Request {uid} failed with a network error: {e}") return "", "failed"