
[Code scan] Accept standard TensorFlow checkpoint inputs in change-bias #5683

Description

@njzjz

This issue comes from a Codex global scan of deepmodeling/deepmd-kit at commit 73de44b1f94471b2e3bdb6b11f57b34d7bc791bb.

Problem

The CLI examples document TensorFlow change-bias with a checkpoint-style input:

deepmd-kit/deepmd/main.py, lines 824 to 833 at 73de44b:

"change-bias",
parents=[parser_log],
help="Change model out bias according to the input data.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
    """\
    examples:
        dp --pt change-bias model.pt -s data -n 10 -m change
        dp --tf change-bias model.ckpt -s data -n 10 -m change
    """

The error raised for frozen models (.pb) also recommends passing a checkpoint directory as a workaround:

raise NotImplementedError(
    "Bias modification for frozen models (.pb) is not yet fully implemented. "
    "Recommended workaround:\n"
    "1. Use a checkpoint directory instead of a frozen model\n"
    "2. Or load the model, modify bias in training, then freeze again\n"
    f" dp --tf change-bias <checkpoint_dir> -b {' '.join(map(str, bias_value)) if bias_value else '<bias_values>'} -o <output_dir>\n"
    " dp freeze -c <output_dir> -o modified_model.pb"
)

But the dispatcher recognizes inputs only by their suffix (.pb, .pbtxt, .ckpt, .meta, .data, or .index) and raises RuntimeError for anything else:

# Determine input type and handle accordingly
if INPUT.endswith(".pb"):
    # Frozen model (.pb)
    return _change_bias_frozen_model(
        INPUT,
        mode,
        bias_value,
        datafile,
        system,
        numb_batch,
        model_branch,
        output,
        log_level,
    )
elif INPUT.endswith(".pbtxt"):
    # Text format frozen model (.pbtxt) - not supported
    raise NotImplementedError(
        "Bias changing for .pbtxt models is not supported. "
        "Please convert to .pb format first using: dp convert-from pbtxt -i model.pbtxt -o model.pb"
    )
elif INPUT.endswith((".ckpt", ".meta", ".data", ".index")):
    # Individual checkpoint files
    checkpoint_prefix = INPUT
    if INPUT.endswith((".meta", ".data", ".index")):
        checkpoint_prefix = INPUT.rsplit(".", 1)[0]
    return _change_bias_checkpoint_file(
        checkpoint_prefix,
        mode,
        bias_value,
        datafile,
        system,
        numb_batch,
        model_branch,
        output,
        log_level,
    )
else:
    raise RuntimeError(
        "The model provided must be a checkpoint file or frozen model file (.pb)"
    )

Real TensorFlow checkpoint prefixes usually carry a global-step suffix, such as model.ckpt-1000, and checkpoint directories have no suffix at all. Both kinds of input fall through to the final else branch and are rejected before _change_bias_checkpoint_file() can inspect the checkpoint state file:

checkpoint_path = Path(checkpoint_prefix)
checkpoint_dir = checkpoint_path.parent
# Check for valid checkpoint and find the actual checkpoint path
checkpoint_state_file = checkpoint_dir / "checkpoint"
if not checkpoint_state_file.exists():
    raise RuntimeError(f"No valid checkpoint found in {checkpoint_dir}")
# Get the latest checkpoint path from the checkpoint state file
checkpoint_state = tf.train.get_checkpoint_state(str(checkpoint_dir))
if checkpoint_state is None or checkpoint_state.model_checkpoint_path is None:
    raise RuntimeError(f"No valid checkpoint state found in {checkpoint_dir}")
# The model_checkpoint_path from get_checkpoint_state is the full path to the checkpoint
actual_checkpoint_path = checkpoint_state.model_checkpoint_path
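
This also matters for the fix: the code above derives the search directory from Path(checkpoint_prefix).parent, so a checkpoint directory cannot simply be passed through unchanged as checkpoint_prefix. The short pathlib sketch below mirrors only the first lines of that excerpt (state_file_for is a hypothetical helper, not deepmd-kit code) to show where the state file would be looked up for a prefix versus a directory.

```python
from pathlib import Path


def state_file_for(checkpoint_prefix: str) -> Path:
    """Hypothetical helper mirroring the excerpt: look for `checkpoint`
    in the parent directory of the given prefix."""
    checkpoint_dir = Path(checkpoint_prefix).parent
    return checkpoint_dir / "checkpoint"


# A prefix resolves to the state file in the directory that holds it.
print(state_file_for("train/model.ckpt-1000"))
# A directory argument resolves one level too high: train/checkpoint,
# not train/ckpt_dir/checkpoint.
print(state_file_for("train/ckpt_dir"))
```

A directory input would therefore have to be mapped to a prefix inside it (for example with tf.train.latest_checkpoint) before _change_bias_checkpoint_file() is called.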

Impact

Users following the documented workaround with a checkpoint directory, or passing a standard checkpoint prefix with a step suffix, get RuntimeError("The model provided must be a checkpoint file or frozen model file (.pb)") before checkpoint loading starts.
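To make the failure concrete, the minimal sketch below copies only the suffix checks from the dispatcher excerpt into a standalone function (classify_input and its return labels are hypothetical, not deepmd-kit code) and runs it on the inputs discussed in this issue. It also includes a data shard in TensorFlow's usual <prefix>.data-00000-of-00001 form; such names do not end in .data, so the excerpt's .data check does not match them either.

```python
# Suffixes copied from the dispatcher excerpt; the function itself is a
# hypothetical stand-in used only to show which inputs reach the `else` branch.
CHECKPOINT_SUFFIXES = (".ckpt", ".meta", ".data", ".index")


def classify_input(path: str) -> str:
    """Reproduce the excerpt's suffix-based dispatch without loading anything."""
    if path.endswith(".pb"):
        return "frozen model"
    if path.endswith(".pbtxt"):
        return "pbtxt (NotImplementedError)"
    if path.endswith(CHECKPOINT_SUFFIXES):
        return "checkpoint file"
    return "rejected (RuntimeError)"


for candidate in [
    "model.ckpt",  # form used in the CLI help example
    "model.ckpt-1000",  # prefix with a global-step suffix
    "train/ckpt_dir",  # checkpoint directory, as in the workaround message
    "model.ckpt-1000.index",  # per-checkpoint index file
    "model.ckpt-1000.data-00000-of-00001",  # TensorFlow data shard
]:
    print(f"{candidate:40} -> {classify_input(candidate)}")
```

Only model.ckpt and model.ckpt-1000.index get through; the step-suffixed prefix, the directory, and the data shard all end in the RuntimeError branch.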

Suggested fix

Accept checkpoint directories that contain a TensorFlow checkpoint state file, and accept checkpoint prefixes such as model.ckpt-1000, either by checking for sibling .meta/.index files or by asking TensorFlow for the checkpoint state, rather than relying only on filename suffixes.
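As a minimal sketch of this suggestion, the resolver below accepts a directory, a bare prefix, or one of the per-checkpoint files, and returns a prefix only when TensorFlow's sibling .index or .meta file exists. All names here are hypothetical, not deepmd-kit API. To stay runnable without TensorFlow, it reads the text-format checkpoint state file with a simplified parser; inside deepmd-kit, tf.train.get_checkpoint_state() (already used in _change_bias_checkpoint_file()) or tf.train.latest_checkpoint() would be the natural replacement for that part.

```python
import re
from pathlib import Path
from typing import Optional

# Matches TensorFlow data shards such as "model.ckpt-1000.data-00000-of-00001".
_DATA_SHARD = re.compile(r"\.data-\d{5}-of-\d{5}$")
# Sibling files whose presence marks a valid checkpoint prefix.
_SIDE_FILES = (".index", ".meta")


def _read_state_prefix(directory: Path) -> Optional[Path]:
    """Return the prefix recorded in `directory/checkpoint`, if any.

    Simplified stand-in for tf.train.get_checkpoint_state(): it only reads
    the `model_checkpoint_path: "..."` line of the text-format state file.
    """
    state = directory / "checkpoint"
    if not state.is_file():
        return None
    for line in state.read_text().splitlines():
        match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
        if match:
            prefix = Path(match.group(1))
            # Relative paths in the state file are relative to its directory.
            return prefix if prefix.is_absolute() else directory / prefix
    return None


def resolve_checkpoint_prefix(user_input: str) -> Optional[Path]:
    """Hypothetical resolver: map a directory, prefix, or checkpoint file
    to a checkpoint prefix, or return None if nothing valid is found."""
    path = Path(user_input)
    if path.is_dir():
        return _read_state_prefix(path)
    name = path.name
    # Strip per-file suffixes: .index, .meta, or a .data-XXXXX-of-YYYYY shard.
    if name.endswith(_SIDE_FILES):
        path = path.with_name(name.rsplit(".", 1)[0])
    elif _DATA_SHARD.search(name):
        path = path.with_name(_DATA_SHARD.sub("", name))
    # Accept the prefix only if TensorFlow's sibling files exist next to it.
    for suffix in _SIDE_FILES:
        if path.with_name(path.name + suffix).is_file():
            return path
    return None
```

Because the resolved prefix sits inside the checkpoint directory, the existing Path(checkpoint_prefix).parent lookup in _change_bias_checkpoint_file() would then find the state file; the .pb and .pbtxt branches could stay as they are.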

Project status: Done