Skip to content

Commit 3c7774c

Browse files
Merge branch 'main' into fsiino/prepare-data-aggregations
2 parents d419c15 + fe9f676 commit 3c7774c

File tree

76 files changed

+1405
-8610
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1405
-8610
lines changed

CONTRIBUTING.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
# Contributing To NeMo-Gym
22

3+
## Quality control
4+
A checklist for all verifier data to be submitted to Nemo Gym. Please follow this pipeline before submitting a merge request.
5+
6+
1. Necessary information to be included in the merge request:
7+
1. Corresponding dataset on the spreadsheet.
8+
2. Description of the prompt. What is the source of the prompt, which domain is it covering?
9+
3. Description of the environment, if there is any.
10+
4. Description of the verifier. How is it verified and whether we have checked the correctness of the verifier.
11+
5. Legal approval status? If synthetically generated by ourselves with open models, please note there so that we know we don’t need legal approval.
12+
2. Simple correctness check: After finishing implementing your own resources_servers (and/or your own customized code for more complicated tasks), please follow the guideline here to run the server, query OpenAI gpt-4o model (or any other model you like) and get 5 example rollouts and corresponding rewards there. Please include in your PR merge request:
13+
1. The command you used to run the server for the uploaded data
14+
2. The resulting rollout and judges (include 5 examples here for people to understand better the data samples, and to ensure reward here is correct.)
15+
3. Other additional notes for running the server properly with the new PR.
16+
3. Test: Please follow the guideline here to implement your own test and run test for your environment. Tests are strongly encouraged and you must have at least one test for every server you make. Test coverage is not explicitly required which means that YOU ARE RESPONSIBLE FOR YOUR OWN SERVER CORRECTNESS AND FUNCTION.
17+
4. Reward Profiling: Please run inference on your prompts and environments (a ~500 small subset is OK) on two models:
18+
1. Qwen 3 30B A3B
19+
2. Qwen 3 235B Instruct (if that’s for agent / agentic coding / instruction following / game environments) or Qwen 3 235B Thinking (if math / competition coding)
20+
3. Generate 16 responses for each prompt, and report the reward distribution (percent of all correct, all incorrect, and mixture of correct & incorrect prompts there).
21+
4. [If using tool calling] Please also provide metrics around the number of tool calls issued on average per prompt in the environment, and the correlation of the reward with the number of tool calls.
22+
5. [After Nemo Gym + Nemo RL integration is done] Training-based correctness check: Please train on the following models with GRPO and include both training accuracy curve and test benchmark accuracy curve:
23+
1. Qwen 30B A3B Instruct
24+
2. [With more compute available] Qwen 235B Instruct
25+
6. [PR Check and Review] Please assign another person in your team for reproducing and reviewing the PRs once it’s ready. The person for review needs to
26+
1. Verify the content for all the above 1-5 steps
27+
2. Check the correctness of the 5 examples
28+
3. Re-run the procedure provided in README to ensure one can generate the same dataset
29+
4. After the person confirms reproduction and gives greenlight on the PR, please ping @banghuaz-nvidia @bxyu-nvidia.
30+
31+
332
## Signing Your Work
433

534
* We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.

README.md

Lines changed: 134 additions & 64 deletions
Large diffs are not rendered by default.

nemo_gym/base_resources_server.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
NeMoGymResponse,
2121
NeMoGymResponseCreateParamsNonStreaming,
2222
)
23-
from nemo_gym.server_utils import BaseRunServerConfig, BaseServer, SimpleServer
23+
from nemo_gym.server_utils import BaseRunServerInstanceConfig, BaseServer, SimpleServer
2424

2525

26-
class BaseResourcesServerConfig(BaseRunServerConfig):
26+
class BaseResourcesServerConfig(BaseRunServerInstanceConfig):
2727
pass
2828

2929

@@ -43,16 +43,30 @@ class BaseVerifyResponse(BaseVerifyRequest):
4343
reward: float
4444

4545

