diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index 73906fe074..bea047ba72 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -5,6 +5,7 @@ convert_12_to_21, convert_13_to_21, convert_20_to_21, + convert_pb_to_pbtxt, convert_pbtxt_to_pb, convert_to_21, ) @@ -17,20 +18,26 @@ def convert( output_model: str, **kwargs, ): - if FROM == "auto": - convert_to_21(input_model, output_model) - elif FROM == "0.12": - convert_012_to_21(input_model, output_model) - elif FROM == "1.0": - convert_10_to_21(input_model, output_model) - elif FROM in ["1.1", "1.2"]: - # no difference between 1.1 and 1.2 - convert_12_to_21(input_model, output_model) - elif FROM == "1.3": - convert_13_to_21(input_model, output_model) - elif FROM == "2.0": - convert_20_to_21(input_model, output_model) - elif FROM == "pbtxt": - convert_pbtxt_to_pb(input_model, output_model) + if output_model[-6:] == ".pbtxt": + if input_model[-6:] != ".pbtxt": + convert_pb_to_pbtxt(input_model, output_model) + else: + raise RuntimeError("input model is already pbtxt") else: - raise RuntimeError("unsupported model version " + FROM) + if FROM == "auto": + convert_to_21(input_model, output_model) + elif FROM == "0.12": + convert_012_to_21(input_model, output_model) + elif FROM == "1.0": + convert_10_to_21(input_model, output_model) + elif FROM in ["1.1", "1.2"]: + # no difference between 1.1 and 1.2 + convert_12_to_21(input_model, output_model) + elif FROM == "1.3": + convert_13_to_21(input_model, output_model) + elif FROM == "2.0": + convert_20_to_21(input_model, output_model) + elif FROM == "pbtxt": + convert_pbtxt_to_pb(input_model, output_model) + else: + raise RuntimeError("unsupported model version " + FROM) diff --git a/deepmd_cli/main.py b/deepmd_cli/main.py index 5a0670d8dc..bffc1c6911 100644 --- a/deepmd_cli/main.py +++ b/deepmd_cli/main.py @@ -491,7 +491,7 @@ def main_parser() -> argparse.ArgumentParser: "--output-model", default="convert_out.pb", type=str, - help="the output model", + help="the output model\nIf OUTPUT_MODEL ends with '.pbtxt', the provided model will be converted to pbtxt format, without version conversion.", ) # neighbor_stat diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 12f677e618..1c6cdc4afc 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -745,7 +745,7 @@ def setUp(self): def test_convert_012(self): old_model = "deeppot.pb" - new_model = "deeppot.pbtxt" + new_model = "deeppot-new.pb" convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from 0.12 -i {old_model} -o {new_model}") dp = DeepPot(new_model) @@ -755,7 +755,7 @@ def test_convert_012(self): def test_convert(self): old_model = "deeppot.pb" - new_model = "deeppot.pbtxt" + new_model = "deeppot-new.pb" convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from -i {old_model} -o {new_model}") dp = DeepPot(new_model)