Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def test_consistency(
self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE
)

for mixed_types, nfp, nap in itertools.product(
[True, False],
for nfp, nap in itertools.product(
[0, 3],
[0, 4],
):
Expand All @@ -84,7 +83,7 @@ def test_consistency(
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
ft1 = DPDipoleFitting.deserialize(ft0.serialize())
ft2 = DipoleFittingNet.deserialize(ft1.serialize())
Expand Down Expand Up @@ -159,9 +158,10 @@ def test_rot(self):
atype = self.atype.reshape(1, 5)
rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE)
coord_rot = torch.matmul(self.coord, rmat)
# use larger cell to rotate only coord and shift to the center of cell
cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)
rng = np.random.default_rng()
for mixed_types, nfp, nap in itertools.product(
[True, False],
for nfp, nap in itertools.product(
[0, 3],
[0, 4],
):
Expand All @@ -171,7 +171,7 @@ def test_rot(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
if nfp > 0:
ifp = torch.tensor(
Expand All @@ -196,7 +196,12 @@ def test_rot(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz + self.shift, atype, self.rcut, self.sel, not mixed_types
xyz + self.shift,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=cell_rot,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -205,7 +210,7 @@ def test_rot(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand All @@ -220,7 +225,7 @@ def test_permu(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=False,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
res = []
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
Expand All @@ -231,7 +236,12 @@ def test_permu(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
coord[idx_perm], atype, self.rcut, self.sel, True
coord[idx_perm],
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -240,7 +250,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand All @@ -261,7 +271,7 @@ def test_trans(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand All @@ -271,7 +281,7 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, False
xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -280,7 +290,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["dipole"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand All @@ -305,7 +315,7 @@ def setUp(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping)
Expand Down
52 changes: 36 additions & 16 deletions source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def test_consistency(
self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE
)

for mixed_types, nfp, nap, fit_diag, scale in itertools.product(
[True, False],
for nfp, nap, fit_diag, scale in itertools.product(
[0, 3],
[0, 4],
[True, False],
Expand All @@ -72,7 +71,7 @@ def test_consistency(
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand Down Expand Up @@ -166,9 +165,10 @@ def test_rot(self):
atype = self.atype.reshape(1, 5)
rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE)
coord_rot = torch.matmul(self.coord, rmat)
# use larger cell to rotate only coord and shift to the center of cell
cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)

for mixed_types, nfp, nap, fit_diag, scale in itertools.product(
[True, False],
for nfp, nap, fit_diag, scale in itertools.product(
[0, 3],
[0, 4],
[True, False],
Expand All @@ -180,7 +180,7 @@ def test_rot(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -207,7 +207,12 @@ def test_rot(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz + self.shift, atype, self.rcut, self.sel, mixed_types
xyz + self.shift,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=cell_rot,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -216,7 +221,7 @@ def test_rot(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["polarizability"])
np.testing.assert_allclose(
to_numpy_array(res[1]),
Expand All @@ -237,7 +242,7 @@ def test_permu(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -250,7 +255,12 @@ def test_permu(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
coord[idx_perm], atype, self.rcut, self.sel, False
coord[idx_perm],
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -259,7 +269,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["polarizability"])

np.testing.assert_allclose(
Expand All @@ -269,15 +279,20 @@ def test_permu(self):

def test_trans(self):
atype = self.atype.reshape(1, 5)
coord_s = self.coord + self.shift
coord_s = torch.matmul(
torch.remainder(
torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0
),
self.cell,
)
for fit_diag, scale in itertools.product([True, False], [None, self.scale]):
ft0 = PolarFittingNet(
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -289,7 +304,12 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, False
xyz,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -298,7 +318,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["polarizability"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand All @@ -323,7 +343,7 @@ def setUp(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)
Expand Down