Skip to content

Commit 70971e0

Browse files
committed
Added exception handlingfor computing family tree metric with duplicate data points
1 parent 188e540 commit 70971e0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/utils/crystal_metric.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,11 @@ def family_tree_metric(reps, aux_info):
123123
dot_products = (gen_representations[2:] - gen_representations[0]) @ pivot
124124
norms = np.linalg.norm((gen_representations[2:] - gen_representations[0]), axis=1) * np.linalg.norm(pivot)
125125

126+
norms = np.where(norms == 0, np.nan, norms)
126127
collinearity = np.abs(dot_products / norms) # Cosine similarity with the pivot
128+
collinearity = np.nan_to_num(collinearity, nan=1.0)
127129
collinearity_by_generation[generation] = collinearity.mean()
128-
print(collinearity.mean())
129-
130+
130131

131132
pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
132133
emb_pca = pca.fit_transform(reps)

0 commit comments

Comments
 (0)