From d9af6fd6042c4cadd255d884e69b2df2a86aa074 Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Tue, 26 Mar 2024 10:12:16 -0700 Subject: [PATCH 1/2] added support for GQA --- MaxText/layers/attentions.py | 84 ++++++++++++++---------------------- 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 05145c5abd..1cbb2e33ef 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -182,7 +182,7 @@ def apply_attention(self, if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: raise ValueError("""Decode not supported with flash attention. Use `dot_product` instead.""") - return self.cudnn_flash_attention(query, key, value), None, None + return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None else: raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.') @@ -254,59 +254,41 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids): x = wrap_flash_attention(query, key, value, decoder_segment_ids) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x - - def cudnn_flash_attention(self, - query: Array, - key: Array, - value: Array) -> Array: + + def cudnn_flash_attention( + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ) -> Array: """CUDNN Flash Attention with Transformer Engine. - - It is an unstable API. In future release, the API can get changed - A stable flash attention API will be included soon. Currently, - 1. It does not support GQA, num_query_heads == num_kv_heads - 2. It supports head_dim till 128 - GQA support with head_dim=256 will be added soon + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon """ - # These imports are only meant to work in a GPU build. - import transformer_engine.jax.fused_attn as fused_attn # pytype: disable=import-error - from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout # pytype: disable=import-error - from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available # pytype: disable=import-error - - batch, s_q, n_heads, head_dim = query.shape # pylint: disable=unused-variable - _, s_kv, _, _ = key.shape - - qkv_layout = QKVLayout.BS3HD - attn_mask_type = AttnMaskType.CAUSAL_MASK - attn_bias_type = AttnBiasType.NO_BIAS - - has_fused_attn_kernel = is_fused_attn_kernel_available( - self.dtype, self.dtype, qkv_layout, - attn_bias_type, - attn_mask_type, - self.dropout_rate, self.num_query_heads, - self.num_kv_heads, s_q, - s_kv, head_dim) - - if not has_fused_attn_kernel: - raise ValueError("Flash attention is not supported for current config i.e. head_dim, seq_len, n_heads etc." - "Please see transformer_engine/common/fused_attn/fused_attn.cpp:NVTE_Fused_Attn_Backend for details") - - q = jnp.reshape(query, (*query.shape[:2], 1, *query.shape[-2:])) - k = jnp.reshape(key, (*query.shape[:2], 1, *query.shape[-2:])) - v = jnp.reshape(value, (*query.shape[:2], 1, *query.shape[-2:])) - qkv = jnp.concatenate((q, k, v), axis=2) # to make it (b, s, 3, h, d) - - return fused_attn.self_fused_attn( - qkv=qkv, - bias=None, - mask=jnp.zeros((batch, 1, s_q, s_kv)), # no padding - seed=None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=1.0/math.sqrt(head_dim), - dropout_probability=self.dropout_rate, - is_training=True) + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + + _, _, _, head_dim = query.shape # pylint: disable=unused-variable + + #generate attn_mask + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + + dpa_layer = DotProductAttention(head_dim=head_dim, + num_attention_heads=self.num_query_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type='causal', # 'causal' or 'padding' + attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=self.dropout_rate, + dropout_rng_name='aqt', + dtype=self.dtype, + float32_logits=self.float32_logits, + qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0/math.sqrt(head_dim), + transpose_batch_sequence=False) + return dpa_layer(query, key, value, mask=attn_mask) + def compute_local_attention(self, attn_weights: Array, From 51f6d85501537044ffbbbdf563c5712da0d7e36d Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Fri, 5 Apr 2024 17:08:18 -0700 Subject: [PATCH 2/2] added GQA support for cudnn flash attention --- MaxText/layers/attentions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 1cbb2e33ef..8ea5e7bd0d 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -289,7 +289,6 @@ def cudnn_flash_attention( transpose_batch_sequence=False) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]: