diff --git a/finetune_mtf.py b/finetune_mtf.py
index 89600e780ba..921e89affbf 100644
--- a/finetune_mtf.py
+++ b/finetune_mtf.py
@@ -9,7 +9,7 @@
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
 from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
-from megatron.model.enums import PositionEmbeddingType
+from megatron.model.enums import PositionEmbeddingType, AttnMaskType
 #from megatron.model import GPTModelPipe
 from megatron.model import GPTModel, ModelType
 from megatron.training import pretrain
@@ -28,7 +28,8 @@ def model_provider(pre_process=True, post_process=True):
         num_tokentypes=0,
         parallel_output=True,
         pre_process=pre_process,
-        post_process=post_process
+        post_process=post_process,
+        attn_mask_type=AttnMaskType.custom,
     )
     return model
diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h
index f9ca0bbc7ec..8abca7e90d4 100644
--- a/megatron/fused_kernels/scaled_masked_softmax.h
+++ b/megatron/fused_kernels/scaled_masked_softmax.h
@@ -47,6 +47,22 @@ __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
 template <>
 __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
 
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
@@ -269,7 +285,7 @@ __global__ void scaled_masked_softmax_warp_forward(
                 if (temp_mask[element] != 1) {
                     elements[i][it + element] = (acc_t)temp_data[element] * scale;
                 } else {
-                    elements[i][it + element] = -10000.0;
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                 }
             }
         } else {
@@ -298,7 +314,11 @@
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         #pragma unroll
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+            if (elements[i][it] <= -std::numeric_limits<acc_t>::infinity()) {
+                elements[i][it] = 0.0f;
+            } else {
+                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+            }
             sum[i] += elements[i][it];
         }
     }
@@ -314,11 +334,15 @@ __global__ void scaled_masked_softmax_warp_forward(
         for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
             int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
-                #pragma unroll
-                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
-                    out[element] = elements[i][it + element] / sum[i];
+                if (sum[i] == 0.0f) {
+                    copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE);
+                } else {
+                    #pragma unroll
+                    for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                        out[element] = elements[i][it + element] / sum[i];
+                    }
+                    copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
                 }
-                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
             } else {
                 break;
             }
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index b6a1d7b5e90..86899da7b45 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -58,11 +58,14 @@ def post_language_model_processing(lm_output, labels, logit_weights,
 class GPTModel(MegatronModule):
     """GPT-2 Language model."""
 
-    def __init__(self,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 pre_process=True,
-                 post_process=True):
+    def __init__(
+        self,
+        num_tokentypes=0,
+        parallel_output=True,
+        pre_process=True,
+        post_process=True,
+        attn_mask_type: AttnMaskType = AttnMaskType.causal,
+    ):
         super(GPTModel, self).__init__()
         args = get_args()
 
@@ -74,7 +77,7 @@ def __init__(self,
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            encoder_attn_mask_type=AttnMaskType.causal,
+            encoder_attn_mask_type=attn_mask_type,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers),
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index f26b0685340..1b85d128330 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -40,7 +40,7 @@ def init_(tensor):
 
 
 def attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
+    attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
     return attention_scores
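
For reference, here is a minimal PyTorch sketch of the semantics the kernel and `attention_mask_func` changes aim for: with a custom mask (as used for packed multitask finetuning), a query row whose keys are all masked should come out of the softmax as all zeros, rather than as a near-uniform distribution over positions filled with a finite `-10000.0`. This is an illustration only, not part of the patch; the `masked_softmax` helper below is hypothetical.

```python
# Illustration only -- a rough eager-mode emulation of the patched behaviour,
# not the fused CUDA kernel itself. `masked_softmax` is a hypothetical helper.
import torch

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Masked positions get -inf (the kernel uses
    # -std::numeric_limits<acc_t>::infinity(); the unfused path uses
    # torch.finfo(dtype).min), so exp() drives them to exactly zero.
    scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # A fully masked row softmaxes to NaN (0/0); the kernel instead writes
    # zeros via copy_zero_vector when sum[i] == 0, which we emulate here.
    return probs.masked_fill(mask.all(dim=-1, keepdim=True), 0.0)

scores = torch.randn(1, 1, 4, 4)
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
mask[..., 3, :] = True                        # last query attends to nothing
print(masked_softmax(scores, mask)[0, 0, 3])  # tensor([0., 0., 0., 0.])
```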