Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source/tests/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
INPUT = os.path.join(modifier_datapath, 'dipole.json')


class TestDataModifier (unittest.TestCase) :
class TestDataModifier (tf.test.TestCase) :

def setUp(self):
# with tf.variable_scope('load', reuse = False) :
Expand Down Expand Up @@ -74,7 +74,7 @@ def _setUp(self):
model.build (data)

# freeze the graph
with tf.Session() as sess:
with self.test_session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph = tf.get_default_graph()
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
modifier_datapath = 'data_modifier'


class TestDataModifier (unittest.TestCase) :
class TestDataModifier (tf.test.TestCase) :

def setUp(self):
# with tf.variable_scope('load', reuse = False) :
Expand Down Expand Up @@ -78,7 +78,7 @@ def _setUp(self):
model.build (data)

# freeze the graph
with tf.Session() as sess:
with self.test_session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph = tf.get_default_graph()
Expand Down
21 changes: 11 additions & 10 deletions source/tests/test_descrpt_nonsmth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class Inter():
def setUp (self,
data,
comp = 0,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -155,18 +156,18 @@ def comp_v_dw (self,



class TestNonSmooth(Inter, unittest.TestCase):
class TestNonSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# self.places = 5
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_se')
Expand All @@ -181,13 +182,13 @@ def test_virial_dw (self) :
virial_dw_test(self, self, suffix = '_se')


class TestLFPbc(unittest.TestCase):
class TestLFPbc(tf.test.TestCase):
def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc = True)
inter1.setUp(data, pbc = False)
inter0.setUp(data, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -227,8 +228,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale = 2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc = True)
inter1.setUp(data1, pbc = False)
inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
6 changes: 3 additions & 3 deletions source/tests/test_descrpt_se_a_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
GLOBAL_TF_FLOAT_PRECISION = tf.float64
GLOBAL_NP_FLOAT_PRECISION = np.float64

class TestModel(unittest.TestCase):
class TestModel(tf.test.TestCase):
def setUp(self) :
gen_data()

Expand Down Expand Up @@ -110,7 +110,7 @@ def test_descriptor_two_sides(self):
t_mesh: test_data['default_mesh'],
is_training: False}

sess = tf.Session()
sess = self.test_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout],
feed_dict = feed_dict_test)
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_descriptor_one_side(self):
t_mesh: test_data['default_mesh'],
is_training: False}

sess = tf.Session()
sess = self.test_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout],
feed_dict = feed_dict_test)
Expand Down
11 changes: 6 additions & 5 deletions source/tests/test_descrpt_se_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

class Inter():
def setUp (self,
data) :
self.sess = tf.Session()
data,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -94,17 +95,17 @@ def comp_ef (self,
return energy, force, virial


class TestDescrptAR(Inter, unittest.TestCase):
class TestDescrptAR(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_se_ar')
Expand Down
21 changes: 11 additions & 10 deletions source/tests/test_descrpt_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
class Inter():
def setUp (self,
data,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -136,17 +137,17 @@ def comp_v_dw (self,



class TestSmooth(Inter, unittest.TestCase):
class TestSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_se_r')
Expand All @@ -161,13 +162,13 @@ def test_virial_dw (self) :
virial_dw_test(self, self, suffix = '_se_r')


class TestSeRPbc(unittest.TestCase):
class TestSeRPbc(tf.test.TestCase):
def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc = True)
inter1.setUp(data, pbc = False)
inter0.setUp(data, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -208,8 +209,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale = 2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc = True)
inter1.setUp(data1, pbc = False)
inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
11 changes: 6 additions & 5 deletions source/tests/test_descrpt_sea_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
class Inter():
def setUp (self,
data,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -148,17 +149,17 @@ def comp_v_dw (self,



class TestSmooth(Inter, unittest.TestCase):
class TestSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_sea_ef')
Expand Down
11 changes: 6 additions & 5 deletions source/tests/test_descrpt_sea_ef_para.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
class Inter():
def setUp (self,
data,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -148,17 +149,17 @@ def comp_v_dw (self,



class TestSmooth(Inter, unittest.TestCase):
class TestSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_sea_ef_para')
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_descrpt_sea_ef_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from deepmd.descriptor import DescrptSeA
from deepmd.descriptor import DescrptSeAEfLower

class TestEfRot(unittest.TestCase):
class TestEfRot(tf.test.TestCase):
def setUp(self):
self.sess = tf.Session()
self.sess = self.test_session().__enter__()
self.natoms = [5, 5, 2, 3]
self.ntypes = 2
self.sel_a = [12,24]
Expand Down
11 changes: 6 additions & 5 deletions source/tests/test_descrpt_sea_ef_vert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
class Inter():
def setUp (self,
data,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -148,17 +149,17 @@ def comp_v_dw (self,



class TestSmooth(Inter, unittest.TestCase):
class TestSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_sea_ef_vert')
Expand Down
21 changes: 11 additions & 10 deletions source/tests/test_descrpt_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
class Inter():
def setUp (self,
data,
pbc = True) :
self.sess = tf.Session()
pbc = True,
sess = None) :
self.sess = sess
self.data = data
self.natoms = self.data.get_natoms()
self.ntypes = self.data.get_ntypes()
Expand Down Expand Up @@ -147,17 +148,17 @@ def comp_v_dw (self,



class TestSmooth(Inter, unittest.TestCase):
class TestSmooth(Inter, tf.test.TestCase):
# def __init__ (self, *args, **kwargs):
# data = Data()
# Inter.__init__(self, data)
# unittest.TestCase.__init__(self, *args, **kwargs)
# tf.test.TestCase.__init__(self, *args, **kwargs)
# self.controller = object()

def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data)
Inter.setUp(self, data, sess=self.test_session().__enter__())

def test_force (self) :
force_test(self, self, suffix = '_smth')
Expand All @@ -172,13 +173,13 @@ def test_virial_dw (self) :
virial_dw_test(self, self, suffix = '_smth')


class TestSeAPbc(unittest.TestCase):
class TestSeAPbc(tf.test.TestCase):
def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc = True)
inter1.setUp(data, pbc = False)
inter0.setUp(data, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -218,8 +219,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale = 2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc = True)
inter1.setUp(data1, pbc = False)
inter0.setUp(data0, pbc = True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc = False, sess=self.test_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
Loading