-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathtraining.py
More file actions
2940 lines (2608 loc) · 120 KB
/
training.py
File metadata and controls
2940 lines (2608 loc) · 120 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Pretrain utilities."""
import dataclasses
from datetime import datetime, timedelta
import functools
import gc
import inspect
import logging
import math
import os
import sys
from typing import List, Optional
import torch.distributed
from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
from .theoretical_memory_usage import report_theoretical_memory
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
try:
from megatron.rl import rl_utils
has_rl_utils = True
except ImportError:
has_rl_utils = False
try:
from megatron.post_training.algos.distillation import (
get_tensor_shapes_adjust_fn_for_distillation,
)
has_nvidia_modelopt = True
except ImportError:
has_nvidia_modelopt = False
try:
from nvidia_resiliency_ext.inprocess import CallWrapper
except ImportError:
CallWrapper = type(None)
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import (
check_param_hashes_across_dp_replicas,
get_model_config,
StragglerDetector,
)
from megatron.core.fp8_utils import correct_amax_history_if_needed
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.training.checkpointing import checkpoint_exists
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig, TorchFullyShardedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel as megatron_FSDP
from megatron.core.optimizer.optimizer import param_group_identifier_keys
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
HAVE_FSDP2 = True
except ImportError:
HAVE_FSDP2 = False
from megatron.core.distributed import finalize_model_grads
from megatron.core.enums import ModelType
from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig
from megatron.core.rerun_state_machine import (
get_rerun_state_machine,
destroy_rerun_state_machine,
RerunDataIterator,
RerunMode,
)
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper
from megatron.core.parallel_state import (
destroy_global_memory_buffer,
destroy_model_parallel,
update_pg_timeout
)
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.num_microbatches_calculator import (
destroy_num_microbatches_calculator,
get_current_global_batch_size,
get_current_running_global_batch_size,
get_num_microbatches,
update_num_microbatches
)
from .async_utils import maybe_finalize_async_save
from .utils import (
append_to_progress_log,
calc_params_l2_norm,
check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model,
update_use_dist_ckpt,
to_empty_if_meta_device,
)
from .global_vars import (
destroy_global_vars,
get_args,
get_signal_handler,
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger,
get_tokenizer,
get_energy_monitor,
)
from . import one_logger_utils
from . import ft_integration
stimer = StragglerDetector()
from megatron.core.msc_utils import MultiStorageClientFeature, open_file
def destroy_global_state():
    """Invoke the global-state teardown helpers in a fixed order.

    Order: global vars, the num-microbatches calculator, the global memory
    buffer, model-parallel state, and finally the rerun state machine.
    """
    teardown_steps = (
        destroy_global_vars,
        destroy_num_microbatches_calculator,
        destroy_global_memory_buffer,
        destroy_model_parallel,
        destroy_rerun_state_machine,
    )
    for teardown in teardown_steps:
        teardown()
def print_datetime(string):
    """Print ``string`` together with the current local wall-clock time (rank 0 only).

    Note that this call will sync across all ranks: a
    ``torch.distributed.barrier()`` is entered before the time is read.
    """
    torch.distributed.barrier()
    now = datetime.now()
    print_rank_0('[{}] datetime: {} '.format(string, now.strftime('%Y-%m-%d %H:%M:%S')))
def num_floating_point_operations(args, batch_size):
    """Estimate the floating-point operations of one training step.

    The estimate counts forward plus backward work for ``batch_size`` sequences
    of ``args.seq_length`` tokens. Hybrid (Mamba / attention / MLP) models use
    ``hybrid_flops``; every other model uses ``transformer_flops``, which
    covers dense, MoE, multi-latent-attention (MLA) and multi-token-prediction
    (MTP) configurations.

    Args:
        args: Megatron arguments namespace. It is only read; this function does
            not modify it.
        batch_size: Number of sequences processed in the step.

    Returns:
        The estimated FLOP count.
    """

    def calculate_layer_counts():
        """Calculate the number of attention, Mamba, and MLP layers."""
        if args.hybrid_override_pattern:
            # '*' = attention, 'M' = Mamba, '-' = MLP; other symbols are ignored.
            counts = {'M': 0, '*': 0, '-': 0}
            for layer_type in args.hybrid_override_pattern:
                if layer_type in counts:
                    counts[layer_type] += 1
            return counts['*'], counts['M'], counts['-']
        else:
            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
            return num_attn_layers, num_mamba_layers, num_mlp_layers

    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """Calculate forward FLOPs for an MLP layer."""
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def attn_layer_flops(
        batch_size, seq_len, hidden_size, num_heads, gqa=True, gqa_groups=8, kv_channels=None
    ):
        """Calculate forward FLOPs for an attention layer."""
        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
        g = gqa_groups if gqa else num_heads
        return (
            4
            * batch_size
            * seq_len
            * hidden_size
            * p
            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
        )

    def mamba_layer_flops(
        batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1, num_heads=128
    ):
        """Calculate forward FLOPs for a Mamba layer."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
        # but small percent of overall layer flops
        d_in = 2 * hidden_size
        if num_heads:
            nheads = num_heads
        else:
            nheads = d_in // head_dim
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def hybrid_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_attn_layers,
        num_mamba_layers,
        num_mlp_layers,
        mamba_state_dim=128,
        mamba_head_dim=64,
        mamba_num_groups=8,
        mamba_num_heads=128,
        num_attn_heads=32,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
        mlp_expansion=4.0,
        swiglu=False,
        vocab_size=256000,
    ):
        """Calculate total (forward + backward) FLOPs for the hybrid model."""
        flops_fwd = (
            num_attn_layers
            * attn_layer_flops(
                batch_size, seq_len, hidden_size, num_attn_heads, gqa, gqa_groups, kv_channels
            )
            + num_mlp_layers
            * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu)
            + num_mamba_layers
            * mamba_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                mamba_state_dim,
                mamba_head_dim,
                mamba_num_groups,
                mamba_num_heads,
            )
            + (2 * batch_size * seq_len * hidden_size * vocab_size)  # logits computation
        )
        # Backward costs roughly twice the forward pass.
        return flops_fwd * 3

    def transformer_flops():
        """Calculate FLOPs for a standard Transformer model."""
        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
        # Attention projection size.
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        # Group Query Attention. Without GQA every head is its own query group.
        # Kept in a local so that estimating FLOPs never mutates the shared
        # `args` namespace (the previous code overwrote args.num_query_groups).
        if args.group_query_attention:
            num_query_groups = args.num_query_groups
        else:
            num_query_groups = args.num_attention_heads
        # MoE.
        if args.num_experts is None:
            # Every Transformer MLP is dense.
            num_dense_layers = args.num_layers
            num_moe_layers = 0
            num_experts_routed_to = 0
            last_layer_is_moe = 0
        else:
            # Calculate number of dense and MoE Transformer MLPs.
            if isinstance(args.moe_layer_freq, int):
                moe_layer_pattern = [
                    1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)
                ]
            elif isinstance(args.moe_layer_freq, list):
                moe_layer_pattern = args.moe_layer_freq
            else:
                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
            assert len(moe_layer_pattern) == args.num_layers, (
                f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
                f"expected {args.num_layers}, "
                f"current moe layer pattern: {args.moe_layer_freq}"
            )
            num_moe_layers = sum(moe_layer_pattern)  # Number of 1s in `moe_layer_pattern`.
            num_dense_layers = args.num_layers - num_moe_layers
            num_experts_routed_to = args.moe_router_topk
            last_layer_is_moe = moe_layer_pattern[-1]

        # MTP layers reuse the MLP type (dense or MoE) of the last decoder layer.
        if args.mtp_num_layers is not None:
            mtp_num_layers = args.mtp_num_layers
            num_moe_layers += last_layer_is_moe * mtp_num_layers
            num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers
            num_layers = args.num_layers + mtp_num_layers
        else:
            mtp_num_layers = 0
            num_layers = args.num_layers

        moe_ffn_hidden_size = (
            args.moe_ffn_hidden_size
            if args.moe_ffn_hidden_size is not None
            else args.ffn_hidden_size
        )
        shared_expert_ffn_hidden_size = (
            0
            if args.moe_shared_expert_intermediate_size is None
            else args.moe_shared_expert_intermediate_size
        )
        # SwiGLU.
        gated_linear_multiplier = 3 / 2 if args.swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
        #       in MLP layer).
        # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
        expansion_factor = 3 * 2 * 2

        if args.multi_latent_attention:
            assert not args.group_query_attention
            # Basic arithmetic
            # let B is batch size, s is seq_len, h is embedding dim,
            # for one self_attention block (prenorm is not included)
            #   qkv projection:   6Bsh^2
            #   attn:             2Bs^2h
            #   attn over value:  2Bs^2h
            #   oproj:            2Bsh^2
            # references
            #   https://arxiv.org/abs/2305.10403
            #   https://arxiv.org/abs/2205.05198
            ## MLA
            if args.q_lora_rank is None:
                q_term = (
                    args.hidden_size
                    * args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                )
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size
                    + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    + 1
                )
            self_attn_term = (
                3
                * 2  # fwd(1) + bwd(2) *FMA
                * num_layers
                * (
                    ## q lora + rope + q norm
                    q_term
                    ## kv lora + rope + kv norm
                    + args.kv_lora_rank
                    * (
                        args.hidden_size
                        + args.num_attention_heads * (args.qk_head_dim + args.v_head_dim)
                        + 1
                    )
                    + args.hidden_size * args.qk_pos_emb_head_dim
                    ## o proj
                    + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
                    ## core attn
                    + args.seq_length
                    * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim))
                    / 2
                    + args.seq_length * args.num_attention_heads * args.v_head_dim / 2
                )
            )
        else:
            ## MHA or GQA
            self_attn_term = (
                expansion_factor
                * num_layers
                * args.hidden_size
                * args.hidden_size
                * (
                    (
                        1
                        + (num_query_groups / args.num_attention_heads)
                        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
                        + (args.seq_length / args.hidden_size / 2)
                    )
                    * query_projection_to_hidden_size_ratio
                )
            )

        total_floating_point_operations = (
            batch_size
            * args.seq_length
            * (
                # MLP
                expansion_factor
                * num_layers
                * args.hidden_size
                * (
                    # dense layer (deepseek v2, v3 style)
                    (args.ffn_hidden_size * gated_linear_multiplier)
                    * (num_dense_layers / num_layers)
                    # routed experts
                    + (moe_ffn_hidden_size * num_experts_routed_to * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                    # Shared Experts.
                    + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                )
                # Self Attention
                + self_attn_term
                # MTP norms and proj
                + 3
                * 2
                * mtp_num_layers
                * (
                    # MTP eh norm + final norm
                    3 * args.hidden_size
                    # MTP eh proj
                    + 2 * args.hidden_size * args.hidden_size
                )
                # Logit.
                + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1)
            )
        )
        return total_floating_point_operations

    # Main entrypoint for FLOPs calculation.
    if args.is_hybrid_model:
        # Calculate the number of each type of layer.
        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()

        # Compute hybrid model FLOPs.
        return hybrid_flops(
            batch_size=batch_size,
            seq_len=args.seq_length,
            hidden_size=args.hidden_size,
            num_attn_layers=num_attn_layers,
            num_mamba_layers=num_mamba_layers,
            num_mlp_layers=num_mlp_layers,
            mamba_state_dim=args.mamba_state_dim,
            mamba_head_dim=args.mamba_head_dim,
            mamba_num_groups=args.mamba_num_groups,
            mamba_num_heads=args.mamba_num_heads,
            num_attn_heads=args.num_attention_heads,
            gqa=args.group_query_attention,
            gqa_groups=args.num_query_groups,
            kv_channels=args.kv_channels,
            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
            swiglu=args.swiglu,
            vocab_size=args.padded_vocab_size,
        )
    else:
        # Compute standard Transformer model FLOPs.
        return transformer_flops()
