Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions deepmd/tf/model/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import operator
from enum import (
Enum,
Expand Down Expand Up @@ -77,11 +78,14 @@ def get_fitting(self) -> Fitting | dict:

def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss | dict | None:
"""Get the loss function(s)."""
# the first model that is not None, or None if all models are None
# Return the first submodel loss that is not None, or None if all
# submodels are non-trainable. Each submodel must receive the original
# loss config: frozen/table submodels return None, so consuming the
# return value would pass None to a later trainable submodel.
for model in self.models:
loss = model.get_loss(loss, lr)
if loss is not None:
return loss
submodel_loss = model.get_loss(copy.deepcopy(loss), lr)
if submodel_loss is not None:
return submodel_loss
return None

def get_rcut(self) -> float:
Expand Down
35 changes: 35 additions & 0 deletions source/tests/tf/test_linear_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
import unittest

import numpy as np

Expand Down Expand Up @@ -121,3 +122,37 @@ def tearDown(self) -> None:
for pb in self.graph_dirs:
os.remove(pb)
del_data()


class TestLinearModelGetLoss(unittest.TestCase):
"""A non-trainable submodel (frozen/pairtab) returns ``None`` from
``get_loss``; a later trainable submodel must still receive the original
loss configuration rather than that ``None``.
"""

def test_get_loss_skips_none_returning_submodels(self) -> None:
class _NoneLossModel:
"""Mimics frozen/pairtab submodels, which own no trainable loss."""

def get_loss(self, loss, lr):
return None

class _EnerLossModel:
"""Mimics a trainable ener submodel: dereferences loss as a dict
(as EnerFitting.get_loss does via ``loss.pop``).
"""

def get_loss(self, loss, lr):
loss.pop("type", "ener")
return "ener-loss"

# bypass the heavy __init__; get_loss only touches self.models
model = object.__new__(LinearEnergyModel)
model.models = [_NoneLossModel(), _EnerLossModel()]

loss_config = {"type": "ener", "start_pref_e": 1.0}
result = model.get_loss(loss_config, lr=None)

self.assertEqual(result, "ener-loss")
# the original config must be preserved for other consumers
self.assertEqual(loss_config, {"type": "ener", "start_pref_e": 1.0})
Loading