Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
is_lmdb,
)
from deepmd.pt.utils.multi_task import (
preprocess_linear_shared_params,
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
Expand Down Expand Up @@ -329,6 +330,8 @@ def train(
assert "RANDOM" not in config["model"]["model_dict"], (
"Model name can not be 'RANDOM' in multi-task mode!"
)
elif config["model"].get("type") == "linear_ener":
config["model"], shared_links = preprocess_linear_shared_params(config["model"])

# update fine-tuning config
finetune_links = None
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.model.task.sezm_ener import (
SeZMEnergyFittingNet,
)
from deepmd.pt.utils.multi_task import (
preprocess_linear_shared_params,
)
from deepmd.utils.spin import (
Spin,
)
Expand Down Expand Up @@ -154,6 +157,7 @@ def get_spin_model(model_params: dict) -> SpinModel:

def get_linear_model(model_params: dict) -> LinearEnergyModel:
model_params = copy.deepcopy(model_params)
model_params, shared_links = preprocess_linear_shared_params(model_params)
weights = model_params.get("weights", "mean")
list_of_models = []
ntypes = len(model_params["type_map"])
Expand Down Expand Up @@ -185,13 +189,15 @@ def get_linear_model(model_params: dict) -> LinearEnergyModel:

atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
return LinearEnergyModel(
model = LinearEnergyModel(
models=list_of_models,
type_map=model_params["type_map"],
weights=weights,
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
model.shared_links = shared_links
return model


def get_zbl_model(model_params: dict) -> DPZBLModel:
Expand Down
78 changes: 78 additions & 0 deletions deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,84 @@ def __init__(
) -> None:
super().__init__(*args, **kwargs)

def share_params(
self,
shared_links: dict[str, Any],
data_stat_protect: float = 1e-2,
resume: bool = False,
) -> None:
"""Share parameters between linear submodels.

``shared_links`` follows the same structure as multitask training, but
model keys are the synthetic names generated by
:func:`preprocess_linear_shared_params`: ``model_0``, ``model_1``, ...
"""

def get_model_index(model_key: str) -> int:
if model_key.startswith("model_"):
return int(model_key.removeprefix("model_"))
return int(model_key)

def get_descriptor(model_idx: int, shared_type: str) -> Any:
descriptor = self.atomic_model.models[model_idx].descriptor
if shared_type == "descriptor":
return descriptor
if "hybrid" in shared_type:
hybrid_index = int(shared_type.split("_")[-1])
return descriptor.descriptor_list[hybrid_index]
raise RuntimeError(f"Unknown class_type {shared_type}!")

def get_fitting_net(model_idx: int, shared_type: str) -> Any:
atomic_model = self.atomic_model.models[model_idx]
if not hasattr(atomic_model, shared_type):
raise RuntimeError(f"Unknown class_type {shared_type}!")
return getattr(atomic_model, shared_type)

for shared_item in shared_links:
shared_base = shared_links[shared_item]["links"][0]
class_type_base = shared_base["shared_type"]
model_idx_base = get_model_index(shared_base["model_key"])
shared_level_base = int(shared_base["shared_level"])

if "descriptor" in class_type_base:
base_class = get_descriptor(model_idx_base, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
shared_level_link = int(link_item["shared_level"])
assert shared_level_link >= shared_level_base, (
"The shared_links must be sorted by shared_level!"
)
Comment on lines +86 to +88

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce true monotonic shared_level ordering.

Line 86 and Line 103 only compare each link against the base level, so an unsorted sequence like 0, 2, 1 still passes. That can apply sharing in the wrong order.

Suggested fix
         for shared_item in shared_links:
             shared_base = shared_links[shared_item]["links"][0]
             class_type_base = shared_base["shared_type"]
             model_idx_base = get_model_index(shared_base["model_key"])
             shared_level_base = int(shared_base["shared_level"])

             if "descriptor" in class_type_base:
                 base_class = get_descriptor(model_idx_base, class_type_base)
+                prev_shared_level = shared_level_base
                 for link_item in shared_links[shared_item]["links"][1:]:
                     class_type_link = link_item["shared_type"]
                     shared_level_link = int(link_item["shared_level"])
-                    assert shared_level_link >= shared_level_base, (
+                    assert shared_level_link >= prev_shared_level, (
                         "The shared_links must be sorted by shared_level!"
                     )
+                    prev_shared_level = shared_level_link
                     assert "descriptor" in class_type_link, (
                         f"Class type mismatched: {class_type_base} vs {class_type_link}!"
                     )
                     link_class = get_descriptor(
                         get_model_index(link_item["model_key"]), class_type_link
@@
             else:
                 base_class = get_fitting_net(model_idx_base, class_type_base)
+                prev_shared_level = shared_level_base
                 for link_item in shared_links[shared_item]["links"][1:]:
                     class_type_link = link_item["shared_type"]
                     shared_level_link = int(link_item["shared_level"])
-                    assert shared_level_link >= shared_level_base, (
+                    assert shared_level_link >= prev_shared_level, (
                         "The shared_links must be sorted by shared_level!"
                     )
+                    prev_shared_level = shared_level_link
                     assert class_type_base == class_type_link, (
                         f"Class type mismatched: {class_type_base} vs {class_type_link}!"
                     )

Also applies to: 103-105

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/dp_linear_model.py` around lines 86 - 88, The assertion
checking shared_level ordering only compares each shared_level_link against the
shared_level_base, which allows unsorted sequences like [0, 2, 1] to pass
validation. To enforce true monotonic ordering, you need to track and compare
each shared_level_link against the previously processed shared_level value (not
just the base level). Update the assertion logic in both locations (around line
86-88 and line 103-105) to maintain a running reference to the last processed
shared_level and compare each new link against that previous value to ensure
strictly increasing order.

assert "descriptor" in class_type_link, (
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_class = get_descriptor(
get_model_index(link_item["model_key"]), class_type_link
)
link_class.share_params(
base_class, shared_level_link, resume=resume
)
else:
base_class = get_fitting_net(model_idx_base, class_type_base)
for link_item in shared_links[shared_item]["links"][1:]:
class_type_link = link_item["shared_type"]
shared_level_link = int(link_item["shared_level"])
assert shared_level_link >= shared_level_base, (
"The shared_links must be sorted by shared_level!"
)
assert class_type_base == class_type_link, (
f"Class type mismatched: {class_type_base} vs {class_type_link}!"
)
link_class = get_fitting_net(
get_model_index(link_item["model_key"]), class_type_link
)
link_class.share_params(
base_class,
shared_level_link,
model_prob=1.0,
protection=data_stat_protect,
resume=resume,
)

def translated_output_def(self) -> dict[str, OutputVariableDef]:
out_def_data = self.model_output_def().get_data()
output_def = {
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,14 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
rank=self.rank,
)

# Linear-model share params
if not self.multi_task and shared_links is not None:
self.model.share_params(
shared_links,
resume=(resuming and not self.finetune_update_stat) or self.rank != 0,
data_stat_protect=model_params.get("data_stat_protect", 1e-2),
)

# Learning rate
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
self.nonfinite_grad_guard = NonFiniteGradGuard()
Expand Down Expand Up @@ -929,7 +937,7 @@ def single_model_finetune(
)

# Multi-task share params
if shared_links is not None:
if self.multi_task and shared_links is not None:
_data_stat_protect = np.array(
[
model_params["model_dict"][ii].get("data_stat_protect", 1e-2)
Expand Down
47 changes: 45 additions & 2 deletions deepmd/pt/utils/multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _cascade_top_level_defaults(model_config: dict[str, Any]) -> None:

def preprocess_shared_params(
model_config: dict[str, Any],
require_shared_type_map: bool = True,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Preprocess the model params for multitask model, and generate the links dict for further sharing.

Expand Down Expand Up @@ -167,7 +168,10 @@ def replace_one_item(
item_params = model_params_item[item_key]
if isinstance(item_params, str):
replace_one_item(model_params_item, item_key, item_params)
elif item_params.get("type", "") == "hybrid":
elif (
isinstance(item_params, dict)
and item_params.get("type", "") == "hybrid"
):
for ii, hybrid_item in enumerate(item_params["list"]):
if isinstance(hybrid_item, str):
replace_one_item(
Expand All @@ -187,10 +191,49 @@ def replace_one_item(
)
# little trick to make spin models in the front to be the base models,
# because its type embeddings are more general.
assert len(type_map_keys) == 1, "Multitask model must have only one type_map!"
if require_shared_type_map:
assert len(type_map_keys) == 1, "Multitask model must have only one type_map!"
else:
assert len(type_map_keys) <= 1, "Model must not reference multiple type_maps!"
return model_config, shared_links


def preprocess_linear_shared_params(
model_config: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any] | None]:
"""Preprocess ``shared_dict`` references in a linear-energy model config.

Linear models store their branches in a list named ``models``, while the
shared-parameter preprocessor works on named ``model_dict`` branches. This
adapter temporarily names linear submodels ``model_0``, ``model_1``, ... so
the same ``shared_dict`` syntax and link generation can be reused.
"""
if "shared_dict" not in model_config:
return model_config, None

linear_config = deepcopy(model_config)
shared_config = {
"model_dict": {
f"model_{idx}": sub_model
for idx, sub_model in enumerate(linear_config["models"])
},
"shared_dict": linear_config["shared_dict"],
}
shared_config, shared_links = preprocess_shared_params(
shared_config,
require_shared_type_map=False,
)
linear_config["models"] = [
shared_config["model_dict"][f"model_{idx}"]
for idx in range(len(linear_config["models"]))
]
if "type_map" not in linear_config and linear_config["models"]:
first_type_map = linear_config["models"][0].get("type_map")
if isinstance(first_type_map, list):
linear_config["type_map"] = deepcopy(first_type_map)
return linear_config, shared_links


def get_class_name(item_key: str, item_params: dict[str, Any]) -> type:
if item_key == "descriptor":
return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a"))
Expand Down
11 changes: 11 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3388,6 +3388,11 @@ def linear_ener_model_args() -> Argument:
'If "mean", the weights are set to be 1 / len(models). '
'If "sum", the weights are set to be 1.'
)
doc_shared_dict = (
"Named shared items that can be referenced by sub-model type_map, "
"descriptor, or fitting_net fields. This follows the multi-task "
"shared_dict syntax."
)
models_args = model_args(exclude_hybrid=True)
models_args.name = "models"
models_args.fold_subdoc = True
Expand All @@ -3405,6 +3410,12 @@ def linear_ener_model_args() -> Argument:
optional=False,
doc=doc_weights,
),
Argument(
"shared_dict",
dict,
optional=True,
doc=doc_only_pt_supported + doc_shared_dict,
),
],
doc=doc_only_tf_supported,
)
Expand Down
78 changes: 78 additions & 0 deletions source/tests/pt/model/test_shared_dict_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import unittest

from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils.multi_task import (
preprocess_linear_shared_params,
)


class TestSharedDictLinear(unittest.TestCase):
def setUp(self) -> None:
self.config = {
"type": "linear_ener",
"type_map": ["O", "H"],
"shared_dict": {
"type_map_all": ["O", "H"],
"shared_descriptor": {
"type": "dpa1",
"rcut": 6.0,
"rcut_smth": 0.5,
"sel": 16,
"neuron": [4, 8, 16],
"axis_neuron": 4,
"seed": 1,
},
},
"models": [
{
"type_map": "type_map_all",
"descriptor": "shared_descriptor",
"fitting_net": {
"neuron": [8, 8],
"resnet_dt": True,
"seed": 1,
},
},
{
"type_map": "type_map_all",
"descriptor": "shared_descriptor",
"fitting_net": {
"neuron": [8, 8],
"resnet_dt": True,
"seed": 2,
},
},
],
"weights": "mean",
}

def test_preprocess_linear_shared_dict(self) -> None:
model_config, shared_links = preprocess_linear_shared_params(
copy.deepcopy(self.config)
)

self.assertEqual(model_config["models"][0]["type_map"], ["O", "H"])
self.assertIsInstance(model_config["models"][0]["descriptor"], dict)
self.assertIn("shared_descriptor", shared_links)
self.assertEqual(
[item["model_key"] for item in shared_links["shared_descriptor"]["links"]],
["model_0", "model_1"],
)

def test_linear_model_shares_descriptor_params(self) -> None:
model = get_model(copy.deepcopy(self.config))
self.assertIsNotNone(model.shared_links)

model.share_params(model.shared_links)

descriptor_0 = model.atomic_model.models[0].descriptor
descriptor_1 = model.atomic_model.models[1].descriptor
self.assertIs(descriptor_0.type_embedding, descriptor_1.type_embedding)


if __name__ == "__main__":
unittest.main()
Loading