Changes from 1 commit
144 commits
928a200
Remove deprecated destination argument to state_dict functions and ma…
jaredcasper Jul 21, 2022
5df9e1f
Remove old merge tool.
jaredcasper Jul 26, 2022
c464a10
Merge branch 'del_merge' into 'main'
jaredcasper Jul 26, 2022
e36cdd7
added a flag to be able to switch between pytorch and ring exchange p2p
shoeybi Jul 26, 2022
8df49e7
Merge branch 'add_ring_exchange_flag' into 'main'
jaredcasper Jul 27, 2022
76db958
support for all mask in fused kernel + avoiding inplace operation in …
kvareddy Jul 28, 2022
189e72a
Merge branch 'fused_softmax_kernel_fixes' into 'main'
jaredcasper Jul 29, 2022
45f4ee5
yttm + ByteLevelBPE + sentencepiece tokenizer support
kvareddy Aug 4, 2022
b7b2d6a
fix a bug for size mismatch
pxuab Aug 6, 2022
83d7867
Merge branch 'beam_search' into 'main'
jaredcasper Aug 6, 2022
a44360e
address review comments
kvareddy Aug 8, 2022
77efccc
Timing levels
shoeybi Aug 10, 2022
d207391
Merge branch 'timing' into 'main'
jaredcasper Aug 10, 2022
27bc133
fixed grad scalar warning so it only prints it for fp16
shoeybi Aug 16, 2022
91384a5
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
aaa5715
fixed grad scalar warning for bf16
shoeybi Aug 16, 2022
d63c254
Merge branch 'fix_grad_scalar_warning' into 'main'
jaredcasper Aug 16, 2022
e38d41c
Memory safety checks were incorrect for the tokens_to_generate=0 case
RPrenger Sep 2, 2022
8b68628
Merge branch 'fixing_safety' into 'main'
jaredcasper Sep 12, 2022
1afe354
Merge branch 'state_dict_fix' into 'main'
jaredcasper Sep 12, 2022
981c3df
support separate datasets for train, valid and test
anmolgupt Sep 22, 2022
fabad46
Clean up licensing.
jaredcasper Sep 23, 2022
28ba253
Merge branch 'licensing' into 'main'
jaredcasper Sep 23, 2022
2e6a46e
Start Megatron-Core with vocab parallel cross entropy
jaredcasper Sep 23, 2022
209f91c
Bring mpu.data into megatron.core.
jaredcasper Sep 23, 2022
c2ea914
Move layers from mpu to core.tensor_parallel.
jaredcasper Sep 23, 2022
5942af9
Alias core.parallel_state as mpu and use it throughout code. RIP mpu.
jaredcasper Sep 23, 2022
a94d0a6
Move get_num_layers into transformer.py.
jaredcasper Sep 24, 2022
e00a1ca
Improve docstrings, destroy global memory buffer.
jaredcasper Sep 24, 2022
cbf780d
Update exports.
jaredcasper Sep 26, 2022
e7e9972
Check for pipeline_parallel > 2 when using interleaving.
jaredcasper Sep 26, 2022
5f4ddd9
Add basic setup.py for core.
jaredcasper Sep 26, 2022
77753d0
Small fixes.
jaredcasper Sep 27, 2022
55817ec
Correct some merge errors.
jaredcasper Sep 27, 2022
2366716
Error, not warn, if gradient_accumulation_fusion is requested but not…
jaredcasper Sep 27, 2022
07916bf
Support gradient accumulation fusion in fp16.
jaredcasper Sep 27, 2022
57bfa7c
Perform distributed optimizer's all-gather in param dtype (instead of…
lmcafee-nvidia Sep 30, 2022
fc7f4f0
Merge branch 'lmcafee/byte-buffer' into 'main'
jaredcasper Sep 30, 2022
41276b6
Merge branch 'main' into nmt-main
kvareddy Oct 3, 2022
b9ae7ba
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 3, 2022
05d731a
Setting up code coverage
shanmugamr1992 Oct 4, 2022
fb8c09e
Code coverage setup
shanmugamr1992 Oct 4, 2022
cbf8250
different encoder/decoder num-layers support
kvareddy Oct 4, 2022
6ab70f5
Adding some basic unit tests
shanmugamr1992 Oct 5, 2022
63e5994
support for separate dataset files for train, valid and test
Oct 5, 2022
2514892
fixed the timer issue for the case with no pipelining
shoeybi Oct 5, 2022
96b7559
Merge branch 'fix_backward_no_pipeline' into 'main'
jaredcasper Oct 6, 2022
6defe18
Setter for pipeline parallel split rank, remove print
ericharper Oct 6, 2022
6d41789
Merge branch 'changes_for_nemo' into 'core'
jaredcasper Oct 6, 2022
b69e219
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
136cf03
Merge branch 'core' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
shanmugamr1992 Oct 6, 2022
056fc7c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
423623c
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
56934a2
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
74ee8c0
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
44c94f5
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
e9f2000
Adding some basic unit tests
shanmugamr1992 Oct 6, 2022
4ec95a2
Adding some basic unit tests
shanmugamr1992 Oct 7, 2022
11392f0
Changes'
shanmugamr1992 Oct 7, 2022
94dd94e
Changes'
shanmugamr1992 Oct 7, 2022
2fd9ea1
Code coverage
shanmugamr1992 Oct 7, 2022
c0329d8
Code coverage
shanmugamr1992 Oct 7, 2022
f861467
Code coverage
shanmugamr1992 Oct 7, 2022
45cd4e0
removed assert for the case of evaluation only without training
Oct 10, 2022
69f3249
address review comments
kvareddy Oct 11, 2022
a95fda7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 11, 2022
c7d57ff
Merge branch 'anmolg/validation_1' into 'main'
jaredcasper Oct 11, 2022
8b94a16
Adding proper test cases
shanmugamr1992 Oct 13, 2022
8806ba7
Merge branch 'properTest' into 'core'
jaredcasper Oct 13, 2022
2a86fa2
Merge branch 'main' into core
jaredcasper Oct 13, 2022
5da3bb9
Merge branch 'core-merge-main' into 'core'
jaredcasper Oct 14, 2022
dbed5e0
inverse_square_root learning param schedule
kvareddy Oct 14, 2022
bdd9731
Remove noop used to try to force scheduling and check for environment…
jaredcasper Oct 14, 2022
d3a416c
Merge branch 'core-noop' into 'core'
jaredcasper Oct 14, 2022
abf60f7
Merge branch 'nmt-main' into 'main'
jaredcasper Oct 14, 2022
544e250
Disable newline after colon
pxuab Oct 20, 2022
f4a8b1d
Merge branch 'disable_newline_after_colon' into 'main'
jaredcasper Oct 20, 2022
2fdd54e
Sending in prompts with the wrong type hangs the server. This is a c…
RPrenger Oct 27, 2022
fdc801e
Merge branch 'check_prompts_is_list' into 'main'
jaredcasper Nov 2, 2022
42c4071
Merge branch 'core' into 'main'
jaredcasper Nov 2, 2022
e0a12fe
Fix merge error.
jaredcasper Nov 8, 2022
1a26b29
Merge branch 'core-fix' into 'main'
jaredcasper Nov 8, 2022
fabd3e4
ViT Backbone Tensor Shape Fix
yaoyu-33 Nov 10, 2022
b4297c6
Merge branch 'yuya/vit_fix' into 'main'
jaredcasper Nov 10, 2022
c3e688d
Support for variable sequence lengths across micro-batches
kvareddy Nov 11, 2022
1ad1e1b
Merge branch 'nmt-main' into 'main'
jaredcasper Nov 11, 2022
7fc9611
Data Preprocessing Optimizations
kvareddy Nov 17, 2022
7016945
Merge branch 'nmt-main' into 'main'
mchrzanowski Nov 17, 2022
6d45a90
Fix DropPath for hidden shape [s, b, h]
yaoyu-33 Nov 22, 2022
d48d95a
Open sourcing lm detoxification code
boxin-wbx Nov 24, 2022
8ce8256
Merge branch 'boxin/detoxify_lm_cr' into 'main'
boxin-wbx Nov 24, 2022
84a43b1
bug fixes in partitioned data preprocessor
mchrzanowski Nov 29, 2022
b24f4ad
Merge branch 'partition_fixes' into 'main'
jaredcasper Nov 29, 2022
52e6368
Merge branch 'yuya/drop_path_fix' into 'main'
jaredcasper Nov 29, 2022
f298a85
Fix typo
janEbert Dec 13, 2022
df3ca00
Set SentencePiece tokenizer global variable
janEbert Dec 13, 2022
072b3a6
Refactor masked LM sampling style selection
janEbert Dec 13, 2022
2c94801
Add more masked LM sampling styles
janEbert Dec 13, 2022
e2bc55c
Allow Prefix-LM style masked LM
janEbert Dec 13, 2022
53f0300
Add UL2 pretraining for T5 model
janEbert Dec 13, 2022
35f232c
Refactor span merging
janEbert Dec 13, 2022
6bd44e7
Allow non-causal GPT models
janEbert Dec 13, 2022
9304618
Support UL2 for decoder-only models
janEbert Dec 13, 2022
9add693
Add custom exceptions
janEbert Dec 14, 2022
20b7acd
Error out on too long sequences
janEbert Dec 14, 2022
b5bef77
Remove additional sequence truncation
janEbert Dec 14, 2022
3e46e3c
Prefer array-from-list creation
janEbert Dec 14, 2022
7bb655c
Remove redundant imports
janEbert Jan 3, 2023
4474556
Fix sometimes not inserting prefixes
janEbert Jan 3, 2023
6f88858
Do not insert `extra_id` tokens for PrefixLM task
janEbert Jan 3, 2023
69fa541
Document `max_seq_length_dec` argument
janEbert Jan 3, 2023
020dd64
Skip redundant computations
janEbert Jan 3, 2023
1820f2b
Fix PrefixLM mean location
janEbert Jan 3, 2023
c4a5b40
Pad decoder-only inputs to same length
janEbert Jan 3, 2023
324d70d
Fix decoder-only attention mask shape
janEbert Jan 3, 2023
eb3dd43
Fix `max_ngrams` for normal sampling style
janEbert Jan 23, 2023
2d1b32d
Do not limit `max_predictions_per_seq`
janEbert Jan 23, 2023
10ef283
Calculate and use amount of filtered tokens
janEbert Jan 23, 2023
6b29f42
Document normal sampling style
janEbert Jan 23, 2023
27fc9fb
Fix PrefixLM possible spans calculation
janEbert Jan 23, 2023
359742e
Avoid mutable pointer in arguments
janEbert Jan 23, 2023
11e3d24
Allow passing callable for getting `model_type`
janEbert Jan 23, 2023
2a67e97
Fix getting model type
janEbert Jan 23, 2023
2dc7587
Allow recognizing when UL2 is used
janEbert Jan 23, 2023
7a4a94d
Only add UL2 tokens if using UL2 pretrain script
janEbert Jan 23, 2023
3c852c0
Support UL2 tokens for all tokenizers
janEbert Jan 23, 2023
c03a7be
Add SEP token to GPT tokenizer if using UL2
janEbert Jan 23, 2023
959daaa
Fix enum name
janEbert Jan 23, 2023
49f6b0f
Fix private UL2 argument default value
janEbert Jan 23, 2023
aa9a1c7
Use binary search for PrefixLM first tail index
janEbert Jan 24, 2023
d906cc1
Calculate n-gram indices lazily
janEbert Jan 24, 2023
3805df7
Prefer list comprehensions
janEbert Jan 24, 2023
69a3519
Merge branch 'janEbert-ul2' into ul2-merge
RaymondLi0 Feb 6, 2023
f5d0df1
support UL2 with HFtokenizer
RaymondLi0 Feb 7, 2023
9f024dc
scale normal distribution variance with its mean, and truncate the di…
RaymondLi0 Feb 7, 2023
f845e38
in the decoder-only case, truncate the masked sequence
RaymondLi0 Feb 7, 2023
ea79fe8
refactor: UL2Dataset does not inherit T5Dataset anymore
RaymondLi0 Feb 7, 2023
d1aed24
fix: mpu.get_cuda_rng_tracker() -> tensor_parallel.get_cuda_rng_track…
RaymondLi0 Feb 10, 2023
e712e7e
remove debug print
RaymondLi0 Feb 10, 2023
458ecf8
move is_ul2 to arguments
RaymondLi0 Feb 14, 2023
b9fa5f7
adjust attention-mask in generation for prefix-lm models
RaymondLi0 Feb 15, 2023
3a305eb
fix assert in tokenizer
RaymondLi0 Feb 17, 2023
96d18f7
Merge branch 'ul2-merge' of github.com:bigcode-project/Megatron-LM in…
RaymondLi0 Feb 17, 2023
fe05ccd
fix pretrain_ul2 for causal-decoder
Mar 8, 2023
Support UL2 for decoder-only models
janEbert committed Dec 13, 2022
commit 9304618d92c1b93039d93319005acb9c29b5eaa7
18 changes: 18 additions & 0 deletions megatron/arguments.py
@@ -7,6 +7,8 @@

import torch

from megatron.model.enums import UL2ModelType

def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
@@ -321,6 +323,17 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False

args.ul2_model_type = UL2ModelType(args.ul2_model_type)
if (
args.ul2_model_type is not UL2ModelType.encoder_decoder
and args.decoder_seq_length is not None
):
print(
f'WARNING: `--decoder_seq_length` is ignored when '
f'`--ul2-model-type` is not '
f'"{UL2ModelType.encoder_decoder.value}"!'
)


if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
@@ -1072,6 +1085,11 @@ def _add_vision_args(parser):
def _add_ul2_args(parser):
group = parser.add_argument_group(title="UL2")

group.add_argument('--ul2-model-type', type=str, default='ED',
choices=['ED', 'ND', 'CD'],
help='What type of model to use for UL2 pretraining. '
'ED = encoder-decoder; ND = non-causal decoder-only; '
'CD = causal decoder-only')
group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
default=None,
help='Probability of each denoising objective to be '
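To see the new flag handling in isolation, here is a minimal, self-contained sketch; the enum values, choices, and warning mirror the hunks above, while the example command line is made up for illustration:

```python
import argparse
import enum


class UL2ModelType(enum.Enum):
    encoder_decoder = 'ED'
    non_causal_decoder = 'ND'
    causal_decoder = 'CD'


parser = argparse.ArgumentParser()
parser.add_argument('--ul2-model-type', type=str, default='ED',
                    choices=['ED', 'ND', 'CD'])
parser.add_argument('--decoder-seq-length', type=int, default=None)
# Illustrative argv only; the real script parses sys.argv.
args = parser.parse_args(['--ul2-model-type', 'ND', '--decoder-seq-length', '128'])

# validate_args() converts the string to the enum and warns when a decoder
# sequence length is given for a decoder-only model.
args.ul2_model_type = UL2ModelType(args.ul2_model_type)
if (args.ul2_model_type is not UL2ModelType.encoder_decoder
        and args.decoder_seq_length is not None):
    print('WARNING: `--decoder_seq_length` is ignored when '
          f'`--ul2-model-type` is not "{UL2ModelType.encoder_decoder.value}"!')
```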
1 change: 1 addition & 0 deletions megatron/data/dataset_utils.py
@@ -599,6 +599,7 @@ def build_dataset(index, name):
args = get_args()
dataset = UL2Dataset(
indexed_dataset=indexed_dataset,
model_type=args.ul2_model_type,
denoiser_ratios=args.ul2_denoiser_ratios,
denoisers=args.ul2_denoisers,
mean_span_lengths=args.ul2_mean_span_lengths,
121 changes: 92 additions & 29 deletions megatron/data/ul2_dataset.py
@@ -2,6 +2,8 @@

"""UL2-style dataset."""

import math

import numpy as np

from megatron import get_tokenizer
@@ -10,16 +12,34 @@
get_samples_mapping,
SamplingStyle
)
from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
from megatron.data.t5_dataset import (
make_history_mask,
merge_subsequent_masks,
pad_and_convert_to_numpy,
T5Dataset,
)
from megatron.model.enums import UL2ModelType


def is_decoder_only(ul2_model_type):
"""Return whether we use a decoder-only model."""
assert isinstance(ul2_model_type, UL2ModelType)
return ul2_model_type is not UL2ModelType.encoder_decoder


def is_prefix_lm(ul2_model_type):
"""Return whether we use a non-causal decoder-only model."""
assert isinstance(ul2_model_type, UL2ModelType)
return ul2_model_type is UL2ModelType.non_causal_decoder


class UL2Dataset(T5Dataset):

def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, denoiser_ratios,
denoisers, mean_span_lengths, mask_ratios,
denoiser_tokens, max_seq_length, max_seq_length_dec,
short_seq_prob, seed):
num_epochs, max_num_samples, model_type,
denoiser_ratios, denoisers, mean_span_lengths,
mask_ratios, denoiser_tokens, max_seq_length,
max_seq_length_dec, short_seq_prob, seed):

if denoiser_ratios is None:
# Uniform distribution by default.
@@ -39,6 +59,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
short_seq_prob, seed)

# Params to store.
self.model_type = model_type
self.denoiser_ratios = [
denoiser_ratio / sum(denoiser_ratios)
for denoiser_ratio in denoiser_ratios
@@ -84,18 +105,17 @@ def __getitem__(self, idx):
self.vocab_id_to_token_dict,
self.cls_ids, self.sep_id,
self.mask_id, self.pad_id,
self.denoiser_ratios, self.denoisers,
self.mean_span_lengths, self.mask_ratios,
np_rng,
self.bos_id, self.eos_id,
self.sentinel_tokens)
self.model_type, self.denoiser_ratios,
self.denoisers, self.mean_span_lengths,
self.mask_ratios, np_rng, self.bos_id,
self.eos_id, self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
max_seq_length, max_seq_length_dec,
vocab_id_list, vocab_id_to_token_dict,
cls_ids, sep_id, mask_id, pad_id,
denoiser_ratios, denoisers,
model_type, denoiser_ratios, denoisers,
mean_span_lengths, mask_ratios,
np_rng, bos_id=None,
eos_id=None, sentinel_tokens=None):
Expand All @@ -112,6 +132,7 @@ def build_training_sample(sample, target_seq_length,
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
model_type: What type of model is used.
denoiser_ratios: Probability of each denoising objective to be selected.
denoisers: What type of UL2 denoising objective the other UL2
configurations refer to.
@@ -174,22 +195,64 @@
sampling_style=sampling_style, prefix_lm=prefix_lm,
)

# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
if is_decoder_only(model_type):
# Concatenate to one sequence.
tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
tokens, masked_spans, bos_id, eos_id, sentinel_tokens)

# Move EOS tokens to end of sequence.
while tokens_enc[-1] == eos_id:
del tokens_enc[-1]
tokens_dec_in.append(eos_id)
labels.append(eos_id)

num_labels = len(labels)

# Move BOS token to start of sequence.
tokens_dec_in = tokens_dec_in[1:]
tokens = np.concatenate([
np.array([bos_id], dtype=np.int64),
tokens_enc,
np.array([sep_id], dtype=np.int64),
tokens_dec_in,
])
labels = np.concatenate([
tokens_enc,
np.array([sep_id], dtype=np.int64),
labels,
])

loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1

dec_mask = make_history_mask(tokens)
if is_prefix_lm(model_type):
dec_mask[:-num_labels, :-num_labels] = 1

train_sample = {
'text': tokens,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'dec_mask': dec_mask,
}
else:
# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
return train_sample
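The decoder-only branch above packs the corrupted input and the span targets into one sequence. The following toy walk-through (made-up token ids, and a stand-in for `make_history_mask`, which the real code imports from `megatron/data/t5_dataset.py`) shows what `build_training_sample` produces in that branch:

```python
import numpy as np

def history_mask(tokens):
    # Stand-in for t5_dataset.make_history_mask: a lower-triangular
    # (causal) attention mask where 1 means "may attend".
    length = len(tokens)
    return np.tril(np.ones((length, length), dtype=np.int64))

# Made-up ids: 1 = BOS, 2 = EOS, 3 = SEP, 100 = a sentinel token.
bos_id, eos_id, sep_id = 1, 2, 3
tokens_enc = [10, 100, 13, 14]      # corrupted input, one span replaced by sentinel 100
tokens_dec_in = [1, 100, 11, 12]    # decoder input: BOS + sentinel + original span
labels = [100, 11, 12, 2]           # targets: sentinel + original span + EOS

num_labels = len(labels)
tokens_dec_in = tokens_dec_in[1:]   # BOS moves to the front of the packed sequence

tokens = np.concatenate([
    np.array([bos_id], dtype=np.int64),
    tokens_enc,
    np.array([sep_id], dtype=np.int64),
    tokens_dec_in,
])
packed_labels = np.concatenate([
    tokens_enc,
    np.array([sep_id], dtype=np.int64),
    labels,
])

loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1         # loss only on the span targets

dec_mask = history_mask(tokens)     # causal decoder ('CD')
prefix_lm = True                    # non-causal decoder ('ND')
if prefix_lm:
    # The prefix (BOS + corrupted input) attends bidirectionally.
    dec_mask[:-num_labels, :-num_labels] = 1

print(tokens)         # [ 1  10 100  13  14   3 100  11  12]
print(packed_labels)  # [10 100  13  14   3 100  11  12   2]
print(loss_mask)      # [ 0   0   0   0   0   1   1   1   1]
```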
5 changes: 5 additions & 0 deletions megatron/model/enums.py
@@ -18,3 +18,8 @@ class AttnMaskType(enum.Enum):
padding = 1
causal = 2
prefix = 3

class UL2ModelType(enum.Enum):
encoder_decoder = 'ED'
non_causal_decoder = 'ND'
causal_decoder = 'CD'
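Taken together with the `is_decoder_only` / `is_prefix_lm` helpers added in `megatron/data/ul2_dataset.py`, this enum reduces to a pair of booleans; a small self-contained sketch (definitions copied from the hunks above):

```python
import enum

class UL2ModelType(enum.Enum):
    encoder_decoder = 'ED'
    non_causal_decoder = 'ND'
    causal_decoder = 'CD'

def is_decoder_only(ul2_model_type):
    return ul2_model_type is not UL2ModelType.encoder_decoder

def is_prefix_lm(ul2_model_type):
    return ul2_model_type is UL2ModelType.non_causal_decoder

for value in ('ED', 'ND', 'CD'):
    model_type = UL2ModelType(value)
    print(value, is_decoder_only(model_type), is_prefix_lm(model_type))
# ED False False
# ND True  True
# CD True  False
```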
110 changes: 80 additions & 30 deletions pretrain_ul2.py
@@ -13,39 +13,71 @@
)
from megatron.core import tensor_parallel
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType
from megatron.data.ul2_dataset import (
is_decoder_only as _is_decoder_only,
is_prefix_lm as _is_prefix_lm,
)
from megatron.model import GPTModel, ModelType, T5Model
from megatron.model.t5_model import t5_position_ids
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


"""
Pipeline parallelism for UL2 with T5
====================================
Pipeline parallelism for UL2
============================

Since UL2 re-uses the T5 model architecture, please see its
Since UL2 re-uses the T5 model architecture for encoder-decoder models
and the GPT model architecture for decoder-only models, please see their
documentation for more information.
"""


def is_decoder_only():
"""Return whether we use a decoder-only model."""
args = get_args()
return _is_decoder_only(args.ul2_model_type)


def is_prefix_lm():
"""Return whether we use a non-causal decoder-only model."""
args = get_args()
return _is_prefix_lm(args.ul2_model_type)


def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model."""

print_rank_0('building UL2 model ...')
model = T5Model(num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
if is_decoder_only():
print_rank_0('Using decoder-only UL2 model.')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
prefix_lm=True
)
else:
print_rank_0('Using encoder-decoder UL2 model.')
model = T5Model(num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model


def get_batch(data_iterator):
"""Build the batch."""

keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
if is_decoder_only():
keys = ['text', 'labels', 'loss_mask', 'dec_mask']
else:
keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
'enc_mask', 'dec_mask', 'enc_dec_mask']
datatype = torch.int64

# Broadcast data.
@@ -56,17 +88,25 @@ def get_batch(data_iterator):
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

# Unpack.
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()
if is_decoder_only():
tokens = data_b['text'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()

dec_mask = (data_b['dec_mask'] < 0.5)
return tokens, loss_mask, labels, dec_mask
else:
tokens_enc = data_b['text_enc'].long()
tokens_dec = data_b['text_dec'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].float()

enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
enc_mask = (data_b['enc_mask'] < 0.5)
dec_mask = (data_b['dec_mask'] < 0.5)
enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)

return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask
return tokens_enc, tokens_dec, loss_mask, labels, \
enc_mask, dec_mask, enc_dec_mask


def loss_func(loss_mask, output_tensor):
@@ -87,18 +127,28 @@ def forward_step(data_iterator, model):

# Get the batch.
timers('batch generator', log_level=2).start()
tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
= get_batch(data_iterator)
if is_decoder_only():
(tokens, loss_mask, lm_labels, dec_mask) = get_batch(data_iterator)
else:
(
tokens_enc, tokens_dec, loss_mask, lm_labels,
enc_mask, dec_mask, enc_dec_mask,
) = get_batch(data_iterator)
timers('batch generator').stop()

# Forward model lm_labels
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)
if is_decoder_only():
position_ids = t5_position_ids(tokens)
output_tensor = model(tokens, position_ids, dec_mask,
labels=lm_labels)
else:
output_tensor = model(tokens_enc,
tokens_dec,
enc_mask,
dec_mask,
enc_dec_mask,
tokentype_ids=None,
lm_labels=lm_labels)

return output_tensor, partial(loss_func, loss_mask)

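One convention worth spelling out in `get_batch`: the dataset emits integer masks where 1 means "may attend", and the `< 0.5` comparison flips them into boolean masks where True marks positions to suppress (the convention the attention kernels appear to expect). A toy illustration with a made-up 4-token prefix-LM mask (3-token bidirectional prefix, one target token):

```python
import numpy as np

# Mask as produced by the dataset for a 4-token non-causal decoder sample
# (1 = may attend): the 3-token prefix is bidirectional, the last token is causal.
dec_mask = np.array([
    [1, 1, 1, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 0],
    [1, 1, 1, 1],
], dtype=np.int64)

# get_batch() applies `dec_mask = (data_b['dec_mask'] < 0.5)`, flipping the
# convention so that True marks attention positions to mask out.
attention_mask = dec_mask < 0.5
print(attention_mask)
# [[False False False  True]
#  [False False False  True]
#  [False False False  True]
#  [False False False False]]
```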