From 92a9aa463223e18af5c44a323f9618aaaf0600ce Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 02:14:29 -0400 Subject: [PATCH 1/6] use cached Session to speed up py tests --- source/tests/test_data_modifier.py | 4 ++-- source/tests/test_data_modifier_shuffle.py | 4 ++-- source/tests/test_descrpt_nonsmth.py | 8 ++++---- source/tests/test_descrpt_se_a_type.py | 6 +++--- source/tests/test_descrpt_se_ar.py | 6 +++--- source/tests/test_descrpt_se_r.py | 8 ++++---- source/tests/test_descrpt_sea_ef.py | 6 +++--- source/tests/test_descrpt_sea_ef_para.py | 6 +++--- source/tests/test_descrpt_sea_ef_rot.py | 4 ++-- source/tests/test_descrpt_sea_ef_vert.py | 6 +++--- source/tests/test_descrpt_smooth.py | 8 ++++---- source/tests/test_dipole_se_a.py | 4 ++-- source/tests/test_embedding_net.py | 4 ++-- source/tests/test_ewald.py | 8 ++++---- source/tests/test_fitting_ener_type.py | 4 ++-- source/tests/test_model_loc_frame.py | 4 ++-- source/tests/test_model_se_a.py | 6 +++--- source/tests/test_model_se_a_aparam.py | 4 ++-- source/tests/test_model_se_a_fparam.py | 4 ++-- source/tests/test_model_se_a_srtab.py | 4 ++-- source/tests/test_model_se_a_type.py | 4 ++-- source/tests/test_model_se_r.py | 4 ++-- source/tests/test_model_se_t.py | 4 ++-- source/tests/test_polar_se_a.py | 4 ++-- source/tests/test_prod_env_mat.py | 4 ++-- source/tests/test_prod_force.py | 4 ++-- source/tests/test_prod_force_grad.py | 4 ++-- source/tests/test_prod_virial.py | 4 ++-- source/tests/test_prod_virial_grad.py | 4 ++-- source/tests/test_type_embed.py | 6 +++--- source/tests/test_wfc.py | 4 ++-- 31 files changed, 77 insertions(+), 77 deletions(-) diff --git a/source/tests/test_data_modifier.py b/source/tests/test_data_modifier.py index c824262eb1..2d7b26ee7f 100644 --- a/source/tests/test_data_modifier.py +++ b/source/tests/test_data_modifier.py @@ -28,7 +28,7 @@ INPUT = os.path.join(modifier_datapath, 'dipole.json') -class TestDataModifier (unittest.TestCase) : +class TestDataModifier (tf.test.TestCase) : def setUp(self): # with tf.variable_scope('load', reuse = False) : @@ -74,7 +74,7 @@ def _setUp(self): model.build (data) # freeze the graph - with tf.Session() as sess: + with self.cached_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_data_modifier_shuffle.py b/source/tests/test_data_modifier_shuffle.py index 397cb7fa5b..4dce6ab59e 100644 --- a/source/tests/test_data_modifier_shuffle.py +++ b/source/tests/test_data_modifier_shuffle.py @@ -29,7 +29,7 @@ modifier_datapath = 'data_modifier' -class TestDataModifier (unittest.TestCase) : +class TestDataModifier (tf.test.TestCase) : def setUp(self): # with tf.variable_scope('load', reuse = False) : @@ -78,7 +78,7 @@ def _setUp(self): model.build (data) # freeze the graph - with tf.Session() as sess: + with self.cached_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_descrpt_nonsmth.py b/source/tests/test_descrpt_nonsmth.py index 1b99934d12..1f79e048da 100644 --- a/source/tests/test_descrpt_nonsmth.py +++ b/source/tests/test_descrpt_nonsmth.py @@ -25,7 +25,7 @@ def setUp (self, data, comp = 0, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -155,12 +155,12 @@ def comp_v_dw (self, -class TestNonSmooth(Inter, unittest.TestCase): +class TestNonSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # self.places = 5 # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): @@ -181,7 +181,7 @@ def test_virial_dw (self) : virial_dw_test(self, self, suffix = '_se') -class TestLFPbc(unittest.TestCase): +class TestLFPbc(tf.test.TestCase): def test_pbc(self): data = Data() inter0 = Inter() diff --git a/source/tests/test_descrpt_se_a_type.py b/source/tests/test_descrpt_se_a_type.py index c35d952157..2bf16437bd 100644 --- a/source/tests/test_descrpt_se_a_type.py +++ b/source/tests/test_descrpt_se_a_type.py @@ -16,7 +16,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -110,7 +110,7 @@ def test_descriptor_two_sides(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) @@ -219,7 +219,7 @@ def test_descriptor_one_side(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 03026cff97..0fb018903d 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -23,7 +23,7 @@ class Inter(): def setUp (self, data) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -94,11 +94,11 @@ def comp_ef (self, return energy, force, virial -class TestDescrptAR(Inter, unittest.TestCase): +class TestDescrptAR(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): diff --git a/source/tests/test_descrpt_se_r.py b/source/tests/test_descrpt_se_r.py index 4d6222a728..8a759f7ebf 100644 --- a/source/tests/test_descrpt_se_r.py +++ b/source/tests/test_descrpt_se_r.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -136,11 +136,11 @@ def comp_v_dw (self, -class TestSmooth(Inter, unittest.TestCase): +class TestSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): @@ -161,7 +161,7 @@ def test_virial_dw (self) : virial_dw_test(self, self, suffix = '_se_r') -class TestSeRPbc(unittest.TestCase): +class TestSeRPbc(tf.test.TestCase): def test_pbc(self): data = Data() inter0 = Inter() diff --git a/source/tests/test_descrpt_sea_ef.py b/source/tests/test_descrpt_sea_ef.py index 43b95f1cbc..e8a972c0b8 100644 --- a/source/tests/test_descrpt_sea_ef.py +++ b/source/tests/test_descrpt_sea_ef.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -148,11 +148,11 @@ def comp_v_dw (self, -class TestSmooth(Inter, unittest.TestCase): +class TestSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): diff --git a/source/tests/test_descrpt_sea_ef_para.py b/source/tests/test_descrpt_sea_ef_para.py index 0e94147161..edc609dd04 100644 --- a/source/tests/test_descrpt_sea_ef_para.py +++ b/source/tests/test_descrpt_sea_ef_para.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -148,11 +148,11 @@ def comp_v_dw (self, -class TestSmooth(Inter, unittest.TestCase): +class TestSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): diff --git a/source/tests/test_descrpt_sea_ef_rot.py b/source/tests/test_descrpt_sea_ef_rot.py index 34c7434a66..67615328d3 100644 --- a/source/tests/test_descrpt_sea_ef_rot.py +++ b/source/tests/test_descrpt_sea_ef_rot.py @@ -12,9 +12,9 @@ from deepmd.descriptor import DescrptSeA from deepmd.descriptor import DescrptSeAEfLower -class TestEfRot(unittest.TestCase): +class TestEfRot(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.natoms = [5, 5, 2, 3] self.ntypes = 2 self.sel_a = [12,24] diff --git a/source/tests/test_descrpt_sea_ef_vert.py b/source/tests/test_descrpt_sea_ef_vert.py index a0d160e966..05a25e2ee9 100644 --- a/source/tests/test_descrpt_sea_ef_vert.py +++ b/source/tests/test_descrpt_sea_ef_vert.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -148,11 +148,11 @@ def comp_v_dw (self, -class TestSmooth(Inter, unittest.TestCase): +class TestSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index 05cb14b26e..8c1cdea174 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = tf.Session() + self.sess = self.cached_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -147,11 +147,11 @@ def comp_v_dw (self, -class TestSmooth(Inter, unittest.TestCase): +class TestSmooth(Inter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # data = Data() # Inter.__init__(self, data) - # unittest.TestCase.__init__(self, *args, **kwargs) + # tf.test.TestCase.__init__(self, *args, **kwargs) # self.controller = object() def setUp(self): @@ -172,7 +172,7 @@ def test_virial_dw (self) : virial_dw_test(self, self, suffix = '_smth') -class TestSeAPbc(unittest.TestCase): +class TestSeAPbc(tf.test.TestCase): def test_pbc(self): data = Data() inter0 = Inter() diff --git a/source/tests/test_dipole_se_a.py b/source/tests/test_dipole_se_a.py index 26a86c137b..fdc5c60d43 100644 --- a/source/tests/test_dipole_se_a.py +++ b/source/tests/test_dipole_se_a.py @@ -14,7 +14,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -90,7 +90,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([dipole, gdipole], feed_dict = feed_dict_test) diff --git a/source/tests/test_embedding_net.py b/source/tests/test_embedding_net.py index 4c07ea575e..1e5d223c08 100644 --- a/source/tests/test_embedding_net.py +++ b/source/tests/test_embedding_net.py @@ -11,9 +11,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class Inter(unittest.TestCase): +class Inter(tf.test.TestCase): def setUp (self) : - self.sess = tf.Session() + self.sess = self.cached_session() self.inputs = tf.constant([ 0., 1., 2.], dtype = tf.float64) self.ndata = 3 self.inputs = tf.reshape(self.inputs, [-1, 1]) diff --git a/source/tests/test_ewald.py b/source/tests/test_ewald.py index 83c5bc3a6d..13a1a5338a 100644 --- a/source/tests/test_ewald.py +++ b/source/tests/test_ewald.py @@ -19,7 +19,7 @@ global_default_places = 5 -class TestEwaldRecp (unittest.TestCase) : +class TestEwaldRecp (tf.test.TestCase) : def setUp(self): boxl = 4.5 # NOTICE grid should not change before and after box pert... box_pert = 0.2 @@ -62,7 +62,7 @@ def setUp(self): def test_py_interface(self): hh = 1e-4 places = 4 - sess = tf.Session() + sess = self.cached_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -96,7 +96,7 @@ def test_py_interface(self): def test_force(self): hh = 1e-4 places = 6 - sess = tf.Session() + sess = self.cached_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -138,7 +138,7 @@ def test_force(self): def test_virial(self): hh = 1e-4 places = 6 - sess = tf.Session() + sess = self.cached_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, diff --git a/source/tests/test_fitting_ener_type.py b/source/tests/test_fitting_ener_type.py index 71312188a3..fea952ebfa 100644 --- a/source/tests/test_fitting_ener_type.py +++ b/source/tests/test_fitting_ener_type.py @@ -15,7 +15,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -97,7 +97,7 @@ def test_fitting(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [pred_atom_ener] = sess.run([atom_ener], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_loc_frame.py b/source/tests/test_model_loc_frame.py index fe4d128c08..4a7dbaa8f6 100644 --- a/source/tests/test_model_loc_frame.py +++ b/source/tests/test_model_loc_frame.py @@ -14,7 +14,7 @@ GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -96,7 +96,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a.py b/source/tests/test_model_se_a.py index 700986e308..9f4df4f00c 100644 --- a/source/tests/test_model_se_a.py +++ b/source/tests/test_model_se_a.py @@ -14,7 +14,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -102,7 +102,7 @@ def test_model_atom_ener(self): t_mesh: test_data['default_mesh'], is_training: False } - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) @@ -189,7 +189,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_aparam.py b/source/tests/test_model_se_a_aparam.py index 77acb2143f..f6782a8f4f 100644 --- a/source/tests/test_model_se_a_aparam.py +++ b/source/tests/test_model_se_a_aparam.py @@ -13,7 +13,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -95,7 +95,7 @@ def test_model(self): t_aparam: np.reshape(test_data['aparam'] [:numb_test, :], [-1]), is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_fparam.py b/source/tests/test_model_se_a_fparam.py index b69ee713e6..6d9077c429 100644 --- a/source/tests/test_model_se_a_fparam.py +++ b/source/tests/test_model_se_a_fparam.py @@ -13,7 +13,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -96,7 +96,7 @@ def test_model(self): t_fparam: np.reshape(test_data['fparam'] [:numb_test, :], [-1]), is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_srtab.py b/source/tests/test_model_se_a_srtab.py index d04f011acd..8d6dea2e31 100644 --- a/source/tests/test_model_se_a_srtab.py +++ b/source/tests/test_model_se_a_srtab.py @@ -23,7 +23,7 @@ def _make_tab(ntype) : prt = np.reshape(prt, [ninter+1, -1]) np.savetxt('tab.xvg', prt.T) -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() _make_tab(2) @@ -117,7 +117,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_type.py b/source/tests/test_model_se_a_type.py index 393cb4c86e..ff663e1ed6 100644 --- a/source/tests/test_model_se_a_type.py +++ b/source/tests/test_model_se_a_type.py @@ -15,7 +15,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -99,7 +99,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_r.py b/source/tests/test_model_se_r.py index 5172746cc5..2375164d6c 100644 --- a/source/tests/test_model_se_r.py +++ b/source/tests/test_model_se_r.py @@ -13,7 +13,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -91,7 +91,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_t.py b/source/tests/test_model_se_t.py index 78b06ea7ab..34bd9b5ae6 100644 --- a/source/tests/test_model_se_t.py +++ b/source/tests/test_model_se_t.py @@ -13,7 +13,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_polar_se_a.py b/source/tests/test_polar_se_a.py index c84bc0aa6d..ca9827d640 100644 --- a/source/tests/test_polar_se_a.py +++ b/source/tests/test_polar_se_a.py @@ -14,7 +14,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([polar, gpolar], feed_dict = feed_dict_test) diff --git a/source/tests/test_prod_env_mat.py b/source/tests/test_prod_env_mat.py index 2e8aa168f8..9102842992 100644 --- a/source/tests/test_prod_env_mat.py +++ b/source/tests/test_prod_env_mat.py @@ -9,9 +9,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class TestProdEnvMat(unittest.TestCase): +class TestProdEnvMat(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force.py b/source/tests/test_prod_force.py index 012fd7ee9b..bf1b29f5f6 100644 --- a/source/tests/test_prod_force.py +++ b/source/tests/test_prod_force.py @@ -9,9 +9,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class TestProdForce(unittest.TestCase): +class TestProdForce(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force_grad.py b/source/tests/test_prod_force_grad.py index da05406477..254a82cc5b 100644 --- a/source/tests/test_prod_force_grad.py +++ b/source/tests/test_prod_force_grad.py @@ -9,9 +9,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class TestProdForceGrad(unittest.TestCase): +class TestProdForceGrad(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial.py b/source/tests/test_prod_virial.py index 4a22e95839..b4fda4b234 100644 --- a/source/tests/test_prod_virial.py +++ b/source/tests/test_prod_virial.py @@ -9,9 +9,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class TestProdVirial(unittest.TestCase): +class TestProdVirial(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial_grad.py b/source/tests/test_prod_virial_grad.py index 60fdd7f66d..31595d58b7 100644 --- a/source/tests/test_prod_virial_grad.py +++ b/source/tests/test_prod_virial_grad.py @@ -9,9 +9,9 @@ from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION -class TestProdVirialGrad(unittest.TestCase): +class TestProdVirialGrad(tf.test.TestCase): def setUp(self): - self.sess = tf.Session() + self.sess = self.cached_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_type_embed.py b/source/tests/test_type_embed.py index 5f71d8ed93..a472a91839 100644 --- a/source/tests/test_type_embed.py +++ b/source/tests/test_type_embed.py @@ -3,7 +3,7 @@ from deepmd.env import tf from deepmd.utils.type_embed import embed_atom_type, TypeEmbedNet -class TestTypeEbd(unittest.TestCase): +class TestTypeEbd(tf.test.TestCase): def test_embed_atom_type(self): ntypes = 3 natoms = tf.constant([5, 5, 3, 0, 2]) @@ -19,7 +19,7 @@ def test_embed_atom_type(self): [7,7,7], [7,7,7]] atom_embed = embed_atom_type(ntypes, natoms, type_embedding) - sess = tf.Session() + sess = self.cached_session() atom_embed = sess.run(atom_embed) for ii in range(5): for jj in range(3): @@ -29,7 +29,7 @@ def test_embed_atom_type(self): def test_type_embed_net(self): ten = TypeEmbedNet([2, 4, 8], seed = 1, uniform_seed = True) type_embedding = ten.build(2) - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) type_embedding = sess.run(type_embedding) diff --git a/source/tests/test_wfc.py b/source/tests/test_wfc.py index c3759429f6..dead1736c0 100644 --- a/source/tests/test_wfc.py +++ b/source/tests/test_wfc.py @@ -13,7 +13,7 @@ GLOBAL_TF_FLOAT_PRECISION = tf.float64 GLOBAL_NP_FLOAT_PRECISION = np.float64 -class TestModel(unittest.TestCase): +class TestModel(tf.test.TestCase): def setUp(self) : gen_data() @@ -83,7 +83,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = tf.Session() + sess = self.cached_session() sess.run(tf.global_variables_initializer()) [p] = sess.run([wfc], feed_dict = feed_dict_test) From 63860baf342e0be32f5e92e846d047b387f073a3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 02:26:55 -0400 Subject: [PATCH 2/6] cached_session is not avaible in TF 1.8; use test_session instead although it has been DEPRECATED --- source/tests/test_data_modifier.py | 2 +- source/tests/test_data_modifier_shuffle.py | 2 +- source/tests/test_descrpt_nonsmth.py | 2 +- source/tests/test_descrpt_se_a_type.py | 4 ++-- source/tests/test_descrpt_se_ar.py | 2 +- source/tests/test_descrpt_se_r.py | 2 +- source/tests/test_descrpt_sea_ef.py | 2 +- source/tests/test_descrpt_sea_ef_para.py | 2 +- source/tests/test_descrpt_sea_ef_rot.py | 2 +- source/tests/test_descrpt_sea_ef_vert.py | 2 +- source/tests/test_descrpt_smooth.py | 2 +- source/tests/test_dipole_se_a.py | 2 +- source/tests/test_embedding_net.py | 2 +- source/tests/test_ewald.py | 6 +++--- source/tests/test_fitting_ener_type.py | 2 +- source/tests/test_model_loc_frame.py | 2 +- source/tests/test_model_se_a.py | 4 ++-- source/tests/test_model_se_a_aparam.py | 2 +- source/tests/test_model_se_a_fparam.py | 2 +- source/tests/test_model_se_a_srtab.py | 2 +- source/tests/test_model_se_a_type.py | 2 +- source/tests/test_model_se_r.py | 2 +- source/tests/test_model_se_t.py | 2 +- source/tests/test_polar_se_a.py | 2 +- source/tests/test_prod_env_mat.py | 2 +- source/tests/test_prod_force.py | 2 +- source/tests/test_prod_force_grad.py | 2 +- source/tests/test_prod_virial.py | 2 +- source/tests/test_prod_virial_grad.py | 2 +- source/tests/test_type_embed.py | 4 ++-- source/tests/test_wfc.py | 2 +- 31 files changed, 36 insertions(+), 36 deletions(-) diff --git a/source/tests/test_data_modifier.py b/source/tests/test_data_modifier.py index 2d7b26ee7f..829a589d7e 100644 --- a/source/tests/test_data_modifier.py +++ b/source/tests/test_data_modifier.py @@ -74,7 +74,7 @@ def _setUp(self): model.build (data) # freeze the graph - with self.cached_session() as sess: + with self.test_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_data_modifier_shuffle.py b/source/tests/test_data_modifier_shuffle.py index 4dce6ab59e..bd4ab58132 100644 --- a/source/tests/test_data_modifier_shuffle.py +++ b/source/tests/test_data_modifier_shuffle.py @@ -78,7 +78,7 @@ def _setUp(self): model.build (data) # freeze the graph - with self.cached_session() as sess: + with self.test_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_descrpt_nonsmth.py b/source/tests/test_descrpt_nonsmth.py index 1f79e048da..b89fa36120 100644 --- a/source/tests/test_descrpt_nonsmth.py +++ b/source/tests/test_descrpt_nonsmth.py @@ -25,7 +25,7 @@ def setUp (self, data, comp = 0, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_se_a_type.py b/source/tests/test_descrpt_se_a_type.py index 2bf16437bd..5a8ebb3fb2 100644 --- a/source/tests/test_descrpt_se_a_type.py +++ b/source/tests/test_descrpt_se_a_type.py @@ -110,7 +110,7 @@ def test_descriptor_two_sides(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) @@ -219,7 +219,7 @@ def test_descriptor_one_side(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 0fb018903d..3cb078c51d 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -23,7 +23,7 @@ class Inter(): def setUp (self, data) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_se_r.py b/source/tests/test_descrpt_se_r.py index 8a759f7ebf..cd632570c7 100644 --- a/source/tests/test_descrpt_se_r.py +++ b/source/tests/test_descrpt_se_r.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_sea_ef.py b/source/tests/test_descrpt_sea_ef.py index e8a972c0b8..c9a1f262c1 100644 --- a/source/tests/test_descrpt_sea_ef.py +++ b/source/tests/test_descrpt_sea_ef.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_sea_ef_para.py b/source/tests/test_descrpt_sea_ef_para.py index edc609dd04..39e8a15a44 100644 --- a/source/tests/test_descrpt_sea_ef_para.py +++ b/source/tests/test_descrpt_sea_ef_para.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_sea_ef_rot.py b/source/tests/test_descrpt_sea_ef_rot.py index 67615328d3..f1bf47f9dd 100644 --- a/source/tests/test_descrpt_sea_ef_rot.py +++ b/source/tests/test_descrpt_sea_ef_rot.py @@ -14,7 +14,7 @@ class TestEfRot(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.natoms = [5, 5, 2, 3] self.ntypes = 2 self.sel_a = [12,24] diff --git a/source/tests/test_descrpt_sea_ef_vert.py b/source/tests/test_descrpt_sea_ef_vert.py index 05a25e2ee9..d6e7e35c81 100644 --- a/source/tests/test_descrpt_sea_ef_vert.py +++ b/source/tests/test_descrpt_sea_ef_vert.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index 8c1cdea174..f47deae858 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True) : - self.sess = self.cached_session() + self.sess = self.test_session() self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_dipole_se_a.py b/source/tests/test_dipole_se_a.py index fdc5c60d43..44c9786efa 100644 --- a/source/tests/test_dipole_se_a.py +++ b/source/tests/test_dipole_se_a.py @@ -90,7 +90,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([dipole, gdipole], feed_dict = feed_dict_test) diff --git a/source/tests/test_embedding_net.py b/source/tests/test_embedding_net.py index 1e5d223c08..c7c4cafd3d 100644 --- a/source/tests/test_embedding_net.py +++ b/source/tests/test_embedding_net.py @@ -13,7 +13,7 @@ class Inter(tf.test.TestCase): def setUp (self) : - self.sess = self.cached_session() + self.sess = self.test_session() self.inputs = tf.constant([ 0., 1., 2.], dtype = tf.float64) self.ndata = 3 self.inputs = tf.reshape(self.inputs, [-1, 1]) diff --git a/source/tests/test_ewald.py b/source/tests/test_ewald.py index 13a1a5338a..463351ca58 100644 --- a/source/tests/test_ewald.py +++ b/source/tests/test_ewald.py @@ -62,7 +62,7 @@ def setUp(self): def test_py_interface(self): hh = 1e-4 places = 4 - sess = self.cached_session() + sess = self.test_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -96,7 +96,7 @@ def test_py_interface(self): def test_force(self): hh = 1e-4 places = 6 - sess = self.cached_session() + sess = self.test_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -138,7 +138,7 @@ def test_force(self): def test_virial(self): hh = 1e-4 places = 6 - sess = self.cached_session() + sess = self.test_session() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, diff --git a/source/tests/test_fitting_ener_type.py b/source/tests/test_fitting_ener_type.py index fea952ebfa..bc4e8842aa 100644 --- a/source/tests/test_fitting_ener_type.py +++ b/source/tests/test_fitting_ener_type.py @@ -97,7 +97,7 @@ def test_fitting(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [pred_atom_ener] = sess.run([atom_ener], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_loc_frame.py b/source/tests/test_model_loc_frame.py index 4a7dbaa8f6..f7401965b7 100644 --- a/source/tests/test_model_loc_frame.py +++ b/source/tests/test_model_loc_frame.py @@ -96,7 +96,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a.py b/source/tests/test_model_se_a.py index 9f4df4f00c..24bf601d74 100644 --- a/source/tests/test_model_se_a.py +++ b/source/tests/test_model_se_a.py @@ -102,7 +102,7 @@ def test_model_atom_ener(self): t_mesh: test_data['default_mesh'], is_training: False } - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) @@ -189,7 +189,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_aparam.py b/source/tests/test_model_se_a_aparam.py index f6782a8f4f..4111f55daf 100644 --- a/source/tests/test_model_se_a_aparam.py +++ b/source/tests/test_model_se_a_aparam.py @@ -95,7 +95,7 @@ def test_model(self): t_aparam: np.reshape(test_data['aparam'] [:numb_test, :], [-1]), is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_fparam.py b/source/tests/test_model_se_a_fparam.py index 6d9077c429..849107d59a 100644 --- a/source/tests/test_model_se_a_fparam.py +++ b/source/tests/test_model_se_a_fparam.py @@ -96,7 +96,7 @@ def test_model(self): t_fparam: np.reshape(test_data['fparam'] [:numb_test, :], [-1]), is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_srtab.py b/source/tests/test_model_se_a_srtab.py index 8d6dea2e31..00a68d1983 100644 --- a/source/tests/test_model_se_a_srtab.py +++ b/source/tests/test_model_se_a_srtab.py @@ -117,7 +117,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_type.py b/source/tests/test_model_se_a_type.py index ff663e1ed6..f98cc73982 100644 --- a/source/tests/test_model_se_a_type.py +++ b/source/tests/test_model_se_a_type.py @@ -99,7 +99,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_r.py b/source/tests/test_model_se_r.py index 2375164d6c..7b09f07304 100644 --- a/source/tests/test_model_se_r.py +++ b/source/tests/test_model_se_r.py @@ -91,7 +91,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_t.py b/source/tests/test_model_se_t.py index 34bd9b5ae6..e9f47f1031 100644 --- a/source/tests/test_model_se_t.py +++ b/source/tests/test_model_se_t.py @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_polar_se_a.py b/source/tests/test_polar_se_a.py index ca9827d640..cc73a17781 100644 --- a/source/tests/test_polar_se_a.py +++ b/source/tests/test_polar_se_a.py @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([polar, gpolar], feed_dict = feed_dict_test) diff --git a/source/tests/test_prod_env_mat.py b/source/tests/test_prod_env_mat.py index 9102842992..c94e8b08d8 100644 --- a/source/tests/test_prod_env_mat.py +++ b/source/tests/test_prod_env_mat.py @@ -11,7 +11,7 @@ class TestProdEnvMat(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force.py b/source/tests/test_prod_force.py index bf1b29f5f6..754e30f08d 100644 --- a/source/tests/test_prod_force.py +++ b/source/tests/test_prod_force.py @@ -11,7 +11,7 @@ class TestProdForce(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force_grad.py b/source/tests/test_prod_force_grad.py index 254a82cc5b..04629e94a6 100644 --- a/source/tests/test_prod_force_grad.py +++ b/source/tests/test_prod_force_grad.py @@ -11,7 +11,7 @@ class TestProdForceGrad(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial.py b/source/tests/test_prod_virial.py index b4fda4b234..3d9467dc52 100644 --- a/source/tests/test_prod_virial.py +++ b/source/tests/test_prod_virial.py @@ -11,7 +11,7 @@ class TestProdVirial(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial_grad.py b/source/tests/test_prod_virial_grad.py index 31595d58b7..3b0878e9d5 100644 --- a/source/tests/test_prod_virial_grad.py +++ b/source/tests/test_prod_virial_grad.py @@ -11,7 +11,7 @@ class TestProdVirialGrad(tf.test.TestCase): def setUp(self): - self.sess = self.cached_session() + self.sess = self.test_session() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_type_embed.py b/source/tests/test_type_embed.py index a472a91839..07c06f4244 100644 --- a/source/tests/test_type_embed.py +++ b/source/tests/test_type_embed.py @@ -19,7 +19,7 @@ def test_embed_atom_type(self): [7,7,7], [7,7,7]] atom_embed = embed_atom_type(ntypes, natoms, type_embedding) - sess = self.cached_session() + sess = self.test_session() atom_embed = sess.run(atom_embed) for ii in range(5): for jj in range(3): @@ -29,7 +29,7 @@ def test_embed_atom_type(self): def test_type_embed_net(self): ten = TypeEmbedNet([2, 4, 8], seed = 1, uniform_seed = True) type_embedding = ten.build(2) - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) type_embedding = sess.run(type_embedding) diff --git a/source/tests/test_wfc.py b/source/tests/test_wfc.py index dead1736c0..4412f26182 100644 --- a/source/tests/test_wfc.py +++ b/source/tests/test_wfc.py @@ -83,7 +83,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.cached_session() + sess = self.test_session() sess.run(tf.global_variables_initializer()) [p] = sess.run([wfc], feed_dict = feed_dict_test) From 1191b0909fa862e618f11218e0a54702ce68b5e0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 03:04:15 -0400 Subject: [PATCH 3/6] bugfix --- source/tests/test_descrpt_nonsmth.py | 13 +++++++------ source/tests/test_descrpt_se_a_type.py | 4 ++-- source/tests/test_descrpt_se_ar.py | 5 +++-- source/tests/test_descrpt_se_r.py | 13 +++++++------ source/tests/test_descrpt_sea_ef.py | 5 +++-- source/tests/test_descrpt_sea_ef_para.py | 5 +++-- source/tests/test_descrpt_sea_ef_rot.py | 2 +- source/tests/test_descrpt_sea_ef_vert.py | 5 +++-- source/tests/test_descrpt_smooth.py | 13 +++++++------ source/tests/test_dipole_se_a.py | 2 +- source/tests/test_embedding_net.py | 2 +- source/tests/test_ewald.py | 6 +++--- source/tests/test_fitting_ener_type.py | 2 +- source/tests/test_model_loc_frame.py | 2 +- source/tests/test_model_se_a.py | 4 ++-- source/tests/test_model_se_a_aparam.py | 2 +- source/tests/test_model_se_a_srtab.py | 2 +- source/tests/test_model_se_a_type.py | 2 +- source/tests/test_model_se_r.py | 2 +- source/tests/test_model_se_t.py | 2 +- source/tests/test_polar_se_a.py | 2 +- source/tests/test_prod_env_mat.py | 2 +- source/tests/test_prod_force.py | 2 +- source/tests/test_prod_force_grad.py | 2 +- source/tests/test_prod_virial.py | 2 +- source/tests/test_prod_virial_grad.py | 2 +- source/tests/test_type_embed.py | 4 ++-- source/tests/test_wfc.py | 2 +- 28 files changed, 59 insertions(+), 52 deletions(-) diff --git a/source/tests/test_descrpt_nonsmth.py b/source/tests/test_descrpt_nonsmth.py index b89fa36120..901d718c09 100644 --- a/source/tests/test_descrpt_nonsmth.py +++ b/source/tests/test_descrpt_nonsmth.py @@ -24,8 +24,9 @@ class Inter(): def setUp (self, data, comp = 0, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess = None) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -186,8 +187,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc = True) - inter1.setUp(data, pbc = False) + inter0.setUp(data, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -227,8 +228,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale = 2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc = True) - inter1.setUp(data1, pbc = False) + inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_descrpt_se_a_type.py b/source/tests/test_descrpt_se_a_type.py index 5a8ebb3fb2..c6f3cb5a19 100644 --- a/source/tests/test_descrpt_se_a_type.py +++ b/source/tests/test_descrpt_se_a_type.py @@ -110,7 +110,7 @@ def test_descriptor_two_sides(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) @@ -219,7 +219,7 @@ def test_descriptor_one_side(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict = feed_dict_test) diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 3cb078c51d..6e0bf7bf93 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -22,8 +22,9 @@ class Inter(): def setUp (self, - data) : - self.sess = self.test_session() + data, + sess) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_se_r.py b/source/tests/test_descrpt_se_r.py index cd632570c7..f698292eda 100644 --- a/source/tests/test_descrpt_se_r.py +++ b/source/tests/test_descrpt_se_r.py @@ -23,8 +23,9 @@ class Inter(): def setUp (self, data, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess = None) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -166,8 +167,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc = True) - inter1.setUp(data, pbc = False) + inter0.setUp(data, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -208,8 +209,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale = 2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc = True) - inter1.setUp(data1, pbc = False) + inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_descrpt_sea_ef.py b/source/tests/test_descrpt_sea_ef.py index c9a1f262c1..18f777a2c4 100644 --- a/source/tests/test_descrpt_sea_ef.py +++ b/source/tests/test_descrpt_sea_ef.py @@ -23,8 +23,9 @@ class Inter(): def setUp (self, data, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess = None) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_sea_ef_para.py b/source/tests/test_descrpt_sea_ef_para.py index 39e8a15a44..d1317f458a 100644 --- a/source/tests/test_descrpt_sea_ef_para.py +++ b/source/tests/test_descrpt_sea_ef_para.py @@ -23,8 +23,9 @@ class Inter(): def setUp (self, data, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess = None) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_sea_ef_rot.py b/source/tests/test_descrpt_sea_ef_rot.py index f1bf47f9dd..f172ec2771 100644 --- a/source/tests/test_descrpt_sea_ef_rot.py +++ b/source/tests/test_descrpt_sea_ef_rot.py @@ -14,7 +14,7 @@ class TestEfRot(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.natoms = [5, 5, 2, 3] self.ntypes = 2 self.sel_a = [12,24] diff --git a/source/tests/test_descrpt_sea_ef_vert.py b/source/tests/test_descrpt_sea_ef_vert.py index d6e7e35c81..75833bcaa5 100644 --- a/source/tests/test_descrpt_sea_ef_vert.py +++ b/source/tests/test_descrpt_sea_ef_vert.py @@ -23,8 +23,9 @@ class Inter(): def setUp (self, data, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess = None) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index f47deae858..71fbd5cfcd 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -23,8 +23,9 @@ class Inter(): def setUp (self, data, - pbc = True) : - self.sess = self.test_session() + pbc = True, + sess) : + self.sess = sess self.data = data self.natoms = self.data.get_natoms() self.ntypes = self.data.get_ntypes() @@ -177,8 +178,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc = True) - inter1.setUp(data, pbc = False) + inter0.setUp(data, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -218,8 +219,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale = 2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc = True) - inter1.setUp(data1, pbc = False) + inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__()) + inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_dipole_se_a.py b/source/tests/test_dipole_se_a.py index 44c9786efa..3687ba47f2 100644 --- a/source/tests/test_dipole_se_a.py +++ b/source/tests/test_dipole_se_a.py @@ -90,7 +90,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([dipole, gdipole], feed_dict = feed_dict_test) diff --git a/source/tests/test_embedding_net.py b/source/tests/test_embedding_net.py index c7c4cafd3d..03084a8267 100644 --- a/source/tests/test_embedding_net.py +++ b/source/tests/test_embedding_net.py @@ -13,7 +13,7 @@ class Inter(tf.test.TestCase): def setUp (self) : - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.inputs = tf.constant([ 0., 1., 2.], dtype = tf.float64) self.ndata = 3 self.inputs = tf.reshape(self.inputs, [-1, 1]) diff --git a/source/tests/test_ewald.py b/source/tests/test_ewald.py index 463351ca58..f4913db36f 100644 --- a/source/tests/test_ewald.py +++ b/source/tests/test_ewald.py @@ -62,7 +62,7 @@ def setUp(self): def test_py_interface(self): hh = 1e-4 places = 4 - sess = self.test_session() + sess = self.test_session().__enter__() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -96,7 +96,7 @@ def test_py_interface(self): def test_force(self): hh = 1e-4 places = 6 - sess = self.test_session() + sess = self.test_session().__enter__() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, @@ -138,7 +138,7 @@ def test_force(self): def test_virial(self): hh = 1e-4 places = 6 - sess = self.test_session() + sess = self.test_session().__enter__() t_energy, t_force, t_virial \ = op_module.ewald_recp(self.coord, self.charge, self.nloc, self.box, ewald_h = self.ewald_h, diff --git a/source/tests/test_fitting_ener_type.py b/source/tests/test_fitting_ener_type.py index bc4e8842aa..25b85391e2 100644 --- a/source/tests/test_fitting_ener_type.py +++ b/source/tests/test_fitting_ener_type.py @@ -97,7 +97,7 @@ def test_fitting(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [pred_atom_ener] = sess.run([atom_ener], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_loc_frame.py b/source/tests/test_model_loc_frame.py index f7401965b7..3ffce1bf41 100644 --- a/source/tests/test_model_loc_frame.py +++ b/source/tests/test_model_loc_frame.py @@ -96,7 +96,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a.py b/source/tests/test_model_se_a.py index 24bf601d74..e8b12c8553 100644 --- a/source/tests/test_model_se_a.py +++ b/source/tests/test_model_se_a.py @@ -102,7 +102,7 @@ def test_model_atom_ener(self): t_mesh: test_data['default_mesh'], is_training: False } - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) @@ -189,7 +189,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_aparam.py b/source/tests/test_model_se_a_aparam.py index 4111f55daf..9c9847c242 100644 --- a/source/tests/test_model_se_a_aparam.py +++ b/source/tests/test_model_se_a_aparam.py @@ -95,7 +95,7 @@ def test_model(self): t_aparam: np.reshape(test_data['aparam'] [:numb_test, :], [-1]), is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_srtab.py b/source/tests/test_model_se_a_srtab.py index 00a68d1983..4b87fc0bb4 100644 --- a/source/tests/test_model_se_a_srtab.py +++ b/source/tests/test_model_se_a_srtab.py @@ -117,7 +117,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_a_type.py b/source/tests/test_model_se_a_type.py index f98cc73982..11979108b4 100644 --- a/source/tests/test_model_se_a_type.py +++ b/source/tests/test_model_se_a_type.py @@ -99,7 +99,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_r.py b/source/tests/test_model_se_r.py index 7b09f07304..56052db8d3 100644 --- a/source/tests/test_model_se_r.py +++ b/source/tests/test_model_se_r.py @@ -91,7 +91,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_model_se_t.py b/source/tests/test_model_se_t.py index e9f47f1031..48fffd76b0 100644 --- a/source/tests/test_model_se_t.py +++ b/source/tests/test_model_se_t.py @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_polar_se_a.py b/source/tests/test_polar_se_a.py index cc73a17781..78da154864 100644 --- a/source/tests/test_polar_se_a.py +++ b/source/tests/test_polar_se_a.py @@ -89,7 +89,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([polar, gpolar], feed_dict = feed_dict_test) diff --git a/source/tests/test_prod_env_mat.py b/source/tests/test_prod_env_mat.py index c94e8b08d8..4b60da39d9 100644 --- a/source/tests/test_prod_env_mat.py +++ b/source/tests/test_prod_env_mat.py @@ -11,7 +11,7 @@ class TestProdEnvMat(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force.py b/source/tests/test_prod_force.py index 754e30f08d..4135ee0e2e 100644 --- a/source/tests/test_prod_force.py +++ b/source/tests/test_prod_force.py @@ -11,7 +11,7 @@ class TestProdForce(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_force_grad.py b/source/tests/test_prod_force_grad.py index 04629e94a6..8b8ad74ed6 100644 --- a/source/tests/test_prod_force_grad.py +++ b/source/tests/test_prod_force_grad.py @@ -11,7 +11,7 @@ class TestProdForceGrad(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial.py b/source/tests/test_prod_virial.py index 3d9467dc52..2a11f0acfd 100644 --- a/source/tests/test_prod_virial.py +++ b/source/tests/test_prod_virial.py @@ -11,7 +11,7 @@ class TestProdVirial(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_prod_virial_grad.py b/source/tests/test_prod_virial_grad.py index 3b0878e9d5..2ca9071a56 100644 --- a/source/tests/test_prod_virial_grad.py +++ b/source/tests/test_prod_virial_grad.py @@ -11,7 +11,7 @@ class TestProdVirialGrad(tf.test.TestCase): def setUp(self): - self.sess = self.test_session() + self.sess = self.test_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, 2.56, 2.18, diff --git a/source/tests/test_type_embed.py b/source/tests/test_type_embed.py index 07c06f4244..42d1c86063 100644 --- a/source/tests/test_type_embed.py +++ b/source/tests/test_type_embed.py @@ -19,7 +19,7 @@ def test_embed_atom_type(self): [7,7,7], [7,7,7]] atom_embed = embed_atom_type(ntypes, natoms, type_embedding) - sess = self.test_session() + sess = self.test_session().__enter__() atom_embed = sess.run(atom_embed) for ii in range(5): for jj in range(3): @@ -29,7 +29,7 @@ def test_embed_atom_type(self): def test_type_embed_net(self): ten = TypeEmbedNet([2, 4, 8], seed = 1, uniform_seed = True) type_embedding = ten.build(2) - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) type_embedding = sess.run(type_embedding) diff --git a/source/tests/test_wfc.py b/source/tests/test_wfc.py index 4412f26182..97af7486bd 100644 --- a/source/tests/test_wfc.py +++ b/source/tests/test_wfc.py @@ -83,7 +83,7 @@ def test_model(self): t_mesh: test_data['default_mesh'], is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [p] = sess.run([wfc], feed_dict = feed_dict_test) From 473807535384b0fc48cabd13262ca21cc0ae63e8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 03:08:14 -0400 Subject: [PATCH 4/6] bugfix --- source/tests/test_descrpt_se_ar.py | 2 +- source/tests/test_descrpt_smooth.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 6e0bf7bf93..260343520b 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -23,7 +23,7 @@ class Inter(): def setUp (self, data, - sess) : + sess = None) : self.sess = sess self.data = data self.natoms = self.data.get_natoms() diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index 71fbd5cfcd..87ea15e9f0 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -24,7 +24,7 @@ class Inter(): def setUp (self, data, pbc = True, - sess) : + sess = None) : self.sess = sess self.data = data self.natoms = self.data.get_natoms() From 16962672212fe0397122365f09686895b87289f3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 03:26:39 -0400 Subject: [PATCH 5/6] bugfix --- source/tests/test_descrpt_nonsmth.py | 2 +- source/tests/test_descrpt_se_ar.py | 2 +- source/tests/test_descrpt_se_r.py | 2 +- source/tests/test_descrpt_sea_ef.py | 2 +- source/tests/test_descrpt_sea_ef_para.py | 2 +- source/tests/test_descrpt_sea_ef_vert.py | 2 +- source/tests/test_descrpt_smooth.py | 2 +- source/tests/test_tab_nonsmth.py | 7 ++++--- source/tests/test_tab_smooth.py | 7 ++++--- 9 files changed, 15 insertions(+), 13 deletions(-) diff --git a/source/tests/test_descrpt_nonsmth.py b/source/tests/test_descrpt_nonsmth.py index 901d718c09..decb18e053 100644 --- a/source/tests/test_descrpt_nonsmth.py +++ b/source/tests/test_descrpt_nonsmth.py @@ -167,7 +167,7 @@ class TestNonSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_se') diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 260343520b..9ebb085bab 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -105,7 +105,7 @@ class TestDescrptAR(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_se_ar') diff --git a/source/tests/test_descrpt_se_r.py b/source/tests/test_descrpt_se_r.py index f698292eda..bf002736b0 100644 --- a/source/tests/test_descrpt_se_r.py +++ b/source/tests/test_descrpt_se_r.py @@ -147,7 +147,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_se_r') diff --git a/source/tests/test_descrpt_sea_ef.py b/source/tests/test_descrpt_sea_ef.py index 18f777a2c4..b15aaaa7b0 100644 --- a/source/tests/test_descrpt_sea_ef.py +++ b/source/tests/test_descrpt_sea_ef.py @@ -159,7 +159,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_sea_ef') diff --git a/source/tests/test_descrpt_sea_ef_para.py b/source/tests/test_descrpt_sea_ef_para.py index d1317f458a..213fc17930 100644 --- a/source/tests/test_descrpt_sea_ef_para.py +++ b/source/tests/test_descrpt_sea_ef_para.py @@ -159,7 +159,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_sea_ef_para') diff --git a/source/tests/test_descrpt_sea_ef_vert.py b/source/tests/test_descrpt_sea_ef_vert.py index 75833bcaa5..8ba82709e3 100644 --- a/source/tests/test_descrpt_sea_ef_vert.py +++ b/source/tests/test_descrpt_sea_ef_vert.py @@ -159,7 +159,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_sea_ef_vert') diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index 87ea15e9f0..adb5beb354 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -158,7 +158,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data) + Inter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, suffix = '_smth') diff --git a/source/tests/test_tab_nonsmth.py b/source/tests/test_tab_nonsmth.py index f6496bbb1d..6a57da4fd7 100644 --- a/source/tests/test_tab_nonsmth.py +++ b/source/tests/test_tab_nonsmth.py @@ -31,9 +31,10 @@ def _make_tab(ntype) : class IntplInter(Inter): def setUp (self, - data) : + data, + sess=None) : # tabulated - Inter.setUp(self, data) + Inter.setUp(self, data, sess=sess) _make_tab(data.get_ntypes()) self.srtab = PairTab('tab.xvg') self.smin_alpha = 0.3 @@ -162,7 +163,7 @@ class TestTabNonSmooth(IntplInter, unittest.TestCase): def setUp(self): self.places = 5 data = Data() - IntplInter.setUp(self, data) + IntplInter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, places=5, suffix = '_tab') diff --git a/source/tests/test_tab_smooth.py b/source/tests/test_tab_smooth.py index 55c16d57cc..112c19a297 100644 --- a/source/tests/test_tab_smooth.py +++ b/source/tests/test_tab_smooth.py @@ -31,9 +31,10 @@ def _make_tab(ntype) : class IntplInter(Inter): def setUp (self, - data) : + data, + sess=None) : # tabulated - Inter.setUp(self, data) + Inter.setUp(self, data, sess=sess) _make_tab(data.get_ntypes()) self.srtab = PairTab('tab.xvg') self.smin_alpha = 0.3 @@ -160,7 +161,7 @@ class TestTabSmooth(IntplInter, unittest.TestCase): def setUp(self): self.places = 5 data = Data() - IntplInter.setUp(self, data) + IntplInter.setUp(self, data, sess=self.test_session().__enter__()) def test_force (self) : force_test(self, self, places=5, suffix = '_tab_smth') From fd8aef7e78927b7aa2355b84fe1235aca123ad60 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 7 Jul 2021 11:39:58 -0400 Subject: [PATCH 6/6] bugfix --- source/tests/test_model_se_a_fparam.py | 2 +- source/tests/test_tab_nonsmth.py | 2 +- source/tests/test_tab_smooth.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/tests/test_model_se_a_fparam.py b/source/tests/test_model_se_a_fparam.py index 849107d59a..23b0da569b 100644 --- a/source/tests/test_model_se_a_fparam.py +++ b/source/tests/test_model_se_a_fparam.py @@ -96,7 +96,7 @@ def test_model(self): t_fparam: np.reshape(test_data['fparam'] [:numb_test, :], [-1]), is_training: False} - sess = self.test_session() + sess = self.test_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict = feed_dict_test) diff --git a/source/tests/test_tab_nonsmth.py b/source/tests/test_tab_nonsmth.py index 6a57da4fd7..0d8305cef4 100644 --- a/source/tests/test_tab_nonsmth.py +++ b/source/tests/test_tab_nonsmth.py @@ -152,7 +152,7 @@ def comp_interpl_ef (self, -class TestTabNonSmooth(IntplInter, unittest.TestCase): +class TestTabNonSmooth(IntplInter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # self.places = 5 # data = Data() diff --git a/source/tests/test_tab_smooth.py b/source/tests/test_tab_smooth.py index 112c19a297..ab4dc65c7a 100644 --- a/source/tests/test_tab_smooth.py +++ b/source/tests/test_tab_smooth.py @@ -150,7 +150,7 @@ def comp_ef (self, -class TestTabSmooth(IntplInter, unittest.TestCase): +class TestTabSmooth(IntplInter, tf.test.TestCase): # def __init__ (self, *args, **kwargs): # self.places = 5 # data = Data()