Changes from 2 commits
144 commits
928a200
Remove deprecated destination argument to state_dict functions and ma…
jaredcasper Jul 21, 2022
5df9e1f
Remove old merge tool.
jaredcasper Jul 26, 2022
c464a10
Merge branch 'del_merge' into 'main'
jaredcasper Jul 26, 2022
e36cdd7
added a flag to be able to switch between pytorch and ring exchange p2p
shoeybi Jul 26, 2022
8df49e7
Merge branch 'add_ring_exchange_flag' into 'main'
jaredcasper Jul 27, 2022
76db958
support for all mask in fused kernel + avoiding inplace operation in …
kvareddy Jul 28, 2022
189e72a
Merge branch 'fused_softmax_kernel_fixes' into 'main'
jaredcasper Jul 29, 2022
45f4ee5
yttm + BytelevelBPE + setencepeice tokenizer support
kvareddy Aug 4, 2022
b7b2d6a
fix a bug for size mismatch
pxuab Aug 6, 2022
83d7867
Merge branch 'beam_search' into 'main'
jaredcasper Aug 6, 2022
a44360e
adress review comments
kvareddy Aug 8, 2022
77efccc
Timing levels
shoeybi Aug 10, 2022
d207391
Merge branch 'timing' into 'main'
jaredcasper Aug 10, 2022
27bc133
fixed grad scalar warning so it only prints it for fp16
shoeybi Aug 16, 2022
91384a5
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
aaa5715
fixed grad scalar warning for bf16
shoeybi Aug 16, 2022
d63c254
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
e38d41c
Memory safety checks were incorrect for the tokens_to_generate=0 case
RPrenger Sep 2, 2022
8b68628
Merge branch 'fixing_safety' into 'main'
jaredcasper Sep 12, 2022
1afe354
Merge branch 'state_dict_fix' into 'main'
jaredcasper Sep 12, 2022
981c3df
support separate datasets for train, valid and test
anmolgupt Sep 22, 2022
fabad46
Clean up licensing.
jaredcasper Sep 23, 2022
28ba253
Merge branch 'licensing' into 'main'
jaredcasper Sep 23, 2022
2e6a46e
Start Megatron-Core with vocab parallel cross entropy
jaredcasper Sep 23, 2022
209f91c
Bring mpu.data into megatron.core.
jaredcasper Sep 23, 2022
c2ea914
Move layers from mpu to core.tensor_parallel.
jaredcasper Sep 23, 2022
5942af9
Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
jaredcasper Sep 23, 2022
a94d0a6
Move get_num_layers into transformer.py.
jaredcasper Sep 24, 2022
e00a1ca
Improve docstrings, destory global memory buffer.
jaredcasper Sep 24, 2022
cbf780d
Update exports.
jaredcasper Sep 26, 2022
e7e9972
Check for pipeline_parallel > 2 when using interleaving.
jaredcasper Sep 26, 2022
5f4ddd9
Add basic setup.py for core.
jaredcasper Sep 26, 2022
77753d0
Small fixes.
jaredcasper Sep 27, 2022
55817ec
Correct some merge errors.
jaredcasper Sep 27, 2022
2366716
Error, not warn, if gradient_accumulation_fusion is requested but not…
jaredcasper Sep 27, 2022
07916bf
Support gradient accumulation fusion in fp16.
jaredcasper Sep 27, 2022
57bfa7c
Perform distributed optimizer's all-gather in param dtype (instead of…
lmcafee-nvidia Sep 30, 2022
fc7f4f0
Merge branch 'lmcafee/byte-buffer' into 'main'
jaredcasper Sep 30, 2022
41276b6
Merge branch 'main' into nmt-main
kvareddy Oct 3, 2022
b9ae7ba
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 3, 2022
05d731a
Setting up code coverage
shanmugamr1992 Oct 4, 2022
fb8c09e
Code coverage setup
shanmugamr1992 Oct 4, 2022
cbf8250
different encoder/decoder num-layers support
kvareddy Oct 4, 2022
6ab70f5
Adding some basic unit tests
shanmugamr1992 Oct 5, 2022
63e5994
support for separate dataset files for train, valid and test
Oct 5, 2022
2514892
fixed the timer issue for the case with no pipelining
shoeybi Oct 5, 2022
96b7559
Merge branch 'fix_backward_no_pipeline' into 'main'
jaredcasper Oct 6, 2022
6defe18
Setter for pipeline parallel split rank, remove print
ericharper Oct 6, 2022
6d41789
Merge branch 'changes_for_nemo' into 'core'
jaredcasper Oct 6, 2022
b69e219
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
136cf03
Merge branch 'core' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
shanmugamr1992 Oct 6, 2022
056fc7c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
423623c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
56934a2
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
74ee8c0
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
44c94f5
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
e9f2000
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
4ec95a2
Adding some basic unit tests
shanmugamr1992 Oct 7, 2022
11392f0
Changes'
shanmugamr1992 Oct 7, 2022
94dd94e
Changes'
shanmugamr1992 Oct 7, 2022
2fd9ea1
Code covearage
shanmugamr1992 Oct 7, 2022
c0329d8
Code covearage
shanmugamr1992 Oct 7, 2022
f861467
Code covearage
shanmugamr1992 Oct 7, 2022
45cd4e0
removed assert for the case of evaluation only without training
Oct 10, 2022
69f3249
address review comments
kvareddy Oct 11, 2022
a95fda7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 11, 2022
c7d57ff
Merge branch 'anmolg/validation_1' into 'main'
jaredcasper Oct 11, 2022
8b94a16
Adding proper test cases
shanmugamr1992 Oct 13, 2022
8806ba7
Merge branch 'properTest' into 'core'
jaredcasper Oct 13, 2022
2a86fa2
Merge branch 'main' into core
jaredcasper Oct 13, 2022
5da3bb9
Merge branch 'core-merge-main' into 'core'
jaredcasper Oct 14, 2022
dbed5e0
inverse_square_root learning param schedule
kvareddy Oct 14, 2022
bdd9731
Remove noop used to try to force scheduling and check for environment…
jaredcasper Oct 14, 2022
d3a416c
Merge branch 'core-noop' into 'core'
jaredcasper Oct 14, 2022
abf60f7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 14, 2022
544e250
Disable newline after colon
pxuab Oct 20, 2022
f4a8b1d
Merge branch 'disable_newline_after_colon' into 'main'
jaredcasper Oct 20, 2022
2fdd54e
Sending in prompts with the wrong type hangs the server. This is a c…
RPrenger Oct 27, 2022
fdc801e
Merge branch 'check_prompts_is_list' into 'main'
jaredcasper Nov 2, 2022
42c4071
Merge branch 'core' into 'main'
jaredcasper Nov 2, 2022
e0a12fe
Fix merge error.
jaredcasper Nov 8, 2022
1a26b29
Merge branch 'core-fix' into 'main'
jaredcasper Nov 8, 2022
fabd3e4
ViT Backbone Tensor Shape Fix
yaoyu-33 Nov 10, 2022
b4297c6
Merge branch 'yuya/vit_fix' into 'main'
jaredcasper Nov 10, 2022
c3e688d
Support for variable sequence lengths across micro-batches
kvareddy Nov 11, 2022
1ad1e1b
Merge branch 'nmt-main' into 'main'
jaredcasper Nov 11, 2022
7fc9611
Data Preprocessing Optimizations
kvareddy Nov 17, 2022
7016945
Merge branch 'nmt-main' into 'main'
mchrzanowski Nov 17, 2022
6d45a90
Fix DropPath for hidden shape [s, b, h]
yaoyu-33 Nov 22, 2022
d48d95a
Open sourcing lm detoxification code
boxin-wbx Nov 24, 2022
8ce8256
Merge branch 'boxin/detoxify_lm_cr' into 'main'
boxin-wbx Nov 24, 2022
84a43b1
bug fixes in partitioned data preprocessor
mchrzanowski Nov 29, 2022
b24f4ad
Merge branch 'partition_fixes' into 'main'
jaredcasper Nov 29, 2022
52e6368
Merge branch 'yuya/drop_path_fix' into 'main'
jaredcasper Nov 29, 2022
f298a85
Fix typo
janEbert Dec 13, 2022
df3ca00
Set SentencePiece tokenizer global variable
janEbert Dec 13, 2022
072b3a6
Refactor masked LM sampling style selection
janEbert Dec 13, 2022
2c94801
Add more masked LM sampling styles
janEbert Dec 13, 2022
e2bc55c
Allow Prefix-LM style masked LM
janEbert Dec 13, 2022
53f0300
Add UL2 pretraining for T5 model
janEbert Dec 13, 2022
35f232c
Refactor span merging
janEbert Dec 13, 2022
6bd44e7
Allow non-causal GPT models
janEbert Dec 13, 2022
9304618
Support UL2 for decoder-only models
janEbert Dec 13, 2022
9add693
Add custom exceptions
janEbert Dec 14, 2022
20b7acd
Error out on too long sequences
janEbert Dec 14, 2022
b5bef77
Remove additional sequence truncation
janEbert Dec 14, 2022
3e46e3c
Prefer array-from-list creation
janEbert Dec 14, 2022
7bb655c
Remove redundant imports
janEbert Jan 3, 2023
4474556
Fix sometimes not inserting prefixes
janEbert Jan 3, 2023
6f88858
Do not insert `extra_id` tokens for PrefixLM task
janEbert Jan 3, 2023
69fa541
Document `max_seq_length_dec` argument
janEbert Jan 3, 2023
020dd64
Skip redundant computations
janEbert Jan 3, 2023
1820f2b
Fix PrefixLM mean location
janEbert Jan 3, 2023
c4a5b40
Pad decoder-only inputs to same length
janEbert Jan 3, 2023
324d70d
Fix decoder-only attention mask shape
janEbert Jan 3, 2023
eb3dd43
Fix `max_ngrams` for normal sampling style
janEbert Jan 23, 2023
2d1b32d
Do not limit `max_predictions_per_seq`
janEbert Jan 23, 2023
10ef283
Calculate and use amount of filtered tokens
janEbert Jan 23, 2023
6b29f42
Document normal sampling style
janEbert Jan 23, 2023
27fc9fb
Fix PrefixLM possible spans calculation
janEbert Jan 23, 2023
359742e
Avoid mutable pointer in arguments
janEbert Jan 23, 2023
11e3d24
Allow passing callable for getting `model_type`
janEbert Jan 23, 2023
2a67e97
Fix getting model type
janEbert Jan 23, 2023
2dc7587
Allow recognizing when UL2 is used
janEbert Jan 23, 2023
7a4a94d
Only add UL2 tokens if using UL2 pretrain script
janEbert Jan 23, 2023
3c852c0
Support UL2 tokens for all tokenizers
janEbert Jan 23, 2023
c03a7be
Add SEP token to GPT tokenizer if using UL2
janEbert Jan 23, 2023
959daaa
Fix enum name
janEbert Jan 23, 2023
49f6b0f
Fix private UL2 argument default value
janEbert Jan 23, 2023
aa9a1c7
Use binary search for PrefixLM first tail index
janEbert Jan 24, 2023
d906cc1
Calculate n-gram indices lazily
janEbert Jan 24, 2023
3805df7
Prefer list comprehensions
janEbert Jan 24, 2023
69a3519
Merge branch 'janEbert-ul2' into ul2-merge
RaymondLi0 Feb 6, 2023
f5d0df1
support UL2 with HFtokenizer
RaymondLi0 Feb 7, 2023
9f024dc
scale normal distribution variance with its mean, and truncate the di…
RaymondLi0 Feb 7, 2023
f845e38
in the decoder-only case, truncate the masked sequence
RaymondLi0 Feb 7, 2023
ea79fe8
refactor: UL2Dataset does not inherit T5Dataset anymore
RaymondLi0 Feb 7, 2023
d1aed24
fix: mpu.get_cuda_rng_tracker() -> tensor_parallel.get_cuda_rng_track…
RaymondLi0 Feb 10, 2023
e712e7e
remove debug print
RaymondLi0 Feb 10, 2023
458ecf8
move is_ul2 to arguments
RaymondLi0 Feb 14, 2023
b9fa5f7
adjust attention-mask in generation for prefix-lm models
RaymondLi0 Feb 15, 2023
3a305eb
fix assert in tokenizer
RaymondLi0 Feb 17, 2023
96d18f7
Merge branch 'ul2-merge' of github.com:bigcode-project/Megatron-LM in…
RaymondLi0 Feb 17, 2023
fe05ccd
fix pretrain_ul2 for causal-decoder
Mar 8, 2023
26 changes: 26 additions & 0 deletions megatron/arguments.py
@@ -411,6 +411,32 @@ def _add_logging_args(parser):
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--timing-log-level', type=int,
default=0, choices=range(0,3),
help='Granularity level to measure and report timing. '
' 0: report only iteration time and make sure timing '
' does not introduce extra overhead.'
' 1: report timing for operations that are executed '
' very limited times (basically once) during '
' each iteration (such as gradient all-reduce) '
' 2: report timing for operations that migh be '
' executed numerous times during each iteration. '
'Note that setting the level to 1 or 2 might '
'cause increase in iteration time.')
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
help='If not set, use barrier with level 1 time '
'measurements. Note that this is up to the user '
'to make sure calling barrier with their timers '
'will not result in hangs. This can happen if for '
'example the user adds a level 1 timer that is not '
'called by all ranks.',
dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
choices=['max', 'minmax', 'all'],
help='Options for logging timing:'
' max: report the max timing across all ranks'
' minmax: report min and max timings across all ranks'
' all: report timings of all ranks.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
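The hunk above adds three timing flags to `megatron/arguments.py`. Below is a minimal, self-contained sketch of how they parse; the flag names, defaults, and choices are taken from the diff, while the standalone parser itself is only illustrative. The point worth noting is that `--no-barrier-with-level-1-timing` uses `action='store_false'` with `dest='barrier_with_L1_time'`, so barriers around level-1 timers are enabled by default and the flag switches them off.

```python
# Illustrative parser reproducing only the three flags added above; the real
# definitions live in _add_logging_args() in megatron/arguments.py.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='logging')
group.add_argument('--timing-log-level', type=int,
                   default=0, choices=range(0, 3))
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                   dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
                   choices=['max', 'minmax', 'all'])

args = parser.parse_args(['--timing-log-level', '1'])
print(args.timing_log_level, args.barrier_with_L1_time, args.timing_log_option)
# Prints "1 True minmax": barrier_with_L1_time defaults to True.
```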
90 changes: 5 additions & 85 deletions megatron/global_vars.py
@@ -17,14 +17,14 @@

import os
import sys
import time
from functools import reduce
import operator
import torch

from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .microbatches import build_num_microbatches_calculator
from .timers import Timers

_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
@@ -108,7 +108,7 @@ def set_global_variables(args):
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
_set_timers(args)
_set_global_memory_buffer()

if args.exit_signal_handler:
@@ -182,11 +182,12 @@ def _set_adlr_autoresume(args):
_GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
def _set_timers(args):
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
_GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


def _set_global_memory_buffer():
"""Initialize global buffer"""
@@ -205,87 +206,6 @@ def _ensure_var_is_not_initialized(var, name):
assert var is None, '{} is already initialized.'.format(name)


class _Timer:
"""Timer."""

def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()

def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True

def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False

def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False

def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_


class Timers:
"""Group of timers."""

def __init__(self):
self.timers = {}

def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]

def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)

def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)


class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
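This hunk removes the inline `_Timer`/`Timers` classes from `global_vars.py`; timers are now imported from `megatron/timers.py` and constructed as `Timers(args.timing_log_level, args.timing_log_option)`. That module is not shown in this diff, so the following is only a hedged sketch of the behaviour implied by the call sites in the optimizer files: timers carry a `log_level`, anything above the configured `--timing-log-level` becomes a no-op, and `start()` may issue a barrier.

```python
# Hedged sketch only: the real implementation is in megatron/timers.py, which
# this diff does not include. Class names here are invented; the constructor
# signature Timers(log_level, log_option) and the call pattern
# timers(name, log_level=...).start(barrier=...) match the diff.
import time

import torch


class _SketchTimer:
    def __init__(self, active):
        self._active = active          # False: every method is a cheap no-op
        self._elapsed = 0.0
        self._start_time = None

    def start(self, barrier=False):
        if not self._active:
            return
        if barrier and torch.distributed.is_initialized():
            torch.distributed.barrier()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self._start_time = time.time()

    def stop(self):
        if not self._active or self._start_time is None:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self._elapsed += time.time() - self._start_time
        self._start_time = None


class SketchTimers:
    """Level-aware timer group, mirroring Timers(timing_log_level, timing_log_option)."""

    def __init__(self, log_level, log_option):
        self._log_level = log_level    # from --timing-log-level
        self._log_option = log_option  # 'max' | 'minmax' | 'all'
        self._timers = {}

    def __call__(self, name, log_level=0):
        if name not in self._timers:
            # Timers requested above the configured level do no timing, so
            # level-1/2 timers add no overhead at --timing-log-level 0.
            self._timers[name] = _SketchTimer(active=log_level <= self._log_level)
        return self._timers[name]
```

Under this sketch, `timers('optimizer-inner-step', log_level=1)` measures anything only when `--timing-log-level` is at least 1, matching the help text's claim that level 0 introduces no extra overhead.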
20 changes: 12 additions & 8 deletions megatron/optimizer/distrib_optimizer.py
@@ -532,17 +532,20 @@ def reduce_model_grads(self, args, timers):
"""

# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()

# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()

# Reduce-scatter setup.
timers('backward-params-all-reduce').start()
timers('grads-reduce-scatter', log_level=1).start(
barrier=args.barrier_with_L1_time)
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
@@ -563,7 +566,7 @@ def reduce_model_grads(self, args, timers):
group = data_parallel_group,
)

timers('backward-params-all-reduce').stop()
timers('grads-reduce-scatter').stop()


def gather_model_params(self, args, timers):
@@ -575,7 +578,8 @@ def gather_model_params(self, args, timers):
can be copied from param.main_grad to param.
"""

timers('backward-params-all-gather').start()
timers('params-all-gather', log_level=1).start(
barrier=args.barrier_with_L1_time)

data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
Expand All @@ -602,7 +606,7 @@ def gather_model_params(self, args, timers):
for param in param_map:
param.detach().copy_(param.main_grad)

timers('backward-params-all-gather').stop()
timers('params-all-gather').stop()


def _collect_main_grad_data_for_unscaling(self):
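Both optimizer files now repeat the same pattern around each collective: rename the timer, tag it with `log_level=1`, and pass `barrier=args.barrier_with_L1_time` to `start()`. A purely illustrative wrapper, not part of the diff, that captures this pattern as a context manager:

```python
# Hypothetical helper (not in the diff): wraps the repeated
# timers(name, log_level=1).start(barrier=args.barrier_with_L1_time) / stop()
# pair used throughout distrib_optimizer.py and optimizer.py.
from contextlib import contextmanager


@contextmanager
def timed(timers, args, name, log_level=1):
    timers(name, log_level=log_level).start(barrier=args.barrier_with_L1_time)
    try:
        yield
    finally:
        timers(name).stop()

# Usage mirroring reduce_model_grads():
#   with timed(timers, args, 'grads-reduce-scatter'):
#       ...  # reduce-scatter gradients over the data-parallel group
```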
45 changes: 29 additions & 16 deletions megatron/optimizer/optimizer.py
@@ -294,21 +294,24 @@ def reduce_model_grads(self, args, timers):
"""All-reduce all grads, and all-reduce embeddings."""

# All-reduce layer-norm grads (for sequence parallelism).
timers('backward-layernorm-all-reduce').start()
timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_layernorm_grads(args)
timers('backward-layernorm-all-reduce').stop()
timers('layernorm-grads-all-reduce').stop()

# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
timers('grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
for model in self.models:
model.allreduce_gradients()
timers('backward-params-all-reduce').stop()
timers('grads-all-reduce').stop()

# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
timers('embedding-grads-all-reduce', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
timers('embedding-grads-all-reduce').stop()


class MixedPrecisionOptimizer(MegatronOptimizer):
@@ -416,7 +419,8 @@ def _unscale_main_grads_and_check_for_nan(self):
def step(self, args, timers):

# Copy gradients from model params to main params.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()

@@ -425,7 +429,8 @@ def step(self, args, timers):
if self.grad_scaler:

# Unscale and check for inf/nan.
timers('optimizer-unscale-and-check-inf').start()
timers('optimizer-unscale-and-check-inf', log_level=1).start(
barrier=args.barrier_with_L1_time)
found_inf_flag = self._unscale_main_grads_and_check_for_nan()
timers('optimizer-unscale-and-check-inf').stop()

@@ -438,25 +443,29 @@ def step(self, args, timers):
return False, None, None

# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# Count the zeros in the grads.
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()

# Step the optimizer.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()

# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=args.barrier_with_L1_time)
self._copy_main_params_to_model_params()
timers('optimizer-copy-main-to-model-params').stop()

@@ -725,7 +734,8 @@ def step(self, args, timers):
Always return successful since there is no overflow."""

# Copy main_grads to grads.
timers('optimizer-copy-to-main-grad').start()
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
if self.params_have_main_grad:
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
@@ -739,20 +749,23 @@ def step(self, args, timers):
timers('optimizer-copy-to-main-grad').stop()

# Clip gradients.
timers('optimizer-clip-main-grad').start()
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=args.barrier_with_L1_time)
grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()

# count the zeros in the grads
timers('optimizer-count-zeros').start()
timers('optimizer-count-zeros', log_level=1).start(
barrier=args.barrier_with_L1_time)
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
timers('optimizer-count-zeros').stop()

# Update parameters.
timers('optimizer-inner-step').start()
timers('optimizer-inner-step', log_level=1).start(
barrier=args.barrier_with_L1_time)
self.optimizer.step()
timers('optimizer-inner-step').stop()

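The new help text for `--no-barrier-with-level-1-timing` warns that barriers are collectives, so a level-1 timer reached by only some ranks will hang the job. A small sketch of that failure mode; every name below is hypothetical:

```python
# Illustration of the hang warned about in the --no-barrier-with-level-1-timing
# help text. Function and timer names are made up.
def risky_step(timers, args, rank):
    if rank == 0:
        # Only rank 0 reaches this timer. With barrier_with_L1_time=True,
        # start() enters torch.distributed.barrier(), a collective that every
        # rank must join, so rank 0 blocks while the other ranks never arrive.
        timers('rank0-only-work', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        # ... rank-0-only work ...
        timers('rank0-only-work').stop()
    # Remedies: run with --no-barrier-with-level-1-timing, or make sure every
    # rank executes the same level-1 timer calls.
```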