Skip to content

Commit f4c7df6

Browse files
committed
Tweaked training scripts
1 parent 83ab4d8 commit f4c7df6

File tree

8 files changed

+109
-111
lines changed

8 files changed

+109
-111
lines changed

notebooks/modadd.ipynb

Lines changed: 33 additions & 30 deletions
Large diffs are not rendered by default.

src/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## How to add a new dataset for experiments
2+
3+
1. Implement a function which returns the dataset dictionary in `utils/dataset.py`.
4+
2. Choose a unique id for the new dataset. Implement a function which evaluates the quality of the learned representations in `utils/crystal_metric.py`. Modify the function `crystal_metric` to support the new data_id.
5+
3. Add the new data_id to the array `data_id_choices` in `run_exp.py`.
6+
4. If any auxiliary information is required to evaluate the representations, add it to the dictionary `aux_info` in `run_exp.py`. Sometimes, this information may depend on the specific dataset; in such cases, make any necessary modifications within each of the three experiment for-loops in `run_exp.py`.
7+
5. Now you're ready to test the new dataset! The command format is:
8+
`python run_exp.py --data_id new_data_id --model_id H_MLP`.

src/run_exp.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]
2020
if __name__ == '__main__':
2121
parser = argparse.ArgumentParser(description='Experiment')
22-
parser.add_argument('--seed', type=int, default=77, help='random seed')
22+
parser.add_argument('--seed', type=int, default=29, help='random seed')
2323
parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
2424
parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')
2525

@@ -42,6 +42,17 @@
4242
'embd_dim': 16,
4343
}
4444

45+
aux_info = {}
46+
if data_id == "lattice":
47+
aux_info["lattice_size"] = 5
48+
elif data_id == "greater":
49+
aux_info["p"] = 30
50+
elif data_id == "equivalence":
51+
aux_info["mod"] = 5
52+
elif data_id == "circle":
53+
aux_info["p"] = 31
54+
else:
55+
raise ValueError(f"Unknown data_id: {data_id}")
4556

4657
# Train the model
4758
print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}")
@@ -81,20 +92,9 @@
8192
torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}.pt")
8293
with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
8394
json.dump(ret_dic["results"], f, indent=4)
84-
85-
aux_info = {}
86-
if data_id == "lattice":
87-
aux_info["lattice_size"] = 5
88-
elif data_id == "greater":
89-
aux_info["p"] = 30
90-
elif data_id == "family_tree":
95+
96+
if data_id == "family_tree":
9197
aux_info["dict_level"] = dataset['dict_level']
92-
elif data_id == "equivalence":
93-
aux_info["mod"] = 10
94-
elif data_id == "circle":
95-
aux_info["p"] = 59
96-
else:
97-
raise ValueError(f"Unknown data_id: {data_id}")
9898

9999
if hasattr(model.embedding, 'weight'):
100100
metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
@@ -128,19 +128,8 @@
128128
with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
129129
json.dump(ret_dic["results"], f, indent=4)
130130

131-
aux_info = {}
132-
if data_id == "lattice":
133-
aux_info["lattice_size"] = 5
134-
elif data_id == "greater":
135-
aux_info["p"] = 30
136-
elif data_id == "family_tree":
131+
if data_id == "family_tree":
137132
aux_info["dict_level"] = dataset['dict_level']
138-
elif data_id == "equivalence":
139-
aux_info["mod"] = 10
140-
elif data_id == "circle":
141-
aux_info["p"] = 59
142-
else:
143-
raise ValueError(f"Unknown data_id: {data_id}")
144133

145134
if hasattr(model.embedding, 'weight'):
146135
metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)
@@ -179,19 +168,8 @@
179168
with open(f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results.json", "w") as f:
180169
json.dump(ret_dic["results"], f, indent=4)
181170

182-
aux_info = {}
183-
if data_id == "lattice":
184-
aux_info["lattice_size"] = 5
185-
elif data_id == "greater":
186-
aux_info["p"] = 30
187-
elif data_id == "family_tree":
171+
if data_id == "family_tree":
188172
aux_info["dict_level"] = dataset['dict_level']
189-
elif data_id == "equivalence":
190-
aux_info["mod"] = 10
191-
elif data_id == "circle":
192-
aux_info["p"] = 31
193-
else:
194-
raise ValueError(f"Unknown data_id: {data_id}")
195173

196174
if hasattr(model.embedding, 'weight'):
197175
metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info)

