Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,89 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.utils import (
AtomExcludeMask,
PairExcludeMask,
)

from .make_base_atomic_model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(np.ndarray)
BaseAtomicModel_ = make_base_atomic_model(np.ndarray)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

def forward_common_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
30 changes: 18 additions & 12 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ def __init__(
descriptor,
fitting,
type_map: Optional[List[str]] = None,
**kwargs,
):
super().__init__()
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
super().__init__(**kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -132,24 +133,29 @@ def forward_atomic(
return ret

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
dd = super().serialize()
dd.update(
{
"@class": "Model",
"type": "standard",
"@version": 1,
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map", None)
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)
return obj

def get_dim_fparam(self) -> int:
Expand Down
33 changes: 19 additions & 14 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
models: List[BaseAtomicModel],
**kwargs,
):
super().__init__()
self.models = models
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down Expand Up @@ -273,34 +273,39 @@ def __init__(
self.smin_alpha = smin_alpha

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

dp_model, zbl_model = LinearAtomicModel.deserialize(data["models"])
dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
smin_alpha=smin_alpha,
**data,
)

def _compute_weight(
Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def get_rcut(self) -> float:
def get_type_map(self) -> Optional[List[str]]:
"""Get the type map."""

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(
"cannot infer the number of types from a None type map"
)

@abstractmethod
def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
Expand Down
28 changes: 16 additions & 12 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,29 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"type": "pairtab",
"@version": 1,
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}
)
return dd

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
rcut = data.pop("rcut")
sel = data.pop("sel")
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def call_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
atomic_ret = self.forward_atomic(
atomic_ret = self.forward_common_atomic(
cc_ext,
extended_atype,
nlist,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def get_model(data: dict) -> DPModel:
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)
79 changes: 79 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,99 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from typing import (
Dict,
List,
Optional,
Tuple,
)

import torch

from deepmd.dpmodel.atomic_model import (
make_base_atomic_model,
)
from deepmd.pt.utils import (
AtomExcludeMask,
PairExcludeMask,
)

BaseAtomicModel_ = make_base_atomic_model(torch.Tensor)


class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
):
self.atom_exclude_types = exclude_types
if exclude_types == []:
self.atom_excl = None
else:
self.atom_excl = AtomExcludeMask(self.get_ntypes(), self.atom_exclude_types)

def reinit_pair_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
):
self.pair_exclude_types = exclude_types
if exclude_types == []:
self.pair_excl = None
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)

@torch.jit.export
def get_model_def_script(self) -> str:
return self.model_def_script

def forward_common_atomic(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

if self.pair_excl is not None:
pair_mask = self.pair_excl(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
ret_dict[kk] = ret_dict[kk] * atom_mask[:, :, None]

return ret_dict

def serialize(self) -> dict:
return {
"atom_exclude_types": self.atom_exclude_types,
"pair_exclude_types": self.pair_exclude_types,
}
Loading