Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
convert_12_to_21,
convert_13_to_21,
convert_20_to_21,
convert_pb_to_pbtxt,
convert_pbtxt_to_pb,
convert_to_21,
)
Expand All @@ -17,20 +18,26 @@ def convert(
output_model: str,
**kwargs,
):
if FROM == "auto":
convert_to_21(input_model, output_model)
elif FROM == "0.12":
convert_012_to_21(input_model, output_model)
elif FROM == "1.0":
convert_10_to_21(input_model, output_model)
elif FROM in ["1.1", "1.2"]:
# no difference between 1.1 and 1.2
convert_12_to_21(input_model, output_model)
elif FROM == "1.3":
convert_13_to_21(input_model, output_model)
elif FROM == "2.0":
convert_20_to_21(input_model, output_model)
elif FROM == "pbtxt":
convert_pbtxt_to_pb(input_model, output_model)
if output_model[-6:] == ".pbtxt":
if input_model[-6:] != ".pbtxt":
convert_pb_to_pbtxt(input_model, output_model)
else:
raise RuntimeError("input model is already pbtxt")
else:
raise RuntimeError("unsupported model version " + FROM)
if FROM == "auto":
convert_to_21(input_model, output_model)
elif FROM == "0.12":
convert_012_to_21(input_model, output_model)
elif FROM == "1.0":
convert_10_to_21(input_model, output_model)
elif FROM in ["1.1", "1.2"]:
# no difference between 1.1 and 1.2
convert_12_to_21(input_model, output_model)
elif FROM == "1.3":
convert_13_to_21(input_model, output_model)
elif FROM == "2.0":
convert_20_to_21(input_model, output_model)
elif FROM == "pbtxt":
convert_pbtxt_to_pb(input_model, output_model)
else:
raise RuntimeError("unsupported model version " + FROM)
2 changes: 1 addition & 1 deletion deepmd_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def main_parser() -> argparse.ArgumentParser:
"--output-model",
default="convert_out.pb",
type=str,
help="the output model",
help="the output model\nIf OUTPUT_MODEL ends with '.pbtxt', the provided model will be converted to pbtxt format, without version conversion.",
)

# neighbor_stat
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def setUp(self):

def test_convert_012(self):
old_model = "deeppot.pb"
new_model = "deeppot.pbtxt"
new_model = "deeppot-new.pb"
convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model)
run_dp(f"dp convert-from 0.12 -i {old_model} -o {new_model}")
dp = DeepPot(new_model)
Expand All @@ -755,7 +755,7 @@ def test_convert_012(self):

def test_convert(self):
old_model = "deeppot.pb"
new_model = "deeppot.pbtxt"
new_model = "deeppot-new.pb"
convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model)
run_dp(f"dp convert-from -i {old_model} -o {new_model}")
dp = DeepPot(new_model)
Expand Down