Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.model.task.sezm_ener import (
SeZMEnergyFittingNet,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.spin import (
Spin,
)
Expand All @@ -44,6 +47,8 @@
)
from .dp_linear_model import (
LinearEnergyModel,
normalize_linear_model_type_map,
validate_linear_shared_descriptor_type_maps,
)
from .dp_model import (
DPModelCommon,
Expand Down Expand Up @@ -155,19 +160,45 @@ def get_spin_model(model_params: dict) -> SpinModel:
def get_linear_model(model_params: dict) -> LinearEnergyModel:
model_params = copy.deepcopy(model_params)
weights = model_params.get("weights", "mean")
shared_links = None
if "shared_dict" in model_params:
shared_config = {
"model_dict": {
f"model_{idx}": sub_model
for idx, sub_model in enumerate(model_params["models"])
},
"shared_dict": model_params.get("shared_dict", {}),
}
if "type_map" in model_params:
shared_config["type_map"] = copy.deepcopy(model_params["type_map"])
shared_config, shared_links = preprocess_shared_params(
shared_config,
require_shared_type_map=False,
)
model_params["models"] = list(shared_config["model_dict"].values())
normalize_linear_model_type_map(model_params)
validate_linear_shared_descriptor_type_maps(
model_params["models"],
shared_links,
)

list_of_models = []
ntypes = len(model_params["type_map"])
for sub_model_params in model_params["models"]:
if "type_map" not in sub_model_params:
sub_model_params["type_map"] = model_params["type_map"]
if "descriptor" in sub_model_params:
# descriptor
sub_model_params["descriptor"]["ntypes"] = ntypes
sub_ntypes = len(sub_model_params["type_map"])
sub_model_params["descriptor"]["ntypes"] = sub_ntypes
descriptor, fitting, _ = _get_standard_model_components(
sub_model_params, ntypes
sub_model_params, sub_ntypes
)
list_of_models.append(
DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
DPAtomicModel(
descriptor,
fitting,
type_map=copy.deepcopy(sub_model_params["type_map"]),
)
)

else: # must be pairtab
Expand All @@ -179,19 +210,23 @@ def get_linear_model(model_params: dict) -> LinearEnergyModel:
sub_model_params["tab_file"],
sub_model_params["rcut"],
sub_model_params["sel"],
type_map=model_params["type_map"],
type_map=copy.deepcopy(sub_model_params["type_map"]),
)
)

atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
return LinearEnergyModel(
model = LinearEnergyModel(
models=list_of_models,
type_map=model_params["type_map"],
weights=weights,
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
model.shared_links = shared_links
if shared_links:
model.share_params(shared_links, resume=True)
return model


def get_zbl_model(model_params: dict) -> DPZBLModel:
Expand Down
230 changes: 226 additions & 4 deletions deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
deepcopy,
)
from typing import (
Any,
)
Expand All @@ -14,6 +18,9 @@
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand All @@ -25,9 +32,66 @@
make_model,
)

log = logging.getLogger(__name__)

DPLinearModel_ = make_model(LinearEnergyAtomicModel)


def _get_linear_model_index(model_key: str) -> int:
if not model_key.startswith("model_"):
raise RuntimeError(f"Unknown linear model key {model_key}!")
return int(model_key.removeprefix("model_"))


def normalize_linear_model_type_map(model_params: dict[str, Any]) -> None:
"""Fill the linear model type_map from sub-models when needed."""
if "type_map" in model_params:
return
for idx, sub_model_params in enumerate(model_params["models"]):
if "type_map" not in sub_model_params:
raise ValueError(
f"Linear sub-model {idx} must define type_map when "
"linear_ener has no top-level type_map."
)
first_type_map = model_params["models"][0]["type_map"]
for idx, sub_model_params in enumerate(model_params["models"][1:], start=1):
if sub_model_params["type_map"] != first_type_map:
raise ValueError(
f"Linear sub-model {idx} type_map differs from sub-model 0. "
"All type_map values must be identical when linear_ener "
"has no top-level type_map."
)
model_params["type_map"] = deepcopy(first_type_map)


def validate_linear_shared_descriptor_type_maps(
models: list[dict[str, Any]],
shared_links: dict[str, Any] | None,
) -> None:
"""Reject descriptor sharing across incompatible linear sub-model type maps."""
if not shared_links:
return
for shared_key, shared_item in shared_links.items():
descriptor_links = [
link for link in shared_item["links"] if "descriptor" in link["shared_type"]
]
if len(descriptor_links) < 2:
continue
base_link = descriptor_links[0]
base_index = _get_linear_model_index(base_link["model_key"])
base_type_map = models[base_index]["type_map"]
for link_item in descriptor_links[1:]:
model_index = _get_linear_model_index(link_item["model_key"])
model_type_map = models[model_index]["type_map"]
if model_type_map != base_type_map:
raise ValueError(
f"Linear sub-model {model_index} type_map {model_type_map} "
f"is incompatible with sub-model {base_index} type_map "
f"{base_type_map} for shared descriptor {shared_key!r}. "
"Shared descriptor links require identical type_map values."
)


@BaseModel.register("linear_ener")
class LinearEnergyModel(DPLinearModel_):
model_type = "linear_ener"
Expand All @@ -39,6 +103,106 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)

def share_params(
self,
shared_links: dict[str, Any],
model_key_prob_map: dict[str, float] | None = None,
data_stat_protect: float = 1e-2,
resume: bool = False,
) -> None:
"""Share parameters between linear sub-models.

``shared_links`` follows the same structure as the multi-task
preprocessor. Linear sub-model keys are named ``model_0``, ``model_1``,
... by ``get_linear_model``.
"""

