From 6853af172ab049d4ac7fe5b68add7017d9331418 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 10 Mar 2025 22:45:17 +0800 Subject: [PATCH 1/4] feat: add CustomSilu --- deepmd/common.py | 1 + deepmd/dpmodel/utils/network.py | 29 +++++ deepmd/pd/utils/utils.py | 38 ++++++ deepmd/pt/entrypoints/main.py | 1 + deepmd/pt/utils/env.py | 1 + deepmd/pt/utils/utils.py | 148 ++++++++++++++++++++++ deepmd/tf/common.py | 46 +++++++ source/tests/pt/test_custom_activation.py | 57 +++++++++ 8 files changed, 321 insertions(+) create mode 100644 source/tests/pt/test_custom_activation.py diff --git a/deepmd/common.py b/deepmd/common.py index 3197b1a858..e99e83dbdc 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -54,6 +54,7 @@ "gelu", "gelu_tf", "silu", + "custom_silu", "none", "linear", ] diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index f51308e881..91dfb167c4 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -325,6 +325,35 @@ def fn(x): # generated by GitHub Copilot return x / (1 + xp.exp(-x)) + return fn + elif activation_function.startswith("custom_silu"): + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + def silu(x): + return x * sigmoid(x) + + def silu_grad(x): + sig = sigmoid(x) + return sig + x * sig * (1 - sig) + + threshold = ( + float(activation_function.split(":")[-1]) + if ":" in activation_function + else 3.0 + ) + slope = float(silu_grad(threshold)) + const = float(silu(threshold)) + + def fn(x): + xp = array_api_compat.array_namespace(x) + return xp.where( + x < threshold, + x * (1 / (1 + xp.exp(-x))), + xp.tanh(slope * (x - threshold)) + const, + ) + return fn elif activation_function.lower() in ("none", "linear"): diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index 8bddddb164..39ff273abd 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -32,10 +32,45 @@ ) +class CustomSilu(paddle.nn.Layer): + def __init__(self, threshold=3.0): + super().__init__() + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + def silu(x): + return x * sigmoid(x) + + def silu_grad(x): + sig = sigmoid(x) + return sig + x * sig * (1 - sig) + + self.threshold = threshold + self.slope = float(silu_grad(threshold)) + self.const = float(silu(threshold)) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + silu_part = F.silu(x) + mask = x >= self.threshold + if paddle.any(mask): + tanh_part = paddle.tanh(self.slope * (x - self.threshold)) + self.const + return paddle.where(x < self.threshold, silu_part, tanh_part) + else: + return silu_part + + class ActivationFn(paddle.nn.Layer): def __init__(self, activation: str | None): super().__init__() self.activation: str = activation if activation is not None else "linear" + if self.activation.lower().startswith("custom_silu"): + threshold = ( + float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 + ) + self.custom_silu = CustomSilu(threshold=threshold) + else: + self.custom_silu = None def forward(self, x: paddle.Tensor) -> paddle.Tensor: """Returns the tensor after applying activation function corresponding to `activation`.""" @@ -53,6 +88,9 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor: return F.sigmoid(x) elif self.activation.lower() == "silu": return F.silu(x) + elif self.activation.lower().startswith("custom_silu"): + assert self.custom_silu is not None + return self.custom_silu(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b0cddae2f1..3fe507ecc2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -249,6 +249,7 @@ def train( output: str = "out.json", ) -> None: log.info("Configuration path: %s", input_file) + env.CUSTOM_OP_USE_JIT = True if LOCAL_RANK == 0: SummaryPrinter()() with open(input_file) as fin: diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 0e1322a640..6f731b121e 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -34,6 +34,7 @@ JIT = False CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True +CUSTOM_OP_USE_JIT = False PRECISION_DICT = { "float16": torch.float16, diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 50d378455b..3e44f4eb58 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -11,6 +11,9 @@ import torch.nn.functional as F from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.pt.utils import ( + env, +) from .env import ( DEVICE, @@ -18,10 +21,152 @@ from .env import PRECISION_DICT as PT_PRECISION_DICT +@torch.jit.script +def custom_silu_forward( + x: torch.Tensor, threshold: float, slope: float, const_val: float +) -> torch.Tensor: + sig = torch.sigmoid(x) + silu = x * sig + tanh_part = torch.tanh(slope * (x - threshold)) + const_val + return torch.where(x >= threshold, tanh_part, silu) + + +@torch.jit.script +def custom_silu_backward( + x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float +): + sig = torch.sigmoid(x) + grad_silu = sig * (1 + x * (1 - sig)) + + tanh_term = torch.tanh(slope * (x - threshold)) + grad_tanh = slope * (1 - tanh_term.pow(2)) + + grad = torch.where(x >= threshold, grad_tanh, grad_silu) + return grad * grad_output, grad + + +@torch.jit.script +def custom_silu_double_backward( + x: torch.Tensor, + grad_grad_output: torch.Tensor, + grad_output: torch.Tensor, + threshold: float, + slope: float, +) -> torch.Tensor: + # Tanh branch + tanh_term = torch.tanh(slope * (x - threshold)) + grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term) + + # SiLU branch + sig = 1.0 / (1.0 + torch.exp(-x)) + sig_prime = sig * (1 - sig) + silu_term = sig_prime * (2 + x * (1 - 2 * sig)) + + grad_grad = torch.where(x >= threshold, grad_grad, silu_term) + + return grad_output * grad_grad * grad_grad_output + + +class CustomSiluFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, threshold, slope, const_val): + ctx.save_for_backward(x) + ctx.threshold = threshold + ctx.slope = slope + ctx.const_val = const_val + return custom_silu_forward(x, threshold, slope, const_val) + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + threshold = ctx.threshold + slope = ctx.slope + + grad_input = CustomSiluGradFunction.apply(x, grad_output, threshold, slope) + return grad_input, None, None, None + + +class CustomSiluGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, grad_output, threshold, slope): + ctx.threshold = threshold + ctx.slope = slope + grad_input, grad = custom_silu_backward(x, grad_output, threshold, slope) + ctx.save_for_backward(x, grad_output, grad) + return grad_input + + @staticmethod + def backward(ctx, grad_grad_output): + (x, grad_output, grad) = ctx.saved_tensors + threshold = ctx.threshold + slope = ctx.slope + + grad_input = custom_silu_double_backward( + x, grad_grad_output, grad_output, threshold, slope + ) + return grad_input, grad * grad_grad_output, None, None + + +class CustomSiluScript(torch.nn.Module): + def __init__(self, threshold: float = 3.0): + super().__init__() + self.threshold = threshold + + # Precompute parameters for the tanh replacement + sigmoid_threshold = 1 / (1 + np.exp(-threshold)) + self.slope = float( + sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold) + ) + self.const_val = float(threshold * sigmoid_threshold) + + def forward(self, x): + return CustomSiluFunction.apply(x, self.threshold, self.slope, self.const_val) + + +class CustomSilu(torch.nn.Module): + def __init__(self, threshold=3.0): + super().__init__() + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + def silu(x): + return x * sigmoid(x) + + def silu_grad(x): + sig = sigmoid(x) + return sig + x * sig * (1 - sig) + + self.threshold = threshold + self.slope = float(silu_grad(threshold)) + self.const = float(silu(threshold)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + silu_part = F.silu(x) + mask = x >= self.threshold + if torch.any(mask): + tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const + return torch.where(x < self.threshold, silu_part, tanh_part) + else: + return silu_part + + class ActivationFn(torch.nn.Module): def __init__(self, activation: Optional[str]) -> None: super().__init__() self.activation: str = activation if activation is not None else "linear" + if self.activation.lower().startswith("custom_silu"): + threshold = ( + float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 + ) + if env.CUSTOM_OP_USE_JIT: + # for efficient training but can not be jit + self.custom_silu = CustomSiluScript(threshold=threshold) + else: + # for jit freeze + self.custom_silu = CustomSilu(threshold=threshold) + else: + self.custom_silu = None def forward(self, x: torch.Tensor) -> torch.Tensor: """Returns the tensor after applying activation function corresponding to `activation`.""" @@ -41,6 +186,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.sigmoid(x) elif self.activation.lower() == "silu": return F.silu(x) + elif self.activation.lower().startswith("custom_silu"): + assert self.custom_silu is not None + return self.custom_silu(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py index c570d06b72..8147d1f68d 100644 --- a/deepmd/tf/common.py +++ b/deepmd/tf/common.py @@ -144,6 +144,47 @@ def silu(x: tf.Tensor) -> tf.Tensor: return x * tf.sigmoid(x) +def get_custom_silu(activation_function: str = "custom_silu"): + import numpy as np + + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + def silu(x): + return x * sigmoid(x) + + def silu_grad(x): + sig = sigmoid(x) + return sig + x * sig * (1 - sig) + + threshold = ( + float(activation_function.split(":")[-1]) if ":" in activation_function else 3.0 + ) + slope = float(silu_grad(threshold)) + const = float(silu(threshold)) + + def custom_silu(x: tf.Tensor) -> tf.Tensor: + """The customized sigmoid-weighted linear unit with tanh. + + Parameters + ---------- + x : tf.Tensor + float Tensor to perform activation + + Returns + ------- + tf.Tensor + `x` with the custom_silu activation applied + """ + return tf.where( + x < threshold, + x * tf.sigmoid(x), + tf.nn.tanh(slope * (x - threshold)) + const, + ) + + return custom_silu + + ACTIVATION_FN_DICT = { "relu": tf.nn.relu, "relu6": tf.nn.relu6, @@ -153,6 +194,7 @@ def silu(x: tf.Tensor) -> tf.Tensor: "gelu": gelu, "gelu_tf": gelu_tf, "silu": silu, + "custom_silu": get_custom_silu("custom_silu"), "linear": lambda x: x, "none": lambda x: x, } @@ -182,6 +224,10 @@ def get_activation_func( if activation_fn is None: activation_fn = "none" assert activation_fn is not None + if activation_fn.lower().startswith("custom_silu"): + ACTIVATION_FN_DICT[activation_fn.lower()] = get_custom_silu( + activation_fn.lower() + ) if activation_fn.lower() not in ACTIVATION_FN_DICT: raise RuntimeError(f"{activation_fn} is not a valid activation function") return ACTIVATION_FN_DICT[activation_fn.lower()] diff --git a/source/tests/pt/test_custom_activation.py b/source/tests/pt/test_custom_activation.py new file mode 100644 index 0000000000..bc547463cf --- /dev/null +++ b/source/tests/pt/test_custom_activation.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + CustomSilu, + CustomSiluScript, + to_numpy_array, +) + +from ..consistent.common import ( + parameterized, +) + + +@parameterized( + (3.0, 10.0), +) +class TestCustomSilu(unittest.TestCase): + def setUp(self) -> None: + (self.threshold,) = self.param + self.custom_silu_naive = CustomSilu(threshold=self.threshold) + self.custom_silu_script = CustomSiluScript(threshold=self.threshold) + + def test_naive_consistent_with_script(self) -> None: + def get_compare(cust_silu_tmp): + x_tmp = torch.arange( + -60.0, + 60.0, + 0.1, + device=env.DEVICE, + requires_grad=True, + dtype=torch.float64, + ) + y_tmp = cust_silu_tmp(x_tmp) + dy_tmp = torch.autograd.grad(y_tmp, x_tmp, y_tmp * 10.0, create_graph=True)[ + 0 + ] + dy2_tmp = torch.autograd.grad(dy_tmp, x_tmp, dy_tmp * 10.0)[0] + return ( + to_numpy_array(y_tmp), + to_numpy_array(dy_tmp), + to_numpy_array(dy2_tmp), + ) + + rtol = 1e-8 + atol = 1e-8 + naive_y, naive_dy, naive_dy2 = get_compare(self.custom_silu_naive) + script_y, script_dy, script_dy2 = get_compare(self.custom_silu_script) + np.testing.assert_allclose(naive_y, script_y, rtol=rtol, atol=atol) + np.testing.assert_allclose(naive_dy, script_dy, rtol=rtol, atol=atol) + np.testing.assert_allclose(naive_dy2, script_dy2, rtol=rtol, atol=atol) From c7e559d00a860c9badf8ab0da6657bac3750e746 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 11 Mar 2025 15:45:56 +0800 Subject: [PATCH 2/4] rename customsilu to silut --- deepmd/common.py | 2 +- deepmd/dpmodel/utils/network.py | 4 ++- deepmd/pd/utils/utils.py | 18 ++++++---- deepmd/pt/utils/utils.py | 42 +++++++++++++---------- deepmd/tf/common.py | 16 ++++----- source/tests/pt/test_custom_activation.py | 18 +++++----- 6 files changed, 54 insertions(+), 46 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index e99e83dbdc..c9873a6d94 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -54,7 +54,7 @@ "gelu", "gelu_tf", "silu", - "custom_silu", + "silut", "none", "linear", ] diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 91dfb167c4..8fe20021aa 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -326,7 +326,9 @@ def fn(x): return x / (1 + xp.exp(-x)) return fn - elif activation_function.startswith("custom_silu"): + elif activation_function.startswith("silut") or activation_function.startswith( + "custom_silu" + ): def sigmoid(x): return 1 / (1 + np.exp(-x)) diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index 39ff273abd..1872d9ca1d 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -32,7 +32,7 @@ ) -class CustomSilu(paddle.nn.Layer): +class SiLUT(paddle.nn.Layer): def __init__(self, threshold=3.0): super().__init__() @@ -64,13 +64,15 @@ class ActivationFn(paddle.nn.Layer): def __init__(self, activation: str | None): super().__init__() self.activation: str = activation if activation is not None else "linear" - if self.activation.lower().startswith("custom_silu"): + if self.activation.lower().startswith( + "silut" + ) or self.activation.lower().startswith("custom_silu"): threshold = ( float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 ) - self.custom_silu = CustomSilu(threshold=threshold) + self.silut = SiLUT(threshold=threshold) else: - self.custom_silu = None + self.silut = None def forward(self, x: paddle.Tensor) -> paddle.Tensor: """Returns the tensor after applying activation function corresponding to `activation`.""" @@ -88,9 +90,11 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor: return F.sigmoid(x) elif self.activation.lower() == "silu": return F.silu(x) - elif self.activation.lower().startswith("custom_silu"): - assert self.custom_silu is not None - return self.custom_silu(x) + elif self.activation.lower().startswith( + "silut" + ) or self.activation.lower().startswith("custom_silu"): + assert self.silut is not None + return self.silut(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 3e44f4eb58..ae092d9468 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -22,7 +22,7 @@ @torch.jit.script -def custom_silu_forward( +def silut_forward( x: torch.Tensor, threshold: float, slope: float, const_val: float ) -> torch.Tensor: sig = torch.sigmoid(x) @@ -32,7 +32,7 @@ def custom_silu_forward( @torch.jit.script -def custom_silu_backward( +def silut_backward( x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float ): sig = torch.sigmoid(x) @@ -46,7 +46,7 @@ def custom_silu_backward( @torch.jit.script -def custom_silu_double_backward( +def silut_double_backward( x: torch.Tensor, grad_grad_output: torch.Tensor, grad_output: torch.Tensor, @@ -67,14 +67,14 @@ def custom_silu_double_backward( return grad_output * grad_grad * grad_grad_output -class CustomSiluFunction(torch.autograd.Function): +class SiLUTFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, threshold, slope, const_val): ctx.save_for_backward(x) ctx.threshold = threshold ctx.slope = slope ctx.const_val = const_val - return custom_silu_forward(x, threshold, slope, const_val) + return silut_forward(x, threshold, slope, const_val) @staticmethod def backward(ctx, grad_output): @@ -82,16 +82,16 @@ def backward(ctx, grad_output): threshold = ctx.threshold slope = ctx.slope - grad_input = CustomSiluGradFunction.apply(x, grad_output, threshold, slope) + grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope) return grad_input, None, None, None -class CustomSiluGradFunction(torch.autograd.Function): +class SiLUTGradFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, grad_output, threshold, slope): ctx.threshold = threshold ctx.slope = slope - grad_input, grad = custom_silu_backward(x, grad_output, threshold, slope) + grad_input, grad = silut_backward(x, grad_output, threshold, slope) ctx.save_for_backward(x, grad_output, grad) return grad_input @@ -101,13 +101,13 @@ def backward(ctx, grad_grad_output): threshold = ctx.threshold slope = ctx.slope - grad_input = custom_silu_double_backward( + grad_input = silut_double_backward( x, grad_grad_output, grad_output, threshold, slope ) return grad_input, grad * grad_grad_output, None, None -class CustomSiluScript(torch.nn.Module): +class SiLUTScript(torch.nn.Module): def __init__(self, threshold: float = 3.0): super().__init__() self.threshold = threshold @@ -120,10 +120,10 @@ def __init__(self, threshold: float = 3.0): self.const_val = float(threshold * sigmoid_threshold) def forward(self, x): - return CustomSiluFunction.apply(x, self.threshold, self.slope, self.const_val) + return SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val) -class CustomSilu(torch.nn.Module): +class SiLUT(torch.nn.Module): def __init__(self, threshold=3.0): super().__init__() @@ -155,18 +155,20 @@ class ActivationFn(torch.nn.Module): def __init__(self, activation: Optional[str]) -> None: super().__init__() self.activation: str = activation if activation is not None else "linear" - if self.activation.lower().startswith("custom_silu"): + if self.activation.lower().startswith( + "silut" + ) or self.activation.lower().startswith("custom_silu"): threshold = ( float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 ) if env.CUSTOM_OP_USE_JIT: # for efficient training but can not be jit - self.custom_silu = CustomSiluScript(threshold=threshold) + self.silut = SiLUTScript(threshold=threshold) else: # for jit freeze - self.custom_silu = CustomSilu(threshold=threshold) + self.silut = SiLUT(threshold=threshold) else: - self.custom_silu = None + self.silut = None def forward(self, x: torch.Tensor) -> torch.Tensor: """Returns the tensor after applying activation function corresponding to `activation`.""" @@ -186,9 +188,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.sigmoid(x) elif self.activation.lower() == "silu": return F.silu(x) - elif self.activation.lower().startswith("custom_silu"): - assert self.custom_silu is not None - return self.custom_silu(x) + elif self.activation.lower().startswith( + "silut" + ) or self.activation.lower().startswith("custom_silu"): + assert self.silut is not None + return self.silut(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py index 8147d1f68d..985c36c686 100644 --- a/deepmd/tf/common.py +++ b/deepmd/tf/common.py @@ -144,7 +144,7 @@ def silu(x: tf.Tensor) -> tf.Tensor: return x * tf.sigmoid(x) -def get_custom_silu(activation_function: str = "custom_silu"): +def get_silut(activation_function: str = "silut"): import numpy as np def sigmoid(x): @@ -163,7 +163,7 @@ def silu_grad(x): slope = float(silu_grad(threshold)) const = float(silu(threshold)) - def custom_silu(x: tf.Tensor) -> tf.Tensor: + def silut(x: tf.Tensor) -> tf.Tensor: """The customized sigmoid-weighted linear unit with tanh. Parameters @@ -174,7 +174,7 @@ def custom_silu(x: tf.Tensor) -> tf.Tensor: Returns ------- tf.Tensor - `x` with the custom_silu activation applied + `x` with the SiLUT activation applied """ return tf.where( x < threshold, @@ -182,7 +182,7 @@ def custom_silu(x: tf.Tensor) -> tf.Tensor: tf.nn.tanh(slope * (x - threshold)) + const, ) - return custom_silu + return silut ACTIVATION_FN_DICT = { @@ -194,7 +194,7 @@ def custom_silu(x: tf.Tensor) -> tf.Tensor: "gelu": gelu, "gelu_tf": gelu_tf, "silu": silu, - "custom_silu": get_custom_silu("custom_silu"), + "silut": get_silut("silut"), "linear": lambda x: x, "none": lambda x: x, } @@ -224,10 +224,8 @@ def get_activation_func( if activation_fn is None: activation_fn = "none" assert activation_fn is not None - if activation_fn.lower().startswith("custom_silu"): - ACTIVATION_FN_DICT[activation_fn.lower()] = get_custom_silu( - activation_fn.lower() - ) + if activation_fn.lower().startswith("silut"): + ACTIVATION_FN_DICT[activation_fn.lower()] = get_silut(activation_fn.lower()) if activation_fn.lower() not in ACTIVATION_FN_DICT: raise RuntimeError(f"{activation_fn} is not a valid activation function") return ACTIVATION_FN_DICT[activation_fn.lower()] diff --git a/source/tests/pt/test_custom_activation.py b/source/tests/pt/test_custom_activation.py index bc547463cf..461c0ee7b9 100644 --- a/source/tests/pt/test_custom_activation.py +++ b/source/tests/pt/test_custom_activation.py @@ -8,8 +8,8 @@ env, ) from deepmd.pt.utils.utils import ( - CustomSilu, - CustomSiluScript, + SiLUT, + SiLUTScript, to_numpy_array, ) @@ -21,14 +21,14 @@ @parameterized( (3.0, 10.0), ) -class TestCustomSilu(unittest.TestCase): +class TestSiLUT(unittest.TestCase): def setUp(self) -> None: (self.threshold,) = self.param - self.custom_silu_naive = CustomSilu(threshold=self.threshold) - self.custom_silu_script = CustomSiluScript(threshold=self.threshold) + self.silut_naive = SiLUT(threshold=self.threshold) + self.silut_script = SiLUTScript(threshold=self.threshold) def test_naive_consistent_with_script(self) -> None: - def get_compare(cust_silu_tmp): + def get_compare(silut_tmp): x_tmp = torch.arange( -60.0, 60.0, @@ -37,7 +37,7 @@ def get_compare(cust_silu_tmp): requires_grad=True, dtype=torch.float64, ) - y_tmp = cust_silu_tmp(x_tmp) + y_tmp = silut_tmp(x_tmp) dy_tmp = torch.autograd.grad(y_tmp, x_tmp, y_tmp * 10.0, create_graph=True)[ 0 ] @@ -50,8 +50,8 @@ def get_compare(cust_silu_tmp): rtol = 1e-8 atol = 1e-8 - naive_y, naive_dy, naive_dy2 = get_compare(self.custom_silu_naive) - script_y, script_dy, script_dy2 = get_compare(self.custom_silu_script) + naive_y, naive_dy, naive_dy2 = get_compare(self.silut_naive) + script_y, script_dy, script_dy2 = get_compare(self.silut_script) np.testing.assert_allclose(naive_y, script_y, rtol=rtol, atol=atol) np.testing.assert_allclose(naive_dy, script_dy, rtol=rtol, atol=atol) np.testing.assert_allclose(naive_dy2, script_dy2, rtol=rtol, atol=atol) From ae6bfdf06db0ed5f26fd4d433227f9ab5466d35b Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:50:45 +0800 Subject: [PATCH 3/4] make jit runtime --- deepmd/pt/utils/utils.py | 93 +++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index ae092d9468..1812aca9ec 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -21,7 +21,6 @@ from .env import PRECISION_DICT as PT_PRECISION_DICT -@torch.jit.script def silut_forward( x: torch.Tensor, threshold: float, slope: float, const_val: float ) -> torch.Tensor: @@ -31,7 +30,6 @@ def silut_forward( return torch.where(x >= threshold, tanh_part, silu) -@torch.jit.script def silut_backward( x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float ): @@ -45,7 +43,6 @@ def silut_backward( return grad * grad_output, grad -@torch.jit.script def silut_double_backward( x: torch.Tensor, grad_grad_output: torch.Tensor, @@ -67,46 +64,6 @@ def silut_double_backward( return grad_output * grad_grad * grad_grad_output -class SiLUTFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, threshold, slope, const_val): - ctx.save_for_backward(x) - ctx.threshold = threshold - ctx.slope = slope - ctx.const_val = const_val - return silut_forward(x, threshold, slope, const_val) - - @staticmethod - def backward(ctx, grad_output): - (x,) = ctx.saved_tensors - threshold = ctx.threshold - slope = ctx.slope - - grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope) - return grad_input, None, None, None - - -class SiLUTGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, grad_output, threshold, slope): - ctx.threshold = threshold - ctx.slope = slope - grad_input, grad = silut_backward(x, grad_output, threshold, slope) - ctx.save_for_backward(x, grad_output, grad) - return grad_input - - @staticmethod - def backward(ctx, grad_grad_output): - (x, grad_output, grad) = ctx.saved_tensors - threshold = ctx.threshold - slope = ctx.slope - - grad_input = silut_double_backward( - x, grad_grad_output, grad_output, threshold, slope - ) - return grad_input, grad * grad_grad_output, None, None - - class SiLUTScript(torch.nn.Module): def __init__(self, threshold: float = 3.0): super().__init__() @@ -118,9 +75,57 @@ def __init__(self, threshold: float = 3.0): sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold) ) self.const_val = float(threshold * sigmoid_threshold) + self.get_script_code() + + def get_script_code(self): + silut_forward_script = torch.jit.script(silut_forward) + silut_backward_script = torch.jit.script(silut_backward) + silut_double_backward_script = torch.jit.script(silut_double_backward) + + class SiLUTFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, threshold, slope, const_val): + ctx.save_for_backward(x) + ctx.threshold = threshold + ctx.slope = slope + ctx.const_val = const_val + return silut_forward_script(x, threshold, slope, const_val) + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + threshold = ctx.threshold + slope = ctx.slope + + grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope) + return grad_input, None, None, None + + class SiLUTGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, grad_output, threshold, slope): + ctx.threshold = threshold + ctx.slope = slope + grad_input, grad = silut_backward_script( + x, grad_output, threshold, slope + ) + ctx.save_for_backward(x, grad_output, grad) + return grad_input + + @staticmethod + def backward(ctx, grad_grad_output): + (x, grad_output, grad) = ctx.saved_tensors + threshold = ctx.threshold + slope = ctx.slope + + grad_input = silut_double_backward_script( + x, grad_grad_output, grad_output, threshold, slope + ) + return grad_input, grad * grad_grad_output, None, None + + self.SiLUTFunction = SiLUTFunction def forward(self, x): - return SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val) + return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val) class SiLUT(torch.nn.Module): From 610564346e1d50c710e4ad8ab9a5a6e7b97e2b42 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 13 Mar 2025 15:55:22 +0800 Subject: [PATCH 4/4] add CUSTOM_OP_USE_JIT to __all__ --- deepmd/pt/utils/env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 6f731b121e..185bb1add3 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -77,6 +77,7 @@ __all__ = [ "CACHE_PER_SYS", + "CUSTOM_OP_USE_JIT", "DEFAULT_PRECISION", "DEVICE", "ENERGY_BIAS_TRAINABLE",