From 75b403e4c53d1989d4d3ede2f07fc97431f40526 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 9 Feb 2024 23:37:08 -0500 Subject: [PATCH] tf: support checkpoint path (instead of directory) in dp freeze Signed-off-by: Jinzhe Zeng --- deepmd/main.py | 2 +- deepmd/tf/entrypoints/freeze.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/deepmd/main.py b/deepmd/main.py index ff7120c8e7..ab2f0c449e 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -275,7 +275,7 @@ def main_parser() -> argparse.ArgumentParser: "--checkpoint", type=str, default=".", - help="Path to checkpoint. TensorFlow backend: a folder; PyTorch backend: either a folder containing checkpoint, or a pt file", + help="Path to checkpoint, either a folder containing checkpoint or the checkpoint prefix", ) parser_frz.add_argument( "-o", diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 9cb59f4c9d..228f8466cb 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -12,6 +12,9 @@ from os.path import ( abspath, ) +from pathlib import ( + Path, +) from typing import ( List, Optional, @@ -479,7 +482,7 @@ def freeze( Parameters ---------- checkpoint_folder : str - location of the folder with model + location of either the folder with checkpoint or the checkpoint prefix output : str output file name node_names : Optional[str], optional @@ -492,8 +495,11 @@ def freeze( other arguments """ # We retrieve our checkpoint fullpath - checkpoint = tf.train.get_checkpoint_state(checkpoint_folder) - input_checkpoint = checkpoint.model_checkpoint_path + if Path(checkpoint_folder).is_dir(): + checkpoint = tf.train.get_checkpoint_state(checkpoint_folder) + input_checkpoint = checkpoint.model_checkpoint_path + else: + input_checkpoint = checkpoint_folder # expand the output file to full path output_graph = abspath(output)