Skip to content

Commit 0d391c7

Browse files
committed
Added code for training transformers with harmonic loss to learn lattice
1 parent 767561b commit 0d391c7

File tree

6 files changed

+640
-2
lines changed

6 files changed

+640
-2
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
__pycache__
2+

dataset.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,17 @@ def combine_dataset(train_dataset, test_dataset):
102102
assert train_dataset['vocab_size'] == test_dataset['vocab_size']
103103
dataset_c['vocab_size'] = train_dataset['vocab_size']
104104

105-
return dataset_c
105+
return dataset_c
106+
107+
108+
# Dataset and DataLoader
class ToyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each input sequence with its target label.

    ``inputs`` and ``targets`` are equal-length indexable sequences;
    ``__getitem__`` yields the ``(input, target)`` pair at a position.
    """

    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

model.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import torch.nn as nn
22
import torch
3+
import torch.optim as optim
34
import random
45
import numpy as np
56
import math
67

8+
from tqdm import tqdm
9+
710
class MLP(nn.Module):
811
def __init__(self, shp, vocab_size, embd_dim, input_token=2, init_scale=1., unembd=False, weight_tied=False, seed=0):
912
super(MLP, self).__init__()
@@ -130,4 +133,60 @@ def pred_logit(self, x):
130133
prob = prob_unnorm/torch.sum(prob_unnorm, dim=1, keepdim=True)
131134
logits = torch.log(prob)
132135
return logits
133-
136+
137+
138+
# 2-Layer Transformer Model with Explicit Residual Connections
class ToyTransformer(nn.Module):
    """Small decoder-style transformer that predicts the next token of a sequence.

    Tokens are embedded, combined with a learned positional encoding, passed
    through ``num_layers`` TransformerDecoderLayers (self-attending: the same
    tensor is fed as target and memory), and the final position's representation
    is mapped to vocabulary logits either by a linear head or by a
    distance-based ``DistLayer`` head (harmonic-loss style).
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, seq_len=16, use_dist_layer=False):
        super(ToyTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned (not sinusoidal) positional encoding, one vector per position;
        # inputs to forward() must have exactly seq_len positions.
        self.positional_encoding = nn.Parameter(torch.randn(seq_len, d_model))

        # Decoder layers used as self-attention blocks; batch_first=True means
        # activations are laid out (batch, seq, dim).
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=64, batch_first=True
            ) for _ in range(num_layers)
        ])
        self.use_dist_layer = use_dist_layer
        if use_dist_layer:
            # Distance-based unembedding head (DistLayer defined in this file).
            self.dist = DistLayer(d_model, vocab_size, n=1., eps=1e-4, bias=False)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return (batch, vocab_size) logits for the token following ``x``.

        ``x``: integer tensor of shape (batch, seq_len), where seq_len matches
        the value the positional encoding was built with.
        """
        embedded = self.embedding(x) + self.positional_encoding

        # Pass through transformer layers with an extra explicit residual
        # connection on top of each layer's own internal residuals.
        x = embedded
        for layer in self.layers:
            x = layer(x, x) + x  # Explicit residual connection

        if self.use_dist_layer:
            x = x[:, -1]
            x = self.dist(x)
            # DistLayer emits unnormalized scores; normalize to probabilities
            # and take log so downstream CrossEntropyLoss sees log-prob logits.
            prob = x / torch.sum(x, dim=1, keepdim=True)
            logits = torch.log(prob)
        else:
            logits = self.fc(x[:, -1])  # Only predict the last token
        return logits

    def train(self, param_dict=None):
        """Run the training loop, or toggle training mode.

        NOTE: this overrides ``nn.Module.train``. Called with a dict it runs
        the loop below; called with no argument or a bool (as
        ``nn.Module.eval()`` and parent ``.train(mode)`` calls do) it delegates
        to ``nn.Module.train`` so standard mode switching keeps working —
        previously ``model.eval()`` crashed on the required dict argument.

        param_dict keys: 'num_epochs', 'learning_rate', 'dataloader'.
        """
        if not isinstance(param_dict, dict):
            # Plain mode switch: defer to the standard nn.Module behavior.
            return super().train(True if param_dict is None else param_dict)

        num_epochs = param_dict['num_epochs']
        learning_rate = param_dict['learning_rate']
        dataloader = param_dict['dataloader']
        criterion = nn.CrossEntropyLoss()

        optimizer = optim.AdamW(self.parameters(), lr=learning_rate)
        for epoch in tqdm(range(num_epochs)):
            total_loss = 0
            for batch_inputs, batch_targets in dataloader:
                optimizer.zero_grad()
                logits = self.forward(batch_inputs)

                # reshape(-1), not squeeze(): a final batch of size 1 would
                # squeeze to 0-d and break CrossEntropyLoss's expected shape.
                loss = criterion(logits, batch_targets.reshape(-1))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")

notebooks/transformer_lattice.ipynb

Lines changed: 437 additions & 0 deletions
Large diffs are not rendered by default.

scratch.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#%%
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from dataset import parallelogram_dataset, repeat_dataset
from model import DistLayer

# Reproducibility: seed both the numpy and torch RNGs.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Work in float64 throughout. torch.set_default_dtype replaces the deprecated
# torch.set_default_tensor_type(torch.DoubleTensor) with identical effect for
# CPU floating-point tensors.
torch.set_default_dtype(torch.float64)

# Lattice task: the vocabulary enumerates the points of a p x p grid.
p = 5
embd_dim = 5
input_token = 3
lattice_dim = 2
vocab_size = p ** lattice_dim


# data
dataset = parallelogram_dataset(p=p, dim=lattice_dim, num=1000, seed=seed)
dataset = repeat_dataset(dataset)

# Parameters
d_model = 16       # Embedding and hidden size
nhead = 2          # Number of attention heads
num_layers = 2     # Number of transformer layers
seq_len = 3        # Sequence length
num_epochs = 500
batch_size = 16
learning_rate = 0.001
36+
37+
# Dataset and DataLoader
class ToyDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset over parallel input/target sequences.

    Indexing position ``idx`` returns the ``(input, target)`` pair stored at
    that position; length equals the number of inputs.
    """

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        pair = (self.inputs[idx], self.targets[idx])
        return pair
