From 31164563daf60c6b2098ef131ee666a4e18e2452 Mon Sep 17 00:00:00 2001 From: nahso Date: Tue, 30 Apr 2024 14:14:02 +0800 Subject: [PATCH] more lossy float32 precision requirements --- .../tf/test_model_compression_se_atten.py | 42 +++++++++---------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/source/tests/tf/test_model_compression_se_atten.py b/source/tests/tf/test_model_compression_se_atten.py index 03ddedad39..1ac82446c6 100644 --- a/source/tests/tf/test_model_compression_se_atten.py +++ b/source/tests/tf/test_model_compression_se_atten.py @@ -28,36 +28,36 @@ def _file_delete(file): os.remove(file) -# 4 tests: -# - type embedding FP64, se_atten FP64 -# - type embedding FP64, se_atten FP32 -# - type embedding FP32, se_atten FP64 -# - type embedding FP32, se_atten FP32 tests = [ { "se_atten precision": "float64", "type embedding precision": "float64", "smooth_type_embedding": True, + "precision_digit": 10, }, { "se_atten precision": "float64", "type embedding precision": "float64", "smooth_type_embedding": False, + "precision_digit": 10, }, { "se_atten precision": "float64", "type embedding precision": "float32", "smooth_type_embedding": True, + "precision_digit": 2, }, { "se_atten precision": "float32", "type embedding precision": "float64", "smooth_type_embedding": True, + "precision_digit": 2, }, { "se_atten precision": "float32", "type embedding precision": "float32", "smooth_type_embedding": True, + "precision_digit": 2, }, ] @@ -158,10 +158,6 @@ def _init_models_exclude_types(): INPUTS_ET, FROZEN_MODELS_ET, COMPRESSED_MODELS_ET = _init_models_exclude_types() -def _get_default_places(nth_test): - return 10 if nth_test == 0 else 3 - - @unittest.skipIf( parse_version(tf.__version__) < parse_version("2"), f"The current tf version {tf.__version__} is too low to run the new testing model.", @@ -200,7 +196,7 @@ def test_attrs(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] self.assertEqual(dp_original.get_ntypes(), 2) self.assertAlmostEqual(dp_original.get_rcut(), 6.0, places=default_places) @@ -218,7 +214,7 @@ def test_1frame(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0 = dp_original.eval( self.coords, self.box, self.atype, atomic=False @@ -244,7 +240,7 @@ def test_1frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0, ae0, av0 = dp_original.eval( self.coords, self.box, self.atype, atomic=True @@ -276,7 +272,7 @@ def test_2frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] coords2 = np.concatenate((self.coords, self.coords)) box2 = np.concatenate((self.box, self.box)) @@ -346,7 +342,7 @@ def test_1frame(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0 = dp_original.eval( self.coords, self.box, self.atype, atomic=False @@ -372,7 +368,7 @@ def test_1frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0, ae0, av0 = dp_original.eval( self.coords, self.box, self.atype, atomic=True @@ -404,7 +400,7 @@ def test_2frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] coords2 = np.concatenate((self.coords, self.coords)) ee0, ff0, vv0, ae0, av0 = dp_original.eval( @@ -473,7 +469,7 @@ def test_1frame(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0 = dp_original.eval( self.coords, self.box, self.atype, atomic=False @@ -505,7 +501,7 @@ def test_1frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0, ae0, av0 = dp_original.eval( self.coords, self.box, self.atype, atomic=True @@ -535,7 +531,7 @@ def test_1frame_atm(self): def test_ase(self): for i in range(len(tests)): - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] from ase import ( Atoms, ) @@ -628,7 +624,7 @@ def test_attrs(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] self.assertEqual(dp_original.get_ntypes(), 2) self.assertAlmostEqual(dp_original.get_rcut(), 6.0, places=default_places) @@ -646,7 +642,7 @@ def test_1frame(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0 = dp_original.eval( self.coords, self.box, self.atype, atomic=False @@ -672,7 +668,7 @@ def test_1frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] ee0, ff0, vv0, ae0, av0 = dp_original.eval( self.coords, self.box, self.atype, atomic=True @@ -704,7 +700,7 @@ def test_2frame_atm(self): for i in range(len(tests)): dp_original = self.dp_originals[i] dp_compressed = self.dp_compresseds[i] - default_places = _get_default_places(i) + default_places = tests[i]["precision_digit"] coords2 = np.concatenate((self.coords, self.coords)) box2 = np.concatenate((self.box, self.box))