diff --git a/deepmd/main.py b/deepmd/main.py index ff7120c8e7..ab2f0c449e 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser: "--checkpoint", type=str, default=".", - help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file", + help="Path to checkpoint, either a folder containing checkpoint or the checkpoint prefix", ) parser_frz.add_argument( "-o", diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 9cb59f4c9d..228f8466cb 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -12,6 +12,9 @@ from os.path import ( abspath, ) +from pathlib import ( + Path, +) from typing import ( List, Optional, @@ -479,7 +482,7 @@ def freeze( Parameters ---------- checkpoint_folder : str - location of the folder with model + location of either the folder with checkpoint or the checkpoint prefix output : str output file name node_names : Optional[str], optional @@ -492,8 +495,11 @@ def freeze( other arguments """ # We retrieve our checkpoint fullpath - checkpoint = tf.train.get_checkpoint_state(checkpoint_folder) - input_checkpoint = checkpoint.model_checkpoint_path + if Path(checkpoint_folder).is_dir(): + checkpoint = tf.train.get_checkpoint_state(checkpoint_folder) + input_checkpoint = checkpoint.model_checkpoint_path + else: + input_checkpoint = checkpoint_folder # expand the output file to full path output_graph = abspath(output)