diff --git a/deepmd/common.py b/deepmd/common.py
index 3197b1a858..c9873a6d94 100644
--- a/deepmd/common.py
+++ b/deepmd/common.py
@@ -54,6 +54,7 @@
     "gelu",
     "gelu_tf",
     "silu",
+    "silut",
     "none",
     "linear",
 ]
diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py
index f51308e881..8fe20021aa 100644
--- a/deepmd/dpmodel/utils/network.py
+++ b/deepmd/dpmodel/utils/network.py
@@ -325,6 +325,40 @@ def fn(x):
             # generated by GitHub Copilot
             return x / (1 + xp.exp(-x))
 
+        return fn
+    elif activation_function.lower().startswith(("silut", "custom_silu")):
+        # SiLUT: SiLU below the threshold, a tanh tail at or above it.  The
+        # tail's slope and offset are SiLU's derivative and value at the
+        # threshold, so value and first derivative are continuous there.
+        # An optional threshold follows a colon in the name ("silut:10.0");
+        # it defaults to 3.0.
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        threshold = (
+            float(activation_function.split(":")[-1])
+            if ":" in activation_function
+            else 3.0
+        )
+        slope = float(silu_grad(threshold))
+        const = float(silu(threshold))
+
+        def fn(x):
+            xp = array_api_compat.array_namespace(x)
+            return xp.where(
+                x < threshold,
+                x * (1 / (1 + xp.exp(-x))),
+                xp.tanh(slope * (x - threshold)) + const,
+            )
+
         return fn
     elif activation_function.lower() in ("none", "linear"):
 
diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py
index 8bddddb164..1872d9ca1d 100644
--- a/deepmd/pd/utils/utils.py
+++ b/deepmd/pd/utils/utils.py
@@ -32,10 +32,52 @@
 )
 
 
+class SiLUT(paddle.nn.Layer):
+    """SiLU below ``threshold``, a shifted tanh at or above it.
+
+    The tanh tail uses SiLU's derivative and value at ``threshold`` as its
+    slope and offset, so the value and first derivative match there.
+    """
+
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        # Evaluate both branches and select element-wise.  A Python-level
+        # `if paddle.any(...)` would make control flow depend on tensor data
+        # and needs a tensor-to-bool conversion on every call.
+        silu_part = F.silu(x)
+        tanh_part = paddle.tanh(self.slope * (x - self.threshold)) + self.const
+        return paddle.where(x < self.threshold, silu_part, tanh_part)
+
+
 class ActivationFn(paddle.nn.Layer):
     def __init__(self, activation: str | None):
         super().__init__()
         self.activation: str = activation if activation is not None else "linear"
+        if self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            self.silut = SiLUT(threshold=threshold)
+        else:
+            self.silut = None
 
     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -53,6 +95,11 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
             return F.sigmoid(x)
         elif self.activation.lower() == "silu":
             return F.silu(x)
