Skip to content

Commit a6df36e

Browse files
fsiino-nvidia and bxyu-nvidia
authored and committed
Add data aggregations to data preparation (#49)
This change updates the train_data_utils via `ng_prepare_data` to apply data aggregations to the other keys within an `example.jsonl` file. --------- Signed-off-by: Frankie Siino <fsiino@nvidia.com> Co-authored-by: bxyu-nvidia <bxyu@nvidia.com>
1 parent e64f71f commit a6df36e

File tree

20 files changed

+572
-144
lines changed

20 files changed

+572
-144
lines changed

nemo_gym/dataset_viewer.py

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@
2929
from nemo_gym.base_resources_server import BaseVerifyResponse
3030
from nemo_gym.server_utils import get_global_config_dict
3131
from nemo_gym.train_data_utils import (
32-
AvgMinMax,
3332
DatasetMetrics,
33+
aggregate_other_metrics,
3434
compute_sample_metrics,
35+
postprocess_other_metrics,
3536
)
3637

3738

@@ -206,59 +207,34 @@ class JsonlDatasetViewerConfig(BaseModel):
206207
jsonl_fpath: str
207208

208209

209-
def aggregate_other_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:
210-
metric_values = {}
211-
string_values = {}
212-
for d in data:
213-
d = d.model_dump() if hasattr(d, "model_dump") else d
214-
for k, v in d.items():
215-
if k in ("responses_create_params", "response"):
216-
continue
217-
if isinstance(v, bool):
218-
v = int(v)
219-
if isinstance(v, (int, float)):
220-
metric_values.setdefault(k, []).append(v)
221-
# get unique count for strings
222-
elif isinstance(v, str):
223-
string_values.setdefault(k, []).append(v)
224-
225-
result = {}
226-
for k, v in metric_values.items():
227-
if v:
228-
obj = AvgMinMax(
229-
total=len(v),
230-
average=sum(v) / len(v),
231-
min=min(v),
232-
max=max(v),
233-
)
234-
result[k] = obj.model_dump(by_alias=True)
235-
236-
for k, v in string_values.items():
237-
result[k] = {"unique_count": len(set(v)), "total_count": len(v)}
238-
239-
return result
240-
241-
242-
def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse], raw_lines: List[str]) -> Dict[str, Any]:
210+
def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:
243211
dataset_metrics = DatasetMetrics()
244-
for line in raw_lines:
212+
other_metrics = {}
213+
214+
for line in data:
215+
line = json.dumps(line.model_dump())
245216
metrics, is_offending = compute_sample_metrics(line)
246217
if not is_offending:
247218
dataset_metrics.add(metrics)
248219

220+
sample_dict = json.loads(line)
221+
aggregate_other_metrics(other_metrics, sample_dict)
222+
223+
postprocess_other_metrics(dataset_metrics, other_metrics)
224+
249225
aggregate_metrics = dataset_metrics.aggregate()
250226
aggregate_metrics_dict = aggregate_metrics.model_dump(by_alias=True)
251-
aggregate_metrics_dict.update(**aggregate_other_metrics(data))
252227
return aggregate_metrics_dict
253228

254229

255230
def build_jsonl_dataset_viewer(config: JsonlDatasetViewerConfig) -> Blocks:
256-
data = []
257-
raw_lines = []
258231
with open(config.jsonl_fpath) as f:
259-
for line in tqdm(f, desc="Loading data"):
260-
raw_lines.append(line)
261-
data.append(DatasetViewerVerifyResponse.model_validate_json(line))
232+
data = list(
233+
tqdm(
234+
map(DatasetViewerVerifyResponse.model_validate_json, f),
235+
desc="Loading data",
236+
)
237+
)
262238

263239
choices = [(f"Sample {i + 1} - Responses ID {d.response.id}", i) for i, d in enumerate(data)]
264240

