diff --git a/deepmd/common.py b/deepmd/common.py index 695fee1a93..1f9d3afb0c 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -2,6 +2,7 @@ import json import warnings +import tensorflow from functools import wraps from pathlib import Path from typing import ( @@ -64,7 +65,12 @@ def gelu(x: tf.Tensor) -> tf.Tensor: Original paper https://arxiv.org/abs/1606.08415 """ - return op_module.gelu(x) + def gelu_wrapper(x): + try: + return tensorflow.nn.gelu(x, approximate=True) + except AttributeError: + return op_module.gelu(x) + return (lambda x: gelu_wrapper(x))(x) # TODO this is not a good way to do things. This is some global variable to which diff --git a/source/op/_gelu.py b/source/op/_gelu.py index ac0585da78..db45ef798e 100644 --- a/source/op/_gelu.py +++ b/source/op/_gelu.py @@ -2,14 +2,17 @@ """ First-order derivatives and second-order derivatives for gelu function. """ - +import tensorflow from tensorflow.python.framework import ops from deepmd.env import op_module -@ops.RegisterGradient("Gelu") -def _gelu_cc (op, dy) : - return op_module.gelu_grad(dy, op.inputs[0]) +try: + gelu = tensorflow.nn.gelu +except AttributeError: + @ops.RegisterGradient("Gelu") + def _gelu_cc (op, dy) : + return op_module.gelu_grad(dy, op.inputs[0]) -@ops.RegisterGradient("GeluGrad") -def _gelu_grad_cc (op, dy) : - return [op_module.gelu_grad(dy, op.inputs[1]), op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])] + @ops.RegisterGradient("GeluGrad") + def _gelu_grad_cc (op, dy) : + return [op_module.gelu_grad(dy, op.inputs[1]), op_module.gelu_grad_grad(dy, op.inputs[0], op.inputs[1])]