Skip to content

Commit ed2a070

Browse files
committed
Added Crystal Metric
1 parent 75bba65 commit ed2a070

File tree

1 file changed

+191
-0
lines changed

1 file changed

+191
-0
lines changed

crystal_metric.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import numpy as np
2+
import torch
3+
4+
from itertools import combinations
5+
from sklearn.decomposition import PCA
6+
7+
def crystal_metric(reps, data_id, aux_info):
    """Compute the crystal metric for the given representations and data_id.

    Selects the task-specific metric implementation named by ``data_id`` and
    delegates to it with ``(reps, aux_info)``.

    Raises:
        ValueError: if ``data_id`` names no known metric.
    """
    if data_id == "lattice":
        metric_fn = lattice_metric
    elif data_id == "greater":
        metric_fn = greater_metric
    elif data_id == "family_tree":
        metric_fn = family_tree_metric
    elif data_id == "equivalence":
        metric_fn = equivalence_metric
    elif data_id == "circle":
        metric_fn = circle_metric
    else:
        raise ValueError(f"Unknown data_id: {data_id}")
    return metric_fn(reps, aux_info)
23+
24+
def lattice_metric(reps, aux_info):
    """Measure how well ``reps`` preserves the parallelogram structure of a
    square 2-D lattice.

    Args:
        reps: array of shape (lattice_size**2, dim); row
            ``lattice_size * i + j`` is the representation of lattice
            point (i, j).
        aux_info: dict with key 'lattice_size'.

    Returns:
        dict with 'metric' (mean relative side-length deviation over all
        completable parallelograms; 0 is perfect) and 'variances'
        (PCA explained-variance ratios of ``reps``).
    """
    lattice_size = aux_info['lattice_size']
    deviation_arr = []
    points = [(i, j) for i in range(lattice_size) for j in range(lattice_size)]

    def side_length_deviation(a, b, c, d):
        # Relative mismatch of the two pairs of opposite sides of the
        # (would-be) parallelogram a-b-d-c, normalized by its overall scale.
        a, b, c, d = np.array(a), np.array(b), np.array(c), np.array(d)

        # Lengths of opposite sides (ab vs cd, ac vs bd) and of the
        # remaining chords used for the scale normalization.
        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)

        side_deviation = np.sqrt(
            (length_ab - length_cd) ** 2 + (length_ac - length_bd) ** 2
        ) / np.sqrt(
            (length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2) / 2
        )
        return side_deviation

    # For every triple of lattice points (a, b, c), complete it to a
    # parallelogram with d = b + c - a and score the represented points.
    for a, b, c in combinations(points, 3):
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])
        # Skip completions that fall outside the lattice.
        if d[0] < 0 or d[0] >= lattice_size or d[1] < 0 or d[1] >= lattice_size:
            continue

        # Skip degenerate triples that share a row or a column.
        if a[0] == b[0] and b[0] == c[0]:
            continue
        if a[1] == b[1] and b[1] == c[1]:
            continue

        # Flatten 2-D lattice coordinates to row indices of `reps`.
        idx = [lattice_size * p[0] + p[1] for p in (a, b, c, d)]
        deviation = side_length_deviation(
            reps[idx[0]], reps[idx[1]], reps[idx[2]], reps[idx[3]]
        )
        deviation_arr.append(deviation)

    # Obtain explained-variance ratios.
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)  # the transformed embedding was unused; fitting suffices
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.mean(deviation_arr),
        'variances': variances,
    }

    return metric_dict
80+
81+
82+
def greater_metric(reps, aux_info):
    """Score how equidistant consecutive representations are.

    A perfect "greater-than" representation places successive items at equal
    spacing, so the coefficient of variation of the consecutive distances
    is 0.

    Args:
        reps: array of shape (n, dim), ordered by the underlying values.
        aux_info: unused; kept for the common metric signature.

    Returns:
        dict with 'metric' (std / mean of consecutive distances) and
        'variances' (PCA explained-variance ratios of ``reps``).
    """
    # Distances between consecutive representations, vectorized over the
    # original per-pair loop (identical values).
    diff_arr = np.linalg.norm(np.diff(reps, axis=0), axis=1)

    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)  # the transformed embedding was unused; fitting suffices
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.std(diff_arr) / np.mean(diff_arr),
        'variances': variances,
    }
    return metric_dict
