5 changes: 3 additions & 2 deletions finetune_mtf.py
@@ -9,7 +9,7 @@

from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
-from megatron.model.enums import PositionEmbeddingType
+from megatron.model.enums import PositionEmbeddingType, AttnMaskType
#from megatron.model import GPTModelPipe
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
@@ -28,7 +28,8 @@ def model_provider(pre_process=True, post_process=True):
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
-post_process=post_process
+post_process=post_process,
+attn_mask_type=AttnMaskType.custom,
)
return model

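The change to `finetune_mtf.py` switches the model to `AttnMaskType.custom`, since packed multitask samples (built by `decoder_packed_mtf_dataset`) want a mask that is causal within each packed document rather than across the whole packed sequence. A minimal sketch of such a mask, assuming a `segment_ids` tensor marks the packed documents (the helper name and shapes are illustrative, not part of this diff):

```python
import torch

def build_packed_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Block-diagonal causal mask for a packed batch.

    segment_ids: (batch, seq_len) ints identifying which packed document each
    token belongs to. Returns a bool mask of shape (batch, 1, seq_len, seq_len)
    where True means "mask this position out", the convention expected by
    attention_mask_func further down in this diff.
    """
    _, seq_len = segment_ids.shape
    # Disallow attending to future positions...
    future = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=segment_ids.device),
        diagonal=1,
    )
    # ...and to tokens that belong to a different packed document.
    other_doc = segment_ids.unsqueeze(-1) != segment_ids.unsqueeze(-2)
    return (future.unsqueeze(0) | other_doc).unsqueeze(1)
```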
36 changes: 30 additions & 6 deletions megatron/fused_kernels/scaled_masked_softmax.h
@@ -47,6 +47,22 @@ __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
@@ -269,7 +285,7 @@ __global__ void scaled_masked_softmax_warp_forward(
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
-elements[i][it + element] = -10000.0;
+elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
@@ -298,7 +314,11 @@ __global__ void scaled_masked_softmax_warp_forward(
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
-elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+if (elements[i][it] <= -std::numeric_limits<acc_t>::infinity()) {
+elements[i][it] = 0.0f;
+} else {
+elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+}
sum[i] += elements[i][it];
}
}
@@ -314,11 +334,15 @@
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
-#pragma unroll
-for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
-out[element] = elements[i][it + element] / sum[i];
+if (sum[i] == 0.0f) {
+copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE);
+} else {
+#pragma unroll
+for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+out[element] = elements[i][it + element] / sum[i];
+}
+copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
}
-copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
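Taken together, the kernel edits replace the old -10000 additive mask with true -inf masking, skip the exp for fully masked elements, and use `copy_zero_vector` to write zeros for rows whose softmax denominator is zero, i.e. rows where every key is masked. A PyTorch reference for the intended numerics, not the CUDA code itself (the function name here is illustrative):

```python
import torch

def masked_softmax_reference(scores: torch.Tensor, mask: torch.Tensor,
                             scale: float = 1.0) -> torch.Tensor:
    """Mirror the fused kernel's new behaviour: -inf masking, zeros for rows
    that are completely masked (mask is bool, True = masked out)."""
    x = (scores * scale).masked_fill(mask, float("-inf"))
    probs = torch.softmax(x, dim=-1)
    # softmax over an all -inf row produces NaN; the kernel's sum == 0.0f
    # branch writes zeros instead, so do the same here.
    fully_masked = mask.all(dim=-1, keepdim=True)
    return probs.masked_fill(fully_masked, 0.0)

# A fully masked row comes back as all zeros rather than NaN, or the uniform
# distribution that the old -10000.0 fill degenerated to.
print(masked_softmax_reference(torch.randn(1, 4), torch.ones(1, 4, dtype=torch.bool)))
```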
15 changes: 9 additions & 6 deletions megatron/model/gpt_model.py
@@ -58,11 +58,14 @@ def post_language_model_processing(lm_output, labels, logit_weights,
class GPTModel(MegatronModule):
"""GPT-2 Language model."""

-def __init__(self,
-num_tokentypes=0,
-parallel_output=True,
-pre_process=True,
-post_process=True):
+def __init__(
+self,
+num_tokentypes=0,
+parallel_output=True,
+pre_process=True,
+post_process=True,
+attn_mask_type: AttnMaskType = AttnMaskType.causal,
+):
super(GPTModel, self).__init__()
args = get_args()

@@ -74,7 +77,7 @@ def __init__(self,
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
-encoder_attn_mask_type=AttnMaskType.causal,
+encoder_attn_mask_type=attn_mask_type,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
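With the new keyword, existing callers of `GPTModel` keep the causal default while the MTF script opts into custom masking. A usage sketch, assuming the usual Megatron initialization (global args, tokenizer, mpu) has already run:

```python
from megatron.model import GPTModel
from megatron.model.enums import AttnMaskType

# Default behaviour is unchanged: a causal decoder mask.
gpt = GPTModel(num_tokentypes=0, parallel_output=True)

# finetune_mtf.py now requests custom masking so it can supply its own
# block-diagonal mask for packed multitask samples.
mtf_gpt = GPTModel(
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    attn_mask_type=AttnMaskType.custom,
)
```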
2 changes: 1 addition & 1 deletion megatron/model/utils.py
@@ -40,7 +40,7 @@ def init_(tensor):


def attention_mask_func(attention_scores, attention_mask):
-attention_scores.masked_fill_(attention_mask, -10000.0)
+attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores


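`attention_mask_func` now fills masked scores with the most negative finite value of the scores' own dtype instead of a flat -10000, so the fill can never sit above a legitimate score in fp16/bf16 and it lines up with the -inf handling in the fused kernel above. A small illustration (shapes are arbitrary; the mask broadcasts over heads):

```python
import torch

scores = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)  # (batch, heads, q, k)
mask = torch.zeros(2, 1, 16, 16, dtype=torch.bool)
mask[..., 8:] = True                                       # mask out the last 8 keys

scores.masked_fill_(mask, torch.finfo(scores.dtype).min)
print(torch.finfo(torch.bfloat16).min)  # about -3.39e38, versus the old flat -10000.0
print(scores[0, 0, 0, -1])              # masked slot holds that dtype-aware minimum
```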