Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 32 additions & 51 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def apply_attention(self,
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError("""Decode not supported with flash attention.
Use `dot_product` instead.""")
return self.cudnn_flash_attention(query, key, value), None, None
return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
else:
raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.')

Expand Down Expand Up @@ -254,59 +254,40 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
x = wrap_flash_attention(query, key, value, decoder_segment_ids)
x = jnp.transpose(x, axes=(0, 2, 1, 3))
return x

def cudnn_flash_attention(self,
query: Array,
key: Array,
value: Array) -> Array:

def cudnn_flash_attention(
self,
query: Array,
key: Array,
value: Array,
decoder_segment_ids: Array | None,
model_mode: str = common_types.MODEL_MODE_TRAIN,
) -> Array:
"""CUDNN Flash Attention with Transformer Engine.

It is an unstable API. In future release, the API can get changed
A stable flash attention API will be included soon. Currently,
1. It does not support GQA, num_query_heads == num_kv_heads
2. It supports head_dim till 128
GQA support with head_dim=256 will be added soon
1. Stable API, supports GQA
2. Supports head_dim till 128; head_dim=256 support will be added soon
"""

# These imports are only meant to work in a GPU build.
import transformer_engine.jax.fused_attn as fused_attn # pytype: disable=import-error
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout # pytype: disable=import-error
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available # pytype: disable=import-error

batch, s_q, n_heads, head_dim = query.shape # pylint: disable=unused-variable
_, s_kv, _, _ = key.shape

qkv_layout = QKVLayout.BS3HD
attn_mask_type = AttnMaskType.CAUSAL_MASK
attn_bias_type = AttnBiasType.NO_BIAS

has_fused_attn_kernel = is_fused_attn_kernel_available(
self.dtype, self.dtype, qkv_layout,
attn_bias_type,
attn_mask_type,
self.dropout_rate, self.num_query_heads,
self.num_kv_heads, s_q,
s_kv, head_dim)

if not has_fused_attn_kernel:
raise ValueError("Flash attention is not supported for current config i.e. head_dim, seq_len, n_heads etc."
"Please see transformer_engine/common/fused_attn/fused_attn.cpp:NVTE_Fused_Attn_Backend for details")

q = jnp.reshape(query, (*query.shape[:2], 1, *query.shape[-2:]))
k = jnp.reshape(key, (*query.shape[:2], 1, *query.shape[-2:]))
v = jnp.reshape(value, (*query.shape[:2], 1, *query.shape[-2:]))
qkv = jnp.concatenate((q, k, v), axis=2) # to make it (b, s, 3, h, d)

return fused_attn.self_fused_attn(
qkv=qkv,
bias=None,
mask=jnp.zeros((batch, 1, s_q, s_kv)), # no padding
seed=None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=1.0/math.sqrt(head_dim),
dropout_probability=self.dropout_rate,
is_training=True)
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error

_, _, _, head_dim = query.shape # pylint: disable=unused-variable

#generate attn_mask
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

dpa_layer = DotProductAttention(head_dim=head_dim,
num_attention_heads=self.num_query_heads,
num_gqa_groups=self.num_kv_heads,
attn_mask_type='causal', # 'causal' or 'padding'
attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
attention_dropout=self.dropout_rate,
dropout_rng_name='aqt',
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=1.0/math.sqrt(head_dim),
transpose_batch_sequence=False)
return dpa_layer(query, key, value, mask=attn_mask)

def compute_local_attention(self,
attn_weights: Array,
Expand Down