Skip to content

Commit 21bd089

Browse files
bxyu-nvidiaabhibha-nvidia
authored andcommitted
Add profiling; improve rollout collection usability and efficiency; add uvicorn logging filtering (#79)
Signed-off-by: Brian Yu <bxyu@nvidia.com> Signed-off-by: Abhibha Gupta <abhibhag@nvidia.com>
1 parent 8cbc5fd commit 21bd089

File tree

11 files changed

+251
-55
lines changed

11 files changed

+251
-55
lines changed

README.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
- [How To: ng\_dump\_config - Dump a YAML config as exactly as NeMo Gym sees it](#how-to-ng_dump_config---dump-a-yaml-config-as-exactly-as-nemo-gym-sees-it)
2020
- [How To: Use NeMo Gym with a non-Responses compatible API endpoint like vLLM](#how-to-use-nemo-gym-with-a-non-responses-compatible-api-endpoint-like-vllm)
2121
- [How To: Multi-verifier usage](#how-to-multi-verifier-usage)
22+
- [How To: Profile your resources server](#how-to-profile-your-resources-server)
2223
- [FAQ: DCO and commit signing VSCode and Git setup](#faq-dco-and-commit-signing-vscode-and-git-setup)
2324
- [FAQ: SFT and RL](#faq-sft-and-rl)
2425
- [FAQ: Error: Found files with missing copyright](#faq-error-found-files-with-missing-copyright)
@@ -772,6 +773,53 @@ ng_run "+config_paths=[$config_paths]"
772773
The same process goes for data preparation and downstream training framework Gym configuration, you would just add additional server configs.
773774

774775

776+
# How To: Profile your resources server
777+
For large scale verifier training, it's critical that your resources server is as efficient as possible. It may be slammed with 16k concurrent requests or more. Gym provides easy tools to profile and understand the efficiency of your servers.
778+
779+
In one terminal, start your agent, model, and resources servers, with profiling enabled.
780+
- `profiling_enabled` (bool): whether profiling is enabled or not. By default this is disabled since it incurs some slight overhead we don't want at runtime.
781+
- `profiling_results_dirpath` (str): The directory to save all server profiling results in. Previous logs for the same will be overwritten in the same directory.
782+
```bash
783+
config_paths="responses_api_models/openai_model/configs/openai_model.yaml,\
784+
resources_servers/library_judge_math/configs/bytedtsinghua_dapo17k.yaml"
785+
ng_run "+config_paths=[${config_paths}]" \
786+
+profiling_enabled=true \
787+
+profiling_results_dirpath=results/profiling/library_judge_math
788+
```
789+
790+
In another terminal, run some large number of rollouts against your servers. Use the `limit` and `num_repeats` flags to adjust the number of samples you want to run.
791+
```bash
792+
ng_collect_rollouts +agent_name=library_judge_math_simple_agent \
793+
+input_jsonl_fpath=resources_servers/library_judge_math/data/dapo17k_bytedtsinghua_train.jsonl \
794+
+output_jsonl_fpath=temp/library_judge_math_rollouts.jsonl \
795+
+limit=1024 \
796+
+num_repeats 1
797+
```
798+
799+
After `ng_collect_rollouts` finishes, ctrl+c to quit your servers. You should see some output in the terminal like this:
800+
```bash
801+
```
802+
803+
The log file content for a server will look something like the following:
804+
```
805+
name ncall tsub ttot tavg
806+
.../nemo-gym/resources_servers/library_judge_math/app.py:118 LibraryJudgeMathResourcesServer.verify 1024 0.009755 17.98387 0.017562
807+
.../nemo-gym/resources_servers/library_judge_math/app.py:145 LibraryJudgeMathResourcesServer._verify_answer 1024 0.002933 17.87998 0.017461
808+
.../nemo-gym/resources_servers/library_judge_math/app.py:173 LibraryJudgeMathResourcesServer._verify_answer_with_library 1024 0.007851 17.87704 0.017458
809+
.../nemo-gym/resources_servers/library_judge_math/app.py:191 <genexpr> 2339 0.001695 0.029082 0.000012
810+
.../nemo-gym/resources_servers/library_judge_math/app.py:163 _mute_output 2048 0.007473 0.016538 0.000008
811+
```
812+
813+
- `ncall`: number of calls (how many times the function/subroutine was invoked).
814+
- The `LibraryJudgeMathResourcesServer.verify` function was invoked 1024 times.
815+
- `tsub`: time spent inside the subroutine itself, excluding calls to other functions (sometimes called "self time").
816+
- The `LibraryJudgeMathResourcesServer.verify` function __itself__ accounted for only 0.009755s of time.
817+
- `ttot`: total time spent in the subroutine, including all the functions it called.
818+
- The `LibraryJudgeMathResourcesServer.verify` function and all functions it called including `_verify_answer`, etc accounted for a total of 17.98387s.
819+
- `tavg`: average time per call (often ttot / ncall).
820+
- The `LibraryJudgeMathResourcesServer.verify` function took 0.017562s per call on average.
821+
822+
775823
# FAQ: DCO and commit signing VSCode and Git setup
776824
Here are some suggestions for easier development using the VSCode code editor.
777825

nemo_gym/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
GlobalConfigDictParserConfig,
3838
get_global_config_dict,
3939
)
40-
from nemo_gym.server_utils import HEAD_SERVER_KEY_NAME, HeadServer, ServerClient, ServerStatus
40+
from nemo_gym.server_utils import (
41+
HEAD_SERVER_KEY_NAME,
42+
HeadServer,
43+
ServerClient,
44+
ServerStatus,
45+
)
4146

4247

4348
def _setup_env_command(dir_path: Path) -> str: # pragma: no cover

nemo_gym/rollout_collection.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,15 @@ class RolloutCollectionConfig(BaseModel):
4343

4444
class RolloutCollectionHelper(BaseModel): # pragma: no cover
4545
async def run_from_config(self, config: RolloutCollectionConfig):
46+
range_iterator = repeat(0)
47+
if config.limit:
48+
range_iterator = range(config.limit)
49+
print(f"Limiting the number of rows to {config.limit}!")
50+
4651
with open(config.input_jsonl_fpath) as input_dataset:
47-
rows = list(map(json.loads, input_dataset))
52+
rows = [row for _, row in zip(range_iterator, map(json.loads, input_dataset))]
4853
print(f"Found {len(rows)} rows!")
4954

50-
if config.limit:
51-
previous_length = len(rows)
52-
rows = rows[: config.limit]
53-
print(f"Limiting rows from {previous_length} to {len(rows)}!")
54-
5555
if config.num_repeats:
5656
previous_length = len(rows)
5757
rows = list(chain.from_iterable(repeat(row, config.num_repeats) for row in rows))
@@ -63,6 +63,11 @@ async def run_from_config(self, config: RolloutCollectionConfig):
6363

6464
server_client = self.setup_server_client()
6565

66+
tqdm_miniters = 10
67+
print(
68+
f"The tqdm progress bar will only update every {tqdm_miniters} samples that finish to ensure that you are not being spammed."
69+
)
70+
6671
metrics = Counter()
6772
with open(config.output_jsonl_fpath, "a") as f:
6873

@@ -73,7 +78,7 @@ async def _post_coroutine(row: dict) -> None:
7378
f.write(json.dumps(result) + "\n")
7479
metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))})
7580

76-
await tqdm.gather(*map(_post_coroutine, rows), desc="Collecting rollouts")
81+
await tqdm.gather(*map(_post_coroutine, rows), desc="Collecting rollouts", miniters=tqdm_miniters)
7782

7883
avg_metrics = {k: v / len(rows) for k, v in metrics.items()}
7984

@@ -88,7 +93,7 @@ async def _post_subroutine(row: Dict) -> Dict:
8893
res = await server_client.post(server_name=row.pop("agent_ref")["name"], url_path="/run", json=row)
8994
return await res.json()
9095

91-
return await tqdm.gather(*map(_post_subroutine, examples), desc="Collecting rollouts")
96+
return await tqdm.gather(*map(_post_subroutine, examples), desc="Collecting rollouts", miniters=10)
9297

9398
def setup_server_client(self, head_server_config: Optional[BaseServerConfig] = None) -> ServerClient:
9499
server_client = ServerClient.load_from_global_config(head_server_config)

nemo_gym/server_utils.py

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
import atexit
1616
import json
1717
from abc import abstractmethod
18+
from contextlib import asynccontextmanager
19+
from io import StringIO
1820
from logging import Filter as LoggingFilter
1921
from logging import LogRecord, getLogger
2022
from os import getenv
23+
from pathlib import Path
2124
from threading import Thread
2225
from typing import Literal, Optional, Tuple, Type, Union, Unpack
2326
from uuid import uuid4
2427

2528
import requests
2629
import uvicorn
30+
import yappi
2731
from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector
2832
from aiohttp.client import _RequestOptions
2933
from fastapi import FastAPI, Request, Response
@@ -32,6 +36,7 @@
3236
from requests.exceptions import ConnectionError
3337
from starlette.middleware.sessions import SessionMiddleware
3438

39+
from nemo_gym import PARENT_DIR
3540
from nemo_gym.config_types import (
3641
BaseRunServerInstanceConfig,
3742
BaseServerConfig,
@@ -50,8 +55,8 @@
5055

5156

5257
class GlobalAIOHTTPAsyncClientConfig(BaseModel):
53-
global_aiohttp_connector_limit: int = 1000
54-
global_aiohttp_connector_limit_per_host: int = 100
58+
global_aiohttp_connector_limit: int = 100 * 1024
59+
global_aiohttp_connector_limit_per_host: int = 1024
5560

5661

5762
def get_global_aiohttp_client(
@@ -123,7 +128,7 @@ async def request(method: str, url: str, **kwargs: Unpack[_RequestOptions]) -> C
123128
await asyncio.sleep(0.5)
124129
except Exception as e:
125130
print(
126-
f"""Hit an exception while making a request (try {num_tries}): {e}
131+
f"""Hit an exception while making a request (try {num_tries}): {type(e)}: {e}
127132
Sleeping 0.5s and retrying...
128133
"""
129134
)
@@ -274,6 +279,20 @@ def load_config_from_global_config(cls) -> "BaseRunServerInstanceConfig":
274279
return server_config
275280

276281

282+
class ProfilingMiddlewareInputConfig(BaseModel):
283+
# Relative to the Gym root dir.
284+
profiling_results_dirpath: Optional[str] = None
285+
286+
287+
class ProfilingMiddlewareConfig(ProfilingMiddlewareInputConfig):
288+
profiling_enabled: bool = False
289+
290+
291+
class UvicornLoggingConfig(BaseModel):
292+
# Default to False for regular use cases.
293+
uvicorn_logging_show_200_ok: bool = False
294+
295+
277296
class SimpleServer(BaseServer):
278297
server_client: ServerClient
279298

@@ -305,36 +324,86 @@ async def add_session_id(request: Request, call_next): # pragma: no cover
305324
session_middleware_key = self.get_session_middleware_key()
306325
app.add_middleware(SessionMiddleware, secret_key=session_middleware_key, session_cookie=session_middleware_key)
307326

327+
def setup_profiling(self, app: FastAPI, profiling_config: ProfilingMiddlewareConfig) -> None: # pragma: no cover
328+
base_profile_dir = Path(PARENT_DIR) / profiling_config.profiling_results_dirpath
329+
server_profile_path = (base_profile_dir / self.get_session_middleware_key()).with_suffix(".log")
330+
331+
base_profile_dir.mkdir(parents=True, exist_ok=True)
332+
333+
main_app_lifespan = app.router.lifespan_context
334+
335+
@asynccontextmanager
336+
async def lifespan_wrapper(app):
337+
yappi.set_clock_type("WALL")
338+
yappi.start()
339+
print(f"🔍 Enabled profiling for {self.config.name}")
340+
341+
async with main_app_lifespan(app) as maybe_state:
342+
yield maybe_state
343+
344+
print(f"🛑 Stopping profiler for {self.config.name}. Check {server_profile_path} for the metrics!")
345+
yappi.stop()
346+
347+
buffer = StringIO()
348+
yappi.get_func_stats().print_all(
349+
out=buffer,
350+
columns={
351+
0: ("name", 200),
352+
1: ("ncall", 10),
353+
2: ("tsub", 8),
354+
3: ("ttot", 8),
355+
4: ("tavg", 8),
356+
},
357+
)
358+
359+
buffer.seek(0)
360+
with open(server_profile_path, "w") as f:
361+
past_header = False
362+
for line in buffer:
363+
if not past_header or self.config.entrypoint in line:
364+
f.write(line)
365+
366+
if line.startswith("name"):
367+
past_header = True
368+
369+
app.router.lifespan_context = lifespan_wrapper
370+
308371
@classmethod
309372
def run_webserver(cls) -> None: # pragma: no cover
373+
global_config_dict = get_global_config_dict()
374+
310375
server_config = cls.load_config_from_global_config()
311376
server_client = ServerClient(
312377
head_server_config=ServerClient.load_head_server_config(),
313-
global_config_dict=get_global_config_dict(),
378+
global_config_dict=global_config_dict,
314379
)
315380
server = cls(config=server_config, server_client=server_client)
316381

317382
app = server.setup_webserver()
318383

319-
class No200Filter(LoggingFilter):
320-
def filter(self, record: LogRecord) -> bool:
321-
msg = record.getMessage()
322-
return not msg.strip().endswith("200")
384+
profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict)
385+
if profiling_config.profiling_enabled:
386+
server.setup_profiling(app, profiling_config)
323387

324-
uvicorn_logger = getLogger("uvicorn.access")
325-
uvicorn_logger.addFilter(No200Filter())
388+
uvicorn_logging_cfg = UvicornLoggingConfig.model_validate(global_config_dict)
389+
if not uvicorn_logging_cfg.uvicorn_logging_show_200_ok:
326390

327-
print(
328-
"Adding a uvicorn logging filter so that the logs aren't spammed with 200 OK messages. This is to help errors pop up better and filter out noise."
329-
)
391+
class No200Filter(LoggingFilter):
392+
def filter(self, record: LogRecord) -> bool:
393+
msg = record.getMessage()
394+
return not msg.strip().endswith("200")
395+
396+
uvicorn_logger = getLogger("uvicorn.access")
397+
uvicorn_logger.addFilter(No200Filter())
398+
399+
print(
400+
"Adding a uvicorn logging filter so that the logs aren't spammed with 200 OK messages. This is to help errors pop up better and filter out noise."
401+
)
330402

331403
uvicorn.run(
332404
app,
333405
host=server.config.host,
334406
port=server.config.port,
335-
# We don't have any explicit lifespan logic, so instead of defaulting to "auto"
336-
# We just turn lifespan off
337-
lifespan="off",
338407
)
339408

340409

nemo_gym/train_data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def collate_samples(
577577

578578
parent = Path(config.output_dirpath)
579579
parent.mkdir(exist_ok=True)
580-
metrics_fpath = parent / f"{type}_metrics.jsonl"
580+
metrics_fpath = parent / f"{type}_metrics.json"
581581
maybe_conflicting_metrics_fpath = self._validate_aggregate_metrics(
582582
aggregate_metrics_dict=aggregate_metrics_dict,
583583
metrics_fpath=metrics_fpath,

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ dependencies = [
131131
# Updated Sun Sep 21, 2025 with aiohttp==3.12.15
132132
# License: Apache 2.0 https://github.com/aio-libs/aiohttp/blob/9a2f146a12e3525b43e96723ef41584bf9cf784e/LICENSE.txt
133133
"aiohttp",
134+
135+
# yappi: profiling tool
136+
# Updated Mon Sep 22, 2025 with yappi==1.6.10
137+
# License: MIT https://github.com/sumerc/yappi/blob/1d3f7501701e1f050b6dcd6a86fd36aec08185c7/LICENSE
138+
"yappi",
134139
]
135140

136141
[dependency-groups]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"name": "train",
3+
"type": "train",
4+
"jsonl_fpath": "resources_servers/comp_coding/data/train.jsonl",
5+
"gitlab_identifier": {
6+
"dataset_name": "comp_coding",
7+
"version": "0.0.1",
8+
"artifact_fpath": "train.jsonl"
9+
},
10+
"license": "Apache 2.0",
11+
"Number of examples": 5000,
12+
"Number of tools": {
13+
"Total # non-null values": 0,
14+
"Average": 0.0,
15+
"Min": 0.0,
16+
"Max": 0.0,
17+
"Median": 0.0,
18+
"Standard deviation": 0.0
19+
},
20+
"Json-dumped number of words (proxy for token count)": {
21+
"Total # non-null values": 5000,
22+
"Average": 336.1797999999992,
23+
"Min": 46.0,
24+
"Max": 1274.0,
25+
"Median": 319.5131482834187,
26+
"Standard deviation": 135.7584072571132
27+
},
28+
"Number of turns": {
29+
"Total # non-null values": 5000,
30+
"Average": 1.0,
31+
"Min": 1.0,
32+
"Max": 1.0,
33+
"Median": 1.0,
34+
"Standard deviation": 0.0
35+
},
36+
"Temperature": {
37+
"Total # non-null values": 0,
38+
"Average": 0.0,
39+
"Min": 0.0,
40+
"Max": 0.0,
41+
"Median": 0.0,
42+
"Standard deviation": 0.0
43+
}
44+
}

resources_servers/comp_coding/data/train_metrics.jsonl

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)