Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,14 @@ def forward(self, x):


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads
self.att_step = att_step

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
Expand All @@ -178,21 +179,42 @@ def forward(self, x, context=None, mask=None):

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
del q, k

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
del mask

# attention, what we cannot get enough of, by halves
sim[4:] = sim[4:].softmax(dim=-1)
sim[:4] = sim[:4].softmax(dim=-1)

sim = einsum('b i j, b j d -> b i d', sim, v)
limit = k.shape[0]
att_step = self.att_step
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))

q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range (0, limit, att_step):

q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

del k_buffer, q_buffer
'''
if exists(mask):
mask_buffer = rearrange(mask[i:i+att_step,:,:,:], 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim_buffer.dtype).max
mask_buffer = repeat(mask_buffer, 'b j -> (b h) () j', h=h)
sim_buffer.masked_fill_(~mask_buffer, max_neg_value)
'''
# attention, what we cannot get enough of, by chunks

sim_buffer = sim_buffer.softmax(dim=-1)

sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i+att_step,:,:] = sim_buffer

del sim_buffer
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)

Expand Down