def get_start_time_from_progress_log():
    """Find when the current world-size run started, using ``<args.save>/progress.txt``.

    Each log line is tab-separated: token 0 is the timestamp, token 2 carries
    ``"...: <world size>"``, token 3 is the event name, and for
    ``"Saved checkpoint"`` events token 7 carries ``"...: <floating-point ops>"``.
    The search restarts whenever a line with a different world size is seen,
    so the earliest ``"Starting job"`` of the latest uninterrupted run with
    ``args.world_size`` is selected.

    Returns:
        Tuple ``(start_time, start_num_floating_point_operations)``:
        ``start_time`` is a ``datetime`` parsed from that job's timestamp, and
        ``start_num_floating_point_operations`` is the floating-point-operation
        count of the most recent checkpoint saved before that job started
        (``0`` if none was saved).

    Raises:
        AssertionError: if ``args.save`` is unset or no matching
            ``"Starting job"`` entry exists.
    """
    args = get_args()
    assert args.save is not None
    progress_log_filename = os.path.join(args.save, "progress.txt")

    # start_time is time when job with same world size started.
    # start_num_floating_point_operations is the number of floating-point operations
    # completed when this job started.
    # latest_num_floating_point_operations is the number of floating-point operations
    # completed in most recent saved checkpoint.
    start_time = None
    start_num_floating_point_operations = None
    latest_num_floating_point_operations = 0

    def _get_field(string, cast):
        """Return the value after ``': '`` in ``string``, converted with ``cast``."""
        return cast(string.split(': ')[1])

    with open_file(progress_log_filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing with IndexError below.
                continue
            line_tokens = line.split('\t')
            world_size_in_line = _get_field(line_tokens[2], int)
            if line_tokens[3] == "Saved checkpoint":
                latest_num_floating_point_operations = _get_field(line_tokens[7], float)
            if world_size_in_line != args.world_size:
                # Re-start search if we see a different world size.
                start_time = None
                start_num_floating_point_operations = None
                continue
            if line_tokens[3] == "Starting job":
                if start_time is None:
                    start_time = line_tokens[0]
                    start_num_floating_point_operations = latest_num_floating_point_operations
    assert (
        start_time is not None and start_num_floating_point_operations is not None
    ), "Should have seen at least one 'Starting job' entry with same world_size"
    return datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S'), start_num_floating_point_operations
def preprocess_common_state_dict(common_state_dict):
    """Return a normalized deep copy of a checkpoint's common state dict.

    The input is not modified. In the copy:
      * ``'args'`` (a namespace) is converted to a plain dict, and its
        ``local_rank`` / ``rank`` entries are removed, since they are expected
        to differ between ranks;
      * when ``use_distributed_optimizer`` is set and an ``'optimizer'`` entry
        exists, each inner optimizer's empty ``param_state`` is dropped and
        its ``param_groups`` are sorted by ``param_group_identifier_keys``.

    Args:
        common_state_dict: Mapping holding at least an ``'args'`` namespace.

    Returns:
        The preprocessed copy.
    """
    import copy

    # Convert args key of type namespace to dictionary
    preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
    preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
    # Remove rank and local rank from state dict if it exists, since they are expected to be different
    preprocessed_common_state_dict['args'].pop('local_rank', None)
    preprocessed_common_state_dict['args'].pop('rank', None)
    if (
        preprocessed_common_state_dict['args']['use_distributed_optimizer']
        and "optimizer" in preprocessed_common_state_dict
    ):

        def reorder_inner_param_groups(optimizer_state_dict):
            """Drop an empty ``param_state`` and sort ``param_groups`` in place."""
            # When distributed optimizer loading, source param groups will be reordered,
            # so we reorder the param groups here to prevent warning.

            # Pop empty param_state.
            if "param_state" in optimizer_state_dict and not optimizer_state_dict["param_state"]:
                optimizer_state_dict.pop("param_state")

            # Reorder param groups.
            if "optimizer" not in optimizer_state_dict:
                return
            inner_optimizer = optimizer_state_dict["optimizer"]
            if "param_groups" not in inner_optimizer:
                return
            param_groups = inner_optimizer["param_groups"]
            key_fn = lambda pg: [pg[key] for key in param_group_identifier_keys]
            param_groups.sort(key=key_fn)
            inner_optimizer["param_groups"] = param_groups

        optimizer_state_dict = preprocessed_common_state_dict['optimizer']
        if "optimizer" in optimizer_state_dict:
            # Only 1 optimizer in chained optimizer.
            reorder_inner_param_groups(optimizer_state_dict)
        else:
            # Multiple optimizers in chained optimizer, keyed by integer index.
            # Visit every integer key actually present; iterating range(len(...))
            # silently skipped entries when the key set was sparse (e.g. {0, 2}).
            for key, sub_state_dict in optimizer_state_dict.items():
                if isinstance(key, int):
                    reorder_inner_param_groups(sub_state_dict)

    return preprocessed_common_state_dict
def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults=None,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
    store=None,
    inprocess_call_wrapper: Optional[CallWrapper] = None,
):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Args:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
            With virtual pipeline parallelism it must also accept a `vp_stage`
            keyword argument.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value, used
            to set already parsed arguments. Forwarded to `initialize_megatron`;
            `None` (the default) means an empty dictionary.
        get_embedding_ranks (TODO):
        get_position_embedding_ranks (TODO):
        non_loss_data_func (callable): A custom function to call during evaluation.
            It can run e.g. benchmarks.
        store: an optional instance of torch.distributed.Store, to be used by
            torch.distributed.init_process_group
        inprocess_call_wrapper: an optional instance of inprocess.CallWrapper,
            it is automatically injected when in-process restart is in use.
            When given, `store` is wrapped in a PrefixStore keyed by the
            wrapper's restart iteration.
    """
    # Avoid a shared mutable default argument.
    if args_defaults is None:
        args_defaults = {}

    if inprocess_call_wrapper is not None:
        restart_iteration = inprocess_call_wrapper.iteration
        store = torch.distributed.PrefixStore(str(restart_iteration), store)

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(
        extra_args_provider=extra_args_provider,
        args_defaults=args_defaults,
        get_embedding_ranks=get_embedding_ranks,
        get_position_embedding_ranks=get_position_embedding_ranks,
        store=store,
    )

    args = get_args()
    timers = get_timers()

    if args.log_progress:
        append_to_progress_log("Starting job")

    # Initialize fault tolerance
    # NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized
    if args.enable_ft_package:
        ft_integration.setup(args)
        ft_integration.maybe_setup_simulated_fault()

    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Use the earliest start timestamp across all ranks (all-reduce MIN), so the
    # reported startup time is the largest one observed by any rank. This is
    # closer to what the job scheduler sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.tensor([_TRAIN_START_TIME], dtype=torch.double, device='cuda')
    torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()

    app_metrics = {}
    app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
    app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)

    print_rank_0(
        'time to initialize megatron (seconds): {:.3f}'.format(time.time() - _TRAIN_START_TIME)
    )
    print_datetime('after megatron is initialized')
    app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track E2E metrics on pretrain start
    one_logger_utils.on_pretrain_start()

    # Context used for persisting some state between checkpoint saves.
    if args.non_persistent_ckpt_type == 'local':
        try:
            from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
                LocalCheckpointManager,
            )
            from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import (
                parse_group_sequence,
                GroupWrapper,
            )
            from nvidia_resiliency_ext.checkpointing.local.replication.strategies import (
                CliqueReplicationStrategy,
            )
        except ModuleNotFoundError as err:
            raise RuntimeError(
                "The 'nvidia_resiliency_ext' module is required for local "
                "checkpointing but was not found. Please ensure it is installed."
            ) from err

        if args.replication:
            repl_strategy = CliqueReplicationStrategy.from_replication_params(
                args.replication_jump, args.replication_factor
            )
        else:
            repl_strategy = None

        checkpointing_context = {
            'local_checkpoint_manager': LocalCheckpointManager(
                args.non_persistent_local_ckpt_dir, repl_strategy=repl_strategy
            )
        }
    else:
        checkpointing_context = {}

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type, checkpointing_context=checkpointing_context
    )
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate scheduler are built')
    config = get_model_config(model[0])

    # Data stuff.
    app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
    timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True)
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of iterators per virtual pipeline stage.
        train_data_iterator = []
        valid_data_iterator = []
        test_data_iterator = []
        for vp_stage in range(len(model)):
            dataset_provider_parameters = inspect.signature(
                train_valid_test_dataset_provider
            ).parameters
            assert "vp_stage" in dataset_provider_parameters, (
                "vp_stage must be a kwarg in train_valid_test_dataset_provider "
                "when using virtual pipeline parallelism"
            )
            vp_stage_train_valid_test_dataset_provider = functools.partial(
                train_valid_test_dataset_provider, vp_stage=vp_stage
            )
            # functools.partial drops custom attributes; carry is_distributed over.
            if getattr(train_valid_test_dataset_provider, 'is_distributed', False):
                vp_stage_train_valid_test_dataset_provider.is_distributed = True
            iterators = build_train_valid_test_data_iterators(
                vp_stage_train_valid_test_dataset_provider
            )
            train_data_iterator.append(iterators[0])
            valid_data_iterator.append(iterators[1])
            test_data_iterator.append(iterators[2])
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator = (
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
        )
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')
    app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    # Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
    one_logger_utils.track_config_flags(
        args.train_iters,
        args.skip_train,
        args.do_train,
        args.do_valid,
        args.do_test,
        args.dataloader_type,
        args.retro_project_dir,
        args.retro_cyclic_train_iters,
    )

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'], barrier=True)

    one_logger = get_one_logger()
    if one_logger:
        one_logger.log_metrics(app_metrics)

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        # Add job name to the wandb config to make it easier to run more singleton dependency jobs.
        wandb_writer.config.update({'slurm_job_name': os.getenv("SLURM_JOB_NAME", "N/A")})

    if not args.skip_train:
        print_rank_0('training ...')

        if args.dataloader_type == 'cyclic' and args.retro_project_dir:
            assert args.retro_cyclic_train_iters is not None
            args.train_iters = args.retro_cyclic_train_iters
            print_rank_0("retro cyclic train iters : %d" % args.train_iters)

        iteration = 0
        if args.do_train and args.train_iters > 0:
            iteration, num_floating_point_operations_so_far = train(
                forward_step_func,
                model,
                optimizer,
                opt_param_scheduler,
                train_data_iterator,
                valid_data_iterator,
                process_non_loss_data_func,
                config,
                checkpointing_context,
                non_loss_data_func,
            )

        print_datetime('after training is done')

        # Save a final checkpoint unless the last iteration was already saved
        # by the periodic save. `iteration != 0` implies train() ran, so
        # num_floating_point_operations_so_far is bound here.
        if args.save and iteration != 0 and iteration % args.save_interval != 0:
            save_checkpoint(
                iteration,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                checkpointing_context,
                train_data_iterator=train_data_iterator,
                preprocess_common_state_dict_fn=preprocess_common_state_dict,
            )

        if one_logger:
            one_logger.log_metrics(
                {'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()}
            )

    else:
        print_rank_0('skipping training (--skip-train is on) ...')

        iteration = args.iteration

    if args.do_valid:
        prefix = f'iteration {iteration} on validation set'
        if getattr(args, 'perform_rl_step', False):
            rl_utils.evaluate_and_print_results_rl(
                valid_data_iterator,
                model,
                optimizer,
                iteration,
                write_to_tensorboard=not args.skip_train,
            )
        else:
            evaluate_and_print_results(
                prefix,
                forward_step_func,
                valid_data_iterator,
                model,
                iteration,
                process_non_loss_data_func,
                config,
                verbose=True,
                write_to_tensorboard=not args.skip_train,
                non_loss_data_func=non_loss_data_func,
            )

    if args.do_test:
        prefix = f'iteration {iteration} on test set'
        evaluate_and_print_results(
            prefix,
            forward_step_func,
            test_data_iterator,
            model,
            iteration,
            process_non_loss_data_func,
            config,
            verbose=True,
            write_to_tensorboard=not args.skip_train,
            non_loss_data_func=non_loss_data_func,
        )

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.finish()

    # Block until any pending asynchronous checkpoint save has completed.
    ft_integration.on_checkpointing_start()
    maybe_finalize_async_save(blocking=True, terminate=True)
    ft_integration.on_checkpointing_end(is_async_finalization=True)

    if one_logger:
        one_logger.log_metrics({'app_finish_time': one_logger_utils.get_timestamp_in_ms()})

    ft_integration.shutdown()
    one_logger_utils.finish()
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    If ``args.train_iters`` is already truthy (iteration-based training), this
    returns immediately without changing anything. Otherwise it computes how
    many iterations consume ``args.train_samples`` and stores the result in
    ``args.train_iters``:

      * constant batch size: ``train_samples // global_batch_size``;
      * batch-size rampup: the rampup is simulated by stepping the
        num-microbatches calculator, one iteration per global batch, until
        more than ``rampup_batch_size[2]`` samples (or the sample budget) are
        used. The calculator is then reset to 0 and the remaining samples are
        split into whole global batches.

    Any partial final batch is discarded.
    """
    # Iteration-based training: nothing to derive.
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        rampup_samples = int(args.rampup_batch_size[2])
        num_iters = 0
        samples_seen = 0
        # Rampup phase: the global batch size grows as samples are consumed.
        while samples_seen <= rampup_samples and samples_seen <= args.train_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Put the calculator back to its initial state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase: only whole global batches count.
        remaining = args.train_samples - samples_seen
        if remaining > 0:
            num_iters += remaining // args.global_batch_size
        args.train_iters = num_iters
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0(f'setting training iterations to {args.train_iters}')
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
# Build model.
def build_model():
if (
mpu.get_pipeline_model_parallel_world_size() > 1
and args.virtual_pipeline_model_parallel_size is not None
):
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
post_process = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
this_model = model_provider_func(
pre_process=pre_process, post_process=post_process, vp_stage=i)
this_model.model_type = model_type
this_model.vp_stage = i
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(pre_process=pre_process, post_process=post_process)
model.model_type = model_type
return model
if args.init_model_with_meta_device:
with torch.device('meta'):
model = build_model()
else:
model = build_model()
if not isinstance(model, list):
model = [model]
# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
num_parameters = sum(
[sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
)
if mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0:
print(
' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
num_parameters,
),
flush=True,
)
# GPU allocation.
# For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
# in the fully_shard function of FSDP2 instead.
if (
not (args.use_torch_fsdp2 and args.use_cpu_initialization)
and not args.init_model_with_meta_device
):
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
config = get_model_config(model[0])
model = [Float16Module(config, model_module) for model_module in model]
# Materialize tensors on meta device (GPU allocation) if not using FSDP2 and not using Megatron FSDP.
if args.init_model_with_meta_device and not args.use_torch_fsdp2 and not args.use_megatron_fsdp:
#for model_module in model:
model = [to_empty_if_meta_device(model_module, device=torch.device("cuda")) for model_module in model]
# Before TE2.x: The model_module.bfloat16()/model_module.half() above will call the inplace
# copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
# from the current fp8 param) to its amax_history. The below function will correct
# the amax_history back.
# After TE2.x: Below function is an empty function and does nothing.
correct_amax_history_if_needed(model)
if wrap_with_ddp:
if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP
elif args.use_megatron_fsdp:
DP = megatron_FSDP
else:
DP = DDP
config = get_model_config(model[0])
if getattr(args, "use_torch_fsdp2", False):
reshard_after_forward = getattr(args, "torch_fsdp2_reshard_after_forward", True)
ddp_config = TorchFullyShardedDataParallelConfig(reshard_after_forward=reshard_after_forward)
else:
kwargs = {}
for f in dataclasses.fields(DistributedDataParallelConfig):
if hasattr(args, f.name):
kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
kwargs['check_for_large_grads'] = args.check_for_large_grads
if args.ddp_num_buckets is not None:
assert args.ddp_bucket_size is None, \
"Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
assert args.ddp_num_buckets > 0, \
"--ddp-num-buckets must be greater than 0"
kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
else:
kwargs['bucket_size'] = args.ddp_bucket_size
kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
kwargs['average_in_collective'] = args.ddp_average_in_collective
if args.use_megatron_fsdp and args.use_precision_aware_optimizer:
kwargs["preserve_fp32_weights"] = False
ddp_config = DistributedDataParallelConfig(**kwargs)
# In the Megatron FSDP and DDP use path, we need to initialize the bucket size.
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
with torch.cuda.stream(torch.cuda.Stream()):
model = [
DP(
config=config,
ddp_config=ddp_config,
module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0)