Skip to content

Commit 80e1990

Browse files
committed
basic implementation of self attention
1 parent b47213b commit 80e1990

1 file changed

Lines changed: 54 additions & 0 deletions

File tree

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
import numpy as np
5+
6+
def self_attention(queries, keys, values, mask=None):
    """
    Compute scaled dot-product self-attention.

    Args:
    - queries: numpy array of shape (batch_size, seq_len, d_model)
    - keys: numpy array of shape (batch_size, seq_len, d_model)
    - values: numpy array of shape (batch_size, seq_len, d_model)
    - mask: optional numpy array broadcastable to (batch_size, seq_len, seq_len);
      positions where mask == 0 are excluded from attention

    Returns:
    - output: numpy array of shape (batch_size, seq_len, d_model)
    - attention_weights: numpy array of shape (batch_size, seq_len, seq_len)
    """
    # Scale by sqrt(d_model) to keep dot-product magnitudes in a range
    # where the softmax has useful gradients (Vaswani et al., 2017).
    d_model = queries.shape[-1]

    # Batched (batch, seq, seq) score matrix: Q @ K^T, scaled.
    attention_scores = np.matmul(queries, keys.transpose(0, 2, 1)) / np.sqrt(d_model)

    # Apply mask if provided: masked positions get a large negative score
    # so they receive ~zero weight after the softmax.
    if mask is not None:
        attention_scores = np.where(mask == 0, -1e9, attention_scores)

    # Numerically stable softmax: subtract the row max before exponentiating.
    # Without this, large scores overflow np.exp to inf and produce NaN weights.
    attention_scores = attention_scores - attention_scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(attention_scores)
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Weighted sum of values: (batch, seq, seq) @ (batch, seq, d_model).
    output = np.matmul(attention_weights, values)

    return output, attention_weights
41+
42+
# Example usage: run self-attention on a random batch and report shapes.
batch_size, seq_len, d_model = 2, 4, 8

# Q, K and V all share the same (batch, seq, model) shape here.
queries, keys, values = (
    np.random.randn(batch_size, seq_len, d_model) for _ in range(3)
)

output, attention_weights = self_attention(queries, keys, values)

print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

0 commit comments

Comments
 (0)