From a42eb627586fc2cff1c08bb8e8ce331564a39064 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 29 May 2024 17:40:05 +0800 Subject: [PATCH 1/4] fix: bugs in uts for polar and dipole fit --- source/tests/pt/model/test_dipole_fitting.py | 26 +++++++------- .../pt/model/test_polarizability_fitting.py | 35 ++++++++++--------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fa4be9171c..2b1f097175 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -73,8 +73,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -84,7 +83,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=False, ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -160,8 +159,7 @@ def test_rot(self): rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) rng = np.random.default_rng() - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -171,7 +169,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=False, ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -196,7 +194,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, not mixed_types + xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -205,7 +203,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -231,7 +229,7 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, True + coord[idx_perm], atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -240,7 +238,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -261,7 +259,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=False, ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: @@ -271,7 +269,7 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -280,7 +278,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -305,7 +303,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=False, ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 84d6bd91ab..b8ac6ff243 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -59,8 +59,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -72,7 +71,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=False, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -167,8 +166,7 @@ def test_rot(self): rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -180,7 +178,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=True, + mixed_types=False, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -207,7 +205,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, mixed_types + xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -216,7 +214,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["polarizability"]) np.testing.assert_allclose( to_numpy_array(res[1]), @@ -237,7 +235,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=False, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -250,7 +248,7 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False + coord[idx_perm], atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -259,7 +257,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose( @@ -269,7 +267,12 @@ def test_permu(self): def test_trans(self): atype = self.atype.reshape(1, 5) - coord_s = self.coord + self.shift + coord_s = torch.matmul( + torch.remainder( + torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0 + ), + self.cell, + ) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( self.nt, @@ -277,7 +280,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=False, fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -289,7 +292,7 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -298,7 +301,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -323,7 +326,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=False, ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = PolarModel(self.dd0, self.ft0, self.type_mapping) From 07b611e8c55bc2382d24ef93fa5b6857bc2055d5 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 29 May 2024 18:04:25 +0800 Subject: [PATCH 2/4] rm shift from rot test --- source/tests/pt/model/test_dipole_fitting.py | 2 +- source/tests/pt/model/test_polarizability_fitting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 2b1f097175..7fb05f935b 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -194,7 +194,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell + xyz, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index b8ac6ff243..5c85920e29 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -205,7 +205,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell + xyz, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( From ebd928989d03c02b5ae24b35dbbb8df9df184036 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 29 May 2024 18:16:45 +0800 Subject: [PATCH 3/4] Revert "rm shift from rot test" This reverts commit 07b611e8c55bc2382d24ef93fa5b6857bc2055d5. --- source/tests/pt/model/test_dipole_fitting.py | 2 +- source/tests/pt/model/test_polarizability_fitting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 7fb05f935b..2b1f097175 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -194,7 +194,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False, box=self.cell + xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 5c85920e29..b8ac6ff243 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -205,7 +205,7 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False, box=self.cell + xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell ) rd0, gr0, _, _, _ = self.dd0( From 328816fb5d145c50c4b8a111d9c4ea00f6cfa97e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 30 May 2024 14:30:47 +0800 Subject: [PATCH 4/4] fix uts and conversations --- source/tests/pt/model/test_dipole_fitting.py | 28 +++++++++++----- .../pt/model/test_polarizability_fitting.py | 33 ++++++++++++++----- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index 2b1f097175..db266c6c8b 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -83,7 +83,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -158,6 +158,8 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) rng = np.random.default_rng() for nfp, nap in itertools.product( [0, 3], @@ -169,7 +171,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -194,7 +196,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -218,7 +225,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: @@ -229,7 +236,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False, box=self.cell + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -259,7 +271,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: @@ -269,7 +281,7 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False, box=self.cell + xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -303,7 +315,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index b8ac6ff243..6826807a45 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -71,7 +71,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -165,6 +165,8 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) for nfp, nap, fit_diag, scale in itertools.product( [0, 3], @@ -178,7 +180,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -205,7 +207,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, False, box=self.cell + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -235,7 +242,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -248,7 +255,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False, box=self.cell + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -280,7 +292,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -292,7 +304,12 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False, box=self.cell + xyz, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -326,7 +343,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)