Skip to content

Commit 34e8c14

Browse files
committed
Added code for parameter sweeping and descendant dataset
1 parent 340839a commit 34e8c14

File tree

12 files changed

+1155
-2
lines changed

12 files changed

+1155
-2
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
__pycache__
2+
results
23

4+
scratch.ipynb

dataset.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,34 @@ def __len__(self):
115115
return len(self.inputs)
116116

117117
def __getitem__(self, idx):
    """Return the (input, target) pair stored at index *idx*.

    Enables integer indexing and use with torch DataLoader; length is
    given by ``__len__`` (``len(self.inputs)``).
    """
    return self.inputs[idx], self.targets[idx]
119+
120+
def descendant_dataset(p, num, seed=0, device='cpu'):
    """Build a binary-tree ancestor/descendant classification dataset.

    Nodes are labelled 1..p-1 in a complete binary tree where node x has
    children 2x and 2x+1.  Each sample is a pair (a, b); its label is
    p+1 when b lies on the path from a down to a leaf (i.e. b is a
    descendant of a, or b == a), and p otherwise.

    Args:
        p: exclusive upper bound on node ids; also fixes vocab_size = p + 2
           (p and p+1 serve as the two class tokens).
        num: number of (a, b) pairs to draw.
        seed: seed for both torch and numpy RNGs, for reproducibility.
        device: device the returned tensors are moved to.

    Returns:
        dict with keys 'data_id' (num x 2 LongTensor of pairs),
        'label' (length-num tensor of p / p+1), and 'vocab_size' (int).
    """
    # Seed both RNGs so repeated calls with the same seed are identical.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Draw `num` node pairs uniformly from [1, p).  (Call kept identical
    # to preserve the RNG stream for a given seed.)
    pairs = np.random.choice(range(1, p), num * 2).reshape(num, 2)

    def is_descendant(ancestor, node):
        # Walk from `node` up toward the root (node 1); its parent is
        # node // 2.  The final comparison covers ancestor == 1 (root).
        while node > 1:
            if node == ancestor:
                return True
            node //= 2
        return node == ancestor

    labels_np = np.array(
        [p + 1 if is_descendant(row[0], row[1]) else p for row in pairs]
    )

    return {
        'data_id': torch.from_numpy(pairs).to(device),
        'label': torch.from_numpy(labels_np).to(device),
        'vocab_size': p + 2,
    }

model.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from tqdm import tqdm
99

10+
from itertools import combinations
11+
from sklearn.decomposition import PCA
12+
1013
class MLP(nn.Module):
1114
def __init__(self, shp, vocab_size, embd_dim, input_token=2, init_scale=1., unembd=False, weight_tied=False, seed=0):
1215
super(MLP, self).__init__()
@@ -169,17 +172,21 @@ def forward(self, x):
169172
else:
170173
logits = self.fc(x[:, -1]) # Only predict the last token
171174
return logits
175+
172176
def train(self, param_dict: dict):
173177

174178
num_epochs = param_dict['num_epochs']
175179
learning_rate = param_dict['learning_rate']
176180
dataloader = param_dict['dataloader']
181+
device = param_dict['device']
177182
criterion = nn.CrossEntropyLoss()
178183

179184
optimizer = optim.AdamW(self.parameters(), lr=learning_rate)
180185
for epoch in tqdm(range(num_epochs)):
181186
total_loss = 0
182187
for batch_inputs, batch_targets in dataloader:
188+
batch_inputs = batch_inputs.to(device)
189+
batch_targets = batch_targets.to(device)
183190
optimizer.zero_grad()
184191
logits = self.forward(batch_inputs)
185192

@@ -189,4 +196,61 @@ def train(self, param_dict: dict):
189196
total_loss += loss.item()
190197

