diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 60b5ecf0e0..514b552aec 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -18,7 +18,7 @@ jobs: - python: 3.8 tf: torch: - - python: "3.11" + - python: "3.12" tf: torch: diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index c4d58ea0cd..fb9e719600 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -136,6 +136,11 @@ def get_tf_requirement(tf_version: str = "") -> dict: extra_select = {} if not (tf_version == "" or tf_version in SpecifierSet(">=2.12", prereleases=True)): extra_requires.append("protobuf<3.20") + # keras 3 is not compatible with tf.compat.v1 + if tf_version == "" or tf_version in SpecifierSet(">=2.15.0rc0", prereleases=True): + extra_requires.append("tf-keras; python_version>='3.9'") + # only TF>=2.16 is compatible with Python 3.12 + extra_requires.append("tf-keras>=2.16.0rc0; python_version>='3.12'") if tf_version == "" or tf_version in SpecifierSet(">=1.15", prereleases=True): extra_select["mpi"] = [ "horovod", diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 327c3c1d3d..1c3c48e484 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -990,6 +990,7 @@ def _attention_layers( input_xyz = tf.keras.layers.LayerNormalization( beta_initializer=tf.constant_initializer(self.beta[i]), gamma_initializer=tf.constant_initializer(self.gamma[i]), + dtype=self.filter_precision, )(input_xyz) # input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n) return input_xyz diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index 2afe5cc862..3127e01e97 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -75,7 +75,9 @@ def dlopen_library(module: str, filename: str): dlopen_library("nvidia.cusparse.lib", "libcusparse.so*") dlopen_library("nvidia.cudnn.lib", "libcudnn.so*") - +# keras 3 is incompatible with tf.compat.v1 +# https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility +os.environ["TF_USE_LEGACY_KERAS"] = "1" # import tensorflow v1 compatability try: import tensorflow.compat.v1 as tf