Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion deepmd/tf/entrypoints/change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ def change_bias(
log_level,
)
else:
# Checkpoint directory, or a checkpoint prefix such as "model.ckpt-1000"
# that carries no recognized suffix. Route these to the checkpoint
# handler when a TensorFlow "checkpoint" state file is present.
input_path = Path(INPUT)
checkpoint_dir = input_path if input_path.is_dir() else input_path.parent
if (checkpoint_dir / "checkpoint").is_file():
return _change_bias_checkpoint_file(
INPUT,
mode,
bias_value,
datafile,
system,
numb_batch,
model_branch,
output,
log_level,
)
raise RuntimeError(
"The model provided must be a checkpoint file or frozen model file (.pb)"
)
Expand All @@ -147,7 +164,9 @@ def _change_bias_checkpoint_file(
tf.reset_default_graph()

checkpoint_path = Path(checkpoint_prefix)
checkpoint_dir = checkpoint_path.parent
checkpoint_dir = (
checkpoint_path if checkpoint_path.is_dir() else checkpoint_path.parent
)

# Check for valid checkpoint and find the actual checkpoint path
checkpoint_state_file = checkpoint_dir / "checkpoint"
Expand Down
41 changes: 41 additions & 0 deletions source/tests/tf/test_change_bias.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import json
import os
import shutil
Expand All @@ -7,10 +8,18 @@
from pathlib import (
Path,
)
from unittest import (
mock,
)

from deepmd.tf.entrypoints.change_bias import (
change_bias,
)

# the ``entrypoints`` package re-exports the ``change_bias`` function, which
# shadows the submodule under attribute access; fetch the real module so its
# private helpers can be patched.
change_bias_module = importlib.import_module("deepmd.tf.entrypoints.change_bias")
from deepmd.tf.train.run_options import (
RunOptions,
)
Expand Down Expand Up @@ -112,6 +121,38 @@ def test_change_bias_no_checkpoint_in_directory(self):

self.assertIn("No valid checkpoint found", str(cm.exception))

def test_change_bias_accepts_checkpoint_prefix_with_step(self):
"""A checkpoint prefix such as ``model.ckpt-1000`` carries no recognized
suffix, but must be routed to the checkpoint handler when a TensorFlow
``checkpoint`` state file sits beside it.
"""
ckpt_dir = self.temp_path / "ckpt_prefix"
ckpt_dir.mkdir()
(ckpt_dir / "checkpoint").write_text('model_checkpoint_path: "model.ckpt-1000"')
prefix = ckpt_dir / "model.ckpt-1000"

with mock.patch.object(
change_bias_module, "_change_bias_checkpoint_file"
) as mocked:
change_bias(INPUT=str(prefix), mode="change", system=".")
mocked.assert_called_once()
self.assertEqual(mocked.call_args.args[0], str(prefix))

def test_change_bias_accepts_checkpoint_directory(self):
"""A checkpoint directory (no suffix) containing a ``checkpoint`` state
file must be routed to the checkpoint handler, not rejected.
"""
ckpt_dir = self.temp_path / "ckpt_dir"
ckpt_dir.mkdir()
(ckpt_dir / "checkpoint").write_text('model_checkpoint_path: "model.ckpt-1000"')

with mock.patch.object(
change_bias_module, "_change_bias_checkpoint_file"
) as mocked:
change_bias(INPUT=str(ckpt_dir), mode="change", system=".")
mocked.assert_called_once()
self.assertEqual(mocked.call_args.args[0], str(ckpt_dir))

def test_change_bias_user_defined_requires_real_model(self):
"""Test that user-defined bias requires a real model with proper structure."""
fake_ckpt_dir = self.temp_path / "fake_checkpoint"
Expand Down
Loading