From 04aacf06f695eae37d2a3d3de55c34117b3f35e3 Mon Sep 17 00:00:00 2001 From: Yifan Li Date: Mon, 23 Oct 2023 15:45:42 -0400 Subject: [PATCH 1/5] support conversion to pbtxt in command line interface --- deepmd/entrypoints/convert.py | 39 +++++++++++++++++++++-------------- deepmd_cli/main.py | 2 +- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index 73906fe074..0511950b73 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -5,6 +5,7 @@ convert_12_to_21, convert_13_to_21, convert_20_to_21, + convert_pb_to_pbtxt, convert_pbtxt_to_pb, convert_to_21, ) @@ -17,20 +18,26 @@ def convert( output_model: str, **kwargs, ): - if FROM == "auto": - convert_to_21(input_model, output_model) - elif FROM == "0.12": - convert_012_to_21(input_model, output_model) - elif FROM == "1.0": - convert_10_to_21(input_model, output_model) - elif FROM in ["1.1", "1.2"]: - # no difference between 1.1 and 1.2 - convert_12_to_21(input_model, output_model) - elif FROM == "1.3": - convert_13_to_21(input_model, output_model) - elif FROM == "2.0": - convert_20_to_21(input_model, output_model) - elif FROM == "pbtxt": - convert_pbtxt_to_pb(input_model, output_model) + if output_model[:-6] == ".pbtxt": + if input_model[:-6] != ".pbtxt": + convert_pb_to_pbtxt(input_model, output_model) + else: + raise RuntimeError("input model is already pbtx") else: - raise RuntimeError("unsupported model version " + FROM) + if FROM == "auto": + convert_to_21(input_model, output_model) + elif FROM == "0.12": + convert_012_to_21(input_model, output_model) + elif FROM == "1.0": + convert_10_to_21(input_model, output_model) + elif FROM in ["1.1", "1.2"]: + # no difference between 1.1 and 1.2 + convert_12_to_21(input_model, output_model) + elif FROM == "1.3": + convert_13_to_21(input_model, output_model) + elif FROM == "2.0": + convert_20_to_21(input_model, output_model) + elif FROM == "pbtxt": + convert_pbtxt_to_pb(input_model, output_model) + else: + raise RuntimeError("unsupported model version " + FROM) diff --git a/deepmd_cli/main.py b/deepmd_cli/main.py index 5a0670d8dc..6ad14c3714 100644 --- a/deepmd_cli/main.py +++ b/deepmd_cli/main.py @@ -491,7 +491,7 @@ def main_parser() -> argparse.ArgumentParser: "--output-model", default="convert_out.pb", type=str, - help="the output model", + help="the output model\nIf it ends with '.pbtxt', the provided model will be converted to pbtxt format, without version conversion.", ) # neighbor_stat From 9bcb2150f8f57d59f15c03f8bafceb4f93a2bd9e Mon Sep 17 00:00:00 2001 From: Yifan Li Date: Mon, 23 Oct 2023 17:02:44 -0400 Subject: [PATCH 2/5] clarify help message --- deepmd_cli/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd_cli/main.py b/deepmd_cli/main.py index 6ad14c3714..bffc1c6911 100644 --- a/deepmd_cli/main.py +++ b/deepmd_cli/main.py @@ -491,7 +491,7 @@ def main_parser() -> argparse.ArgumentParser: "--output-model", default="convert_out.pb", type=str, - help="the output model\nIf it ends with '.pbtxt', the provided model will be converted to pbtxt format, without version conversion.", + help="the output model\nIf OUTPUT_MODEL ends with '.pbtxt', the provided model will be converted to pbtxt format, without version conversion.", ) # neighbor_stat From ef61957b3ba6024d91d355a78703c8b4c9f867d9 Mon Sep 17 00:00:00 2001 From: Yifan Li Date: Mon, 23 Oct 2023 17:52:36 -0400 Subject: [PATCH 3/5] fix slicing error --- deepmd/entrypoints/convert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index 0511950b73..ed125f04c7 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -18,8 +18,8 @@ def convert( output_model: str, **kwargs, ): - if output_model[:-6] == ".pbtxt": - if input_model[:-6] != ".pbtxt": + if output_model[-6:] == ".pbtxt": + if input_model[-6:] != ".pbtxt": convert_pb_to_pbtxt(input_model, output_model) else: raise RuntimeError("input model is already pbtx") From 48c001be4764ee829a34a0cb530ed939946330ed Mon Sep 17 00:00:00 2001 From: Yifan Li Date: Mon, 23 Oct 2023 18:09:25 -0400 Subject: [PATCH 4/5] fix typo --- deepmd/entrypoints/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index ed125f04c7..bea047ba72 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -22,7 +22,7 @@ def convert( if input_model[-6:] != ".pbtxt": convert_pb_to_pbtxt(input_model, output_model) else: - raise RuntimeError("input model is already pbtx") + raise RuntimeError("input model is already pbtxt") else: if FROM == "auto": convert_to_21(input_model, output_model) From f2d7c8d294767e6dfc21b679bea9e0610d8b0ca1 Mon Sep 17 00:00:00 2001 From: Yifan Li Date: Mon, 23 Oct 2023 18:37:47 -0400 Subject: [PATCH 5/5] avoid using .pbtxt as model conversion output name in unittest --- source/tests/test_deeppot_a.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 12f677e618..1c6cdc4afc 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -745,7 +745,7 @@ def setUp(self): def test_convert_012(self): old_model = "deeppot.pb" - new_model = "deeppot.pbtxt" + new_model = "deeppot-new.pb" convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from 0.12 -i {old_model} -o {new_model}") dp = DeepPot(new_model) @@ -755,7 +755,7 @@ def test_convert_012(self): def test_convert(self): old_model = "deeppot.pb" - new_model = "deeppot.pbtxt" + new_model = "deeppot-new.pb" convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model) run_dp(f"dp convert-from -i {old_model} -o {new_model}") dp = DeepPot(new_model)