### Modified from the PyTorch codebase (torch.nn.MultiheadAttention).

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_


class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
        \quad \text{where} \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
        \quad \text{and} \quad \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

    Args:
        embed_dim: total dimension of the model
        num_heads: number of parallel attention heads; embed_dim is split
            evenly across them, so it must be divisible by num_heads
        dropout: dropout probability applied to the attention weights
        bias: add a learnable bias to the input and output projections
        add_bias_kv: add a learnable bias to the key and value sequences
        add_zero_attn: append a zero vector to the key and value sequences

    Examples::

        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # scale factor 1/sqrt(d_k) from scaled dot-product attention
        self.scaling = self.head_dim ** -0.5

        # packed input projection: rows [0, E) -> Q, [E, 2E) -> K, [2E, 3E) -> V
        self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj_weight[:self.embed_dim, :])
        xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :])
        xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :])

        xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

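    # forward() relies on a fairseq-style incremental-state API to cache decoder
    # states across time steps. The two helpers below are a minimal dict-backed
    # sketch of that API; keying the shared dict by module identity is an
    # assumption here, not the original implementation.
    def _get_input_buffer(self, incremental_state):
        # return this module's cache, creating an empty one on first use
        return incremental_state.setdefault(id(self), {})

    def _set_input_buffer(self, incremental_state, buffer):
        incremental_state[id(self)] = buffer
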
    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False, attn_mask=None, softmax=True):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: [batch size, sequence length]; True entries mark
                key positions that are padding and should be ignored
            incremental_state: if provided, previous time steps are cached
            need_weights: also return attn_output_weights
            static_kv: key and value do not change between decoding steps
            attn_mask: additive mask of shape [target length, sequence length]
            softmax: if False, skip the softmax and use the raw attention scores

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        # identical storage means self-attention (q == k == v) or
        # encoder-decoder attention (k == v), which lets us batch projections
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention: one matmul against the packed weight, then split
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention: key and value share a projection input
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        # pre-scale queries by 1/sqrt(head_dim)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        # fold heads into the batch dimension: (bsz * num_heads, length, head_dim)
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            # append an all-zero key/value pair so every query can attend to
            # at least one (empty) position
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        # raw attention scores: (bsz * num_heads, tgt_len, src_len)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            # additive mask, broadcast over the folded batch/head dimension
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            # set scores at padded key positions to -inf so softmax zeroes them
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if softmax:
            # normalize the scores, computing the softmax in float32 for
            # numerical stability under half precision
            attn_output_weights = F.softmax(attn_output_weights.float(), dim=-1)
        # dropout on the (normalized or raw) attention weights
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        # weighted sum of values, then un-fold the heads and project out
        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def _in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def _in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        # slice the requested rows (Q, K or V block) out of the packed weight
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
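

# Minimal usage sketch: self-attention over random inputs, with shapes matching
# the forward() docstring. The sizes below are arbitrary illustration values,
# and the incremental loop relies on the dict-backed cache helpers sketched
# above, not on the original fairseq incremental-state API.
if __name__ == "__main__":
    embed_dim, num_heads = 16, 4
    tgt_len, bsz = 5, 2
    attn = MultiheadAttention(embed_dim, num_heads)

    x = torch.randn(tgt_len, bsz, embed_dim)  # [target length, batch size, embed dim]
    out, weights = attn(x, x, x)  # same tensor three times -> self-attention path
    print(out.shape)      # torch.Size([5, 2, 16])
    print(weights.shape)  # torch.Size([2, 5, 5])

    # decode one step at a time, growing the cached keys/values
    incremental_state = {}
    for t in range(tgt_len):
        step = x[t:t + 1]  # one time step: [1, batch size, embed dim]
        step_out, _ = attn(step, step, step, incremental_state=incremental_state)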