46+
class BaseSeedSessionRequest(BaseModel):
47+
pass
48+
49+
50+
class BaseSeedSessionResponse(BaseModel):
51+
pass
52+
53+
4654
class SimpleResourcesServer(BaseResourcesServer, SimpleServer):
4755
config: BaseResourcesServerConfig
4856

4957
def setup_webserver(self) -> FastAPI:
5058
app = FastAPI()
5159

60+
self.setup_session_middleware(app)
61+
62+
app.post("/seed_session")(self.seed_session)
5263
app.post("/verify")(self.verify)
5364

5465
return app
5566

67+
async def seed_session(self, body: BaseSeedSessionRequest) -> BaseSeedSessionResponse:
68+
return BaseSeedSessionResponse()
69+
5670
@abstractmethod
5771
async def verify(self, body: BaseVerifyRequest) -> BaseVerifyResponse:
5872
pass

nemo_gym/base_responses_api_agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
NeMoGymResponse,
2121
NeMoGymResponseCreateParamsNonStreaming,
2222
)
23-
from nemo_gym.server_utils import BaseRunServerConfig, BaseServer, SimpleServer
23+
from nemo_gym.server_utils import BaseRunServerInstanceConfig, BaseServer, SimpleServer
2424

2525

26-
class BaseResponsesAPIAgentConfig(BaseRunServerConfig):
26+
class BaseResponsesAPIAgentConfig(BaseRunServerInstanceConfig):
2727
pass
2828

2929

@@ -37,6 +37,8 @@ class SimpleResponsesAPIAgent(BaseResponsesAPIAgent, SimpleServer):
3737
def setup_webserver(self) -> FastAPI:
3838
app = FastAPI()
3939

40+
self.setup_session_middleware(app)
41+
4042
app.post("/v1/responses")(self.responses)
4143
app.post("/run")(self.run)
4244

nemo_gym/base_responses_api_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
NeMoGymResponse,
2222
NeMoGymResponseCreateParamsNonStreaming,
2323
)
24-
from nemo_gym.server_utils import BaseRunServerConfig, BaseServer, SimpleServer
24+
from nemo_gym.server_utils import BaseRunServerInstanceConfig, BaseServer, SimpleServer
2525

2626

27-
class BaseResponsesAPIModelConfig(BaseRunServerConfig):
27+
class BaseResponsesAPIModelConfig(BaseRunServerInstanceConfig):
2828
pass
2929

3030

@@ -36,6 +36,8 @@ class SimpleResponsesAPIModel(BaseResponsesAPIModel, SimpleServer):
3636
def setup_webserver(self) -> FastAPI:
3737
app = FastAPI()
3838

39+
self.setup_session_middleware(app)
40+
3941
app.post("/v1/chat/completions")(self.chat_completions)
4042

4143
app.post("/v1/responses")(self.responses)

nemo_gym/cli.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pathlib import Path
2121
from subprocess import Popen
2222
from threading import Thread
23+
from time import sleep
2324
from typing import Dict, List, Optional
2425

2526
from devtools import pprint
@@ -35,7 +36,7 @@
3536
GlobalConfigDictParserConfig,
3637
get_global_config_dict,
3738
)
38-
from nemo_gym.server_utils import HeadServer
39+
from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, HeadServer, ServerClient, ServerStatus
3940

4041

4142
def _setup_env_command(dir_path: Path) -> str: # pragma: no cover
@@ -71,7 +72,7 @@ def model_post_init(self, context): # pragma: no cover
7172
return super().model_post_init(context)
7273

7374

74-
class ServerInstance(BaseModel):
75+
class ServerInstanceDisplayConfig(BaseModel):
7576
process_name: str
7677
server_type: str
7778
name: str
@@ -87,7 +88,8 @@ class ServerInstance(BaseModel):
8788
class RunHelper: # pragma: no cover
8889
_head_server_thread: Thread
8990
_processes: Dict[str, Popen]
90-
_server_instances: List[ServerInstance]
91+
_server_instance_display_configs: List[ServerInstanceDisplayConfig]
92+
_server_client: ServerClient
9193

