diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index b8892c2d95..ca445c8588 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -192,3 +192,6 @@ def forward( # (nframes, nloc, 3) out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + # make jit happy with torch 2.0.0 + exclude_types: List[int] diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 55ffd8c650..b58b0c9b19 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -214,6 +214,9 @@ def forward( """ return self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + # make jit happy with torch 2.0.0 + exclude_types: List[int] + @Fitting.register("ener") class EnergyFittingNet(InvarFitting): @@ -262,6 +265,9 @@ def serialize(self) -> dict: "type": "ener", } + # make jit happy with torch 2.0.0 + exclude_types: List[int] + @Fitting.register("direct_force") @Fitting.register("direct_force_ener") diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index eb6ccc2b7d..d7428c4d53 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -316,3 +316,6 @@ def forward( out = out + bias return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + + # make jit happy with torch 2.0.0 + exclude_types: List[int]