diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 4338ebe998..a49edb534c 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -796,42 +796,36 @@ def optim_angle_update( feat: str = "edge", ) -> np.ndarray: xp = array_api_compat.array_namespace(angle_ebd, node_ebd, edge_ebd) - angle_dim = angle_ebd.shape[-1] - node_dim = node_ebd.shape[-1] - edge_dim = edge_ebd.shape[-1] - sub_angle_idx = (0, angle_dim) - sub_node_idx = (angle_dim, angle_dim + node_dim) - sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) - sub_edge_idx_ik = ( - angle_dim + node_dim + edge_dim, - angle_dim + node_dim + 2 * edge_dim, - ) if feat == "edge": + assert self.edge_angle_linear1 is not None matrix, bias = self.edge_angle_linear1.w, self.edge_angle_linear1.b elif feat == "angle": + assert self.angle_self_linear is not None matrix, bias = self.angle_self_linear.w, self.angle_self_linear.b else: raise NotImplementedError + assert bias is not None + + angle_dim = angle_ebd.shape[-1] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] assert angle_dim + node_dim + 2 * edge_dim == matrix.shape[0] + # Array API does not provide a way to split the array + sub_angle = matrix[:angle_dim, ...] # angle_dim + sub_node = matrix[angle_dim : angle_dim + node_dim, ...] # node_dim + sub_edge_ij = matrix[ + angle_dim + node_dim : angle_dim + node_dim + edge_dim, ... + ] # edge_dim + sub_edge_ik = matrix[angle_dim + node_dim + edge_dim :, ...] # edge_dim # nf * nloc * a_sel * a_sel * angle_dim - sub_angle_update = xp.matmul( - angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1], :] - ) - + sub_angle_update = xp.matmul(angle_ebd, sub_angle) # nf * nloc * angle_dim - sub_node_update = xp.matmul( - node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :] - ) - + sub_node_update = xp.matmul(node_ebd, sub_node) # nf * nloc * a_nnei * angle_dim - sub_edge_update_ij = xp.matmul( - edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1], :] - ) - sub_edge_update_ik = xp.matmul( - edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1], :] - ) + sub_edge_update_ij = xp.matmul(edge_ebd, sub_edge_ij) + sub_edge_update_ik = xp.matmul(edge_ebd, sub_edge_ik) result_update = ( bias @@ -851,11 +845,6 @@ def optim_edge_update( feat: str = "node", ) -> np.ndarray: xp = array_api_compat.array_namespace(node_ebd, node_ebd_ext, edge_ebd, nlist) - node_dim = node_ebd.shape[-1] - edge_dim = edge_ebd.shape[-1] - sub_node_idx = (0, node_dim) - sub_node_ext_idx = (node_dim, 2 * node_dim) - sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim) if feat == "node": matrix, bias = self.node_edge_linear.w, self.node_edge_linear.b @@ -863,24 +852,24 @@ def optim_edge_update( matrix, bias = self.edge_self_linear.w, self.edge_self_linear.b else: raise NotImplementedError - assert 2 * node_dim + edge_dim == matrix.shape[0] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + assert node_dim * 2 + edge_dim == matrix.shape[0] + # Array API does not provide a way to split the array + node = matrix[:node_dim, ...] # node_dim + node_ext = matrix[node_dim : 2 * node_dim, ...] # node_dim + edge = matrix[2 * node_dim : 2 * node_dim + edge_dim, ...] # edge_dim # nf * nloc * node/edge_dim - sub_node_update = xp.matmul( - node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1], :] - ) + sub_node_update = xp.matmul(node_ebd, node) # nf * nall * node/edge_dim - sub_node_ext_update = xp.matmul( - node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1], :] - ) + sub_node_ext_update = xp.matmul(node_ebd_ext, node_ext) # nf * nloc * nnei * node/edge_dim sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist) # nf * nloc * nnei * node/edge_dim - sub_edge_update = xp.matmul( - edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1], :] - ) + sub_edge_update = xp.matmul(edge_ebd, edge) result_update = ( bias diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 0aa4bf03df..f109109cfd 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -397,48 +397,37 @@ def optim_angle_update( edge_ebd: torch.Tensor, feat: str = "edge", ) -> torch.Tensor: - angle_dim = angle_ebd.shape[-1] - node_dim = node_ebd.shape[-1] - edge_dim = edge_ebd.shape[-1] - sub_angle_idx = (0, angle_dim) - sub_node_idx = (angle_dim, angle_dim + node_dim) - sub_edge_idx_ij = (angle_dim + node_dim, angle_dim + node_dim + edge_dim) - sub_edge_idx_ik = ( - angle_dim + node_dim + edge_dim, - angle_dim + node_dim + 2 * edge_dim, - ) - if feat == "edge": + assert self.edge_angle_linear1 is not None matrix, bias = self.edge_angle_linear1.matrix, self.edge_angle_linear1.bias elif feat == "angle": + assert self.angle_self_linear is not None matrix, bias = self.angle_self_linear.matrix, self.angle_self_linear.bias else: raise NotImplementedError - assert angle_dim + node_dim + 2 * edge_dim == matrix.size()[0] + assert bias is not None - # nf * nloc * a_sel * a_sel * angle_dim - sub_angle_update = torch.matmul( - angle_ebd, matrix[sub_angle_idx[0] : sub_angle_idx[1]] + angle_dim = angle_ebd.shape[-1] + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + # angle_dim, node_dim, edge_dim, edge_dim + sub_angle, sub_node, sub_edge_ij, sub_edge_ik = torch.split( + matrix, [angle_dim, node_dim, edge_dim, edge_dim] ) + # nf * nloc * a_sel * a_sel * angle_dim + sub_angle_update = torch.matmul(angle_ebd, sub_angle) # nf * nloc * angle_dim - sub_node_update = torch.matmul( - node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] - ) - + sub_node_update = torch.matmul(node_ebd, sub_node) # nf * nloc * a_nnei * angle_dim - sub_edge_update_ij = torch.matmul( - edge_ebd, matrix[sub_edge_idx_ij[0] : sub_edge_idx_ij[1]] - ) - sub_edge_update_ik = torch.matmul( - edge_ebd, matrix[sub_edge_idx_ik[0] : sub_edge_idx_ik[1]] - ) + sub_edge_update_ij = torch.matmul(edge_ebd, sub_edge_ij) + sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik) result_update = ( bias - + sub_node_update[:, :, None, None, :] - + sub_edge_update_ij[:, :, None, :, :] - + sub_edge_update_ik[:, :, :, None, :] + + sub_node_update.unsqueeze(2).unsqueeze(3) + + sub_edge_update_ij.unsqueeze(2) + + sub_edge_update_ik.unsqueeze(3) + sub_angle_update ) return result_update @@ -451,42 +440,30 @@ def optim_edge_update( nlist: torch.Tensor, feat: str = "node", ) -> torch.Tensor: - node_dim = node_ebd.shape[-1] - edge_dim = edge_ebd.shape[-1] - sub_node_idx = (0, node_dim) - sub_node_ext_idx = (node_dim, 2 * node_dim) - sub_edge_idx = (2 * node_dim, 2 * node_dim + edge_dim) - if feat == "node": matrix, bias = self.node_edge_linear.matrix, self.node_edge_linear.bias elif feat == "edge": matrix, bias = self.edge_self_linear.matrix, self.edge_self_linear.bias else: raise NotImplementedError - assert 2 * node_dim + edge_dim == matrix.size()[0] + assert bias is not None - # nf * nloc * node/edge_dim - sub_node_update = torch.matmul( - node_ebd, matrix[sub_node_idx[0] : sub_node_idx[1]] - ) + node_dim = node_ebd.shape[-1] + edge_dim = edge_ebd.shape[-1] + # node_dim, node_dim, edge_dim + node, node_ext, edge = torch.split(matrix, [node_dim, node_dim, edge_dim]) + # nf * nloc * node/edge_dim + sub_node_update = torch.matmul(node_ebd, node) # nf * nall * node/edge_dim - sub_node_ext_update = torch.matmul( - node_ebd_ext, matrix[sub_node_ext_idx[0] : sub_node_ext_idx[1]] - ) + sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext) # nf * nloc * nnei * node/edge_dim sub_node_ext_update = _make_nei_g1(sub_node_ext_update, nlist) - # nf * nloc * nnei * node/edge_dim - sub_edge_update = torch.matmul( - edge_ebd, matrix[sub_edge_idx[0] : sub_edge_idx[1]] - ) + sub_edge_update = torch.matmul(edge_ebd, edge) result_update = ( - bias - + sub_node_update[:, :, None, :] - + sub_edge_update - + sub_node_ext_update + bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update ) return result_update @@ -614,7 +591,7 @@ def forward( nb, nloc, self.n_multi_edge_message, self.n_dim ) for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) else: n_update_list.append(node_edge_update) # update node_ebd @@ -649,14 +626,14 @@ def forward( edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[:, :, :, : self.e_a_compress_dim] + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd edge_ebd_for_angle = edge_ebd # nb x nloc x a_nnei x e_dim - edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] + edge_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] # nb x nloc x a_nnei x e_dim edge_for_angle = torch.where( a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 @@ -704,9 +681,7 @@ def forward( # nb x nloc x a_nnei x a_nnei x e_dim weighted_edge_angle_update = ( - a_sw[:, :, :, None, None] - * a_sw[:, :, None, :, None] - * edge_angle_update + a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update ) # nb x nloc x a_nnei x e_dim reduced_edge_angle_update = torch.sum(