Commit 8e41fc2

Comp coding fixes; lots of misc infra items (#90)
Signed-off-by: Brian Yu <bxyu@nvidia.com>
Signed-off-by: Khushi Bhardwaj <kbhardwaj@nvidia.com>
Co-authored-by: Khushi Bhardwaj <kbhardwaj@nvidia.com>
1 parent 4d7ce94 commit 8e41fc2

29 files changed (+1899, −312 lines)

README.md

Lines changed: 4 additions & 3 deletions
@@ -577,14 +577,15 @@ ng_collect_rollouts +agent_name=multineedle_simple_agent \
 +output_jsonl_fpath=results/multineedle_rollout_collection.jsonl \
 +limit=null \
 +num_repeats=null \
-+num_samples_in_parallel=null
++num_samples_in_parallel=null \
++responses_create_params.max_output_tokens=32_768
 ```

 The supported parameters include:
 - `limit`: Limits how many examples from the input JSONL file to process
 - `num_repeats`: Repeats each input example multiple times to collect multiple rollouts per example
 - `num_samples_in_parallel`: Controls how many rollout collection requests run concurrently
-
+- `responses_create_params`: A dictionary of sampling parameter overrides.

 View the rollouts just collected!
 ```
@@ -797,7 +798,7 @@ ng_collect_rollouts +agent_name=library_judge_math_simple_agent \
 +input_jsonl_fpath=resources_servers/library_judge_math/data/dapo17k_bytedtsinghua_train.jsonl \
 +output_jsonl_fpath=temp/library_judge_math_rollouts.jsonl \
 +limit=1024 \
-+num_repeats 1
++num_repeats=1
 ```

 After `ng_collect_rollouts` finishes, ctrl+c to quit your servers. You should see some output in the terminal like this:

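The new `responses_create_params` override documented above is merged on top of each input example's own `responses_create_params` before the rollout request is sent (see the `nemo_gym/rollout_collection.py` diff further down, which uses the dict-union operator). A minimal, standalone Python sketch of that merge behavior; the row contents here are invented, and only the `max_output_tokens` key comes from the command in the diff:

```python
# Illustrative only: a CLI-level override dict merged over a row's own sampling params.
# Uses Python 3.9+ dict union (`|`), where keys on the right-hand side win.
row = {"responses_create_params": {"model": "example-model", "max_output_tokens": 1024}}
cli_overrides = {"max_output_tokens": 32_768}

row["responses_create_params"] = row["responses_create_params"] | cli_overrides
print(row["responses_create_params"])
# {'model': 'example-model', 'max_output_tokens': 32768}
```
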
nemo_gym/config_types.py

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ class DatasetConfig(BaseModel):
     license: Optional[
         Union[
             Literal["Apache 2.0"],
+            Literal["MIT"],
             Literal["Creative Commons Attribution 4.0 International"],
             Literal["Creative Commons Attribution-ShareAlike 4.0 International"],
             Literal["TBD"],

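For context on the one-line change above: adding `Literal["MIT"]` to the union means Pydantic now accepts `"MIT"` as a `license` value while still rejecting anything outside the listed literals. A minimal, standalone sketch of that behavior (this is not the repo's actual `DatasetConfig`; the class name and field set are trimmed for illustration):

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel, ValidationError


class DatasetConfigSketch(BaseModel):
    # Trimmed-down stand-in for the real DatasetConfig license field.
    license: Optional[Union[Literal["Apache 2.0"], Literal["MIT"], Literal["TBD"]]] = None


print(DatasetConfigSketch(license="MIT"))  # accepted after this commit

try:
    DatasetConfigSketch(license="GPL-3.0")  # not in the Literal union, still rejected
except ValidationError as exc:
    print(exc)
```
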
nemo_gym/openai_utils.py

Lines changed: 4 additions & 1 deletion
@@ -74,7 +74,7 @@
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import TypedDict

-from nemo_gym.server_utils import request
+from nemo_gym.server_utils import raise_for_status, request


 ########################################
@@ -432,6 +432,7 @@ async def create_chat_completion(self, **kwargs):
             json=kwargs,
             headers={"Authorization": f"Bearer {self.api_key}"},
         )
+        raise_for_status(response)
         return await response.json()

     async def create_response(self, **kwargs):
@@ -441,6 +442,7 @@ async def create_response(self, **kwargs):
             json=kwargs,
             headers={"Authorization": f"Bearer {self.api_key}"},
         )
+        raise_for_status(response)
         return await response.json()

     async def create_tokenize(self, **kwargs):
@@ -451,4 +453,5 @@ async def create_tokenize(self, **kwargs):
             json=kwargs,
             headers={"Authorization": f"Bearer {self.api_key}"},
         )
+        raise_for_status(response)
         return await response.json()

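The three call sites above now pass the aiohttp response through `raise_for_status` (defined in the `nemo_gym/server_utils.py` diff further down) before JSON-decoding it, so HTTP errors surface as exceptions rather than being parsed as if they were successful payloads. A standalone sketch of that pattern using plain aiohttp; the session, URL, and error handling here are illustrative, not taken from the repo:

```python
import asyncio

import aiohttp


def raise_for_status(response: aiohttp.ClientResponse) -> None:
    # Mirrors the helper added in server_utils.py: fail loudly on non-2xx statuses.
    if not response.ok:
        response.raise_for_status()  # raises aiohttp.ClientResponseError


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # httpbin.org/status/500 is just a convenient public endpoint that returns an error status.
        async with session.get("https://httpbin.org/status/500") as response:
            try:
                raise_for_status(response)  # raises instead of letting us decode an error body
                print(await response.text())
            except aiohttp.ClientResponseError as exc:
                print(f"Request failed: {exc.status} {exc.message}")


asyncio.run(main())
```
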
nemo_gym/rollout_collection.py

Lines changed: 7 additions & 2 deletions
@@ -17,9 +17,9 @@
 from collections import Counter
 from contextlib import nullcontext
 from itertools import chain, repeat
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from tqdm.asyncio import tqdm

 from nemo_gym.config_types import BaseServerConfig
@@ -39,6 +39,7 @@ class RolloutCollectionConfig(BaseModel):
     limit: Optional[int] = None
     num_repeats: Optional[int] = None
     num_samples_in_parallel: Optional[int] = None
+    responses_create_params: Dict[str, Any] = Field(default_factory=dict)


 class RolloutCollectionHelper(BaseModel):  # pragma: no cover
@@ -68,10 +69,14 @@ async def run_from_config(self, config: RolloutCollectionConfig):
                 f"The tqdm progress bar will only update every {tqdm_miniters} samples that finish to ensure that you are not being spammed."
             )

+        if config.responses_create_params:
+            print(f"Overriding responses_create_params fields with {config.responses_create_params}")
+
         metrics = Counter()
         with open(config.output_jsonl_fpath, "a") as f:

             async def _post_coroutine(row: dict) -> None:
+                row["responses_create_params"] = row["responses_create_params"] | config.responses_create_params
                 async with semaphore:
                     response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
                     result = await response.json()

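One small note on the new config field above: `Field(default_factory=dict)` gives every config instance its own fresh empty dict, which is falsy, so the "Overriding responses_create_params" message only prints when overrides are actually supplied. A minimal standalone sketch of that behavior (the class here is a stand-in, not the repo's `RolloutCollectionConfig`):

```python
from typing import Any, Dict

from pydantic import BaseModel, Field


class ConfigSketch(BaseModel):
    # default_factory=dict -> each instance gets its own empty dict by default.
    responses_create_params: Dict[str, Any] = Field(default_factory=dict)


default_cfg = ConfigSketch()
override_cfg = ConfigSketch(responses_create_params={"max_output_tokens": 32_768})

for cfg in (default_cfg, override_cfg):
    if cfg.responses_create_params:  # an empty dict is falsy, so the default prints nothing
        print(f"Overriding responses_create_params fields with {cfg.responses_create_params}")
```
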
nemo_gym/server_utils.py

Lines changed: 93 additions & 31 deletions
@@ -14,6 +14,7 @@
 import asyncio
 import atexit
 import json
+import resource
 from abc import abstractmethod
 from contextlib import asynccontextmanager
 from io import StringIO
@@ -22,6 +23,7 @@
 from os import getenv
 from pathlib import Path
 from threading import Thread
+from traceback import print_exc
 from typing import Literal, Optional, Tuple, Type, Union, Unpack
 from uuid import uuid4

@@ -31,6 +33,7 @@
 from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector
 from aiohttp.client import _RequestOptions
 from fastapi import FastAPI, Request, Response
+from fastapi.responses import JSONResponse
 from omegaconf import DictConfig, OmegaConf
 from pydantic import BaseModel, ConfigDict
 from requests.exceptions import ConnectionError
@@ -62,7 +65,7 @@ class GlobalAIOHTTPAsyncClientConfig(BaseModel):
 def get_global_aiohttp_client(
     global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
     global_config_dict_parser_cls: Type[GlobalConfigDictParser] = GlobalConfigDictParser,
-) -> ClientSession:
+) -> ClientSession:  # pragma: no cover
     global _GLOBAL_AIOHTTP_CLIENT

     if _GLOBAL_AIOHTTP_CLIENT is not None:
@@ -77,7 +80,7 @@ def get_global_aiohttp_client(
     return set_global_aiohttp_client(cfg)


-def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSession:
+def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSession:  # pragma: no cover
     assert not is_global_aiohttp_client_setup(), (
         "There is already a global aiohttp client setup. Please refactor your code or call `global_aiohttp_client_exit` if you want to explicitly re-make the client!"
     )
@@ -97,11 +100,11 @@ def set_global_aiohttp_client(cfg: GlobalAIOHTTPAsyncClientConfig) -> ClientSess
     return _GLOBAL_AIOHTTP_CLIENT


-def is_global_aiohttp_client_setup() -> bool:
+def is_global_aiohttp_client_setup() -> bool:  # pragma: no cover
     return _GLOBAL_AIOHTTP_CLIENT is not None


-def global_aiohttp_client_exit():
+def global_aiohttp_client_exit():  # pragma: no cover
     if not is_global_aiohttp_client_setup():
         return

@@ -118,7 +121,9 @@ def global_aiohttp_client_exit():
 MAX_NUM_TRIES = 3


-async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> ClientResponse:
+async def request(
+    method: str, url: str, _internal: bool = False, **kwargs: Unpack[_RequestOptions]
+) -> ClientResponse:  # pragma: no cover
     client = get_global_aiohttp_client()
     num_tries = 1
     while True:
@@ -127,18 +132,27 @@ async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> C
         except ServerDisconnectedError:
             await asyncio.sleep(0.5)
         except Exception as e:
-            print(
-                f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e}
+            # Don't increment internal since we know we are ok. If we are not, the head server will shut everything down anyways.
+            if not _internal:
+                print(
+                    f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e}
 Sleeping 0.5s and retrying...
 """
-            )
-            if num_tries >= MAX_NUM_TRIES:
-                raise e
+                )
+                if num_tries >= MAX_NUM_TRIES:
+                    raise e
+
+                num_tries += 1

-            num_tries += 1
             await asyncio.sleep(0.5)


+def raise_for_status(response: ClientResponse) -> None:  # pragma: no cover
+    if not response.ok:
+        print(response.content)
+        response.raise_for_status()
+
+
 DEFAULT_HEAD_SERVER_PORT = 11000

 ServerStatus = Union[Literal["success"], Literal["connection_error"], Literal["timeout"], Literal["unknown_error"]]
@@ -193,7 +207,7 @@ async def request(
         if isinstance(json_obj, BaseModel):
             kwargs["json"] = json_obj.model_dump(exclude_unset=True)

-        return await request(method=method, url=f"{base_url}{url_path}", **kwargs)
+        return await request(method=method, url=f"{base_url}{url_path}", _internal=True, **kwargs)

     async def get(
         self,
@@ -324,6 +338,24 @@ async def add_session_id(request: Request, call_next):  # pragma: no cover
         session_middleware_key = self.get_session_middleware_key()
         app.add_middleware(SessionMiddleware, secret_key=session_middleware_key, session_cookie=session_middleware_key)

+    def setup_exception_middleware(self, app: FastAPI) -> None:  # pragma: no cover
+        @app.middleware("http")
+        async def exception_handling_middleware(request: Request, call_next):
+            try:
+                return await call_next(request)
+            except Exception as e:
+                print_exc()
+                print(
+                    f"🚨 Caught an exception printed above in {self.config.name} ({self.__class__.__name__}). If you expect this to be fed back into this model, the exception repr i.e. `repr(e)` is returned to the model. However, please make sure this exception is caught in your server and returned to the model as appropriate. See https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception"
+                )
+                return JSONResponse(content=repr(e), status_code=500)
+            except:
+                print_exc()
+                print(
+                    f"🚨 Caught an unknown exception printed above in {self.config.name} ({self.__class__.__name__}). If you expect this to be fed back into this model, nothing meaningful is returned to the model. Please make sure this exception is caught in your server and returned to the model as appropriate. See https://fastapi.tiangolo.com/tutorial/handling-errors/#use-httpexception"
+                )
+                return JSONResponse(content="An unknown error occurred", status_code=500)
+
     def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareConfig) -> None:  # pragma: no cover
         base_profile_dir = Path(PARENT_DIR) / profiling_config.profiling_results_dirpath
         server_profile_path = (base_profile_dir / self.get_session_middleware_key()).with_suffix(".log")
@@ -332,18 +364,7 @@ def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareCon

         main_app_lifespan = app.router.lifespan_context

-        @asynccontextmanager
-        async def lifespan_wrapper(app):
-            yappi.set_clock_type("WALL")
-            yappi.start()
-            print(f"🔍 Enabled profiling for {self.config.name}")
-
-            async with main_app_lifespan(app) as maybe_state:
-                yield maybe_state
-
-            print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!")
-            yappi.stop()
-
+        def _dump_yappi_stats() -> str:
             buffer = StringIO()
             yappi.get_func_stats().print_all(
                 out=buffer,
@@ -357,17 +378,56 @@ async def lifespan_wrapper(app):
             )

             buffer.seek(0)
-            with open(server_profile_path, "w") as f:
-                past_header = False
-                for line in buffer:
-                    if not past_header or self.config.entrypoint in line:
-                        f.write(line)
+            res = ""
+            past_header = False
+            for line in buffer:
+                if not past_header or self.config.entrypoint in line:
+                    res += line
+
+                if line.startswith("name"):
+                    past_header = True

-                    if line.startswith("name"):
-                        past_header = True
+            return res
+
+        @asynccontextmanager
+        async def lifespan_wrapper(app):
+            yappi.set_clock_type("CPU")
+            yappi.start()
+            print(f"🔍 Enabled profiling for {self.config.name}")
+
+            async with main_app_lifespan(app) as maybe_state:
+                yield maybe_state
+
+            print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!")
+            yappi.stop()
+
+            with open(server_profile_path, "w") as f:
+                f.write(_dump_yappi_stats())

         app.router.lifespan_context = lifespan_wrapper

+        @app.get("/stats")
+        def stats():
+            return Response(_dump_yappi_stats())
+
+    def set_ulimit(self, target_soft_limit: int = 65535):  # pragma: no cover
+        # From https://github.com/vllm-project/vllm/blob/fed8a9b107df3e27d57728c6911c7d308b871477/vllm/utils/__init__.py#L2790
+        resource_type = resource.RLIMIT_NOFILE
+        current_soft, current_hard = resource.getrlimit(resource_type)
+
+        if current_soft < target_soft_limit:
+            try:
+                resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+            except ValueError as e:
+                print(
+                    "Found ulimit of %s and failed to automatically increase "
+                    "with error %s. This can cause fd limit errors like "
+                    "`OSError: [Errno 24] Too many open files`. Consider "
+                    "increasing with ulimit -n",
+                    current_soft,
+                    e,
+                )
+
     @classmethod
     def run_webserver(cls) -> None:  # pragma: no cover
         global_config_dict = get_global_config_dict()
@@ -380,6 +440,8 @@ def run_webserver(cls) -> None:  # pragma: no cover
         server = cls(config=server_config, server_client=server_client)

         app = server.setup_webserver()
+        server.set_ulimit()
+        server.setup_exception_middleware(app)

         profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict)
         if profiling_config.profiling_enabled:

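Of the server_utils.py changes above, the exception-handling middleware is the one most likely to change observable behavior: per the diff, an unhandled exception in any endpoint now comes back to the caller as a 500 whose body is `repr(e)` instead of propagating out of the request. A minimal, self-contained FastAPI sketch of that pattern; the app and the `/boom` route are invented for illustration and are not part of the repo:

```python
from traceback import print_exc

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


@app.middleware("http")
async def exception_handling_middleware(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception as e:
        print_exc()  # keep the full traceback in the server logs
        # Surface the error to the caller as a 500 whose body is repr(e).
        return JSONResponse(content=repr(e), status_code=500)


@app.get("/boom")
async def boom():
    raise ValueError("something went wrong")  # reaches the client as a 500


```

Served with uvicorn, a request to `/boom` would return status 500 with the JSON string `"ValueError('something went wrong')"` as the body.
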
resources_servers/comp_coding/README.md

Lines changed: 3 additions & 13 deletions
@@ -2,7 +2,7 @@

 ### Overview
 Verifies competitive programming solutions by executing submitted code against unit tests. The server consumes agent trajectories and returns a reward based on whether the assistant's code produces the correct outputs for given test inputs.
-Model registry link: https://gitlab-master.nvidia.com/bxyu/nemo-gym/-/ml/models/53#/
+Model registry link: https://gitlab-master.nvidia.com/bxyu/nemo-gym/-/ml/models/53#/

 ### Input schema
 - `responses_create_params`: OpenAI Responses create params
@@ -65,7 +65,7 @@ ng_prepare_data "+config_paths=[$config_paths]" \
 # Download train data from gitlab model registry
 ng_download_dataset_from_gitlab \
 +dataset_name=comp_coding \
-+version=0.0.1 \
++version=0.1.1 \
 +run_id=5a1167ef-3533-486f-9c0e-49d1e97fc887 \
 +artifact_fpath=train.jsonl \
 +output_fpath=resources_servers/comp_coding/data/train.jsonl
@@ -90,15 +90,5 @@ uv run python resources_servers/comp_coding/scripts/validate_dataset.py \
   --in data/comp_coding/train.jsonl --fail-fast
 ```

-### Error handling
-The server provides specific error messages for different failure modes:
-- `Empty model output`: No text found in the response
-- `Missing verifier_metadata.unit_tests`: Required test data not provided
-- `Invalid unit_tests`: Malformed test case data
-- `Could not extract code`: No valid Python code found in response
-- `INVALID_TEST_FORMAT`: Test inputs/outputs length mismatch or empty
-- `TEST_CASE_N_FAILED`: Specific test case failed with expected vs actual output
-- `TEST_CASE_N_ERROR`: Runtime error during test execution
-
 ## Licensing information
-TODO: @kbhardwaj to confirm data/code licensing information w Vahid and team
+Apache 2.0
