diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fa4be9171c..db266c6c8b 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -73,8 +73,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -84,7 +83,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -159,9 +158,10 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) rng = np.random.default_rng() - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -171,7 +171,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -196,7 +196,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, not mixed_types + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -205,7 +210,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -220,7 +225,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: @@ -231,7 +236,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, True + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -240,7 +250,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -261,7 +271,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: @@ -271,7 +281,7 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -280,7 +290,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -305,7 +315,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 84d6bd91ab..6826807a45 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -59,8 +59,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -72,7 +71,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -166,9 +165,10 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -180,7 +180,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -207,7 +207,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, mixed_types + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -216,7 +221,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["polarizability"]) np.testing.assert_allclose( to_numpy_array(res[1]), @@ -237,7 +242,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -250,7 +255,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -259,7 +269,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose( @@ -269,7 +279,12 @@ def test_permu(self): def test_trans(self): atype = self.atype.reshape(1, 5) - coord_s = self.coord + self.shift + coord_s = torch.matmul( + torch.remainder( + torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0 + ), + self.cell, + ) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( self.nt, @@ -277,7 +292,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -289,7 +304,12 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -298,7 +318,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -323,7 +343,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)