Commit 16ce70e

add CTA
1 parent 69685b4 commit 16ce70e

13 files changed

Lines changed: 1677 additions & 0 deletions

cta/attention.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
### modified from pytorch codebase


import warnings
import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch import nn


class MultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O
        \text{where } head_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention heads

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj_weight[:self.embed_dim, :])
        xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :])
        xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :])

        xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False, attn_mask=None, softmax=True):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if provided, padding elements in the key are ignored by the attention
            incremental_state: if provided, previous time steps are cached
            need_weights: output attn_output_weights
            static_kv: key and value are static
            softmax: if False, the softmax normalization of the attention scores is skipped

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self._in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self._in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self._in_proj_kv(key)
        else:
            q = self._in_proj_q(query)
            k = self._in_proj_k(key)
            v = self._in_proj_v(value)
        q *= self.scaling

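        # At this point q, k, v are (length, bsz, embed_dim); q has been
        # pre-multiplied by head_dim ** -0.5, so the batched matmul further
        # down computes scaled dot-product attention scores.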
        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

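        # q is (bsz * num_heads, tgt_len, head_dim); k and v are
        # (bsz * num_heads, src_len, head_dim), so the bmm below produces one
        # (tgt_len, src_len) score matrix per head and batch element.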
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if softmax:
            attn_output_weights = F.softmax(
                attn_output_weights.float(), dim=-1,
                dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype)
            attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
        else:
            # when softmax is disabled, keep the raw masked scores (only dropout is applied)
            attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

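    # The helpers below slice the packed (3 * embed_dim, embed_dim) in_proj_weight:
    # rows [0, embed_dim) project queries, [embed_dim, 2 * embed_dim) keys,
    # and [2 * embed_dim, 3 * embed_dim) values.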
    def _in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def _in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def _in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def _in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def _in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
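
For reference, a minimal usage sketch of the module added in this commit. It assumes the file is importable as cta.attention (based on its path) and uses illustrative shapes that follow the forward docstring; it is not part of the commit itself.

    import torch
    from cta.attention import MultiheadAttention  # assumed import path, based on the file location

    embed_dim, num_heads = 64, 4
    attn = MultiheadAttention(embed_dim, num_heads, dropout=0.1)

    # query/key/value are [length, batch size, embed dim], as in the forward docstring
    query = torch.randn(10, 2, embed_dim)
    key = value = torch.randn(12, 2, embed_dim)

    # standard softmax attention
    attn_output, attn_weights = attn(query, key, value)
    # attn_output: [10, 2, 64], attn_weights: [2, 10, 12]

    # with softmax=False, the raw (unnormalized, masked) scores are used to mix the values
    raw_output, raw_scores = attn(query, key, value, softmax=False)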
