From aad66f91b6fa409716069860925fb347c8a53113 Mon Sep 17 00:00:00 2001 From: caic99 Date: Mon, 31 Mar 2025 08:56:09 +0000 Subject: [PATCH 1/2] perf: reschedule plus op --- deepmd/pt/model/descriptor/repflow_layer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 01f8477b7b..0aa4bf03df 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -435,11 +435,12 @@ def optim_angle_update( ) result_update = ( - sub_angle_update + bias + sub_node_update[:, :, None, None, :] + sub_edge_update_ij[:, :, None, :, :] + sub_edge_update_ik[:, :, :, None, :] - ) + bias + + sub_angle_update + ) return result_update def optim_edge_update( @@ -482,8 +483,11 @@ def optim_edge_update( ) result_update = ( - sub_edge_update + sub_node_ext_update + sub_node_update[:, :, None, :] - ) + bias + bias + + sub_node_update[:, :, None, :] + + sub_edge_update + + sub_node_ext_update + ) return result_update def forward( From 2dd7869a3daeca65e86db9266664d09e4edfc5b4 Mon Sep 17 00:00:00 2001 From: caic99 Date: Mon, 31 Mar 2025 09:00:45 +0000 Subject: [PATCH 2/2] update dp backend --- deepmd/dpmodel/descriptor/repflows.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 56692589d7..4338ebe998 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -834,11 +834,12 @@ def optim_angle_update( ) result_update = ( - sub_angle_update + bias + sub_node_update[:, :, xp.newaxis, xp.newaxis, :] + sub_edge_update_ij[:, :, xp.newaxis, :, :] + sub_edge_update_ik[:, :, :, xp.newaxis, :] - ) + bias + + sub_angle_update + ) return result_update def optim_edge_update( @@ -882,8 +883,11 @@ def optim_edge_update( ) result_update = ( - sub_edge_update + sub_node_ext_update + sub_node_update[:, :, xp.newaxis, :] - ) + bias + bias + + sub_node_update[:, :, xp.newaxis, :] + + sub_edge_update + + sub_node_ext_update + ) return result_update def call(