From 40272299d8657e4d089f09a5032350e5a344a3e9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 1 Jul 2026 15:37:25 +0800 Subject: [PATCH] fix(tf): accept checkpoint dirs and step prefixes in change-bias The change-bias dispatcher only routed inputs ending in .pb/.pbtxt/.ckpt/ .meta/.data/.index and rejected everything else with "must be a checkpoint file or frozen model file (.pb)". Standard TensorFlow checkpoint prefixes such as model.ckpt-1000 carry no recognized suffix, and checkpoint directories have none either, so the inputs recommended by the CLI docs and the frozen-model fallback message were rejected before _change_bias_checkpoint_file (which already reads the checkpoint state file) could run. Route directory and suffix-less prefix inputs to the checkpoint handler when a TensorFlow "checkpoint" state file is present in the effective directory, and make _change_bias_checkpoint_file resolve checkpoint_dir for a directory input. Add dispatch tests for a step-suffixed prefix and a checkpoint directory. Fix #5683 --- deepmd/tf/entrypoints/change_bias.py | 21 +++++++++++++- source/tests/tf/test_change_bias.py | 41 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/entrypoints/change_bias.py b/deepmd/tf/entrypoints/change_bias.py index 0363ac7437..46c540a244 100644 --- a/deepmd/tf/entrypoints/change_bias.py +++ b/deepmd/tf/entrypoints/change_bias.py @@ -126,6 +126,23 @@ def change_bias( log_level, ) else: + # Checkpoint directory, or a checkpoint prefix such as "model.ckpt-1000" + # that carries no recognized suffix. Route these to the checkpoint + # handler when a TensorFlow "checkpoint" state file is present. + input_path = Path(INPUT) + checkpoint_dir = input_path if input_path.is_dir() else input_path.parent + if (checkpoint_dir / "checkpoint").is_file(): + return _change_bias_checkpoint_file( + INPUT, + mode, + bias_value, + datafile, + system, + numb_batch, + model_branch, + output, + log_level, + ) raise RuntimeError( "The model provided must be a checkpoint file or frozen model file (.pb)" ) @@ -147,7 +164,9 @@ def _change_bias_checkpoint_file( tf.reset_default_graph() checkpoint_path = Path(checkpoint_prefix) - checkpoint_dir = checkpoint_path.parent + checkpoint_dir = ( + checkpoint_path if checkpoint_path.is_dir() else checkpoint_path.parent + ) # Check for valid checkpoint and find the actual checkpoint path checkpoint_state_file = checkpoint_dir / "checkpoint" diff --git a/source/tests/tf/test_change_bias.py b/source/tests/tf/test_change_bias.py index 4392bbd139..6d40b56a09 100644 --- a/source/tests/tf/test_change_bias.py +++ b/source/tests/tf/test_change_bias.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib import json import os import shutil @@ -7,10 +8,18 @@ from pathlib import ( Path, ) +from unittest import ( + mock, +) from deepmd.tf.entrypoints.change_bias import ( change_bias, ) + +# the ``entrypoints`` package re-exports the ``change_bias`` function, which +# shadows the submodule under attribute access; fetch the real module so its +# private helpers can be patched. +change_bias_module = importlib.import_module("deepmd.tf.entrypoints.change_bias") from deepmd.tf.train.run_options import ( RunOptions, ) @@ -112,6 +121,38 @@ def test_change_bias_no_checkpoint_in_directory(self): self.assertIn("No valid checkpoint found", str(cm.exception)) + def test_change_bias_accepts_checkpoint_prefix_with_step(self): + """A checkpoint prefix such as ``model.ckpt-1000`` carries no recognized + suffix, but must be routed to the checkpoint handler when a TensorFlow + ``checkpoint`` state file sits beside it. + """ + ckpt_dir = self.temp_path / "ckpt_prefix" + ckpt_dir.mkdir() + (ckpt_dir / "checkpoint").write_text('model_checkpoint_path: "model.ckpt-1000"') + prefix = ckpt_dir / "model.ckpt-1000" + + with mock.patch.object( + change_bias_module, "_change_bias_checkpoint_file" + ) as mocked: + change_bias(INPUT=str(prefix), mode="change", system=".") + mocked.assert_called_once() + self.assertEqual(mocked.call_args.args[0], str(prefix)) + + def test_change_bias_accepts_checkpoint_directory(self): + """A checkpoint directory (no suffix) containing a ``checkpoint`` state + file must be routed to the checkpoint handler, not rejected. + """ + ckpt_dir = self.temp_path / "ckpt_dir" + ckpt_dir.mkdir() + (ckpt_dir / "checkpoint").write_text('model_checkpoint_path: "model.ckpt-1000"') + + with mock.patch.object( + change_bias_module, "_change_bias_checkpoint_file" + ) as mocked: + change_bias(INPUT=str(ckpt_dir), mode="change", system=".") + mocked.assert_called_once() + self.assertEqual(mocked.call_args.args[0], str(ckpt_dir)) + def test_change_bias_user_defined_requires_real_model(self): """Test that user-defined bias requires a real model with proper structure.""" fake_ckpt_dir = self.temp_path / "fake_checkpoint"