+        elif self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            assert self.silut is not None
+            return self.silut(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index b0cddae2f1..3fe507ecc2 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -249,6 +249,9 @@ def train(
     output: str = "out.json",
 ) -> None:
     log.info("Configuration path: %s", input_file)
+    # Training selects the TorchScript-backed custom ops (see ActivationFn in
+    # deepmd/pt/utils/utils.py); the default, False, keeps modules freezable.
+    env.CUSTOM_OP_USE_JIT = True
     if LOCAL_RANK == 0:
         SummaryPrinter()()
     with open(input_file) as fin:
diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py
index 0e1322a640..185bb1add3 100644
--- a/deepmd/pt/utils/env.py
+++ b/deepmd/pt/utils/env.py
@@ -34,6 +34,7 @@
 JIT = False
 CACHE_PER_SYS = 5  # keep at most so many sets per sys in memory
 ENERGY_BIAS_TRAINABLE = True
+CUSTOM_OP_USE_JIT = False  # set to True by the training entrypoint
 
 PRECISION_DICT = {
     "float16": torch.float16,
@@ -76,6 +77,7 @@
 
 __all__ = [
     "CACHE_PER_SYS",
+    "CUSTOM_OP_USE_JIT",
     "DEFAULT_PRECISION",
     "DEVICE",
     "ENERGY_BIAS_TRAINABLE",
diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py
index 50d378455b..1812aca9ec 100644
--- a/deepmd/pt/utils/utils.py
+++ b/deepmd/pt/utils/utils.py
@@ -11,6 +11,9 @@
 import torch.nn.functional as F
 
 from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
+from deepmd.pt.utils import (
+    env,
+)
 
 from .env import (
     DEVICE,
@@ -18,10 +21,170 @@
 from .env import PRECISION_DICT as PT_PRECISION_DICT
 
 
+def silut_forward(
+    x: torch.Tensor, threshold: float, slope: float, const_val: float
+) -> torch.Tensor:
+    """Evaluate SiLUT: SiLU where ``x < threshold``, a shifted tanh elsewhere."""
+    sig = torch.sigmoid(x)
+    silu = x * sig
+    tanh_part = torch.tanh(slope * (x - threshold)) + const_val
+    return torch.where(x >= threshold, tanh_part, silu)
+
+
+def silut_backward(
+    x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float
+):
+    """First derivative of SiLUT.
+
+    Returns ``(grad * grad_output, grad)`` where ``grad`` is d(SiLUT)/dx at ``x``.
+    """
+    sig = torch.sigmoid(x)
+    grad_silu = sig * (1 + x * (1 - sig))
+
+    tanh_term = torch.tanh(slope * (x - threshold))
+    grad_tanh = slope * (1 - tanh_term.pow(2))
+
+    grad = torch.where(x >= threshold, grad_tanh, grad_silu)
+    return grad * grad_output, grad
+
+
+def silut_double_backward(
+    x: torch.Tensor,
+    grad_grad_output: torch.Tensor,
+    grad_output: torch.Tensor,
+    threshold: float,
+    slope: float,
+) -> torch.Tensor:
+    """Return ``grad_output * d2(SiLUT)/dx2 * grad_grad_output`` at ``x``."""
+    # Tanh branch
+    tanh_term = torch.tanh(slope * (x - threshold))
+    grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term)
+
+    # SiLU branch
+    sig = torch.sigmoid(x)
+    sig_prime = sig * (1 - sig)
+    silu_term = sig_prime * (2 + x * (1 - 2 * sig))
+
+    grad_grad = torch.where(x >= threshold, grad_grad, silu_term)
+
+    return grad_output * grad_grad * grad_grad_output
+
+
+class SiLUTScript(torch.nn.Module):
+    """SiLUT with TorchScript kernels and hand-written first/second derivatives."""
+
+    def __init__(self, threshold: float = 3.0):
+        super().__init__()
+        self.threshold = threshold
+
+        # Precompute parameters for the tanh replacement
+        sigmoid_threshold = 1 / (1 + np.exp(-threshold))
+        self.slope = float(
+            sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold)
+        )
+        self.const_val = float(threshold * sigmoid_threshold)
+        self.get_script_code()
+
+    def get_script_code(self):
+        """Script the kernels and set ``self.SiLUTFunction``."""
+        silut_forward_script = torch.jit.script(silut_forward)
+        silut_backward_script = torch.jit.script(silut_backward)
+        silut_double_backward_script = torch.jit.script(silut_double_backward)
+
+        class SiLUTFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, threshold, slope, const_val):
+                ctx.save_for_backward(x)
+                ctx.threshold = threshold
+                ctx.slope = slope
+                ctx.const_val = const_val
+                return silut_forward_script(x, threshold, slope, const_val)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                (x,) = ctx.saved_tensors
+                threshold = ctx.threshold
+                slope = ctx.slope
+
+                grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope)
+                return grad_input, None, None, None
+
+        class SiLUTGradFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, grad_output, threshold, slope):
+                ctx.threshold = threshold
+                ctx.slope = slope
+                grad_input, grad = silut_backward_script(
+                    x, grad_output, threshold, slope
+                )
+                ctx.save_for_backward(x, grad_output, grad)
+                return grad_input
+
+            @staticmethod
+            def backward(ctx, grad_grad_output):
+                (x, grad_output, grad) = ctx.saved_tensors
+                threshold = ctx.threshold
+                slope = ctx.slope
+
+                grad_input = silut_double_backward_script(
+                    x, grad_grad_output, grad_output, threshold, slope
+                )
+                # gradients w.r.t. (x, grad_output, threshold, slope)
+                return grad_input, grad * grad_grad_output, None, None
+
+        self.SiLUTFunction = SiLUTFunction
+
+    def forward(self, x):
+        return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val)
+
+
+class SiLUT(torch.nn.Module):
+    """SiLU below ``threshold``, a shifted tanh at or above it (plain autograd)."""
+
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Evaluate both branches and select element-wise.  A Python-level
+        # `if torch.any(...)` would make control flow depend on tensor data
+        # and needs a tensor-to-bool conversion on every call.
+        silu_part = F.silu(x)
+        tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
+        return torch.where(x < self.threshold, silu_part, tanh_part)
+
+
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]) -> None:
         super().__init__()
         self.activation: str = activation if activation is not None else "linear"
+        if self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            if env.CUSTOM_OP_USE_JIT:
+                # for efficient training but can not be jit
+                self.silut = SiLUTScript(threshold=threshold)
+            else:
+                # for jit freeze
+                self.silut = SiLUT(threshold=threshold)
+        else:
+            self.silut = None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -41,6 +204,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             return torch.sigmoid(x)
         elif self.activation.lower() == "silu":
             return F.silu(x)