src/unit_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]
2020
if __name__ == '__main__':
2121
parser = argparse.ArgumentParser(description='Experiment')
22-
parser.add_argument('--seed', type=int, default=51, help='random seed')
22+
parser.add_argument('--seed', type=int, default=29, help='random seed')
2323
parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
2424
parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')
2525

@@ -65,7 +65,7 @@
6565
if data_id == "lattice":
6666
aux_info["lattice_size"] = 5
6767
elif data_id == "greater":
68-
aux_info["p"] = 200
68+
aux_info["p"] = 30
6969
elif data_id == "family_tree":
7070
aux_info["dict_level"] = dataset['dict_level']
7171
elif data_id == "equivalence":

src/utils/crystal_metric.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -99,45 +99,39 @@ def greater_metric(reps, aux_info):
9999

100100
def family_tree_metric(reps, aux_info):
101101

102+
dict_level = aux_info['dict_level']
103+
reps = reps[:(max(dict_level.keys()) + 1)]
104+
102105
pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
103106
reps = pca.fit_transform(reps)
104107
reps = reps[:, :2]
105108

106-
dict_level = aux_info['dict_level']
107-
108-
# Group individuals by generation
109-
generation_groups = {}
110-
for individual, generation in dict_level.items():
111-
if generation not in generation_groups:
112-
generation_groups[generation] = []
113-
generation_groups[generation].append(individual)
114109

115-
116-
# Compute the collinearity of representations for individuals within the same generation
117-
collinearity_by_generation = {}
118-
119-
for generation, individuals in generation_groups.items():
120-
# Get the indices of individuals in this generation
121-
indices = [individual for individual in individuals]
122-
# Extract their representations
123-
gen_representations = reps[indices]
124-
125-
# Compute collinearity by fixing one vector as a pivot
126-
if gen_representations.shape[0] > 2: # Ensure there are at least three individuals
127-
pivot = gen_representations[1] - gen_representations[0] # Difference between first two vectors
128-
dot_products = (gen_representations[2:] - gen_representations[0]) @ pivot
129-
norms = np.linalg.norm((gen_representations[2:] - gen_representations[0]), axis=1) * np.linalg.norm(pivot)
130-
131-
norms = np.where(norms == 0, np.nan, norms)
132-
collinearity = np.abs(dot_products / norms) # Cosine similarity with the pivot
133-
collinearity = np.nan_to_num(collinearity, nan=1.0)
134-
collinearity_by_generation[generation] = collinearity.mean()
110+
# Group embeddings by generation
111+
levels = {}
112+
for node, generation in dict_level.items():
113+
if generation not in levels:
114+
levels[generation] = []
115+
levels[generation].append(reps[node])
116+
117+
# Compute one-dimensionality for each generation
118+
level_scores = {}
119+
for generation, points in levels.items():
120+
if len(points) < 5:
121+
continue
122+
123+
points_array = np.stack(points) # Convert to NumPy array
124+
pca_sub = PCA(n_components=min(points_array.shape[0], points_array.shape[1]))
125+
pca_sub.fit(points_array)
126+
one_dimensionality = pca_sub.explained_variance_ratio_[0] # Ratio of variance explained by the first PC
127+
level_scores[generation] = one_dimensionality
135128

136129

130+
# pca.fit_transform(reps)
137131
variances = pca.explained_variance_ratio_
138132

139133
metric_dict = {
140-
'metric': float(1 - np.mean([collinearity for collinearity in collinearity_by_generation.values() if not np.isnan(collinearity)])),
134+
'metric': float(1 - np.mean(list(level_scores.values()))),
141135
'variances': variances.tolist(),
142136
}
143137
return metric_dict
@@ -156,6 +150,10 @@ def equivalence_metric(reps, aux_info):
156150
else:
157151
diff_arr.append(np.linalg.norm(reps[i] - reps[j]))
158152

153+
# Filter Outliers
154+
diff_arr = np.array(diff_arr)
155+
diff_arr = diff_arr[diff_arr < np.mean(cross_diff_arr)]
156+
159157
pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
160158
emb_pca = pca.fit_transform(reps)
161159
variances = pca.explained_variance_ratio_
@@ -174,18 +172,28 @@ def circle_metric(reps, aux_info):
174172
emb_pca = pca.fit_transform(reps)
175173
variances = pca.explained_variance_ratio_
176174