48+
49+
# Wrap the lattice training split in the map-style dataset defined above.
toy_dataset = ToyDataset(dataset['train_data_id'], dataset['train_label'])
# Sanity check: the locally computed vocab size should match the dataset's own.
print(vocab_size, dataset['vocab_size'])

dataloader = torch.utils.data.DataLoader(toy_dataset, batch_size=batch_size, shuffle=True)
53+
54+
# 2-Layer Transformer Model with Explicit Residual Connections
class ToyTransformer(nn.Module):
    """Decoder-style toy transformer predicting the next lattice token.

    Tokens are embedded, shifted by a learned positional encoding, run through
    ``num_layers`` self-attending TransformerDecoderLayers, and the last
    position's representation is mapped to vocabulary logits by a linear head.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, seq_len=3):
        super(ToyTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # seq_len is now an explicit parameter (default matches the script's
        # seq_len = 3) instead of silently reading the module-level global.
        self.positional_encoding = nn.Parameter(torch.randn(seq_len, d_model))

        # Decoder layers used as self-attention blocks (forward feeds the same
        # tensor as target and memory); batch_first -> (batch, seq, dim).
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=64, batch_first=True
            ) for _ in range(num_layers)
        ])
        # NOTE: the original also built an unused DistLayer(d_model, d_model)
        # here; removed as dead code — forward() never referenced it.
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return (batch, vocab_size) logits for the last position of ``x``."""
        embedded = self.embedding(x) + self.positional_encoding

        # Pass through transformer layers with an extra explicit residual
        # connection on top of each layer's internal ones.
        x = embedded
        for layer in self.layers:
            x = layer(x, x) + x  # Explicit residual connection

        logits = self.fc(x[:, -1])  # Only predict the last token
        return logits
80+
81+
model = ToyTransformer(vocab_size, d_model, nhead, num_layers)
criterion = nn.CrossEntropyLoss()  # only used by the commented-out CE variant below
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    total_loss = 0
    for batch_inputs, batch_targets in dataloader:
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_indices = torch.arange(logits.size(0))
        # Harmonic-style loss: the target's inverse (logit + 1e-4), normalized by
        # the sum of all inverse logits per row, summed over the batch. Minimizing
        # this appears to push the target logit up relative to the others — the
        # 1e-4 guards the division, though negative logits can still flip signs.
        # NOTE(review): if batch_targets has shape (B, 1), the advanced index
        # broadcasts to a (B, B) matrix rather than picking one entry per row —
        # confirm targets are 1-D here.
        loss = ((1/(logits[batch_indices,batch_targets] + 1e-4)) / (1/(logits + 1e-4)).sum(dim=1)).sum()
        # loss = criterion(logits, batch_targets.squeeze())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss every 10 epochs.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
100+
101+
# %%
# Inspect the learned token embeddings: project the embedding matrix to 2-D
# with PCA and label each projected point with its token id (for the lattice
# task the ids enumerate grid points, so structure should be visible).
emb = model.embedding.weight

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

pca = PCA(n_components=2)
emb_pca = pca.fit_transform(emb.detach().numpy())

# Annotate every embedding with its vocabulary index, then scatter-plot all
# projected points on the current axes.
for i in range(len(emb_pca)):
    plt.text(emb_pca[i, 0], emb_pca[i, 1], str(i), fontsize=12)
plt.scatter(emb_pca[:, 0], emb_pca[:, 1])
# %%

visualization.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from sklearn.decomposition import PCA
2+
import matplotlib.pyplot as plt
3+
4+
def visualize_embedding(emb, title=""):
    """Scatter-plot a 2-D PCA projection of an embedding matrix.

    Each row of ``emb`` (a torch tensor, rows = tokens) is projected onto the
    first two principal components and drawn as a point labeled with its row
    index. Draws on the current matplotlib axes; the caller shows/saves the
    figure. Also prints the explained variance ratio of the projection.
    """
    pca = PCA(n_components=2)
    # .cpu() before .numpy(): converting a CUDA tensor directly raises.
    emb_pca = pca.fit_transform(emb.detach().cpu().numpy())
    print("Explained Variance Ratio", pca.explained_variance_ratio_)
    dim1 = 0
    dim2 = 1
    plt.title(title)
    for i in range(len(emb_pca)):
        plt.text(emb_pca[i, dim1], emb_pca[i, dim2], str(i), fontsize=12)
    plt.scatter(emb_pca[:, dim1], emb_pca[:, dim2])

0 commit comments

Comments
 (0)