9294
def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) -> None:
9395
global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config)
@@ -100,8 +102,8 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
100102

101103
top_level_paths = [k for k in global_config_dict.keys() if k not in NEMO_GYM_RESERVED_TOP_LEVEL_KEYS]
102104

103-
processes: Dict[str, Popen] = dict()
104-
server_instances: List[ServerInstance] = []
105+
self._processes: Dict[str, Popen] = dict()
106+
self._server_instance_display_configs: List[ServerInstanceDisplayConfig] = []
105107

106108
# TODO there is a better way to resolve this that uses nemo_gym/global_config.py::ServerInstanceConfig
107109
for top_level_path in top_level_paths:
@@ -133,13 +135,13 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
133135
python {str(entrypoint_fpath)}"""
134136

135137
process = _run_command(command, dir_path)
136-
processes[top_level_path] = process
138+
self._processes[top_level_path] = process
137139

138140
host = server_config_dict.get("host")
139141
port = server_config_dict.get("port")
140142

141-
server_instances.append(
142-
ServerInstance(
143+
self._server_instance_display_configs.append(
144+
ServerInstanceDisplayConfig(
143145
process_name=top_level_path,
144146
server_type=first_key,
145147
name=second_key,
@@ -153,14 +155,25 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
153155
)
154156
)
155157

156-
self._processes = processes
157-
self._server_instances = server_instances
158+
self._server_client = ServerClient(
159+
head_server_config=ServerClient.load_head_server_config(),
160+
global_config_dict=global_config_dict,
161+
)
162+
163+
print("Waiting for head server to spin up")
164+
while True:
165+
status = self._server_client.poll_for_status(HEAD_SERVER_KEY_NAME)
166+
if status == "success":
167+
break
158168

159-
# TODO: Server block summaries may get cut off/interleaved by other process output(s)
160-
self.display_server_instance_info()
169+
print(f"Head server is not up yet (status `{status}`). Sleeping 3s")
170+
sleep(3)
171+
172+
print("Waiting for servers to spin up")
173+
self.wait_for_spinup()
161174

162175
def display_server_instance_info(self) -> None:
163-
if not getattr(self, "_server_instances", None):
176+
if not self._server_instance_display_configs:
164177
print("No server instances to display.")
165178
return
166179

@@ -172,7 +185,7 @@ def display_server_instance_info(self) -> None:
172185
{"#" * 100}
173186
""")
174187

175-
for i, inst in enumerate(self._server_instances, 1):
188+
for i, inst in enumerate(self._server_instance_display_configs, 1):
176189
print(f"[{i}] {inst.process_name} ({inst.server_type}/{inst.name})")
177190
pprint(inst.model_dump())
178191
print(f"{'#' * 100}\n")
@@ -185,12 +198,37 @@ def poll(self) -> None:
185198
if process.poll() is not None:
186199
raise RuntimeError(f"Process `{process_name}` finished unexpectedly!")
187200

201+
def wait_for_spinup(self) -> None:
202+
sleep_interval = 3
203+
204+
# Until we spin up or error out.
205+
while True:
206+
self.poll()
207+
statuses = self.check_http_server_statuses()
208+
209+
num_spun_up = statuses.count("success")
210+
if len(statuses) != num_spun_up:
211+
print(
212+
f"""{num_spun_up} / {len(statuses)} servers ready ({statuses.count("timeout")} timed out, {statuses.count("connection_error")} connection errored, {statuses.count("unknown_error")} had unknown errors).
213+
Waiting for servers to spin up. Sleeping {sleep_interval}s..."""
214+
)
215+
else:
216+
print(f"All {num_spun_up} / {len(statuses)} servers ready! Polling every 60s")
217+
self.display_server_instance_info()
218+
return
219+
220+
sleep(sleep_interval)
221+
188222
def run_forever(self) -> None:
189223
async def sleep():
190224
# Indefinitely
191225
while True:
192226
self.poll()
193-
await asyncio.sleep(60) # Check every 60s.
227+
228+
statuses = self.check_http_server_statuses()
229+
assert statuses.count("success") == len(statuses), "Found non-success statuses"
230+
231+
await asyncio.sleep(60)
194232

