From 275fb5258b2f762c6a33e161df9229ae42113c88 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 3 Mar 2025 14:43:41 +0800 Subject: [PATCH 1/8] add atomic_weight for model inference --- deepmd/pt/model/model/dipole_model.py | 4 ++++ deepmd/pt/model/model/make_model.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index a24820b74a..108f1241be 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -60,6 +60,7 @@ def forward( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: model_ret = self.forward_common( coord, @@ -68,6 +69,7 @@ def forward( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + atomic_weight=atomic_weight, ) if self.get_fitting_net() is not None: model_predict = {} @@ -98,6 +100,7 @@ def forward_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[torch.Tensor] = None, ): model_ret = self.forward_common_lower( extended_coord, @@ -108,6 +111,7 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, extra_nlist_sort=self.need_sorted_nlist_for_lower(), + atomic_weight=atomic_weight, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c32abaa095..0793ec2cf1 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -135,6 +135,7 @@ def forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Return model prediction. 
@@ -188,6 +189,7 @@ def forward_common( do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap, + atomic_weight=atomic_weight, ) model_predict = communicate_extended_output( model_predict_lower, @@ -242,6 +244,7 @@ def forward_common_lower( do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, extra_nlist_sort: bool = False, + atomic_weight: Optional[torch.Tensor] = None, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -293,6 +296,14 @@ def forward_common_lower( aparam=ap, comm_dict=comm_dict, ) + # add weight to atomic_output + kw = next(iter(self.atomic_output_def().var_defs.keys())) + atomic_weight = torch.ones_like(atomic_ret[kw]) + if atomic_weight is not None: + # atomic_weight: nf x nloc x dim + atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape( + *atomic_ret[kw].shape[:-1], -1 + ) model_predict = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), From b5356d9e275fba730cc7e8ea695dfd3c37fa1e7f Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 3 Mar 2025 15:22:12 +0800 Subject: [PATCH 2/8] remove test tensor --- deepmd/pt/model/model/make_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 0793ec2cf1..93aaa5f9f2 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -298,7 +298,6 @@ def forward_common_lower( ) # add weight to atomic_output kw = next(iter(self.atomic_output_def().var_defs.keys())) - atomic_weight = torch.ones_like(atomic_ret[kw]) if atomic_weight is not None: # atomic_weight: nf x nloc x dim atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape( From 205e506c95e3e3d57db2bf37e8fb48b05d6e1e62 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 3 Mar 2025 15:52:10 +0800 Subject: [PATCH 3/8] fix bug in model frozen --- deepmd/pt/model/model/make_model.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 93aaa5f9f2..2bcd0fab50 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -297,11 +297,11 @@ def forward_common_lower( comm_dict=comm_dict, ) # add weight to atomic_output - kw = next(iter(self.atomic_output_def().var_defs.keys())) + kw = self.atomic_model.fitting_net.var_name if atomic_weight is not None: # atomic_weight: nf x nloc x dim - atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape( - *atomic_ret[kw].shape[:-1], -1 + atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view( + [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1] ) model_predict = fit_output_to_model_output( atomic_ret, From 50a8a21d94d98d501b7ff0ccc7c6412ee2838b8c Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 3 Mar 2025 19:57:56 +0800 Subject: [PATCH 4/8] deal with case without fitting net --- deepmd/pt/model/model/make_model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 2bcd0fab50..3dd2d0a06d 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -297,12 +297,14 @@ def forward_common_lower( comm_dict=comm_dict, ) # add weight to atomic_output - kw = self.atomic_model.fitting_net.var_name - if atomic_weight is not None: - # atomic_weight: nf x nloc x dim - atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view( - [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1] - ) + if hasattr(self.atomic_model, "fitting_net"): + if hasattr(self.atomic_model.fitting_net, "var_name"): + kw = self.atomic_model.fitting_net.var_name + if atomic_weight is not None: + # atomic_weight: nf x nloc x dim + atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view( + [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1] + ) model_predict = fit_output_to_model_output( atomic_ret, 
self.atomic_output_def(), From 32a64e18388aaa5b4508e5d372b7b1080413b07c Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Tue, 11 Mar 2025 22:10:48 +0800 Subject: [PATCH 5/8] move `atomic_weight` to atomic model level --- deepmd/pt/model/atomic_model/base_atomic_model.py | 7 +++++++ deepmd/pt/model/model/make_model.py | 10 +--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 56af5f4f43..ae8bbe2bd8 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -203,6 +203,7 @@ def forward_common_atomic( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, comm_dict: Optional[dict[str, torch.Tensor]] = None, + atomic_weight: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Common interface for atomic inference. @@ -271,6 +272,10 @@ def forward_common_atomic( ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] ).view(out_shape) + if atomic_weight is not None: + ret_dict[kk] = ret_dict[kk] * atomic_weight.view( + [out_shape[0], out_shape[1], -1] + ) ret_dict["mask"] = atom_mask return ret_dict @@ -284,6 +289,7 @@ def forward( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, comm_dict: Optional[dict[str, torch.Tensor]] = None, + atomic_weight: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: return self.forward_common_atomic( extended_coord, @@ -293,6 +299,7 @@ def forward( fparam=fparam, aparam=aparam, comm_dict=comm_dict, + atomic_weight=atomic_weight, ) def change_type_map( diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 3dd2d0a06d..40a48a6c51 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -295,16 +295,8 @@ def forward_common_lower( fparam=fp, aparam=ap, comm_dict=comm_dict, + 
atomic_weight=atomic_weight, ) - # add weight to atomic_output - if hasattr(self.atomic_model, "fitting_net"): - if hasattr(self.atomic_model.fitting_net, "var_name"): - kw = self.atomic_model.fitting_net.var_name - if atomic_weight is not None: - # atomic_weight: nf x nloc x dim - atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view( - [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1] - ) model_predict = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), From ad063bd9ca4f1c4132fbeeb4285564d59b1c1b78 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Tue, 11 Mar 2025 23:53:32 +0800 Subject: [PATCH 6/8] (np) add atomic_weight to atomic model and dipole model; add UT for pt/np --- .../dpmodel/atomic_model/base_atomic_model.py | 8 +++++ deepmd/dpmodel/model/make_model.py | 32 +++++++++++++++---- source/tests/pt/model/test_dp_atomic_model.py | 15 +++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index eb95886598..63fcbc4dbe 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -149,6 +149,7 @@ def forward_common_atomic( mapping: Optional[np.ndarray] = None, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, + atomic_weight: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: """Common interface for atomic inference. 
@@ -213,6 +214,11 @@ def forward_common_atomic( tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) ret_dict[kk] = xp.reshape(tmp_arr, out_shape) + if atomic_weight is not None: + _out_shape = ret_dict[kk].shape + ret_dict[kk] = ret_dict[kk] * atomic_weight.reshape( + [_out_shape[0], _out_shape[1], -1] + ) ret_dict["mask"] = xp.astype(atom_mask, xp.int32) return ret_dict @@ -225,6 +231,7 @@ def call( mapping: Optional[np.ndarray] = None, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, + atomic_weight: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: return self.forward_common_atomic( extended_coord, @@ -233,6 +240,7 @@ def call( mapping=mapping, fparam=fparam, aparam=aparam, + atomic_weight=atomic_weight, ) def serialize(self) -> dict: diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index ec0b986394..9f95db3416 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -63,6 +63,7 @@ def model_call_from_call_lower( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[np.ndarray] = None, ): """Return model prediction from lower interface. @@ -121,6 +122,7 @@ def model_call_from_call_lower( fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, + atomic_weight=atomic_weight, ) model_predict = communicate_extended_output( model_predict_lower, @@ -224,6 +226,7 @@ def call( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: """Return model prediction. @@ -250,8 +253,12 @@ def call( The keys are defined by the `ModelOutputDef`. 
""" - cc, bb, fp, ap, input_prec = self.input_type_cast( - coord, box=box, fparam=fparam, aparam=aparam + cc, bb, fp, ap, aw, input_prec = self.input_type_cast( + coord, + box=box, + fparam=fparam, + aparam=aparam, + atomic_weight=atomic_weight, ) del coord, box, fparam, aparam model_predict = model_call_from_call_lower( @@ -266,6 +273,7 @@ def call( fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, + atomic_weight=aw, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict @@ -279,6 +287,7 @@ def call_lower( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[np.ndarray] = None, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -316,8 +325,11 @@ def call_lower( nlist, extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( - extended_coord, fparam=fparam, aparam=aparam + cc_ext, _, fp, ap, aw, input_prec = self.input_type_cast( + extended_coord, + fparam=fparam, + aparam=aparam, + atomic_weight=atomic_weight, ) del extended_coord, fparam, aparam model_predict = self.forward_common_atomic( @@ -328,6 +340,7 @@ def call_lower( fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, + atomic_weight=aw, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict @@ -341,6 +354,7 @@ def forward_common_atomic( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[np.ndarray] = None, ): atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -349,6 +363,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + atomic_weight=atomic_weight, ) return fit_output_to_model_output( atomic_ret, @@ -365,11 +380,13 @@ def input_type_cast( box: Optional[np.ndarray] = None, fparam: 
Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, + atomic_weight: Optional[np.ndarray] = None, ) -> tuple[ np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], + Optional[np.ndarray], str, ]: """Cast the input data to global float type.""" @@ -379,11 +396,11 @@ def input_type_cast( ### _lst: list[Optional[np.ndarray]] = [ vv.astype(coord.dtype) if vv is not None else None - for vv in [box, fparam, aparam] + for vv in [box, fparam, aparam, atomic_weight] ] - box, fparam, aparam = _lst + box, fparam, aparam, atomic_weight = _lst if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]: - return coord, box, fparam, aparam, input_prec + return coord, box, fparam, aparam, atomic_weight, input_prec else: pp = self.global_np_float_precision return ( @@ -391,6 +408,7 @@ def input_type_cast( box.astype(pp) if box is not None else None, fparam.astype(pp) if fparam is not None else None, aparam.astype(pp) if aparam is not None else None, + atomic_weight.astype(pp) if atomic_weight is not None else None, input_prec, ) diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 6d6a22f357..2c9c00faeb 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -73,6 +73,13 @@ def test_self_consistency(self) -> None: to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"]), ) + # add test for atomic_weight + aw = torch.rand([nf, nloc, 1], dtype=dtype, device=env.DEVICE) + ret2 = md0.forward_common_atomic(*args, atomic_weight=aw) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"] * aw.reshape(nf, nloc, -1)), + to_numpy_array(ret2["energy"]), + ) def test_dp_consistency(self) -> None: nf, nloc, nnei = self.nlist.shape @@ -101,6 +108,14 @@ def test_dp_consistency(self) -> None: ret0["energy"], to_numpy_array(ret1["energy"]), ) + # add test for atomic_weight + aw = torch.rand([nf, nloc, 1], dtype=dtype, 
device=env.DEVICE) + ret2 = md0.forward_common_atomic(*args0, atomic_weight=to_numpy_array(aw)) + ret3 = md1.forward_common_atomic(*args1, atomic_weight=aw) + np.testing.assert_allclose( + ret2["energy"], + to_numpy_array(ret3["energy"]), + ) def test_jit(self) -> None: nf, nloc, nnei = self.nlist.shape From 778af5b64d9be9fb7ad7293eb3843a59db9ca7f9 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 12 Mar 2025 00:05:30 +0800 Subject: [PATCH 7/8] add doc for atomic_weight --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 3 +++ deepmd/pt/model/atomic_model/base_atomic_model.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 63fcbc4dbe..f5abe98fd5 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -171,6 +171,9 @@ def forward_common_atomic( frame parameters, shape: nf x dim_fparam aparam atomic parameter, shape: nf x nloc x dim_aparam + atomic_weight + atomic weights for scaling outputs, shape: nf x nloc x dim_aw + if provided, all output values will be multiplied by this weight. Returns ------- diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index ae8bbe2bd8..02ddd9b501 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -227,6 +227,9 @@ def forward_common_atomic( atomic parameter, shape: nf x nloc x dim_aparam comm_dict The data needed for communication for parallel inference. + atomic_weight + atomic weights for scaling outputs, shape: nf x nloc x dim_aw + if provided, all output values will be multiplied by this weight. 
Returns ------- From 088b25265a61eef8da0a48f22c3e93a8840e9e94 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 12 Mar 2025 00:29:59 +0800 Subject: [PATCH 8/8] fix bug in jax2tf model convert --- deepmd/jax/atomic_model/dp_atomic_model.py | 2 + deepmd/jax/jax2tf/make_model.py | 7 +++- deepmd/jax/jax2tf/serialization.py | 46 +++++++++++++++++----- deepmd/jax/jax2tf/tfmodel.py | 17 +++++++- deepmd/jax/model/base_model.py | 12 ++++++ deepmd/jax/model/dp_model.py | 2 + deepmd/jax/utils/serialization.py | 3 +- 7 files changed, 76 insertions(+), 13 deletions(-) diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index adfc22c6fa..49a2882c88 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -58,6 +58,7 @@ def forward_common_atomic( mapping: Optional[jnp.ndarray] = None, fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, + atomic_weight: Optional[jnp.ndarray] = None, ) -> dict[str, jnp.ndarray]: return super().forward_common_atomic( extended_coord, @@ -66,6 +67,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + atomic_weight=atomic_weight, ) return jax_atomic_model diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index 29ed131f8e..f8462c984e 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -30,6 +30,7 @@ def model_call_from_call_lower( tnp.ndarray, tnp.ndarray, tnp.ndarray, + tnp.ndarray, bool, ], dict[str, tnp.ndarray], @@ -43,6 +44,7 @@ def model_call_from_call_lower( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, + atomic_weight: tnp.ndarray, do_atomic_virial: bool = False, ): """Return model prediction from lower interface. 
@@ -72,8 +74,8 @@ def model_call_from_call_lower( """ atype_shape = tf.shape(atype) nframes, nloc = atype_shape[0], atype_shape[1] - cc, bb, fp, ap = coord, box, fparam, aparam - del coord, box, fparam, aparam + cc, bb, fp, ap, aw = coord, box, fparam, aparam, atomic_weight + del coord, box, fparam, aparam, atomic_weight if tf.shape(bb)[-1] != 0: coord_normalized = normalize_coord( cc.reshape(nframes, nloc, 3), @@ -102,6 +104,7 @@ def model_call_from_call_lower( mapping, fparam=fp, aparam=ap, + atomic_weight=aw, ) model_predict = communicate_extended_output( model_predict_lower, diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index aac022ace9..f04b999419 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -40,7 +40,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial(do_atomic_virial, has_ghost_atoms): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam + coord, + atype, + nlist, + mapping, + fparam, + aparam, + atomic_weight, ): return call_lower( coord, @@ -49,6 +55,7 @@ def call_lower_with_fixed_do_atomic_virial( mapping, fparam, aparam, + atomic_weight=atomic_weight, do_atomic_virial=do_atomic_virial, ) @@ -68,12 +75,14 @@ def call_lower_with_fixed_do_atomic_virial( f"(nf, nloc + {nghost})", f"(nf, {model.get_dim_fparam()})", f"(nf, nloc, {model.get_dim_aparam()})", + "(nf, nloc, 1)", ], with_gradient=True, ) # Save a function that can take scalar inputs. # We need to explicit set the function name, so C++ can find it. 
+ # bug: replace 1 with fitting output dim @tf.function( autograph=False, input_signature=[ @@ -83,24 +92,32 @@ def call_lower_with_fixed_do_atomic_virial( tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + tf.TensorSpec([None, None, 1], tf.float64), ], ) def call_lower_without_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam + coord, + atype, + nlist, + mapping, + fparam, + aparam, + atomic_weight, ): nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( do_atomic_virial=False, has_ghost_atoms=False - )(coord, atype, nlist, mapping, fparam, aparam), + )(coord, atype, nlist, mapping, fparam, aparam, atomic_weight), lambda: exported_whether_do_atomic_virial( do_atomic_virial=False, has_ghost_atoms=True - )(coord, atype, nlist, mapping, fparam, aparam), + )(coord, atype, nlist, mapping, fparam, aparam, atomic_weight), ) tf_model.call_lower = call_lower_without_atomic_virial + # bug: replace 1 with fitting output dim @tf.function( autograph=False, input_signature=[ @@ -110,18 +127,21 @@ def call_lower_without_atomic_virial( tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + tf.TensorSpec([None, None, 1], tf.float64), ], ) - def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + def call_lower_with_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam, atomic_weight + ): nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( do_atomic_virial=True, has_ghost_atoms=False - )(coord, atype, nlist, mapping, fparam, aparam), + )(coord, atype, nlist, mapping, fparam, 
aparam, atomic_weight), lambda: exported_whether_do_atomic_virial( do_atomic_virial=True, has_ghost_atoms=True - )(coord, atype, nlist, mapping, fparam, aparam), + )(coord, atype, nlist, mapping, fparam, aparam, atomic_weight), ) tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial @@ -138,6 +158,7 @@ def call( box: Optional[tnp.ndarray] = None, fparam: Optional[tnp.ndarray] = None, aparam: Optional[tnp.ndarray] = None, + atomic_weight: Optional[tnp.ndarray] = None, ): """Return model prediction. @@ -173,11 +194,13 @@ def call( box=box, fparam=fparam, aparam=aparam, + atomic_weight=atomic_weight, do_atomic_virial=do_atomic_virial, ) return call + # bug: replace 1 with fitting output dim @tf.function( autograph=True, input_signature=[ @@ -186,6 +209,7 @@ def call( tf.TensorSpec([None, None, None], tf.float64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + tf.TensorSpec([None, None, 1], tf.float64), ], ) def call_with_atomic_virial( @@ -194,13 +218,15 @@ def call_with_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, + atomic_weight: tnp.ndarray, ): return make_call_whether_do_atomic_virial(do_atomic_virial=True)( - coord, atype, box, fparam, aparam + coord, atype, box, fparam, aparam, atomic_weight ) tf_model.call_atomic_virial = call_with_atomic_virial + # bug: replace 1 with fitting output dim @tf.function( autograph=True, input_signature=[ @@ -209,6 +235,7 @@ def call_with_atomic_virial( tf.TensorSpec([None, None, None], tf.float64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + tf.TensorSpec([None, None, 1], tf.float64), ], ) def call_without_atomic_virial( @@ -217,9 +244,10 @@ def call_without_atomic_virial( box: tnp.ndarray, fparam: tnp.ndarray, aparam: tnp.ndarray, + atomic_weight: tnp.ndarray, ): return make_call_whether_do_atomic_virial(do_atomic_virial=False)( - 
coord, atype, box, fparam, aparam + coord, atype, box, fparam, aparam, atomic_weight ) tf_model.call = call_without_atomic_virial diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 0d7b13ba1f..64c6fb7540 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -79,6 +79,7 @@ def __call__( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[jnp.ndarray] = None, ) -> Any: """Return model prediction. @@ -105,7 +106,9 @@ def __call__( The keys are defined by the `ModelOutputDef`. """ - return self.call(coord, atype, box, fparam, aparam, do_atomic_virial) + return self.call( + coord, atype, box, fparam, aparam, do_atomic_virial, atomic_weight + ) def call( self, @@ -115,6 +118,7 @@ def call( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[jnp.ndarray] = None, ): """Return model prediction. 
@@ -157,12 +161,17 @@ def call( (coord.shape[0], coord.shape[1], self.get_dim_aparam()), dtype=jnp.float64, ) + if atomic_weight is None: + atomic_weight = jnp.ones( + (coord.shape[0], coord.shape[1], 1), dtype=jnp.float64 + ) return call( coord, atype, box, fparam, aparam, + atomic_weight, ) def model_output_def(self): @@ -179,6 +188,7 @@ def call_lower( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[jnp.ndarray] = None, ): if do_atomic_virial: call_lower = self._call_lower_atomic_virial @@ -194,6 +204,10 @@ def call_lower( (extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()), dtype=jnp.float64, ) + if atomic_weight is None: + atomic_weight = jnp.ones( + (extended_coord.shape[0], nlist.shape[1], 1), dtype=jnp.float64 + ) return call_lower( extended_coord, extended_atype, @@ -201,6 +215,7 @@ def call_lower( mapping, fparam, aparam, + atomic_weight, ) def get_type_map(self) -> list[str]: diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 5ca372c86a..4af0395a15 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -27,6 +27,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[jnp.ndarray] = None, ): atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -35,6 +36,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + atomic_weight=atomic_weight, ) atomic_output_def = self.atomic_output_def() model_predict = {} @@ -56,6 +58,7 @@ def eval_output( mapping, fparam, aparam, + atomic_weight, *, _kk=kk, _atom_axis=atom_axis, @@ -67,6 +70,9 @@ def eval_output( mapping=mapping[None, ...] if mapping is not None else None, fparam=fparam[None, ...] if fparam is not None else None, aparam=aparam[None, ...]
if aparam is not None else None, + atomic_weight=atomic_weight[None, ...] + if atomic_weight is not None + else None, ) return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis) @@ -79,6 +85,7 @@ def eval_output( mapping, fparam, aparam, + atomic_weight, ) # extended_force: [nf, nall, *def, 3] def_ndim = len(vdef.shape) @@ -101,6 +108,7 @@ def eval_ce( mapping, fparam, aparam, + atomic_weight, *, _kk=kk, _atom_axis=atom_axis - 1, @@ -113,6 +121,9 @@ def eval_ce( mapping=mapping[None, ...] if mapping is not None else None, fparam=fparam[None, ...] if fparam is not None else None, aparam=aparam[None, ...] if aparam is not None else None, + atomic_weight=atomic_weight[None, ...] + if atomic_weight is not None + else None, ) nloc = nlist.shape[0] cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...] @@ -130,6 +141,7 @@ def eval_ce( mapping, fparam, aparam, + atomic_weight, ) # move the first 3 to the last # [nf, *def, nall, 3, 3] diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py index 436582f22b..15e67780bc 100644 --- a/deepmd/jax/model/dp_model.py +++ b/deepmd/jax/model/dp_model.py @@ -56,6 +56,7 @@ def forward_common_atomic( fparam: Optional[jnp.ndarray] = None, aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, + atomic_weight: Optional[jnp.ndarray] = None, ): return forward_common_atomic( self, @@ -66,6 +67,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + atomic_weight=atomic_weight, ) def format_nlist( diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d4da49e08..a1a170e594 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -57,7 +57,7 @@ def exported_whether_do_atomic_virial( do_atomic_virial: bool, has_ghost_atoms: bool ): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, mapping, fparam, aparam + coord, atype, nlist, mapping, fparam, aparam, atomic_weight ): return call_lower( coord, @@ 
-66,6 +66,7 @@ def call_lower_with_fixed_do_atomic_virial( mapping, fparam, aparam, + atomic_weight=atomic_weight, do_atomic_virial=do_atomic_virial, )