-
Notifications
You must be signed in to change notification settings - Fork 364
Expand file tree
/
Copy pathtrain.py
More file actions
2014 lines (1773 loc) · 83.1 KB
/
train.py
File metadata and controls
2014 lines (1773 loc) · 83.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Implements RL on general MDPs
"""
from __future__ import annotations
import asyncio
import io
import logging
import re
import time
from collections.abc import Callable, Coroutine, Iterable, Iterator, Sequence
from concurrent.futures import Executor
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar
import chz
import numpy as np
import tinker
import torch
from tinker.types import LossFnType
from tqdm import tqdm
from tinker_cookbook import checkpoint_utils, model_info
from tinker_cookbook.display import colorize_example
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder
from tinker_cookbook.exceptions import ConfigurationError
from tinker_cookbook.rl.data_processing import (
assemble_training_data,
compute_advantages,
remove_constant_reward_groups,
)
from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics
from tinker_cookbook.rl.metrics import (
compute_kl_sample_train,
compute_post_kl,
compute_sampling_client_metrics,
incorporate_kl_penalty,
)
from tinker_cookbook.rl.rollout_logging import (
RolloutSummaryExportConfig,
RolloutSummaryGroup,
rollout_summaries_jsonl_path,
write_rollout_summaries_jsonl_from_groups,
)
from tinker_cookbook.rl.rollout_strategy import (
RolloutStrategy,
rollout_strategy_from_config,
)
from tinker_cookbook.rl.rollouts import (
RolloutErrorCounter,
do_group_rollout, # noqa: F401 — re-exported for verifiers monkey-patching
do_group_rollout_and_filter_constant_reward,
set_rollout_executor,
)
from tinker_cookbook.rl.types import (
EnvGroupBuilder,
RLDataset,
RLDatasetBuilder,
TrajectoryGroup,
)
from tinker_cookbook.tokenizer_utils import Tokenizer
from tinker_cookbook.utils import logtree, ml_log, trace
from tinker_cookbook.utils.deprecation import warn_deprecated
from tinker_cookbook.utils.misc_utils import iteration_dir, safezip, split_list
# Module-level logger; print_group emits its trajectory dumps through it.
logger = logging.getLogger(__name__)
# Generic result type for gather_with_progress (results keep the coroutine's type).
T = TypeVar("T")
@chz.chz
class KLReferenceConfig:
    """Configuration for the KL penalty reference model.

    If not specified in Config, the training model's base model is used.

    NOTE(review): Config's own field comment says ``kl_reference_config`` is
    "Required when kl_penalty_coef > 0", which conflicts with the fallback
    described above -- confirm which behavior main() actually implements.

    Attributes:
        base_model (str): Name of the base model to use as the KL reference.
        load_checkpoint_path (str | None): Optional checkpoint path to load
            reference model weights from. If None, the base model weights
            are used directly.
    """

    # Base model name for the reference policy.
    base_model: str
    # Optional checkpoint whose weights replace the base weights for the reference.
    load_checkpoint_path: str | None = None
async def gather_with_progress(
    coroutines: Iterable[Coroutine[Any, Any, T]],
    desc: str,
) -> list[T]:
    """Await many coroutines concurrently while a tqdm bar ticks per completion.

    Results come back in input order, exactly like ``asyncio.gather``; the
    progress bar advances by one each time any individual coroutine finishes,
    and is always closed, even if one of the coroutines raises.

    Args:
        coroutines (Iterable[Coroutine[Any, Any, T]]): Coroutines to await.
            The iterable is fully materialized up front to size the bar.
        desc (str): Label shown on the tqdm progress bar.

    Returns:
        list[T]: One result per input coroutine, in input order.
    """
    pending = list(coroutines)
    progress = tqdm(total=len(pending), desc=desc)

    async def _run_and_tick(job: Coroutine[Any, Any, T]) -> T:
        outcome = await job
        progress.update(1)
        return outcome

    try:
        wrapped = [_run_and_tick(job) for job in pending]
        return await asyncio.gather(*wrapped)
    finally:
        progress.close()
def _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str:
    """Return the display name of an evaluator, or ``""`` if it has none.

    Only ``RLTestSetEvaluator`` instances carry a name; every other evaluator
    type (and an ``RLTestSetEvaluator`` whose name is None) yields ``""``.
    """
    if not isinstance(evaluator, RLTestSetEvaluator):
        return ""
    label = evaluator.name
    return "" if label is None else label
def _sanitize_filename_component(text: str) -> str:
"""Make a safe filename component."""
sanitized = re.sub(r"[^A-Za-z0-9_.-]+", "_", text)
return sanitized.strip("._") or "unnamed"
def _maybe_export_rollout_summary_jsonl(
    *,
    config: Config,
    output_dir: Path | None,
    base_name: str,
    split: str,
    iteration: int,
    groups_P: Sequence[RolloutSummaryGroup],
) -> None:
    """Write per-trajectory rollout summaries for one train/eval pass, if enabled.

    A thin policy gate around the rollout_logging helpers: nothing is written
    unless ``config.rollout_json_export`` is truthy and an ``output_dir`` is
    given. The file lands at ``<base_name>_rollout_summaries.jsonl`` (as named
    by ``rollout_summaries_jsonl_path``) inside ``output_dir``.
    """
    if config.rollout_json_export and output_dir is not None:
        target_path = rollout_summaries_jsonl_path(output_dir, base_name)
        write_rollout_summaries_jsonl_from_groups(
            target_path,
            split=split,
            iteration=iteration,
            groups_P=groups_P,
        )
# Introductory paragraph logged as the first entry of every logtree trace opened
# by _get_logtree_scope, so readers of the generated HTML know where it came from.
_LOGTREE_EXPLANATION = (
    "This HTML log was generated by logtree during RL training. "
    "It shows rollouts and rewards for a subset of trajectory groups in this iteration. "
    "To customize what gets logged, modify the logtree calls in your Env implementation "
    "(see examples in tinker_cookbook/recipes/)."
)
@contextmanager
def _get_logtree_scope(
    output_dir: Path | None, num_groups_to_log: int, f_name: str, scope_name: str
) -> Iterator[None]:
    """Open a logtree trace for the duration of the ``with`` body.

    All logtree output produced inside the context is recorded under a section
    titled ``scope_name``. The trace is rendered to ``output_dir/f_name.html``
    and, once the trace context has exited, serialized to
    ``output_dir/f_name_logtree.json``.

    The context is a no-op (nothing is created or written) when ``output_dir``
    is None or ``num_groups_to_log <= 0``. Note that this function only sets up
    the trace; deciding *which* groups to log is done by the callers.

    Args:
        output_dir: Directory for the HTML/JSON files; created if missing.
        num_groups_to_log: Logging is disabled when this is <= 0.
        f_name: Base file name (without extension) for the outputs.
        scope_name: Title of the top-level logtree section.
    """
    if output_dir is None or num_groups_to_log <= 0:
        yield
        return
    output_dir.mkdir(parents=True, exist_ok=True)
    logtree_path = str(output_dir / f"{f_name}.html")
    logtree_json_path = str(output_dir / f"{f_name}_logtree.json")
    # Pre-bind so the finally clause can tell whether init_trace succeeded.
    logtree_trace = None
    try:
        with logtree.init_trace(scope_name, path=logtree_path) as logtree_trace:
            logtree.log_text(_LOGTREE_EXPLANATION)
            yield
    finally:
        # Runs even if the body raised, so a partial trace is still dumped to JSON.
        if logtree_trace is not None:
            logtree.write_trace_json(logtree_trace, logtree_json_path)
def _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:
assert num_inds <= len(scores)
sorted_inds = np.argsort(scores)
uniform_inds = np.linspace(0, len(sorted_inds) - 1, num_inds).astype(int)
return [int(sorted_inds[i]) for i in uniform_inds]
def print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):
    """Log a readable dump of (a sample of) one trajectory group.

    At most four trajectories are shown. When the group is larger, a subset
    spread across the sorted total-reward range is chosen with
    ``_select_representative_inds``. For each shown trajectory the dump
    contains its total reward, its trajectory-level metrics, its non-empty
    per-transition metrics, and every assembled training datum colorized by
    advantage. The whole dump goes out as a single ``logger.info`` record.

    Args:
        traj_group (TrajectoryGroup): The trajectory group to display.
        tokenizer (Tokenizer): Tokenizer used by ``colorize_example`` to decode
            tokens for display.
    """
    # Keep console output readable by showing only a handful of trajectories.
    shown_limit = 4
    if len(traj_group.trajectories_G) > shown_limit:
        chosen = _select_representative_inds(traj_group.get_total_rewards(), shown_limit)
        traj_group = TrajectoryGroup(
            trajectories_G=[traj_group.trajectories_G[j] for j in chosen],
            final_rewards_G=[traj_group.final_rewards_G[j] for j in chosen],
            metrics_G=[traj_group.metrics_G[j] for j in chosen],
        )

    total_rewards = traj_group.get_total_rewards()
    group_advantages = compute_advantages([traj_group])
    datums, datum_metadata = assemble_training_data([traj_group], group_advantages)

    report = io.StringIO()

    def emit(text: str):
        print(text, file=report)

    def emit_trajectory_header(traj_idx: int):
        # Reward line, then trajectory-level and per-step metrics, if any.
        emit(f"****** trajectory idx={traj_idx}, reward={total_rewards[traj_idx]:.3g} ******")
        traj_metrics = traj_group.metrics_G[traj_idx]
        if traj_metrics:
            emit("Trajectory metrics:")
            for metric_name, metric_value in traj_metrics.items():
                emit(f" {metric_name}: {metric_value}")
        step_metrics = [
            step.metrics
            for step in traj_group.trajectories_G[traj_idx].transitions
            if step.metrics
        ]
        if step_metrics:
            emit("Per-step metrics:")
            for step_no, per_step in enumerate(step_metrics):
                emit(f" Step {step_no}:")
                for metric_name, metric_value in per_step.items():
                    emit(f" {metric_name}: {metric_value}")

    emit("\n====== Trajectory Group ======")
    previous_meta = None
    for datum, meta in safezip(datums, datum_metadata):
        traj_idx = meta["traj_idx"]
        # A header is printed whenever the metadata changes between datums.
        if meta != previous_meta:
            emit_trajectory_header(traj_idx)
        emit("---- datum ----")
        emit(colorize_example(datum, tokenizer, key="advantages"))
        previous_meta = meta
    emit("====== End Trajectory Group ======")
    logger.info(report.getvalue().rstrip())
def _remove_mask(datum: tinker.Datum) -> tinker.Datum:
    """Return a copy of ``datum`` whose ``loss_fn_inputs`` lack the ``"mask"`` entry.

    The model input is reused as-is; every other loss input is carried over in
    its original order.
    """
    unmasked_inputs = dict(
        pair for pair in datum.loss_fn_inputs.items() if pair[0] != "mask"
    )
    return tinker.Datum(model_input=datum.model_input, loss_fn_inputs=unmasked_inputs)
def _training_logprobs_from_fwd_bwd(
    fwd_bwd_result: tinker.ForwardBackwardOutput,
) -> list[torch.Tensor]:
    """Convert each loss-function output's ``"logprobs"`` entry to a torch tensor.

    Returns one tensor per entry of ``fwd_bwd_result.loss_fn_outputs``, in order.
    """
    per_datum_outputs = fwd_bwd_result.loss_fn_outputs
    return [entry["logprobs"].to_torch() for entry in per_datum_outputs]
@trace.scope
async def train_step(
    data_D: list[tinker.Datum],
    training_client: tinker.TrainingClient,
    learning_rate: float,
    num_substeps: int,
    loss_fn: LossFnType,
    loss_fn_config: dict[str, Any] | None = None,
    metrics: dict[str, Any] | None = None,
) -> list[torch.Tensor]:
    """Train the model on collected trajectories.

    Pipelines ``forward_backward`` and ``optim_step`` so they land on the same
    clock cycle, maximizing GPU utilization. The data is split into
    ``num_substeps`` batches (fewer if there are fewer datums); each batch's
    requests are enqueued before the previous batch's results are consumed,
    to keep the pipeline full.

    Every datum has its ``"mask"`` loss input stripped (via ``_remove_mask``)
    before being sent to ``forward_backward``.

    Args:
        data_D (list[tinker.Datum]): Training data assembled from trajectory
            rollouts, including advantages and log-probabilities.
        training_client (tinker.TrainingClient): Client connected to the
            Tinker training service.
        learning_rate (float): Learning rate for the Adam optimizer
            (beta1=0.9, beta2=0.95, eps=1e-8 are fixed here).
        num_substeps (int): Number of sub-batches to split data_D into.
            Each sub-batch triggers one forward_backward + optim_step pair.
        loss_fn (LossFnType): Loss function identifier (e.g.
            ``"importance_sampling"``, ``"ppo"``).
        loss_fn_config (dict[str, Any] | None): Extra configuration passed
            to the loss function. Defaults to None.
        metrics (dict[str, Any] | None): If provided, optimizer metrics from
            the final optim_step are merged into this dict in-place.

    Returns:
        list[torch.Tensor]: Per-datum training log-probabilities returned by
        the forward pass, one tensor per datum in data_D (empty if
        ``split_list`` produced no batches).

    Example::

        logprobs = await train_step(
            data_D=data,
            training_client=client,
            learning_rate=1e-5,
            num_substeps=2,
            loss_fn="importance_sampling",
        )
    """
    batches = split_list(data_D, min(num_substeps, len(data_D)))
    if not batches:
        return []
    adam_params = tinker.AdamParams(learning_rate=learning_rate, beta1=0.9, beta2=0.95, eps=1e-8)
    training_logprobs_D: list[torch.Tensor] = []
    # Only the last substep's optimizer response is kept (used for `metrics`).
    optim_result: tinker.OptimStepResponse | None = None
    # Enqueue first batch
    fwd_bwd_future = await training_client.forward_backward_async(
        [_remove_mask(d) for d in batches[0]], loss_fn=loss_fn, loss_fn_config=loss_fn_config
    )
    optim_future = await training_client.optim_step_async(adam_params)
    for i in range(len(batches)):
        # Enqueue next batch before consuming current results (to stay on same clock cycle)
        if i + 1 < len(batches):
            next_fwd_bwd_future = await training_client.forward_backward_async(
                [_remove_mask(d) for d in batches[i + 1]],
                loss_fn=loss_fn,
                loss_fn_config=loss_fn_config,
            )
            next_optim_future = await training_client.optim_step_async(adam_params)
        else:
            next_fwd_bwd_future = None
            next_optim_future = None
        # Consume current results
        fwd_bwd_result = await fwd_bwd_future.result_async()
        training_logprobs_D.extend(_training_logprobs_from_fwd_bwd(fwd_bwd_result))
        optim_result = await optim_future.result_async()
        # Move to next iteration
        if next_fwd_bwd_future is not None and next_optim_future is not None:
            fwd_bwd_future = next_fwd_bwd_future
            optim_future = next_optim_future
    if metrics is not None and optim_result is not None and optim_result.metrics:
        metrics.update(optim_result.metrics)
    return training_logprobs_D
@chz.chz
class StreamMinibatchConfig:
    """Configuration for training with minibatch streaming.

    Once enough trajectories for a minibatch have been accumulated, training
    begins immediately rather than waiting for the full batch. This overlaps
    sampling and training within a single iteration.

    Attributes:
        groups_per_batch (int): Total number of trajectory groups across all
            minibatches and substeps in one training iteration.
        num_minibatches (int): Number of minibatches per optimizer substep.
            Each minibatch triggers one ``forward_backward()`` call, and one
            ``optim_step()`` is issued per substep.
    """

    # Total number of trajectory groups across all minibatches and substeps
    groups_per_batch: int
    # For each substep, we will divide up the number of trajectory groups
    # into this many minibatches.
    # We will do num_minibatches forward_backward() passes and one optim_step()
    # per substep.
    num_minibatches: int
@chz.chz
class AsyncConfig:
    """Configuration for async (off-policy) RL training.

    In async mode, sampling and training run concurrently. Trajectory groups
    generated from a sampler that is too many steps behind the current
    training step are discarded (or requeued) to limit off-policy staleness.

    Attributes:
        max_steps_off_policy (int): Maximum number of training steps a sample
            can lag behind the current step before being considered stale.
        groups_per_batch (int): Minimum number of trajectory groups required
            to form a training batch, even after discarding stale samples.
    """

    # If samples are generated from a sampler more than this many steps ago,
    # we will skip training on them.
    max_steps_off_policy: int
    # We will ensure all batches have at least this many groups, even
    # as we discard stale samples
    groups_per_batch: int
@chz.chz
class Config:
    """Configuration for RL training.

    This is the main configuration object for :func:`main`. It controls the
    model, dataset, optimizer, loss function, KL penalty, evaluation cadence,
    checkpointing, logging, and execution mode (sync, async, or streaming
    minibatch).

    All fields use ``chz`` dataclass semantics and can be overridden via CLI
    or YAML configuration files.
    """

    # -------------------------------------------------------------------------
    # Core parameters (recommended to set for nearly all runs)
    # -------------------------------------------------------------------------
    # Base learning rate used by Adam.
    learning_rate: float
    # Builds the RL dataset; also determines number of groups per batch.
    dataset_builder: RLDatasetBuilder
    # Model name (base weights) to train.
    model_name: str
    # Maximum number of generated tokens per rollout trajectory.
    max_tokens: int
    # Directory for checkpoints, logs, and traces. The munger expands a leading
    # "~" so home-relative paths work from the CLI.
    log_path: str = chz.field(munger=lambda _, s: str(Path(s).expanduser()))
    # Evaluation cadence in training iterations (0 = disabled).
    # NOTE: the streaming-minibatch loop additionally evaluates on the last
    # iteration of its range, even when this is 0.
    eval_every: int = 20
    # Checkpoint cadence in training iterations (0 = disabled).
    save_every: int = 20
    # Optional evaluators run during training.
    evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)
    # Start training from weights at this checkpoint (fresh optimizer state).
    load_checkpoint_path: str | None = None
    # Renderer used by the training dataset/environment.
    renderer_name: str | None = None
    # Optional W&B project and run name.
    wandb_project: str | None = None
    wandb_name: str | None = None

    # -------------------------------------------------------------------------
    # KL penalty configuration (advanced)
    # -------------------------------------------------------------------------
    # KL penalty coefficient against reference policy (0 = disabled).
    kl_penalty_coef: float = 0.0
    # Optional position discount for KL penalty terms.
    kl_discount_factor: float = 0.0
    # Required when kl_penalty_coef > 0.
    # NOTE(review): KLReferenceConfig's docstring says the training model's base
    # model is used when this is unset; confirm which behavior main() implements.
    kl_reference_config: KLReferenceConfig | None = None

    # -------------------------------------------------------------------------
    # Loss and optimizer behavior (advanced)
    # -------------------------------------------------------------------------
    # Loss function and configuration.
    # See https://tinker-docs.thinkingmachines.ai/losses
    loss_fn: LossFnType = "importance_sampling"
    loss_fn_config: dict[str, Any] | None = None
    # Number of optimizer steps per training iteration.
    # Useful for very large batch sizes.
    num_substeps: int = 1
    # LoRA rank for the training adapter.
    lora_rank: int = 32

    # -------------------------------------------------------------------------
    # Sampling and diagnostics (advanced)
    # -------------------------------------------------------------------------
    # Changing sampling temperature is not generally recommended; T=1 is near-optimal
    # for most post-trained models, and non-1 temperatures currently do not play
    # well with KL penalty.
    temperature: float = 1.0
    # Compute extra post-update KL metrics (adds overhead).
    compute_post_kl: bool = False
    # Remove groups where all trajectories have identical reward.
    remove_constant_reward_groups: bool = False
    # Tolerance for errors during rollouts (container crashes, sandbox flakes, etc.).
    # False (default): crash on any error (FailFast).
    # True: retry failed trajectories with default budget (RetryOnFailure(max_retries=3)).
    # RolloutStrategy instance: custom strategy (e.g. RetryOnFailure(max_retries=5)).
    rollout_error_tolerance: bool | RolloutStrategy = False
    # Emit async trace events for debugging/profiling.
    enable_trace: bool = False
    # Save a Gantt chart HTML every N iterations (0 = disabled). Requires plotly.
    span_chart_every: int = 0

    # -------------------------------------------------------------------------
    # Execution mode knobs (advanced)
    # -------------------------------------------------------------------------
    # Enable async/off-policy training mode when set.
    async_config: AsyncConfig | None = None
    # Enable sync training with streaming minibatches when set.
    stream_minibatch_config: StreamMinibatchConfig | None = None
    # Optional service base URL override (primarily internal/dev use).
    base_url: str | None = None

    # -------------------------------------------------------------------------
    # Checkpoint retention and logging detail (advanced)
    # -------------------------------------------------------------------------
    # Periodic checkpoints use this TTL; the final checkpoint is kept indefinitely.
    # None disables expiry entirely.
    ttl_seconds: int | None = 604800  # 7 days
    # Rolling checkpoint cadence (0 = disabled). Saves training state for resume
    # but skips the sampler-weight export, making it cheaper than periodic checkpoints.
    rolling_save_every: int = 0
    # TTL for rolling checkpoints; short to auto-clean if explicit deletion fails.
    rolling_ttl_seconds: int = 7200  # 2 hours
    # Number of groups to log per iteration (0 = disable logging). The first
    # this-many groups of a batch get logtree logging enabled.
    num_groups_to_log: int = 4
    # Write per-trajectory rollout summary JSONL files for train/eval passes.
    rollout_json_export: bool = True
    # Maximum number of training iterations. If None, train on the full dataset.
    max_steps: int | None = None
@trace.scope
async def run_single_evaluation(
    evaluator: SamplingClientEvaluator,
    config: Config,
    i_batch: int,
    sampling_client: tinker.SamplingClient,
    evaluator_label: str,
) -> dict[str, Any]:
    """Run one evaluator inside its own logtree scope and return its metrics.

    Logtree output goes to ``eval_<evaluator_label>`` files in the iteration
    directory. ``RLTestSetEvaluator`` instances additionally receive a
    ``RolloutSummaryExportConfig`` (or None when JSON export is disabled or
    there is no iteration directory); any other evaluator is simply called
    with the sampling client.

    Args:
        evaluator (SamplingClientEvaluator): The evaluator to run.
        config (Config): RL training configuration (used for logging settings).
        i_batch (int): Current training iteration index.
        sampling_client (tinker.SamplingClient): Sampling client with the
            current model weights.
        evaluator_label (str): Filesystem-safe label used for log file naming.

    Returns:
        dict[str, Any]: Evaluation metrics produced by the evaluator.
    """
    display_name = _get_evaluator_name(evaluator)
    file_base = f"eval_{evaluator_label}"
    out_dir = iteration_dir(config.log_path, i_batch)
    with _get_logtree_scope(
        output_dir=out_dir,
        num_groups_to_log=config.num_groups_to_log,
        f_name=file_base,
        scope_name=f"Running evaluation {display_name} {i_batch}",
    ):
        if not isinstance(evaluator, RLTestSetEvaluator):
            return await evaluator(sampling_client)
        export_cfg = None
        if config.rollout_json_export and out_dir is not None:
            export_cfg = RolloutSummaryExportConfig(
                path=rollout_summaries_jsonl_path(out_dir, file_base),
                split=f"eval/{evaluator_label}",
                iteration=i_batch,
                sampling_client_step=i_batch,
            )
        return await evaluator(
            sampling_client,
            rollout_summary_export=export_cfg,
        )
@trace.scope
async def run_evaluations_parallel(
    evaluators: list[SamplingClientEvaluator],
    sampling_client: tinker.SamplingClient,
    config: Config,
    i_batch: int,
) -> dict[str, Any]:
    """Run every evaluator concurrently and merge their metric dicts.

    Each evaluator runs as a named ``asyncio.Task``. Its label is the
    evaluator's name, falling back to its position in ``evaluators``, made
    filesystem-safe. Results are merged in evaluator order, so a later
    evaluator's key overwrites an earlier one's.

    Args:
        evaluators (list[SamplingClientEvaluator]): Evaluators to execute.
        sampling_client (tinker.SamplingClient): Sampling client with the
            current model weights.
        config (Config): RL training configuration.
        i_batch (int): Current training iteration index.

    Returns:
        dict[str, Any]: Merged metrics from all evaluators.
    """
    labels = [
        _sanitize_filename_component(_get_evaluator_name(ev) or str(position))
        for position, ev in enumerate(evaluators)
    ]
    # Named tasks make individual evaluations easy to spot in traces.
    eval_tasks = [
        asyncio.create_task(
            run_single_evaluation(ev, config, i_batch, sampling_client, label),
            name=f"eval_{label}_iteration_{i_batch:06d}",
        )
        for ev, label in zip(evaluators, labels)
    ]
    combined: dict[str, Any] = {}
    for per_evaluator in await asyncio.gather(*eval_tasks):
        combined.update(per_evaluator)
    return combined
async def _await_while_monitoring_workers(
    main: Coroutine[Any, Any, Any],
    worker_tasks: Sequence[asyncio.Task[Any]],
) -> Any:
    """Await ``main`` while failing fast if any of ``worker_tasks`` raises.

    ``main`` consumes items that the workers produce through a queue. A worker
    that dies with an exception never enqueues its item. Awaiting ``main``
    alone would then block forever, and the worker's error would only surface
    as an asyncio "exception was never retrieved" log message. Running
    ``main`` as a task next to the workers lets the first failure propagate.

    On exit (success or failure), still-running tasks are cancelled and every
    task is gathered, so nothing is left pending and no exception goes
    unretrieved.

    Args:
        main: Consumer coroutine whose result is returned.
        worker_tasks: Producer tasks that ``main`` depends on.

    Returns:
        The result of ``main``.

    Raises:
        The first worker exception observed, otherwise whatever ``main`` raised.
    """
    main_task = asyncio.create_task(main)
    all_tasks: list[asyncio.Task[Any]] = [main_task, *worker_tasks]
    try:
        # Returns as soon as any task raises, or once every task has finished.
        await asyncio.wait(all_tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in worker_tasks:
            if task.done() and not task.cancelled():
                worker_exc = task.exception()
                if worker_exc is not None:
                    raise worker_exc
        # main is done here; result() re-raises if main itself failed.
        return main_task.result()
    finally:
        for task in all_tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*all_tasks, return_exceptions=True)


@trace.scope
async def do_sync_training_with_stream_minibatch(
    start_batch: int,
    end_batch: int,
    num_batches: int,
    config: Config,
    training_client: tinker.TrainingClient,
    kl_reference_client: tinker.SamplingClient | None,
    evaluators: list[SamplingClientEvaluator],
    dataset: RLDataset,
    ml_logger: ml_log.Logger,
    tokenizer: Tokenizer,
    error_counter: RolloutErrorCounter | None = None,
    strategy: RolloutStrategy | None = None,
    rolling_mgr: checkpoint_utils.RollingCheckpointManager | None = None,
):
    """Implement fully synchronous on-policy training with minibatch streaming.

    Once enough trajectories for a minibatch have been accumulated, training
    begins immediately rather than waiting for the full batch. This overlaps
    sampling and training within a single iteration, reducing wall-clock time.

    Per iteration:
      1. Evaluators run when ``eval_every > 0`` and ``i_batch`` is a multiple
         of it, and always on the last iteration of the range (before that
         iteration's training).
      2. One rollout worker task per ``EnvGroupBuilder`` pushes its trajectory
         group (or None if it was filtered out) onto a queue.
      3. ``do_train_step_streaming_and_get_sampling_client`` consumes the
         queue and returns the new sampling client. If any rollout worker
         raises, the exception propagates instead of leaving the trainer
         waiting forever on the queue.
      4. Rollout summaries are exported, a rolling checkpoint is optionally
         saved, and metrics and timing spans are logged.

    Args:
        start_batch (int): First training iteration index (inclusive).
        end_batch (int): Last training iteration index (exclusive).
        num_batches (int): Total number of batches in the dataset, used for
            progress fraction calculation.
        config (Config): RL training configuration. Must have
            ``stream_minibatch_config`` set.
        training_client (tinker.TrainingClient): Client connected to the
            Tinker training service.
        kl_reference_client (tinker.SamplingClient | None): Sampling client
            for the KL reference model, or None if KL penalty is disabled.
        evaluators (list[SamplingClientEvaluator]): Evaluators to run
            periodically during training.
        dataset (RLDataset): The RL dataset providing batches of
            ``EnvGroupBuilder`` instances.
        ml_logger (ml_log.Logger): Logger for metrics and W&B integration.
        tokenizer (Tokenizer): Tokenizer for decoding rollout tokens.
        error_counter (RolloutErrorCounter | None): Optional counter for
            tracking rollout errors. Defaults to None.
        strategy (RolloutStrategy | None): Rollout error handling strategy.
            Defaults to None.
        rolling_mgr (checkpoint_utils.RollingCheckpointManager | None): If
            given, asked after every iteration to (maybe) save a rolling
            checkpoint at step ``i_batch + 1``. Defaults to None.
    """
    # Initial sampling client
    sampling_client, _ = await save_checkpoint_and_get_sampling_client(
        training_client,
        start_batch,
        config.log_path,
        config.save_every,
        start_batch,
        config.ttl_seconds,
    )
    for i_batch in range(start_batch, end_batch):
        metrics: dict[str, Any] = {
            "progress/batch": i_batch,
            "optim/lr": config.learning_rate,
            "progress/done_frac": (i_batch + 1) / num_batches,
        }
        with trace.trace_iteration(step=i_batch) as window:
            # Run evaluations
            if (
                config.eval_every > 0 and i_batch % config.eval_every == 0
            ) or i_batch == end_batch - 1:
                async with trace.scope_span("run_evals"):
                    eval_metrics = await run_evaluations_parallel(
                        evaluators, sampling_client, config, i_batch
                    )
                    metrics.update(eval_metrics)
            iter_dir = iteration_dir(config.log_path, i_batch)
            with _get_logtree_scope(
                iter_dir,
                config.num_groups_to_log,
                "train",
                f"RL Iteration {i_batch}",
            ):
                # Samplers will produce trajectory groups asynchronously,
                # and the trainer will consume them as soon as they are ready
                trajectory_groups_queue: asyncio.Queue[
                    WrappedTrajectoryGroup | _Shutdown | None
                ] = asyncio.Queue()
                env_group_builders_P = dataset.get_batch(i_batch)

                @trace.scope
                async def trajectory_group_worker_task(
                    builder: EnvGroupBuilder, enable_logging: bool
                ) -> None:
                    worker_metrics: dict[str, Any] = {}
                    t_start = time.time()
                    async with trace.scope_span("trajectory_group_worker"):
                        trajectory_group = await do_group_rollout_and_filter_constant_reward(
                            sampling_client,
                            builder,
                            max_tokens=config.max_tokens,
                            temperature=config.temperature,
                            do_remove_constant_reward_groups=config.remove_constant_reward_groups,
                            enable_logging=enable_logging,
                            strategy=strategy,
                        )
                    worker_metrics["time/trajectory_group_worker_loop/total"] = (
                        time.time() - t_start
                    )
                    # Ingest error info (safe: same event loop thread)
                    if error_counter is not None:
                        error_counter.ingest(trajectory_group)
                    if trajectory_group is not None:
                        trajectory_groups_queue.put_nowait(
                            WrappedTrajectoryGroup(
                                trajectory_group=trajectory_group,
                                env_group_builder=builder,
                                sampling_client_step=i_batch,
                                metrics=worker_metrics,
                            )
                        )
                    else:
                        # Filtered-out group: still enqueue so the consumer's count stays right.
                        trajectory_groups_queue.put_nowait(None)

                # Sample all trajectories asynchronously. If we have multiple minibatches,
                # then sampling can overlap with training. Task handles are kept so the
                # tasks cannot be garbage-collected mid-run and their failures are observed.
                worker_tasks = [
                    asyncio.create_task(
                        trajectory_group_worker_task(
                            builder, enable_logging=i < config.num_groups_to_log
                        ),
                        name=f"trajectory_group_worker_task_{i}",
                    )
                    for i, builder in enumerate(env_group_builders_P)
                ]
                # Run multiple optimizer substeps per training iteration; fail fast
                # if any rollout worker raises instead of blocking on the queue.
                streaming_result = await _await_while_monitoring_workers(
                    do_train_step_streaming_and_get_sampling_client(
                        config,
                        i_batch,
                        trajectory_groups_queue,
                        training_client,
                        kl_reference_client,
                        tokenizer,
                    ),
                    worker_tasks,
                )
                # _Shutdown cannot appear in the sync path's local queue
                assert streaming_result is not None, "Unexpected shutdown in sync streaming path"
                (
                    sampling_client,
                    full_batch_metrics,
                    full_batch_wrapped_trajectory_groups,
                ) = streaming_result
            _maybe_export_rollout_summary_jsonl(
                config=config,
                output_dir=iter_dir,
                base_name="train",
                split="train",
                iteration=i_batch,
                groups_P=[
                    RolloutSummaryGroup(
                        trajectory_group=group.trajectory_group,
                        tags=group.env_group_builder.logging_tags(),
                        sampling_client_step=group.sampling_client_step,
                    )
                    for group in full_batch_wrapped_trajectory_groups
                ],
            )
            # Rolling checkpoint (fire-and-forget, overlaps with next iteration)
            if rolling_mgr is not None:
                await rolling_mgr.maybe_save_async(step=i_batch + 1, loop_state={"batch": i_batch + 1})
        # Log metrics
        metrics.update(full_batch_metrics)
        if error_counter is not None:
            metrics.update(error_counter.get_metrics())
        metrics.update(window.get_timing_metrics())
        window.write_spans_jsonl(Path(config.log_path) / "timing_spans.jsonl", step=i_batch)
        if (
            config.span_chart_every > 0
            and i_batch % config.span_chart_every == 0
            and iter_dir is not None
        ):
            iter_dir.mkdir(parents=True, exist_ok=True)
            trace.save_gantt_chart_html(window, i_batch, iter_dir / "timing_gantt.html")
        ml_logger.log_metrics(metrics, step=i_batch)
@chz.chz
class WrappedTrajectoryGroup:
    """A wrapper around a trajectory group that includes generation metadata.

    Used when sampling and training are overlapped (streaming minibatch or
    async modes) so that staleness can be checked and stale groups requeued.

    Attributes:
        trajectory_group (TrajectoryGroup): The collected trajectory group.
        env_group_builder (EnvGroupBuilder): The builder that produced this
            group. Retained so that stale groups can be requeued; its
            ``logging_tags()`` also feed rollout summary export.
        sampling_client_step (int): The training step at which the sampling
            client was created for this rollout.
        metrics (dict[str, Any]): Timing and worker-level metrics collected
            during rollout generation.
    """

    trajectory_group: TrajectoryGroup
    # The env group builder that produced the trajectory group.
    # Pass this along in case the sampler is too stale, and we need to
    # requeue this group.
    env_group_builder: EnvGroupBuilder
    # The step that produced this trajectory group.
    sampling_client_step: int
    # Worker-level metrics; e.g. the sync streaming worker records rollout wall
    # time under "time/trajectory_group_worker_loop/total".
    metrics: dict[str, Any] = chz.field(default_factory=dict)
@dataclass
class _Shutdown:
    """Marker object placed on an async queue to tell its consumer to stop.

    It is part of the cascading shutdown used by async RL training:
    dataloader -> workers -> training loop -> evaluation loop.

    - The dataloader enqueues one per trajectory worker.
    - The last worker to exit forwards one to the training loop.
    - The training loop then flags the evaluation loop to stop.

    Consumers detect it with ``isinstance(item, _Shutdown)``. It carries no
    payload.
    """
class _AsyncCounter:
    """Integer counter shared between coroutines, guarded by an asyncio lock.

    ``do_async_training`` seeds it with the number of trajectory workers.
    Each worker decrements it as it exits, so the worker that observes the
    final value knows it is the last one alive.
    """

    def __init__(self, start: int) -> None:
        # The lock serialises updates from concurrent coroutines.
        self._lock = asyncio.Lock()
        self._value = start

    async def decrement_and_get(self) -> int:
        """Subtract one from the counter and return the updated value."""
        async with self._lock:
            remaining = self._value - 1
            self._value = remaining
        return remaining
@trace.scope
async def do_async_training(
start_batch: int,
end_batch: int,
num_batches: int,
config: Config,
training_client: tinker.TrainingClient,
kl_reference_client: tinker.SamplingClient | None,
evaluators: list[SamplingClientEvaluator],
dataset: RLDataset,
ml_logger: ml_log.Logger,
tokenizer: Tokenizer,
error_counter: RolloutErrorCounter | None = None,
strategy: RolloutStrategy | None = None,
rolling_mgr: checkpoint_utils.RollingCheckpointManager | None = None,
):
"""Implement async off-policy training, capped at K steps off policy.
Launches four concurrent coroutine groups that communicate via async
queues:
1. **Dataloader loop** -- feeds ``EnvGroupBuilder`` items into a queue.
2. **Trajectory worker loops** (one per ``groups_per_batch``) -- consume
builders, run rollouts, and push ``WrappedTrajectoryGroup`` results.
3. **Training loop** -- accumulates groups, discards stale samples, and
performs forward_backward + optim_step.
4. **Evaluation loop** -- runs evaluators whenever the sampling client is
updated.
Shutdown cascades from the dataloader through workers, training, and
finally evaluation via ``_Shutdown`` sentinels and ``asyncio.Event`` flags.
Args:
start_batch (int): First training iteration index (inclusive).
end_batch (int): Last training iteration index (exclusive).
num_batches (int): Total number of batches, used for progress
fraction calculation.
config (Config): RL training configuration. Must have
``async_config`` set.
training_client (tinker.TrainingClient): Client connected to the
Tinker training service.
kl_reference_client (tinker.SamplingClient | None): Sampling client
for the KL reference model, or None if KL penalty is disabled.
evaluators (list[SamplingClientEvaluator]): Evaluators to run
periodically during training.
dataset (RLDataset): The RL dataset providing batches of
``EnvGroupBuilder`` instances.
ml_logger (ml_log.Logger): Logger for metrics and W&B integration.
tokenizer (Tokenizer): Tokenizer for decoding rollout tokens.
error_counter (RolloutErrorCounter | None): Optional counter for
tracking rollout errors. Defaults to None.
strategy (RolloutStrategy | None): Rollout error handling strategy.
Defaults to None.
"""
assert config.async_config is not None
# We will have groups_per_batch workers generating rollouts, so cap the
# queue size to be groups_per_batch.
env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | _Shutdown](
maxsize=config.async_config.groups_per_batch
)
trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | _Shutdown | None]()
# Initial sampling client to use
path_dict = await checkpoint_utils.save_checkpoint_async(
training_client=training_client,
name=f"{start_batch:06d}",
log_path=config.log_path,
loop_state={"batch": start_batch},
kind="both",
ttl_seconds=config.ttl_seconds,
)
# Shutdown coordination — cascading sequence:
# 1. Dataloader exhausts data → sets dataloader_done_event (prevents requeuing stale
# samples) and enqueues one _Shutdown sentinel per worker into env_group_builders_queue.
# 2. Each trajectory worker receives its _Shutdown sentinel → exits and decrements
# worker_alive_counter. The last worker enqueues a _Shutdown into trajectory_groups_queue.
# 3. Training loop receives _Shutdown from trajectory_groups_queue → finishes current
# batch, sets evaluation_loop_should_shutdown_event, and exits.
# 4. Eval loop sees evaluation_loop_should_shutdown_event → exits.
dataloader_done_event = asyncio.Event()
evaluation_loop_should_shutdown_event = asyncio.Event()
worker_alive_counter = _AsyncCounter(config.async_config.groups_per_batch)
# This will be updated by the training loop
sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
sampling_client_step = start_batch
sampling_client_updated_event = asyncio.Event()
sampling_client_updated_event.set()
@trace.scope
async def dataloader_loop():
"""Gets the next set of env builders to run"""
i_batch = start_batch
while i_batch < end_batch:
env_group_builders_P = dataset.get_batch(i_batch)
for env_group_builder in env_group_builders_P:
await env_group_builders_queue.put(env_group_builder)
i_batch += 1
# Signal that no more data will be produced, so stale samples should not be requeued
dataloader_done_event.set()
# Enqueue shutdown sentinels — one per worker — to cascade the shutdown
logger.info("[dataloader_loop] No more data, shutting down trajectory group workers")
assert config.async_config is not None
for _ in range(config.async_config.groups_per_batch):
await env_group_builders_queue.put(_Shutdown())
logger.info("[dataloader_loop] Terminated")
@trace.scope
async def trajectory_group_worker_loop():
"""Generates trajectories for a single env builder"""
while True:
env_group_builder = await env_group_builders_queue.get()
if isinstance(env_group_builder, _Shutdown):
logger.info("[trajectory_group_worker_loop] Received shutdown signal")
break
# Save a reference to the sampling client step in case it changes
# while we're running the rollout
sampling_client_step_copy = sampling_client_step
worker_metrics: dict[str, Any] = {}
t_start = time.time()
async with trace.scope_span("trajectory_group_worker"):
trajectory_group = await do_group_rollout_and_filter_constant_reward(
sampling_client,
env_group_builder,
max_tokens=config.max_tokens,
temperature=config.temperature,
do_remove_constant_reward_groups=config.remove_constant_reward_groups,
strategy=strategy,
)
worker_metrics["time/trajectory_group_worker_loop/total"] = time.time() - t_start
# Ingest error info (safe: same event loop thread)
if error_counter is not None:
error_counter.ingest(trajectory_group)
if trajectory_group is None:
trajectory_groups_queue.put_nowait(None)
else:
trajectory_groups_queue.put_nowait(
WrappedTrajectoryGroup(
trajectory_group=trajectory_group,