
[Code scan] Keep TensorFlow linear-model loss configs after frozen components #5682

Description

@njzjz

This issue comes from a Codex global scan of deepmodeling/deepmd-kit at commit 73de44b1f94471b2e3bdb6b11f57b34d7bc791bb.

Problem

LinearEnergyModel.get_loss() overwrites the caller's loss config with each submodel's return value:

```python
def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss | dict | None:
    """Get the loss function(s)."""
    # the first model that is not None, or None if all models are None
    for model in self.models:
        loss = model.get_loss(loss, lr)
        if loss is not None:
            return loss
    return None
```
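To make the control flow concrete, here is a minimal, self-contained sketch that copies the loop above into a free function and drives it with stub submodels. NoLossModel, RecordingModel and linear_get_loss_current are hypothetical names, not deepmd-kit classes, and the learning-rate object is replaced by a plain float. The recording stub tolerates None, whereas the real EnerFitting.get_loss() crashes on it.

```python
from typing import Any, Optional


class NoLossModel:
    """Hypothetical stand-in for a frozen/table submodel: never owns a loss."""

    def get_loss(self, loss: Optional[dict], lr: Any) -> None:
        return None


class RecordingModel:
    """Hypothetical stand-in for a trainable submodel that records its input."""

    def __init__(self) -> None:
        self.received: Any = "not called"

    def get_loss(self, loss: Optional[dict], lr: Any) -> Optional[dict]:
        self.received = loss
        # Pretend to build a loss object; here we just echo the config back.
        return loss


def linear_get_loss_current(models: list, loss: Optional[dict], lr: Any):
    """Same loop as LinearEnergyModel.get_loss(): `loss` is rebound each step."""
    for model in models:
        loss = model.get_loss(loss, lr)
        if loss is not None:
            return loss
    return None


trainable = RecordingModel()
result = linear_get_loss_current([NoLossModel(), trainable], {"type": "ener"}, 1e-3)
print(trainable.received)  # None: the config was lost after the first submodel
print(result)  # None
```

Swapping the order (trainable first) returns the config unchanged, which is why the problem only appears when a non-trainable component comes first.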

Frozen and table submodels return None because they do not own a trainable loss:

```python
# Frozen submodel
def get_fitting(self) -> Fitting | dict:
    """Get the fitting(s)."""
    return {}

def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss | dict | None:
    """Get the loss function(s)."""
    # loss should be never used for a frozen model
    return
```

```python
# Table submodel
def get_fitting(self) -> Fitting | dict:
    """Get the fitting(s)."""
    # nothing needs to do
    return {}

def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss | dict | None:
    """Get the loss function(s)."""
    # nothing needs to do
    return
```

If such a non-trainable submodel appears before a trainable energy model, the next submodel receives loss=None. EnerFitting.get_loss() then uses that value as a dict:

```python
"""Get the loss function.

Parameters
----------
loss : dict
    The loss function parameters.
lr : LearningRateSchedule
    The learning rate.

Returns
-------
Loss
    The loss function.
"""
_loss_type = loss.pop("type", "ener")
loss["starter_learning_rate"] = lr.start_lr()
if _loss_type == "ener":
    return EnerStdLoss(**loss)
```
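The excerpt's dict handling can be reproduced without TensorFlow. The function below is a hypothetical mimic (ener_get_loss_sketch is not a deepmd-kit name): it takes a plain float instead of a learning-rate schedule, returns a plain dict where the real code builds EnerStdLoss, and its final ValueError branch is an assumption, since the excerpt ends at the "ener" case. It shows both the crash on None and the fact that .pop() edits the caller's dict in place, which is one reason copying the config seems safer than only skipping None results.

```python
from typing import Optional


def ener_get_loss_sketch(loss: Optional[dict], start_lr: float) -> dict:
    """Hypothetical mimic of the EnerFitting.get_loss() excerpt above.

    Returns a plain dict instead of constructing EnerStdLoss(**loss).
    """
    _loss_type = loss.pop("type", "ener")  # AttributeError if loss is None
    loss["starter_learning_rate"] = start_lr  # mutates the caller's dict
    if _loss_type == "ener":
        return dict(loss)
    # Assumption: the excerpt does not show other loss types.
    raise ValueError(f"unsupported loss type: {_loss_type}")


# 1) A None config (what follows a frozen/table submodel) crashes on .pop().
try:
    ener_get_loss_sketch(None, 1e-3)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'pop'

# 2) A real config works, but the caller's dict is modified in place.
config = {"type": "ener"}
ener_get_loss_sketch(config, 1e-3)
print(config)  # {'starter_learning_rate': 0.001}
```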

Impact

Linear TensorFlow models that place frozen or table components before a trainable energy component can crash during trainer construction with AttributeError: 'NoneType' object has no attribute 'pop', raised by loss.pop(...), even though the trainable submodel could have used the original loss config.

Suggested fix

Keep the original loss configuration separate from each submodel's returned Loss object. For example, iterate with model.get_loss(copy.deepcopy(loss), lr) and return the first non-None result.
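A minimal sketch of the suggested fix, written as a free function rather than a patch to LinearEnergyModel. linear_get_loss_fixed, NoLossModel and PoppingModel are hypothetical names, and a float stands in for LearningRateExp. The submodel result is stored in a separate variable, so a None return never replaces the config, and each submodel gets its own deep copy, so an in-place .pop() cannot leak into the next submodel or back to the caller.

```python
import copy
from typing import Any, Optional


def linear_get_loss_fixed(models: list, loss: Optional[dict], lr: Any):
    """Sketch of the suggested fix: never rebind or share the caller's config."""
    for model in models:
        result = model.get_loss(copy.deepcopy(loss), lr)
        if result is not None:
            return result
    return None


class NoLossModel:
    """Hypothetical frozen/table stand-in: never owns a loss."""

    def get_loss(self, loss, lr):
        return None


class PoppingModel:
    """Hypothetical trainable stand-in that mutates its input like the excerpt."""

    def get_loss(self, loss, lr):
        loss_type = loss.pop("type", "ener")
        loss["starter_learning_rate"] = lr
        return {"loss_type": loss_type, **loss}


original = {"type": "ener"}
print(linear_get_loss_fixed([NoLossModel(), PoppingModel()], original, 1e-3))
# {'loss_type': 'ener', 'starter_learning_rate': 0.001}
print(original)  # {'type': 'ener'} (unchanged)
```

copy.deepcopy is used rather than dict.copy() so that nested values in a loss config are protected as well; for the excerpt shown, which only pops and sets top-level keys, a shallow copy would behave the same.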

Metadata


Assignees: No one assigned
Type: No type
Projects: Status: Done
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