def get_sub_model(model_key: str): # noqa: ANN202
model_index = _get_linear_model_index(model_key)
return self.atomic_model.models[model_index]

def get_descriptor_class(model_key: str, shared_type: str): # noqa: ANN202
sub_model = get_sub_model(model_key)
if shared_type == "descriptor":
return sub_model.descriptor
if "hybrid" in shared_type:
hybrid_index = int(shared_type.split("_")[-1])
return sub_model.descriptor.descrpt_list[hybrid_index]
raise RuntimeError(f"Unknown class_type {shared_type}!")

for shared_item in shared_links:
shared_base = shared_links[shared_item]["links"][0]
class_type_base = shared_base["shared_type"]
model_key_base = shared_base["model_key"]
shared_level_base = int(shared_base["shared_level"])
previous_shared_level = shared_level_base
if "descriptor" in class_type_base:
base_class = get_descriptor_class(model_key_base, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
model_key_link = link_item["model_key"]
shared_level_link = int(link_item["shared_level"])
if shared_level_link < previous_shared_level:
raise ValueError(
"The shared_links must be sorted by shared_level!"
)
previous_shared_level = shared_level_link
if "descriptor" not in class_type_link:
raise ValueError(
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_class = get_descriptor_class(model_key_link, class_type_link)
link_class.share_params(
base_class, shared_level_link, resume=resume
)
log.warning(
"Shared params of %s.%s and %s.%s!",
model_key_base,
class_type_base,
model_key_link,
class_type_link,
)
else:
base_model = get_sub_model(model_key_base)
if hasattr(base_model, class_type_base):
base_class = getattr(base_model, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
model_key_link = link_item["model_key"]
shared_level_link = int(link_item["shared_level"])
if shared_level_link < previous_shared_level:
raise ValueError(
"The shared_links must be sorted by shared_level!"
)
previous_shared_level = shared_level_link
if class_type_base != class_type_link:
raise ValueError(
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_model = get_sub_model(model_key_link)
link_class = getattr(link_model, class_type_link)
if model_key_prob_map is None:
frac_prob = 1.0
else:
frac_prob = (
model_key_prob_map[model_key_link]
/ model_key_prob_map[model_key_base]
)
link_class.share_params(
base_class,
shared_level_link,
model_prob=frac_prob,
protection=data_stat_protect,
resume=resume,
)
log.warning(
"Shared params of %s.%s and %s.%s!",
model_key_base,
class_type_base,
model_key_link,
class_type_link,
)

def translated_output_def(self) -> dict[str, OutputVariableDef]:
out_def_data = self.model_output_def().get_data()
output_def = {
Expand Down Expand Up @@ -159,14 +323,72 @@ def update_sel(
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy = deepcopy(local_jdata)
original_models = deepcopy(local_jdata_cpy["models"])
has_shared_dict = "shared_dict" in local_jdata_cpy
if has_shared_dict:
shared_config = {
"model_dict": {
f"model_{idx}": sub_model
for idx, sub_model in enumerate(local_jdata_cpy["models"])
},
"shared_dict": local_jdata_cpy.get("shared_dict", {}),
}
if "type_map" in local_jdata_cpy:
shared_config["type_map"] = deepcopy(local_jdata_cpy["type_map"])
shared_config, shared_links = preprocess_shared_params(
shared_config,
require_shared_type_map=False,
)
local_jdata_cpy["models"] = list(shared_config["model_dict"].values())
normalize_linear_model_type_map(local_jdata_cpy)
validate_linear_shared_descriptor_type_maps(
local_jdata_cpy["models"],
shared_links,
)
type_map = local_jdata_cpy["type_map"]
min_nbor_dist = None
for idx, sub_model in enumerate(local_jdata_cpy["models"]):
if "tab_file" not in sub_model:
sub_model, temp_min = DPModelCommon.update_sel(
train_data, type_map, local_jdata["models"][idx]
sub_type_map = sub_model.get("type_map", type_map)
local_jdata_cpy["models"][idx], temp_min = DPModelCommon.update_sel(
train_data, sub_type_map, sub_model
)
if min_nbor_dist is None or temp_min <= min_nbor_dist:
min_nbor_dist = temp_min
return local_jdata_cpy, min_nbor_dist
if not has_shared_dict:
return local_jdata_cpy, min_nbor_dist

def get_shared_key(shared_ref: str) -> str:
return shared_ref.split(":", maxsplit=1)[0]

ret_jdata = deepcopy(local_jdata)
ret_jdata["models"] = original_models
if "type_map" not in ret_jdata:
ret_jdata["type_map"] = deepcopy(type_map)
for idx, original_sub_model in enumerate(original_models):
if "tab_file" in original_sub_model:
continue
updated_sub_model = local_jdata_cpy["models"][idx]
descriptor_ref = original_sub_model.get("descriptor")
if isinstance(descriptor_ref, str):
ret_jdata["shared_dict"][get_shared_key(descriptor_ref)] = (
updated_sub_model["descriptor"]
)
elif (
isinstance(descriptor_ref, dict)
and descriptor_ref.get("type") == "hybrid"
):
updated_descriptor = updated_sub_model["descriptor"]
for hybrid_idx, hybrid_ref in enumerate(descriptor_ref["list"]):
if isinstance(hybrid_ref, str):
ret_jdata["shared_dict"][get_shared_key(hybrid_ref)] = (
updated_descriptor["list"][hybrid_idx]
)
else:
ret_jdata["models"][idx]["descriptor"]["list"][hybrid_idx] = (
updated_descriptor["list"][hybrid_idx]
)
else:
ret_jdata["models"][idx]["descriptor"] = updated_sub_model["descriptor"]
Comment thread
njzjz marked this conversation as resolved.
return ret_jdata, min_nbor_dist
Loading
Loading