Skip to content

Commit 0e6b611

Browse files
committed
Modify plotting script
1 parent c4dac27 commit 0e6b611

File tree

3 files changed

+56
-23
lines changed

3 files changed

+56
-23
lines changed

notebooks/plot_runs.ipynb

Lines changed: 29 additions & 19 deletions
Large diffs are not rendered by default.

src/utils/dataset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,29 @@ def mod_classification_dataset(p, num, seed=0, device='cpu'):
328328

329329
return dataset
330330

331+
332+
def mod_equiv_dataset(p, num, seed=0, device='cpu', mod=5):
    """Generate a binary modular-equivalence classification dataset.

    Draws ``num`` pairs ``(a, b)`` uniformly from ``range(p)`` and labels each
    pair with the token ``p`` when ``(a - b) % mod == 0`` (the two inputs are
    equivalent modulo ``mod``) and ``p + 1`` otherwise.  The two label tokens
    sit just above the input vocabulary, so the total vocabulary size is
    ``p + 2``.

    Args:
        p: size of the input token alphabet (values drawn from ``range(p)``).
        num: number of ``(a, b)`` sample pairs to generate.
        seed: seed applied to both torch and numpy RNGs for reproducibility.
        device: torch device the returned tensors are moved to.
        mod: modulus defining the equivalence relation.  Defaults to 5,
            which preserves the original hard-coded behavior.

    Returns:
        dict with keys:
            'data_id':    int64 tensor of shape (num, 2) — the input pairs.
            'label':      int64 tensor of shape (num,) — p or p+1 per pair.
            'vocab_size': int, equal to p + 2.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Sample num pairs of tokens from [0, p); reshape into (num, 2).
    x = np.random.choice(range(p), num * 2).reshape(num, 2)

    # Vectorized label rule: p when the pair is equivalent mod `mod`,
    # p+1 otherwise (replaces the original per-row list comprehension).
    target = np.where((x[:, 0] - x[:, 1]) % mod == 0, p, p + 1)

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p + 2,
    }
    return dataset
353+
331354
def family_tree_dataset(p, num, seed=0, device='cpu'):
332355

333356
torch.manual_seed(seed)

src/utils/driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def train_single_model(param_dict: dict):
6767
elif data_id == "greater":
6868
dataset = greater_than_dataset(p=30, num=data_size, seed=seed, device=device)
6969
elif data_id == "family_tree":
70-
dataset = family_tree_dataset_2(p=127, num=data_size, seed=seed, device=device)
70+
dataset = family_tree_dataset_2(p=255, num=data_size, seed=seed, device=device)
7171
elif data_id == "equivalence":
7272
input_token = 1
73-
dataset = mod_classification_dataset(p=100, num=data_size, seed=seed, device=device)
73+
dataset = mod_equiv_dataset(p=50, num=data_size, seed=seed, device=device)
7474
elif data_id == "circle":
7575
dataset = modular_addition_dataset(p=31, num=data_size, seed=seed, device=device)
7676
elif data_id=="permutation":
@@ -87,15 +87,15 @@ def train_single_model(param_dict: dict):
8787
weight_tied = True
8888
hidden_size = 100
8989
shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
90-
model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=np.sqrt(embd_dim), init_scale=1).to(device)
90+
model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, weight_tied=weight_tied, seed=seed, n=(embd_dim), init_scale=1).to(device)
9191
elif model_id == "standard_MLP":
9292
unembd = True
9393
weight_tied = True
9494
hidden_size = 100
9595
shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
9696
model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, unembd=unembd, weight_tied=weight_tied, seed=seed, init_scale=1).to(device)
9797
elif model_id == "H_transformer":
98-
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=np.sqrt(embd_dim),seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
98+
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, n_dist=embd_dim,seq_len=input_token, seed=seed, use_dist_layer=True, init_scale=1).to(device)
9999
elif model_id == "standard_transformer":
100100
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2, num_layers=2, seq_len=input_token, seed=seed, use_dist_layer=False, init_scale=1).to(device)
101101
else:

0 commit comments

Comments
 (0)