175+
points = emb_pca[:, :2]
176+
177+
min_x, min_y = points.min(axis=0)
178+
max_x, max_y = points.max(axis=0)
179+
width = max_x - min_x
180+
height = max_y - min_y
181+
182+
# Normalize points to [0, 1] in both dimensions
183+
normalized_points = (points - [min_x, min_y]) / [width, height]
184+
177185
# Compute the centroid of the points
178-
centroid = np.mean(emb_pca, axis=0)
186+
centroid = np.mean(normalized_points, axis=0)
179187

180188
# Compute distances of points from the centroid
181-
distances = np.linalg.norm(emb_pca - centroid, axis=1)
189+
distances = np.linalg.norm(normalized_points - centroid, axis=1)
182190

183191
# Mean and standard deviation of distances
184192
mean_distance = np.mean(distances)
185193
std_distance = np.std(distances)
186194

187195
# Circularity score
188-
circularity_score = 1 - (std_distance / mean_distance)
196+
circularity_score = (std_distance / mean_distance)
189197

190198

191199
metric_dict = {

src/utils/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def is_desc(a, b):
151151
data_id = torch.from_numpy(x).to(device)
152152
labels = torch.from_numpy(target).to(device)
153153

154-
vocab_size = p+2
154+
vocab_size = p
155155

156156
dataset = {}
157157
dataset['data_id'] = data_id

src/utils/driver.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,14 @@ def train_single_model(param_dict: dict):
6363
dataset = parallelogram_dataset(p=5, dim=2, num=data_size, seed=seed, device=device)
6464
input_token = 3
6565
elif data_id == "greater":
66-
dataset = greater_than_dataset(p=200, num=data_size, seed=seed, device=device)
66+
dataset = greater_than_dataset(p=30, num=data_size, seed=seed, device=device)
6767
elif data_id == "family_tree":
6868
dataset = family_tree_dataset_2(p=127, num=data_size, seed=seed, device=device)
6969
elif data_id == "equivalence":
7070
input_token = 2
71-
dataset = mod_classification_dataset(p=300, num=data_size, seed=seed, device=device)
71+
dataset = mod_classification_dataset(p=100, num=data_size, seed=seed, device=device)
7272
elif data_id == "circle":
7373
dataset = modular_addition_dataset(p=31, num=data_size, seed=seed, device=device)
74-
input_token = 3
7574
else:
7675
raise ValueError(f"Unknown data_id: {data_id}")
7776

@@ -92,9 +91,9 @@ def train_single_model(param_dict: dict):
9291
shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
9392
model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim, input_token=input_token, unembd=unembd, weight_tied=weight_tied, seed=seed).to(device)
9493
elif model_id == "H_transformer":
95-
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=16, num_layers=3, seq_len=input_token, use_dist_layer=True).to(device)
94+
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=8, num_layers=1, seq_len=input_token, use_dist_layer=True).to(device)
9695
elif model_id == "standard_transformer":
97-
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=16, num_layers=3, seq_len=input_token, use_dist_layer=False).to(device)
96+
model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=8, num_layers=1, seq_len=input_token, use_dist_layer=False).to(device)
9897
else:
9998
raise ValueError(f"Unknown model_id: {model_id}")
10099

@@ -106,7 +105,7 @@ def train_single_model(param_dict: dict):
106105
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
107106

108107
ret_dic = {}
109-
ret_dic["results"] = model.train(param_dict={'num_epochs': 10000, 'learning_rate': 0.01, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device})
108+
ret_dic["results"] = model.train(param_dict={'num_epochs': 4000, 'learning_rate': 0.001, 'train_dataloader': train_dataloader, 'test_dataloader': test_dataloader, 'device': device})
110109
ret_dic["model"] = model
111110
ret_dic["dataset"] = dataset
112111

src/utils/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def train(self, param_dict: dict):
2929
test_accuracies = []
3030

3131
best_loss = float('inf')
32-
patience = 200
32+
patience = 100
3333
min_delta = 1e-4
3434
counter = 0
3535

@@ -57,7 +57,7 @@ def train(self, param_dict: dict):
5757
else:
5858
total_loss = loss + lamb_reg * torch.mean(torch.sqrt(torch.mean(self.embedding.data**2, dim=0)))
5959

60-
loss.backward()
60+
total_loss.backward()
6161
optimizer.step()
6262
train_loss += loss.item()
6363

@@ -100,9 +100,11 @@ def train(self, param_dict: dict):
100100
else:
101101
counter += 1 # Increment counter if no improvement
102102

103+
'''
103104
if counter >= patience:
104105
print("Early stopping triggered!")
105106
break
107+
'''
106108

107109
ret_dic = {}
108110
ret_dic['train_losses'] = train_losses

0 commit comments

Comments
 (0)