195233
try:
196234
asyncio.run(sleep())
@@ -204,6 +242,18 @@ async def sleep():
204242

205243
print("NeMo Gym finished!")
206244

245+
def check_http_server_statuses(self) -> List[ServerStatus]:
246+
print(
247+
"Checking for HTTP server statuses (you should see some HTTP requests to `/` that may 404. This is expected.)"
248+
)
249+
statuses = []
250+
for server_instance_display_config in self._server_instance_display_configs:
251+
name = server_instance_display_config.config_path
252+
status = self._server_client.poll_for_status(name)
253+
statuses.append(status)
254+
255+
return statuses
256+
207257

208258
def run(
209259
global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
@@ -447,7 +497,7 @@ def init_resources_server(): # pragma: no cover
447497
name: {server_type_name}_resources_server
448498
model_server:
449499
type: responses_api_models
450-
name: openai_model
500+
name: policy_model
451501
datasets:
452502
- name: train
453503
type: train

nemo_gym/config_types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@ class DatasetConfig(BaseModel):
8787
jsonl_fpath: str
8888

8989
gitlab_identifier: Optional[JsonlDatasetGitlabIdentifer] = None
90-
license: Optional[Union[Literal["Apache 2.0"], Literal["TBD"]]] = None
90+
license: Optional[
91+
Union[
92+
Literal["Apache 2.0"],
93+
Literal["Creative Commons Attribution 4.0 International"],
94+
Literal["Creative Commons Attribution-ShareAlike 4.0 International"],
95+
Literal["TBD"],
96+
]
97+
] = None
9198

9299
@model_validator(mode="after")
93100
def check_train_validation_sets(self) -> "DatasetConfig":
@@ -112,6 +119,10 @@ class BaseRunServerConfig(BaseServerConfig):
112119
entrypoint: str
113120

114121

122+
class BaseRunServerInstanceConfig(BaseRunServerConfig):
123+
name: str # This name is unique at runtime.
124+
125+
115126
########################################
116127
# Server type and server instance configs
117128
########################################

nemo_gym/dataset_viewer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def format_function_call_output(m: FunctionCallOutput) -> List[ChatMessage]:
5151
ChatMessage(
5252
content=content,
5353
role="assistant",
54-
metadata=MetadataDict(title="Function call output", status="done"),
54+
metadata=MetadataDict(title=f"Function call output (tool call ID `{m['call_id']}`)", status="done"),
5555
)
5656
]
5757

@@ -67,7 +67,7 @@ def format_function_call(m: ResponseFunctionToolCallParam) -> List[ChatMessage]:
6767
ChatMessage(
6868
content=content,
6969
role="assistant",
70-
metadata=MetadataDict(title=f"Function call - `{name}`", status="done"),
70+
metadata=MetadataDict(title=f"Function call - `{name}` (tool call ID `{m['call_id']}`)", status="done"),
7171
)
7272
]
7373

nemo_gym/openai_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,15 @@ class NeMoGymResponseReasoningItemForTraining(NeMoGymResponseReasoningItem, Toke
197197
pass
198198

199199

200+
RESPONSES_TO_TRAIN = {
201+
NeMoGymEasyInputMessage: NeMoGymEasyInputMessageForTraining,
202+
NeMoGymMessage: NeMoGymMessageForTraining,
203+
NeMoGymResponseOutputMessage: NeMoGymResponseOutputMessageForTraining,
204+
NeMoGymResponseFunctionToolCall: NeMoGymResponseFunctionToolCallForTraining,
205+
NeMoGymResponseReasoningItem: NeMoGymResponseReasoningItemForTraining,
206+
}
207+
208+
200209
NeMoGymResponseInputItem = Union[
201210
NeMoGymEasyInputMessage,
202211
NeMoGymMessage,

0 commit comments

Comments
 (0)