+        elif self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            assert self.silut is not None
+            return self.silut(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py
index c570d06b72..985c36c686 100644
--- a/deepmd/tf/common.py
+++ b/deepmd/tf/common.py
@@ -144,6 +144,52 @@ def silu(x: tf.Tensor) -> tf.Tensor:
     return x * tf.sigmoid(x)
 
 
+def get_silut(activation_function: str = "silut"):
+    """Build the SiLUT activation named by ``activation_function``.
+
+    An optional threshold follows a colon (e.g. ``"silut:10.0"``); it
+    defaults to 3.0.
+    """
+    import numpy as np
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def silu(x):
+        return x * sigmoid(x)
+
+    def silu_grad(x):
+        sig = sigmoid(x)
+        return sig + x * sig * (1 - sig)
+
+    threshold = (
+        float(activation_function.split(":")[-1]) if ":" in activation_function else 3.0
+    )
+    slope = float(silu_grad(threshold))
+    const = float(silu(threshold))
+
+    def silut(x: tf.Tensor) -> tf.Tensor:
+        """The customized sigmoid-weighted linear unit with tanh.
+
+        Parameters
+        ----------
+        x : tf.Tensor
+            float Tensor to perform activation
+
+        Returns
+        -------
+        tf.Tensor
+            `x` with the SiLUT activation applied
+        """
+        return tf.where(
+            x < threshold,
+            x * tf.sigmoid(x),
+            tf.nn.tanh(slope * (x - threshold)) + const,
+        )
+
+    return silut
+
+
 ACTIVATION_FN_DICT = {
     "relu": tf.nn.relu,
     "relu6": tf.nn.relu6,
@@ -153,6 +199,7 @@ def silu(x: tf.Tensor) -> tf.Tensor:
     "gelu": gelu,
     "gelu_tf": gelu_tf,
     "silu": silu,
+    "silut": get_silut("silut"),
     "linear": lambda x: x,
     "none": lambda x: x,
 }
@@ -182,6 +229,9 @@ def get_activation_func(
     if activation_fn is None:
         activation_fn = "none"
     assert activation_fn is not None
+    # Thresholded variants (e.g. "silut:10.0") are registered on first use.
+    if activation_fn.lower().startswith("silut"):
+        ACTIVATION_FN_DICT[activation_fn.lower()] = get_silut(activation_fn.lower())
     if activation_fn.lower() not in ACTIVATION_FN_DICT:
         raise RuntimeError(f"{activation_fn} is not a valid activation function")
     return ACTIVATION_FN_DICT[activation_fn.lower()]
diff --git a/source/tests/pt/test_custom_activation.py b/source/tests/pt/test_custom_activation.py
new file mode 100644
index 0000000000..461c0ee7b9
--- /dev/null
+++ b/source/tests/pt/test_custom_activation.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+
+import numpy as np
+import torch
+
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.utils import (
+    SiLUT,
+    SiLUTScript,
+    to_numpy_array,
+)
+
+from ..consistent.common import (
+    parameterized,
+)
+
+
+@parameterized(
+    (3.0, 10.0),
+)
+class TestSiLUT(unittest.TestCase):
+    def setUp(self) -> None:
+        (self.threshold,) = self.param
+        self.silut_naive = SiLUT(threshold=self.threshold)
+        self.silut_script = SiLUTScript(threshold=self.threshold)
+
+    def test_naive_consistent_with_script(self) -> None:
+        def get_compare(silut_tmp):
+            x_tmp = torch.arange(
+                -60.0,
+                60.0,
+                0.1,
+                device=env.DEVICE,
+                requires_grad=True,
+                dtype=torch.float64,
+            )
+            y_tmp = silut_tmp(x_tmp)
+            dy_tmp = torch.autograd.grad(y_tmp, x_tmp, y_tmp * 10.0, create_graph=True)[
+                0
+            ]
+            dy2_tmp = torch.autograd.grad(dy_tmp, x_tmp, dy_tmp * 10.0)[0]
+            return (
+                to_numpy_array(y_tmp),
+                to_numpy_array(dy_tmp),
+                to_numpy_array(dy2_tmp),
+            )
+
+        rtol = 1e-8
+        atol = 1e-8
+        naive_y, naive_dy, naive_dy2 = get_compare(self.silut_naive)
+        script_y, script_dy, script_dy2 = get_compare(self.silut_script)
+        np.testing.assert_allclose(naive_y, script_y, rtol=rtol, atol=atol)
+        np.testing.assert_allclose(naive_dy, script_dy, rtol=rtol, atol=atol)
+        np.testing.assert_allclose(naive_dy2, script_dy2, rtol=rtol, atol=atol)