|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | + |
| 4 | +from itertools import combinations |
| 5 | +from sklearn.decomposition import PCA |
| 6 | + |
def crystal_metric(reps, data_id, aux_info):
    """Dispatch to the metric that matches *data_id*.

    Parameters
    ----------
    reps : np.ndarray of shape (n_items, dim)
        Learned representations, one row per item.
    data_id : str
        One of "lattice", "greater", "family_tree", "equivalence",
        "circle".
    aux_info : dict
        Task-specific extras forwarded to the selected metric function.

    Returns
    -------
    dict
        Whatever the selected metric function returns (a dict with at
        least 'metric' and 'variances').

    Raises
    ------
    ValueError
        If *data_id* does not name a known metric.
    """
    # Guard-style dispatch: each known id returns immediately,
    # anything that falls through is unknown.
    if data_id == "lattice":
        return lattice_metric(reps, aux_info)
    if data_id == "greater":
        return greater_metric(reps, aux_info)
    if data_id == "family_tree":
        return family_tree_metric(reps, aux_info)
    if data_id == "equivalence":
        return equivalence_metric(reps, aux_info)
    if data_id == "circle":
        return circle_metric(reps, aux_info)
    raise ValueError(f"Unknown data_id: {data_id}")
| 23 | + |
def lattice_metric(reps, aux_info):
    """Measure how faithfully *reps* embeds a square lattice.

    For every lattice triple (a, b, c) the fourth parallelogram corner
    d = b + c - a is completed; the corresponding representation vectors
    should then also form a parallelogram.  The metric is the mean
    relative deviation of opposite side lengths over all such
    parallelograms — 0 means a perfect lattice embedding.

    Parameters
    ----------
    reps : np.ndarray of shape (lattice_size**2, dim)
        Representation of lattice point (i, j) is assumed to live at row
        ``lattice_size * i + j`` — TODO confirm against the data loader.
    aux_info : dict
        Must contain 'lattice_size' (int side length of the lattice).

    Returns
    -------
    dict
        'metric': float, mean parallelogram deviation (lower is better);
        'variances': PCA explained-variance ratios of *reps* (diagnostic).
    """
    lattice_size = aux_info['lattice_size']
    points = [(i, j) for i in range(lattice_size) for j in range(lattice_size)]

    def side_length_deviation(a, b, c, d):
        # Deviation of quadrilateral (a, b, d, c) from a perfect
        # parallelogram: opposite side pairs ab/cd and ac/bd should match.
        a, b, c, d = np.array(a), np.array(b), np.array(c), np.array(d)

        # Lengths of the two pairs of opposite sides plus the remaining
        # two edges used for scale normalization.
        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)

        # Mismatch of opposite sides, normalized by the RMS edge length so
        # the score is invariant to the overall scale of the embedding.
        side_deviation = np.sqrt((length_ab - length_cd)**2 + (length_ac - length_bd)**2) / np.sqrt((length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2)/2)

        return side_deviation

    deviation_arr = []
    # Enumerate all triples; each valid triple induces one parallelogram.
    for a, b, c in combinations(points, 3):
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])

        # Skip when the completed corner falls outside the lattice.
        if d[0] < 0 or d[0] >= lattice_size or d[1] < 0 or d[1] >= lattice_size:
            continue

        # Skip degenerate triples lying on one row or one column.
        if a[0] == b[0] and b[0] == c[0]:
            continue
        if a[1] == b[1] and b[1] == c[1]:
            continue

        # Flatten 2-D lattice coordinates to row indices of `reps`.
        rows = [lattice_size * p[0] + p[1] for p in (a, b, c, d)]
        deviation_arr.append(side_length_deviation(*(reps[r] for r in rows)))

    # Obtain explained variance ratios (fit only — the projection itself
    # is not needed here).
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.mean(deviation_arr),
        'variances': variances,
    }

    return metric_dict
| 80 | + |
| 81 | + |
def greater_metric(reps, aux_info):
    """Score how equidistant consecutive representations are.

    A perfect representation of a total order places the items along a
    line at equal spacing, so the distances between consecutive rows of
    *reps* should all be equal.  The metric is the coefficient of
    variation (std / mean) of those consecutive distances; 0 is perfect.

    Parameters
    ----------
    reps : np.ndarray of shape (n_items, dim)
        Representations ordered by the underlying "greater" relation.
    aux_info : dict
        Unused; kept for a uniform metric-function signature.

    Returns
    -------
    dict
        'metric': float coefficient of variation (lower is better);
        'variances': PCA explained-variance ratios of *reps* (diagnostic).
    """
    # Distances between consecutive representations, vectorized instead
    # of an explicit index loop.
    diffs = np.linalg.norm(np.diff(reps, axis=0), axis=1)

    # Explained variance ratios (fit only — projection not needed).
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.std(diffs) / np.mean(diffs),
        'variances': variances,
    }
    return metric_dict
