1515from abc import abstractmethod
1616from collections import Counter , defaultdict
1717from itertools import count , repeat
18+ from math import sqrt
1819from pathlib import Path
1920from shutil import copyfileobj
20- from typing import Dict , List , Literal , Optional , Self , Tuple , Union
21+ from typing import Any , Dict , List , Literal , Optional , Self , Tuple , Union
2122
2223from devtools import pprint
2324from omegaconf import DictConfig
2425from pydantic import BaseModel , ConfigDict , Field , ValidationError
26+ from tdigest import TDigest
2527from tqdm .auto import tqdm
2628
2729from nemo_gym .base_resources_server import BaseRunRequest
@@ -79,27 +81,90 @@ def _aggregate(self: Self) -> Self:
7981
8082
class AvgMinMax(Accumulator):
    """Streaming summary statistics over a series of numeric observations.

    Values are fed in one at a time with ``observe``; two accumulators are
    combined with ``_add``; ``_aggregate`` produces a new instance whose
    reported fields (count, average, min, max, median, standard deviation)
    are filled in from the running state.
    """

    # Presumably required so the non-pydantic TDigest object can be a field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Reported statistics, serialized under their aliases.
    total: int = Field(serialization_alias="Total # non-null values", default=0)
    average: float = Field(serialization_alias="Average", default=0)
    min: float = Field(serialization_alias="Min", default=float("inf"))
    max: float = Field(serialization_alias="Max", default=float("-inf"))
    median: float = Field(serialization_alias="Median", default=0)
    stddev: float = Field(serialization_alias="Standard deviation", default=0)

    # Internal running state, excluded from serialization.
    mean: float = Field(default=0, exclude=True)  # running mean of the observed values
    M2: float = Field(default=0, exclude=True)  # running sum of squared differences from the mean
    # T-Digest sketch used to estimate the median without storing and sorting
    # every value: the reported median is the sketch's 50th percentile, i.e. an
    # approximation of the true median.
    tdigest: TDigest = Field(default_factory=TDigest, exclude=True)

    def observe(self, x: float) -> None:
        """Fold a single value ``x`` into the running statistics."""
        # Incremental (Welford-style) update of the count, mean and M2.
        self.total += 1
        offset = x - self.mean
        self.mean += offset / self.total
        self.M2 += offset * (x - self.mean)

        # Track the extremes.
        if x < self.min:
            self.min = x
        if x > self.max:
            self.max = x

        # Feed the quantile sketch that backs the median.
        self.tdigest.update(x)

    def _add(self: Self, other: Self) -> None:
        """Merge the running state of ``other`` into this accumulator."""
        if other.total == 0:
            # Nothing was observed on the other side.
            return

        if self.total == 0:
            # Nothing observed here yet: adopt the other side's state, using a
            # fresh digest so the two accumulators do not share one object.
            self.total, self.mean, self.M2 = other.total, other.mean, other.M2
            self.min, self.max = other.min, other.max
            self.tdigest = TDigest() + other.tdigest
            return

        # Combine the two (count, mean, M2) triples.
        count_self, count_other = self.total, other.total
        combined = count_self + count_other
        mean_gap = other.mean - self.mean
        self.mean = self.mean + mean_gap * (count_other / combined)
        self.M2 = self.M2 + other.M2 + (mean_gap * mean_gap) * (count_self * count_other / combined)
        self.total = combined

        if other.min < self.min:
            self.min = other.min
        if other.max > self.max:
            self.max = other.max

        # Merge the quantile sketches so the median covers both sides.
        self.tdigest = self.tdigest + other.tdigest

    def _aggregate(self: Self) -> Self:
        """Build a new ``AvgMinMax`` holding the final reported statistics.

        With no observations every reported value is 0; the standard deviation
        uses the ``n - 1`` denominator and is 0 for fewer than two observations.
        """
        count = self.total
        has_values = count > 0

        if count > 1:
            spread = sqrt(self.M2 / (count - 1))
        else:
            spread = 0.0

        if has_values and self.tdigest.n > 0:
            midpoint = float(self.tdigest.percentile(50))
        else:
            midpoint = 0.0

        return AvgMinMax(
            total=self.total,
            average=self.mean if has_values else 0.0,
            min=self.min if has_values else 0.0,
            max=self.max if has_values else 0.0,
            median=midpoint,
            stddev=spread,
        )