191198
if (epoch + 1) % 50 == 0:
192-
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
199+
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")
200+
201+
202+
def eval(self, grid_size=5, n_components=10):
    """Quantify how lattice-like the learned embeddings are.

    Treats embedding indices 0..grid_size**2-1 as a grid_size x grid_size
    lattice (index = grid_size*row + col).  For every triple of lattice
    points whose implied fourth parallelogram vertex stays on the grid,
    it measures how far the corresponding embedding vectors deviate from
    a true parallelogram, and reports the mean deviation together with
    the PCA explained-variance spectrum of the embedding matrix.

    NOTE(review): this overrides nn.Module.eval(), so model.eval() no
    longer switches the module out of training mode — consider renaming
    (kept as-is here to avoid breaking existing callers).

    Args:
        grid_size: side length of the lattice (default 5, the original
            hard-coded value).
        n_components: number of PCA components to report (default 10).

    Returns:
        dict with 'parallelogram_quality' (mean deviation, lower is more
        lattice-like) and 'variances' (explained_variance_ratio_ array).
    """
    points = [(i, j) for i in range(grid_size) for j in range(grid_size)]
    deviation_arr = []

    def side_length_deviation(a, b, c, d):
        # Relative mismatch of the two pairs of opposite sides of the
        # quadrilateral spanned by embedding vectors a, b, c, d,
        # normalized by the RMS side length so it is scale-invariant.
        a, b, c, d = np.array(a), np.array(b), np.array(c), np.array(d)
        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)
        return np.sqrt(
            (length_ab - length_cd) ** 2 + (length_ac - length_bd) ** 2
        ) / np.sqrt(
            (length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2) / 2
        )

    # Hoisted out of the loop: one device transfer instead of one per vertex.
    emb = self.embedding.weight.cpu().detach().numpy()

    for pa, pb, pc in combinations(points, 3):
        # Fourth parallelogram vertex implied by the other three.
        pd = (pc[0] + pb[0] - pa[0], pc[1] + pb[1] - pa[1])
        if pd[0] < 0 or pd[0] >= grid_size or pd[1] < 0 or pd[1] >= grid_size:
            continue
        # Skip triples that are collinear along a grid row or column
        # (degenerate "parallelograms").
        if pa[0] == pb[0] == pc[0]:
            continue
        if pa[1] == pb[1] == pc[1]:
            continue

        # Map (row, col) lattice coordinates to embedding-table rows.
        va = emb[grid_size * pa[0] + pa[1]]
        vb = emb[grid_size * pb[0] + pb[1]]
        vc = emb[grid_size * pc[0] + pc[1]]
        vd = emb[grid_size * pd[0] + pd[1]]
        deviation_arr.append(side_length_deviation(va, vb, vc, vd))

    # Bug fix: the original called pca.fit_transform a second time on the
    # already-reduced data (discarding the result) before reading
    # explained_variance_ratio_, which renormalized the ratios to the
    # retained subspace instead of the full embedding variance.
    # Fit once, on the raw embedding matrix.
    pca = PCA(n_components=n_components)
    pca.fit(emb)
    variances = pca.explained_variance_ratio_

    return {
        'parallelogram_quality': np.mean(deviation_arr),
        'variances': variances,
    }

notebooks/plot_sweep.ipynb

Lines changed: 144 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/transformer_lattice.ipynb

Lines changed: 170 additions & 0 deletions
Large diffs are not rendered by default.

scripts/data_size_sweep.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/bash
# SLURM job: sweep sweep_transformers.py over log-spaced training-set sizes.
#SBATCH -t 2:00:00
#SBATCH --gres=gpu:1
#SBATCH -n 32

# Abort on the first failure (or unset variable) instead of silently
# continuing the sweep with missing runs.
set -euo pipefail

# 10 integer dataset sizes, log-spaced between 10^1 and 10^4.
sizes=$(python3 -c "import numpy as np; print(' '.join(map(str, np.logspace(1, 4, num=10, dtype=int))))")

# $sizes is deliberately left unquoted so the shell word-splits it into
# one loop argument per size.
for size in $sizes
do
    python3 ../sweep_transformers.py --data_size "$size" --use_harmonic 0
done

0 commit comments

Comments
 (0)