| 99 | + |
def family_tree_metric(reps, aux_info):
    """Score how collinear each generation's representations are.

    Individuals of the same generation should lie on a common line in
    representation space.  For each generation with at least three
    members, collinearity is measured as the mean absolute cosine
    similarity between a pivot direction (second member minus first)
    and the offsets of the remaining members from the first.

    Parameters
    ----------
    reps : np.ndarray of shape (n_individuals, dim)
        One representation per individual.
    aux_info : dict
        Must contain 'dict_level' mapping individual -> generation.
        Keys are assumed to be valid row indices into *reps* — TODO
        confirm against the caller.

    Returns
    -------
    dict
        'metric': 1 - mean collinearity over generations (0 is perfect);
        'variances': PCA explained-variance ratios of *reps* (diagnostic).
    """
    dict_level = aux_info['dict_level']

    # Group individuals by generation.
    generation_groups = {}
    for individual, generation in dict_level.items():
        generation_groups.setdefault(generation, []).append(individual)

    # Mean |cosine| with the per-generation pivot direction.
    collinearity_by_generation = {}

    for generation, individuals in generation_groups.items():
        # Extract this generation's representations.
        gen_representations = reps[individuals]

        # Collinearity needs at least three individuals: two to define
        # the pivot direction, plus at least one to compare against it.
        if gen_representations.shape[0] > 2:
            pivot = gen_representations[1] - gen_representations[0]
            offsets = gen_representations[2:] - gen_representations[0]
            dot_products = offsets @ pivot
            norms = np.linalg.norm(offsets, axis=1) * np.linalg.norm(pivot)

            # |cos| with the pivot; 1 means perfectly collinear.  A zero
            # pivot or offset yields NaN, filtered out below.
            collinearity = np.abs(dot_products / norms)
            collinearity_by_generation[generation] = collinearity.mean()

    # Explained variance ratios (fit only — projection not needed).
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': 1 - np.mean([collinearity for collinearity in collinearity_by_generation.values() if not np.isnan(collinearity)]),
        'variances': variances,
    }
    return metric_dict
| 140 | + |
def equivalence_metric(reps, aux_info):
    """Score how well *reps* clusters equivalence classes modulo `mod`.

    Rows i and j belong to the same class iff ``i % mod == j % mod``.
    The metric is the mean within-class distance divided by the mean
    cross-class distance; smaller is better (0 would mean every class
    collapses to a single point while classes stay apart).

    Parameters
    ----------
    reps : np.ndarray of shape (n_items, dim)
        One representation per item.
    aux_info : dict
        Must contain 'mod' (int, the modulus defining the classes).

    Returns
    -------
    dict
        'metric': float within/cross distance ratio (lower is better);
        'variances': PCA explained-variance ratios of *reps* (diagnostic).
    """
    mod = aux_info['mod']
    n = reps.shape[0]

    # Partition all ordered pairs into within-class and cross-class
    # distances.
    # NOTE(review): i == j pairs contribute zero distances to the
    # within-class mean, slightly deflating it — kept intentionally so
    # scores stay comparable with earlier runs; consider excluding them.
    diff_arr = []
    cross_diff_arr = []
    for i in range(n):
        for j in range(n):
            if i % mod != j % mod:
                cross_diff_arr.append(np.linalg.norm(reps[i] - reps[j]))
            else:
                diff_arr.append(np.linalg.norm(reps[i] - reps[j]))

    # Explained variance ratios (fit only — projection not needed).
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': np.mean(diff_arr) / np.mean(cross_diff_arr),
        'variances': variances,
    }
    return metric_dict
| 165 | + |
| 166 | + |
def circle_metric(reps, aux_info):
    """Score how circular the PCA projection of *reps* is.

    Projects the representations with PCA and checks that every
    projected point sits at the same distance from the cloud's centroid.
    The score is ``1 - std(radii) / mean(radii)``: 1 for a perfect
    circle, lower as the radii spread out.

    Parameters
    ----------
    reps : np.ndarray of shape (n_items, dim)
        One representation per item.
    aux_info : dict
        Unused; kept for a uniform metric-function signature.

    Returns
    -------
    dict
        'metric': float circularity score (higher is better);
        'variances': PCA explained-variance ratios of *reps*.
    """
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    emb_pca = pca.fit_transform(reps)
    variances = pca.explained_variance_ratio_

    # Radius of each projected point around the centroid of the cloud.
    center = emb_pca.mean(axis=0)
    radii = np.linalg.norm(emb_pca - center, axis=1)

    # Perfectly circular data has identical radii, i.e. zero relative
    # spread.
    circularity_score = 1 - radii.std() / radii.mean()

    metric_dict = {
        'metric': circularity_score,
        'variances': variances,
    }
    return metric_dict
0 commit comments