@@ -274,7 +250,7 @@ def select_item(value: int):
274250
}
275251
"""
276252
with Blocks(analytics_enabled=False, css=CSS) as demo:
277-
aggregate_dicts = get_aggregate_metrics(data, raw_lines)
253+
aggregate_dicts = get_aggregate_metrics(data)
278254
JSON(value=aggregate_dicts, label="Aggregate Metrics", open=False)
279255

280256
item_dropdown = Dropdown(choices=choices, value=0, label="Samples")

nemo_gym/train_data_utils.py

Lines changed: 128 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from abc import abstractmethod
1616
from collections import Counter, defaultdict
1717
from itertools import count, repeat
18+
from math import sqrt
1819
from pathlib import Path
1920
from shutil import copyfileobj
20-
from typing import Dict, List, Literal, Optional, Self, Tuple, Union
21+
from typing import Any, Dict, List, Literal, Optional, Self, Tuple, Union
2122

2223
from devtools import pprint
2324
from omegaconf import DictConfig
2425
from pydantic import BaseModel, ConfigDict, Field, ValidationError
26+
from tdigest import TDigest
2527
from tqdm.auto import tqdm
2628

2729
from nemo_gym.base_resources_server import BaseRunRequest
@@ -79,27 +81,90 @@ def _aggregate(self: Self) -> Self:
7981

8082

8183
class AvgMinMax(Accumulator):
84+
model_config = ConfigDict(arbitrary_types_allowed=True)
8285
total: int = Field(serialization_alias="Total # non-null values", default=0)
8386
average: float = Field(serialization_alias="Average", default=0)
8487
min: float = Field(serialization_alias="Min", default=float("inf"))
8588
max: float = Field(serialization_alias="Max", default=float("-inf"))
89+
median: float = Field(serialization_alias="Median", default=0)
90+
stddev: float = Field(serialization_alias="Standard deviation", default=0)
91+
# Internal state
92+
mean: float = Field(default=0, exclude=True) # running value (before final average)
93+
M2: float = Field(default=0, exclude=True) # sum of squared differences (for variance)
94+
tdigest: TDigest = Field(default_factory=TDigest, exclude=True)
95+
"""
96+
T-Digest is used to estimate the Median without storing and sorting all values. The Median is essentially an approximation using the 50th percentile, which is very close to the true Median.
97+
"""
98+
99+
def observe(self, x: float) -> None:
100+
if x < self.min:
101+
self.min = x
102+
if x > self.max:
103+
self.max = x
104+
105+
# Update running mean and variance
106+
self.total += 1
107+
delta = x - self.mean
108+
self.mean += delta / self.total
109+
self.M2 += delta * (x - self.mean)
110+
111+
# Update quantile estimator (for median)
112+
self.tdigest.update(x)
86113

87114
def _add(self: Self, other: Self) -> None:
88-
self.total += other.total
89-
self.average += other.average
90-
self.min = min(self.min, other.min)
91-
self.max = max(self.max, other.max)
115+
# Merge accumulators
116+
if other.total == 0:
117+
return
118+
if self.total == 0:
119+
self.total = other.total
120+
self.mean = other.mean
121+
self.M2 = other.M2
122+
self.min = other.min
123+
self.max = other.max
124+
self.tdigest = TDigest()
125+
self.tdigest = self.tdigest + other.tdigest
126+
return
127+
128+
# Merge mean and variance
129+
n1, n2 = self.total, other.total
130+
delta = other.mean - self.mean
131+
n = n1 + n2
132+
self.mean = self.mean + delta * (n2 / n)
133+
self.M2 = self.M2 + other.M2 + (delta * delta) * (n1 * n2 / n)
134+
self.total = n
135+
136+
if other.min < self.min:
137+
self.min = other.min
138+
if other.max > self.max:
139+
self.max = other.max
140+
141+
# Merge t-digests for quantiles/median
142+
self.tdigest = self.tdigest + other.tdigest
143+
144+
def _aggregate(self: Self) -> Self:
145+
n = self.total
146+
mean = self.mean if n > 0 else 0.0
147+
stddev = sqrt(self.M2 / (n - 1)) if n > 1 else 0.0
148+
med = float(self.tdigest.percentile(50)) if n > 0 and self.tdigest.n > 0 else 0.0
92149

93-
def _aggregate(self) -> Self:
94150
return AvgMinMax(
95151
total=self.total,
96-
average=self.average / max(self.total, 1),
97-
min=self.min if self.total > 0 else 0,
98-
max=self.max if self.total > 0 else 0,
152+
average=mean,
153+
min=self.min if n > 0 else 0.0,
154+
max=self.max if n > 0 else 0.0,
155+
median=med,
156+
stddev=stddev,
99157
)
100158

101159

160+
class StringMetrics(BaseModel):
161+
unique_count: int
162+
total_count: int
163+
164+
102165
class DatasetMetrics(Accumulator):
166+
model_config = ConfigDict(extra="allow") # Allow any arbitrary fields
167+
103168
number_of_examples: int = Field(serialization_alias="Number of examples", default=0)
104169
number_of_tools: AvgMinMax = Field(serialization_alias="Number of tools", default_factory=AvgMinMax)
105170
json_dumped_number_of_words: AvgMinMax = Field(
@@ -118,16 +183,60 @@ def _add(self: Self, other: Self) -> None:
118183
self.number_of_turns.add(other.number_of_turns)
119184
self.temperature.add(other.temperature)
120185

186+
# Merge extra fields safely
187+
if other.model_extra:
188+
for k, v in other.model_extra.items():
189+
if k in DatasetMetrics.model_fields.keys():
190+
continue
191+
setattr(self, k, v)
192+
121193
def _aggregate(self: Self) -> Self:
194+
extras = {}
195+
if self.model_extra:
196+
for k, v in self.model_extra.items():
197+
if k in DatasetMetrics.model_fields.keys():
198+
continue
199+
extras[k] = v
122200
return DatasetMetrics(
123201
number_of_examples=self.number_of_examples,
124202
number_of_tools=self.number_of_tools.aggregate(),
125203
json_dumped_number_of_words=self.json_dumped_number_of_words.aggregate(),
126204
number_of_turns=self.number_of_turns.aggregate(),
127205
temperature=self.temperature.aggregate(),
206+
**extras,
128207
)
129208

130209

210+
def aggregate_other_metrics(metrics: Dict[str, Any], sample: Dict[str, Any]) -> None:
211+
"""Combines misc items (those other than response/response create params) into current metrics"""
212+
for k, v in sample.items():
213+
if k in ("responses_create_params", "response"):
214+
continue
215+
216+
values = v if isinstance(v, list) else [v]
217+
218+
for item in values:
219+
if isinstance(item, bool):
220+
item = int(item)
221+
if isinstance(item, (int, float)):
222+
if k not in metrics:
223+
metrics[k] = AvgMinMax()
224+
metrics[k].observe(item)
225+
elif isinstance(item, str):
226+
if k not in metrics:
227+
metrics[k] = Counter()
228+
metrics[k][item] += 1
229+
230+
231+
def postprocess_other_metrics(metrics: DatasetMetrics, other_metrics: Dict[str, Any]) -> None:
232+
"""Aggregates metrics and merges current metrics (containing only AvgMinMax) with StringMetrics"""
233+
for k, v in other_metrics.items():
234+
if isinstance(v, AvgMinMax):
235+
setattr(metrics, k, v.aggregate())
236+
elif isinstance(v, Counter):
237+
setattr(metrics, k, StringMetrics(unique_count=len(v), total_count=sum(v.values())))
238+
239+
131240
def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
132241
try:
133242
sample_dict = json.loads(sample_dict_str)
@@ -146,43 +255,24 @@ def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
146255
number_of_tools_metrics = AvgMinMax()
147256
if responses_create_params.get("tools") is not None:
148257
number_of_tools = len(responses_create_params["tools"])
149-
number_of_tools_metrics = AvgMinMax(
150-
total=1,
151-
average=number_of_tools,
152-
min=number_of_tools,
153-
max=number_of_tools,
154-
)
258+
number_of_tools_metrics.observe(number_of_tools)
155259

156260
if isinstance(inputs, str):
157261
inputs = [{"role": "user", "content": inputs}]
158262
user_inputs = [i for i in inputs if i.get("role") == "user"] if inputs else []
159263
number_of_turns_metrics = AvgMinMax()
160264
if user_inputs:
161265
number_of_turns = len(user_inputs)
162-
number_of_turns_metrics = AvgMinMax(
163-
total=1,
164-
average=number_of_turns,
165-
min=number_of_turns,
166-
max=number_of_turns,
167-
)
266+
number_of_turns_metrics.observe(number_of_turns)
168267

169268
temperature_metrics = AvgMinMax()
170269
if responses_create_params.get("temperature") is not None:
171270
temperature = responses_create_params["temperature"]
172-
temperature_metrics = AvgMinMax(
173-
total=1,
174-
average=temperature,
175-
min=temperature,
176-
max=temperature,
177-
)
271+
temperature_metrics.observe(temperature)
178272

273+
json_dumped_number_of_words_metrics = AvgMinMax()
179274
json_dumped_number_of_words = len(json.dumps(responses_create_params).split())
180-
json_dumped_number_of_words_metrics = AvgMinMax(
181-
total=1,
182-
average=json_dumped_number_of_words,
183-
min=json_dumped_number_of_words,
184-
max=json_dumped_number_of_words,
185-
)
275+
json_dumped_number_of_words_metrics.observe(json_dumped_number_of_words)
186276

187277
metrics = DatasetMetrics(
188278
number_of_examples=1,
@@ -200,6 +290,7 @@ class DatasetValidatorState(BaseModel):
200290
metrics: DatasetMetrics = Field(default_factory=DatasetMetrics)
201291
key_counts: Counter = Field(default_factory=Counter)
202292
offending_example_idxs: List[int] = Field(default_factory=list)
293+
other_metrics: Dict[str, Any] = Field(default_factory=dict)
203294

204295

205296
class TrainDataProcessor(BaseModel):
@@ -358,6 +449,8 @@ def _validate_samples_and_aggregate_metrics_single_sample(
358449
state.key_counts.update(sample_dict.keys())
359450
state.metrics.add(metrics)
360451

452+
aggregate_other_metrics(state.other_metrics, sample_dict)
453+
361454
def _validate_samples_and_aggregate_metrics_single_dataset(
362455
self, dataset_config: DatasetConfig
363456
) -> DatasetValidatorState:
@@ -373,6 +466,8 @@ def _validate_samples_and_aggregate_metrics_single_dataset(
373466
)
374467
)
375468

469+
postprocess_other_metrics(state.metrics, state.other_metrics)
470+
376471
return state
377472

378473
def _validate_aggregate_metrics(self, aggregate_metrics_dict: Dict, metrics_fpath: Path) -> Optional[Path]:

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ dependencies = [
121121
# Updated Tue Aug 05, 2025 with mlflow==3.2.0
122122
# License: Apache 2.0 https://github.com/mlflow/mlflow/blob/1510ed1bc92d3a4258973005d64f64a43136e251/LICENSE.txt
123123
"mlflow",
124+
125+
# Tdigest: Data structure for percentiles and quantiles, specifically calculating metrics such as median in a memory-efficient way.
126+
# Updated Wed Sep 17, 2025 with tdigest==0.5.2.2
127+
# License: MIT https://github.com/CamDavidsonPilon/tdigest/blob/e35cfd708962ae5e9d1c5d2b15a99af7b2e2f323/LICENSE.txt
128+
"tdigest>=0.5.2.2",
124129
]
125130

126131
[dependency-groups]

resources_servers/comp_coding/data/example_metrics.json

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,32 @@
99
"Total # non-null values": 0,
1010
"Average": 0.0,
1111
"Min": 0.0,
12-
"Max": 0.0
12+
"Max": 0.0,
13+
"Median": 0.0,
14+
"Standard deviation": 0.0
1315
},
1416
"Json-dumped number of words (proxy for token count)": {
1517
"Total # non-null values": 5,
1618
"Average": 457.0,
1719
"Min": 348.0,
18-
"Max": 542.0
20+
"Max": 542.0,
21+
"Median": 473.0,
22+
"Standard deviation": 79.75587752636166
1923
},
2024
"Number of turns": {
2125
"Total # non-null values": 5,
2226
"Average": 1.0,
2327
"Min": 1.0,
24-
"Max": 1.0
28+
"Max": 1.0,
29+
"Median": 1.0,
30+
"Standard deviation": 0.0
2531
},
2632
"Temperature": {
2733
"Total # non-null values": 0,
2834
"Average": 0.0,
2935
"Min": 0.0,
30-
"Max": 0.0
36+
"Max": 0.0,
37+
"Median": 0.0,
38+
"Standard deviation": 0.0
3139
}
3240
}

0 commit comments

Comments
 (0)