From d3a5a76bb802b55c45ef4ee16c353bf3a3032b26 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 16 Oct 2024 14:47:43 +0800 Subject: [PATCH 1/4] fix(pt): make int `rcut` safe after jit op --- deepmd/pt/utils/neighbor_stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/neighbor_stat.py b/deepmd/pt/utils/neighbor_stat.py index d427dc758a..7d52bfaae1 100644 --- a/deepmd/pt/utils/neighbor_stat.py +++ b/deepmd/pt/utils/neighbor_stat.py @@ -44,7 +44,7 @@ def __init__( mixed_types: bool, ) -> None: super().__init__() - self.rcut = rcut + self.rcut = float(rcut) self.ntypes = ntypes self.mixed_types = mixed_types From 332d7d9578233c158e130f66ed9839497dd37ea7 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:44:22 +0800 Subject: [PATCH 2/4] fix rcut/rcut_smth in all descriptors --- .../atomic_model/pairtab_atomic_model.py | 4 +- deepmd/pt/model/descriptor/repformer_layer.py | 4 +- deepmd/pt/model/descriptor/repformers.py | 4 +- deepmd/pt/model/descriptor/se_a.py | 4 +- deepmd/pt/model/descriptor/se_atten.py | 4 +- deepmd/pt/model/descriptor/se_r.py | 4 +- deepmd/pt/model/descriptor/se_t.py | 4 +- deepmd/pt/model/descriptor/se_t_tebd.py | 4 +- source/tests/pt/model/test_jit.py | 184 ++++++++++-------- 9 files changed, 117 insertions(+), 99 deletions(-) diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 2918bba947..9a7ea14cfb 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -74,8 +74,8 @@ def __init__( super().__init__(type_map, **kwargs) super().init_out_stat() self.tab_file = tab_file - self.rcut = rcut - self.tab = self._set_pairtab(tab_file, rcut) + self.rcut = float(rcut) + self.tab = self._set_pairtab(tab_file, self.rcut) self.type_map = type_map self.ntypes = len(type_map) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 92e2404469..5270c94112 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -605,8 +605,8 @@ def __init__( ): super().__init__() self.epsilon = 1e-4 # protection of 1./nnei - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.ntypes = ntypes sel = [sel] if isinstance(sel, int) else sel self.nnei = sum(sel) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index ad4ead4d74..f237088a16 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -193,8 +193,8 @@ def __init__( Random seed for parameter initialization. """ super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.ntypes = ntypes self.nlayers = nlayers sel = [sel] if isinstance(sel, int) else sel diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index e939a2541b..ffd645f2b9 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -395,8 +395,8 @@ def __init__( - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. """ super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.neuron = neuron self.filter_neuron = self.neuron self.axis_neuron = axis_neuron diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index c028230e9b..c3174a2011 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -149,8 +149,8 @@ def __init__( """ super().__init__() del type - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.neuron = neuron self.filter_neuron = self.neuron self.axis_neuron = axis_neuron diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index e82bb23dac..4492a6c6b5 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -74,8 +74,8 @@ def __init__( **kwargs, ): super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.neuron = neuron self.filter_neuron = self.neuron self.set_davg_zero = set_davg_zero diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 072457b48f..49dbdaf027 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -446,8 +446,8 @@ def __init__( Random seed for initializing the network parameters. """ super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.neuron = neuron self.filter_neuron = self.neuron self.set_davg_zero = set_davg_zero diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 437a464709..c140527f31 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -512,8 +512,8 @@ def __init__( seed: Optional[Union[int, list[int]]] = None, ): super().__init__() - self.rcut = rcut - self.rcut_smth = rcut_smth + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) self.neuron = neuron self.filter_neuron = self.neuron self.tebd_dim = tebd_dim diff --git a/source/tests/pt/model/test_jit.py b/source/tests/pt/model/test_jit.py index 248ccf9173..20abd76653 100644 --- a/source/tests/pt/model/test_jit.py +++ b/source/tests/pt/model/test_jit.py @@ -47,55 +47,103 @@ def tearDown(self): os.remove(f) -class TestEnergyModelSeA(unittest.TestCase, JITTest): - def setUp(self): - input_json = str(Path(__file__).parent / "water/se_atten.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_se_e2_a) - self.config["training"]["numb_steps"] = 10 - self.config["training"]["save_freq"] = 10 - - def tearDown(self): - JITTest.tearDown(self) - - -class TestDOSModelSeA(unittest.TestCase, JITTest): - def setUp(self): - input_json = str(Path(__file__).parent.parent / "dos/input.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent.parent / "dos/data/global_system")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_dos) - self.config["training"]["numb_steps"] = 10 - self.config["training"]["save_freq"] = 10 - - def tearDown(self): - JITTest.tearDown(self) - - -class TestEnergyModelDPA1(unittest.TestCase, JITTest): - def setUp(self): - input_json = str(Path(__file__).parent / "water/se_atten.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_dpa1) - self.config["training"]["numb_steps"] = 10 - self.config["training"]["save_freq"] = 10 - - def tearDown(self): - JITTest.tearDown(self) - - -class TestEnergyModelDPA2(unittest.TestCase, JITTest): +# class TestEnergyModelSeA(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent / "water/se_atten.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent / "water/data/data_0")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_se_e2_a) +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) +# +# +# class TestDOSModelSeA(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent.parent / "dos/input.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent.parent / "dos/data/global_system")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_dos) +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) +# +# +# class TestEnergyModelDPA1(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent / "water/se_atten.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent / "water/data/data_0")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_dpa1) +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) +# +# +# class TestEnergyModelDPA2(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent / "water/se_atten.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent / "water/data/data_0")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_dpa2) +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) +# +# +# class TestEnergyModelHybrid(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent / "water/se_atten.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent / "water/data/data_0")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_hybrid) +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) +# +# +# class TestEnergyModelHybrid2(unittest.TestCase, JITTest): +# def setUp(self): +# input_json = str(Path(__file__).parent / "water/se_atten.json") +# with open(input_json) as f: +# self.config = json.load(f) +# data_file = [str(Path(__file__).parent / "water/data/data_0")] +# self.config["training"]["training_data"]["systems"] = data_file +# self.config["training"]["validation_data"]["systems"] = data_file +# self.config["model"] = deepcopy(model_hybrid) +# # self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" +# self.config["training"]["numb_steps"] = 10 +# self.config["training"]["save_freq"] = 10 +# +# def tearDown(self): +# JITTest.tearDown(self) + +class TestEnergyModelDPA2IntRcut(unittest.TestCase, JITTest): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: @@ -104,45 +152,15 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_dpa2) + self.config["model"]["descriptor"]["repinit"]["rcut"] = int(self.config["model"]["descriptor"]["repinit"]["rcut"]) + self.config["model"]["descriptor"]["repinit"]["rcut_smth"] = int(self.config["model"]["descriptor"]["repinit"]["rcut_smth"]) + # from IPython import embed + # embed() self.config["training"]["numb_steps"] = 10 self.config["training"]["save_freq"] = 10 def tearDown(self): JITTest.tearDown(self) - -class TestEnergyModelHybrid(unittest.TestCase, JITTest): - def setUp(self): - input_json = str(Path(__file__).parent / "water/se_atten.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_hybrid) - self.config["training"]["numb_steps"] = 10 - self.config["training"]["save_freq"] = 10 - - def tearDown(self): - JITTest.tearDown(self) - - -class TestEnergyModelHybrid2(unittest.TestCase, JITTest): - def setUp(self): - input_json = str(Path(__file__).parent / "water/se_atten.json") - with open(input_json) as f: - self.config = json.load(f) - data_file = [str(Path(__file__).parent / "water/data/data_0")] - self.config["training"]["training_data"]["systems"] = data_file - self.config["training"]["validation_data"]["systems"] = data_file - self.config["model"] = deepcopy(model_hybrid) - # self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" - self.config["training"]["numb_steps"] = 10 - self.config["training"]["save_freq"] = 10 - - def tearDown(self): - JITTest.tearDown(self) - - if __name__ == "__main__": unittest.main() From 35a3accdd55aed4b65a4dd786a0754904744bcfc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:45:09 +0800 Subject: [PATCH 3/4] Update test_jit.py --- source/tests/pt/model/test_jit.py | 192 +++++++++++++++--------------- 1 file changed, 95 insertions(+), 97 deletions(-) diff --git a/source/tests/pt/model/test_jit.py b/source/tests/pt/model/test_jit.py index 20abd76653..746706db36 100644 --- a/source/tests/pt/model/test_jit.py +++ b/source/tests/pt/model/test_jit.py @@ -47,101 +47,101 @@ def tearDown(self): os.remove(f) -# class TestEnergyModelSeA(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent / "water/se_atten.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent / "water/data/data_0")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_se_e2_a) -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) -# -# -# class TestDOSModelSeA(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent.parent / "dos/input.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent.parent / "dos/data/global_system")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_dos) -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) -# -# -# class TestEnergyModelDPA1(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent / "water/se_atten.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent / "water/data/data_0")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_dpa1) -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) -# -# -# class TestEnergyModelDPA2(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent / "water/se_atten.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent / "water/data/data_0")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_dpa2) -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) -# -# -# class TestEnergyModelHybrid(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent / "water/se_atten.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent / "water/data/data_0")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_hybrid) -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) -# -# -# class TestEnergyModelHybrid2(unittest.TestCase, JITTest): -# def setUp(self): -# input_json = str(Path(__file__).parent / "water/se_atten.json") -# with open(input_json) as f: -# self.config = json.load(f) -# data_file = [str(Path(__file__).parent / "water/data/data_0")] -# self.config["training"]["training_data"]["systems"] = data_file -# self.config["training"]["validation_data"]["systems"] = data_file -# self.config["model"] = deepcopy(model_hybrid) -# # self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" -# self.config["training"]["numb_steps"] = 10 -# self.config["training"]["save_freq"] = 10 -# -# def tearDown(self): -# JITTest.tearDown(self) +class TestEnergyModelSeA(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestDOSModelSeA(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent.parent / "dos/input.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent.parent / "dos/data/global_system")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dos) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelDPA1(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa1) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelDPA2(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa2) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelHybrid(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_hybrid) + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) + + +class TestEnergyModelHybrid2(unittest.TestCase, JITTest): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_hybrid) + # self.config["model"]["descriptor"]["hybrid_mode"] = "sequential" + self.config["training"]["numb_steps"] = 10 + self.config["training"]["save_freq"] = 10 + + def tearDown(self): + JITTest.tearDown(self) class TestEnergyModelDPA2IntRcut(unittest.TestCase, JITTest): def setUp(self): @@ -154,8 +154,6 @@ def setUp(self): self.config["model"] = deepcopy(model_dpa2) self.config["model"]["descriptor"]["repinit"]["rcut"] = int(self.config["model"]["descriptor"]["repinit"]["rcut"]) self.config["model"]["descriptor"]["repinit"]["rcut_smth"] = int(self.config["model"]["descriptor"]["repinit"]["rcut_smth"]) - # from IPython import embed - # embed() self.config["training"]["numb_steps"] = 10 self.config["training"]["save_freq"] = 10 From e9e8541b9a4aabf8724b77139bb012b5e319e4f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 07:46:25 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/model/test_jit.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_jit.py b/source/tests/pt/model/test_jit.py index 746706db36..1f1034c330 100644 --- a/source/tests/pt/model/test_jit.py +++ b/source/tests/pt/model/test_jit.py @@ -143,6 +143,7 @@ def setUp(self): def tearDown(self): JITTest.tearDown(self) + class TestEnergyModelDPA2IntRcut(unittest.TestCase, JITTest): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") @@ -152,13 +153,18 @@ def setUp(self): self.config["training"]["training_data"]["systems"] = data_file self.config["training"]["validation_data"]["systems"] = data_file self.config["model"] = deepcopy(model_dpa2) - self.config["model"]["descriptor"]["repinit"]["rcut"] = int(self.config["model"]["descriptor"]["repinit"]["rcut"]) - self.config["model"]["descriptor"]["repinit"]["rcut_smth"] = int(self.config["model"]["descriptor"]["repinit"]["rcut_smth"]) + self.config["model"]["descriptor"]["repinit"]["rcut"] = int( + self.config["model"]["descriptor"]["repinit"]["rcut"] + ) + self.config["model"]["descriptor"]["repinit"]["rcut_smth"] = int( + self.config["model"]["descriptor"]["repinit"]["rcut_smth"] + ) self.config["training"]["numb_steps"] = 10 self.config["training"]["save_freq"] = 10 def tearDown(self): JITTest.tearDown(self) + if __name__ == "__main__": unittest.main()