From a09b69d3a1e92434bacebdff0348d693be80a7d5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Mar 2024 18:30:21 -0500 Subject: [PATCH 1/3] tf: remove freeze warning for optional nodes Fix #3334. Signed-off-by: Jinzhe Zeng --- deepmd/tf/entrypoints/freeze.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 228f8466cb..093b1801d0 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -359,7 +359,13 @@ def freeze_graph( output_node = _make_node_names( freeze_type, modifier, out_suffix=out_suffix, node_names=node_names ) - different_set = set(output_node) - set(input_node) + # see #3334 + optional_node = [ + "train_attr/min_nbor_dist", + "fitting_attr/aparam_nall", + "spin_attr/ntypes_spin", + ] + different_set = set(output_node) - set(input_node) - set(optional_node) if different_set: log.warning( "The following nodes are not in the graph: %s. " From 1fd9a7b8d478bb760ae96b119990baf862d712af Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Mar 2024 18:32:46 -0500 Subject: [PATCH 2/3] still need to do intersection Signed-off-by: Jinzhe Zeng --- deepmd/tf/entrypoints/freeze.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 093b1801d0..9b4dd59c7a 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -365,13 +365,14 @@ def freeze_graph( "fitting_attr/aparam_nall", "spin_attr/ntypes_spin", ] - different_set = set(output_node) - set(input_node) - set(optional_node) + different_set = set(output_node) - set(input_node) if different_set: - log.warning( - "The following nodes are not in the graph: %s. " - "Skip freezeing these nodes. You may be freezing " - "a checkpoint generated by an old version." % different_set - ) + if different_set - set(optional_node): + log.warning( + "The following nodes are not in the graph: %s. " + "Skip freezeing these nodes. You may be freezing " + "a checkpoint generated by an old version." % different_set + ) # use intersection as output list output_node = list(set(output_node) & set(input_node)) log.info(f"The following nodes will be frozen: {output_node}") From 4714b3f5c05d0c06725fcaa0c368c432fd10be7c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Mar 2024 18:34:04 -0500 Subject: [PATCH 3/3] improve warning Signed-off-by: Jinzhe Zeng --- deepmd/tf/entrypoints/freeze.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index 9b4dd59c7a..c7ab1023fa 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -367,7 +367,8 @@ def freeze_graph( ] different_set = set(output_node) - set(input_node) if different_set: - if different_set - set(optional_node): + different_set -= set(optional_node) + if different_set: log.warning( "The following nodes are not in the graph: %s. " "Skip freezeing these nodes. You may be freezing "