100158
101159
class StringMetrics(BaseModel):
    """Summary of the string values collected for one field."""

    # Number of distinct string values.
    unique_count: int
    # Number of string values, counting repeats.
    total_count: int
163+
164+
102165class DatasetMetrics (Accumulator ):
166+ model_config = ConfigDict (extra = "allow" ) # Allow any arbitrary fields
167+
103168 number_of_examples : int = Field (serialization_alias = "Number of examples" , default = 0 )
104169 number_of_tools : AvgMinMax = Field (serialization_alias = "Number of tools" , default_factory = AvgMinMax )
105170 json_dumped_number_of_words : AvgMinMax = Field (
@@ -118,16 +183,60 @@ def _add(self: Self, other: Self) -> None:
118183 self .number_of_turns .add (other .number_of_turns )
119184 self .temperature .add (other .temperature )
120185
186+ # Merge extra fields safely
187+ if other .model_extra :
188+ for k , v in other .model_extra .items ():
189+ if k in DatasetMetrics .model_fields .keys ():
190+ continue
191+ setattr (self , k , v )
192+
def _aggregate(self: Self) -> Self:
    """Build a new ``DatasetMetrics`` with every declared statistic aggregated.

    The example count is copied, each declared ``AvgMinMax`` field is replaced
    by its ``aggregate()`` result, and any extra (undeclared) fields stored on
    this instance are passed through unchanged.
    """
    declared = DatasetMetrics.model_fields
    # Extra fields only; a name that is also a declared field is never forwarded.
    passthrough = {name: value for name, value in (self.model_extra or {}).items() if name not in declared}

    return DatasetMetrics(
        number_of_examples=self.number_of_examples,
        number_of_tools=self.number_of_tools.aggregate(),
        json_dumped_number_of_words=self.json_dumped_number_of_words.aggregate(),
        number_of_turns=self.number_of_turns.aggregate(),
        temperature=self.temperature.aggregate(),
        **passthrough,
    )
129208
130209
def aggregate_other_metrics(metrics: Dict[str, Any], sample: Dict[str, Any]) -> None:
    """Fold the miscellaneous top-level fields of one sample into ``metrics``.

    Every key of ``sample`` other than ``responses_create_params`` and
    ``response`` is inspected. A list value contributes each of its items; any
    other value is treated as a single item. Numeric items (booleans count as
    0/1) are observed by an ``AvgMinMax`` stored under the key, string items
    are tallied in a ``Counter`` stored under the key, and items of any other
    type are ignored.

    The first usable item seen for a key decides which kind of accumulator the
    key gets. Later items of the other kind are skipped, so a field that mixes
    strings and numbers no longer crashes the whole run.

    Args:
        metrics: Mapping from field name to its accumulator; updated in place.
        sample: One decoded sample.
    """
    for k, v in sample.items():
        if k in ("responses_create_params", "response"):
            continue

        values = v if isinstance(v, list) else [v]

        for item in values:
            if isinstance(item, bool):
                # Count booleans as 0/1 so they are summarized numerically.
                item = int(item)

            if isinstance(item, (int, float)):
                if k not in metrics:
                    metrics[k] = AvgMinMax()
                accumulator = metrics[k]
                if isinstance(accumulator, Counter):
                    # Key already tracks strings; a number cannot be tallied there.
                    continue
                accumulator.observe(item)
            elif isinstance(item, str):
                if k not in metrics:
                    metrics[k] = Counter()
                accumulator = metrics[k]
                if not isinstance(accumulator, Counter):
                    # Key already tracks numbers; a string cannot be observed there.
                    continue
                accumulator[item] += 1
229+
230+
def postprocess_other_metrics(metrics: DatasetMetrics, other_metrics: Dict[str, Any]) -> None:
    """Attach the finalized miscellaneous metrics to ``metrics`` as extra fields.

    For each entry of ``other_metrics``: an ``AvgMinMax`` is stored as its
    ``aggregate()`` result, a ``Counter`` is stored as a ``StringMetrics`` with
    the number of distinct keys and the sum of all counts, and any other value
    is ignored.

    A key that matches a declared field of the metrics model is skipped, so a
    sample key that happens to share a declared field's name cannot overwrite
    that declared metric.

    Args:
        metrics: The metrics object to extend; updated in place.
        other_metrics: Mapping from field name to its accumulator.
    """
    declared_fields = type(metrics).model_fields

    for k, v in other_metrics.items():
        if k in declared_fields:
            # Never clobber a declared metric with a same-named sample key.
            continue
        if isinstance(v, AvgMinMax):
            setattr(metrics, k, v.aggregate())
        elif isinstance(v, Counter):
            setattr(metrics, k, StringMetrics(unique_count=len(v), total_count=sum(v.values())))
238+
239+
131240def compute_sample_metrics (sample_dict_str : str ) -> Tuple [DatasetMetrics , bool ]:
132241 try :
133242 sample_dict = json .loads (sample_dict_str )
@@ -146,43 +255,24 @@ def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
146255 number_of_tools_metrics = AvgMinMax ()
147256 if responses_create_params .get ("tools" ) is not None :
148257 number_of_tools = len (responses_create_params ["tools" ])
149- number_of_tools_metrics = AvgMinMax (
150- total = 1 ,
151- average = number_of_tools ,
152- min = number_of_tools ,
153- max = number_of_tools ,
154- )
258+ number_of_tools_metrics .observe (number_of_tools )
155259
156260 if isinstance (inputs , str ):
157261 inputs = [{"role" : "user" , "content" : inputs }]
158262 user_inputs = [i for i in inputs if i .get ("role" ) == "user" ] if inputs else []
159263 number_of_turns_metrics = AvgMinMax ()
160264 if user_inputs :
161265 number_of_turns = len (user_inputs )
162- number_of_turns_metrics = AvgMinMax (
163- total = 1 ,
164- average = number_of_turns ,
165- min = number_of_turns ,
166- max = number_of_turns ,
167- )
266+ number_of_turns_metrics .observe (number_of_turns )
168267
169268 temperature_metrics = AvgMinMax ()
170269 if responses_create_params .get ("temperature" ) is not None :
171270 temperature = responses_create_params ["temperature" ]
172- temperature_metrics = AvgMinMax (
173- total = 1 ,
174- average = temperature ,
175- min = temperature ,
176- max = temperature ,
177- )
271+ temperature_metrics .observe (temperature )
178272
273+ json_dumped_number_of_words_metrics = AvgMinMax ()
179274 json_dumped_number_of_words = len (json .dumps (responses_create_params ).split ())
180- json_dumped_number_of_words_metrics = AvgMinMax (
181- total = 1 ,
182- average = json_dumped_number_of_words ,
183- min = json_dumped_number_of_words ,
184- max = json_dumped_number_of_words ,
185- )
275+ json_dumped_number_of_words_metrics .observe (json_dumped_number_of_words )
186276
187277 metrics = DatasetMetrics (
188278 number_of_examples = 1 ,
@@ -200,6 +290,7 @@ class DatasetValidatorState(BaseModel):
200290 metrics : DatasetMetrics = Field (default_factory = DatasetMetrics )
201291 key_counts : Counter = Field (default_factory = Counter )
202292 offending_example_idxs : List [int ] = Field (default_factory = list )
293+ other_metrics : Dict [str , Any ] = Field (default_factory = dict )
203294
204295
205296class TrainDataProcessor (BaseModel ):
@@ -358,6 +449,8 @@ def _validate_samples_and_aggregate_metrics_single_sample(
358449 state .key_counts .update (sample_dict .keys ())
359450 state .metrics .add (metrics )
360451
452+ aggregate_other_metrics (state .other_metrics , sample_dict )
453+
361454 def _validate_samples_and_aggregate_metrics_single_dataset (
362455 self , dataset_config : DatasetConfig
363456 ) -> DatasetValidatorState :
@@ -373,6 +466,8 @@ def _validate_samples_and_aggregate_metrics_single_dataset(
373466 )
374467 )
375468
469+ postprocess_other_metrics (state .metrics , state .other_metrics )
470+
376471 return state
377472
378473 def _validate_aggregate_metrics (self , aggregate_metrics_dict : Dict , metrics_fpath : Path ) -> Optional [Path ]:
0 commit comments