99+
100+
def family_tree_metric(reps, aux_info):
    """Score collinearity of representations within each generation.

    Individuals of the same generation are expected to lie on a common line.
    Within each generation the direction between the first two members is
    used as a pivot, and every other member's offset from the first member is
    compared to it by absolute cosine similarity.

    Args:
        reps: array of shape (n, dim), indexed by individual id.
        aux_info: dict with 'dict_level' mapping individual id -> generation.

    Returns:
        dict with 'metric' (1 - mean collinearity over generations; 0 means
        perfectly collinear) and 'variances' (PCA explained-variance ratios
        of ``reps``).
    """
    dict_level = aux_info['dict_level']

    # Group individual ids by generation.
    generation_groups = {}
    for individual, generation in dict_level.items():
        if generation not in generation_groups:
            generation_groups[generation] = []
        generation_groups[generation].append(individual)

    # Collinearity of representations for individuals within the same
    # generation.
    collinearity_by_generation = {}

    for generation, individuals in generation_groups.items():
        # Extract the generation's representations (ids index rows of reps).
        gen_representations = reps[list(individuals)]

        # Need at least three members: two define the pivot direction,
        # the rest are compared against it.
        if gen_representations.shape[0] > 2:
            pivot = gen_representations[1] - gen_representations[0]
            offsets = gen_representations[2:] - gen_representations[0]
            dot_products = offsets @ pivot
            norms = np.linalg.norm(offsets, axis=1) * np.linalg.norm(pivot)

            # |cosine| of each offset against the pivot; NaNs from
            # zero-length offsets/pivot are filtered out below.
            collinearity = np.abs(dot_products / norms)
            collinearity_by_generation[generation] = collinearity.mean()

    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)  # the transformed embedding was unused; fitting suffices
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': 1 - np.mean(
            [c for c in collinearity_by_generation.values() if not np.isnan(c)]
        ),
        'variances': variances,
    }
    return metric_dict
140+
141+
def equivalence_metric(reps, aux_info):
    """Score how well representations cluster by residue class mod ``mod``.

    Compares the mean pairwise distance within an equivalence class
    (``i % mod == j % mod``) against the mean distance across classes;
    lower is better (tight classes, well-separated from each other).

    Args:
        reps: array of shape (n, dim), indexed by integer value.
        aux_info: dict with key 'mod'.

    Returns:
        dict with 'metric' (within-class mean distance / cross-class mean
        distance) and 'variances' (PCA explained-variance ratios of ``reps``).
    """
    mod = aux_info['mod']
    n = reps.shape[0]

    # Pairwise distances, split by whether the indices share a residue class.
    # NOTE(review): i == j self-pairs (distance 0) are counted in diff_arr,
    # matching the original behaviour — confirm this bias is intended.
    diff_arr = []
    cross_diff_arr = []
    for i in range(n):
        for j in range(n):
            if i % mod != j % mod:
                cross_diff_arr.append(np.linalg.norm(reps[i] - reps[j]))
            else:
                diff_arr.append(np.linalg.norm(reps[i] - reps[j]))

    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)  # the transformed embedding was unused; fitting suffices
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.mean(diff_arr) / np.mean(cross_diff_arr),
        'variances': variances,
    }
    return metric_dict
165+
166+
167+
def circle_metric(reps, aux_info):
    """Score how circular the PCA projection of ``reps`` is.

    Projects the representations with PCA and compares the spread of
    point-to-centroid distances against their mean: a perfect circle has
    identical distances, giving a score of 1.

    Args:
        reps: array of shape (n, dim).
        aux_info: unused; kept for the common metric signature.

    Returns:
        dict with 'metric' (1 - std/mean of centroid distances) and
        'variances' (PCA explained-variance ratios).
    """
    n_components = min(reps.shape[0], reps.shape[1])
    pca = PCA(n_components=n_components)
    emb_pca = pca.fit_transform(reps)
    variances = pca.explained_variance_ratio_

    # Distance of every projected point from the cloud's centre.
    centroid = emb_pca.mean(axis=0)
    distances = np.linalg.norm(emb_pca - centroid, axis=1)

    # 1 means perfectly circular (all centroid distances equal).
    circularity_score = 1 - distances.std() / distances.mean()

    return {
        'metric': circularity_score,
        'variances': variances,
    }

0 commit comments

Comments
 (0)