From 2fdf40493f4d07e67bb2145d1288284edeef13c2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 11 Mar 2025 20:09:37 +0800 Subject: [PATCH 1/6] feat(jax): Hessian Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/ener_model.py | 19 +++ deepmd/dpmodel/model/transform_output.py | 92 +++++++++++ deepmd/jax/model/base_model.py | 13 ++ source/tests/jax/test_dp_hessian_model.py | 106 ++++++++++++ source/tests/jax/test_make_hessian_model.py | 171 ++++++++++++++++++++ 5 files changed, 401 insertions(+) create mode 100644 source/tests/jax/test_dp_hessian_model.py create mode 100644 source/tests/jax/test_make_hessian_model.py diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py index e4233eb397..88e65a849a 100644 --- a/deepmd/dpmodel/model/ener_model.py +++ b/deepmd/dpmodel/model/ener_model.py @@ -1,10 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) + from deepmd.dpmodel.atomic_model import ( DPEnergyAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) from .dp_model import ( DPModelCommon, @@ -25,3 +32,15 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPEnergyModel_.__init__(self, *args, **kwargs) + self._enable_hessian = False + self.hess_fitting_def = None + + def enable_hessian(self): + self.hess_fitting_def = deepcopy(self.atomic_output_def()) + self.hess_fitting_def["energy"].r_hessian = True + self._enable_hessian = True + + def atomic_output_def(self) -> FittingOutputDef: + if self._enable_hessian: + return self.hess_fitting_def + return super().atomic_output_def() diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index af1429ce25..9b71fda285 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -11,6 +11,7 @@ ModelOutputDef, OutputVariableDef, get_deriv_name, + get_hessian_name, get_reduce_name, ) @@ 
-81,6 +82,7 @@ def communicate_extended_output( """ xp = array_api_compat.get_namespace(mapping) + mapping_ = mapping new_ret = {} for kk in model_output_def.keys_outp(): vv = model_ret[kk] @@ -116,6 +118,96 @@ def communicate_extended_output( else: # name holders new_ret[kk_derv_r] = None + if vdef.r_hessian: + kk_hess = get_hessian_name(kk) + if model_ret[kk_hess] is not None: + # jax only + if array_api_compat.is_jax_array(force): + from deepmd.jax.common import ( + scatter_sum, + ) + from deepmd.jax.env import ( + jnp, + ) + + # [nf, *def, nall, 3, nall, 3] + hess_ = model_ret[kk_hess] + def_ndim = len(vdef.shape) + # [nf, nall1, nall2, *def, 3(1), 3(2)] + hess_1 = jnp.transpose( + hess_, + ( + 0, + def_ndim + 1, + def_ndim + 3, + *range(1, def_ndim + 1), + def_ndim + 2, + def_ndim + 4, + ), + ) + nall = hess_1.shape[1] + # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] + hessian1 = jnp.zeros( + [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nall, *vdef.shape, 3, 3], + ) + hessian1 = scatter_sum( + hessian1, + 1, + mapping_hess, + hess_1, + ) + # [nf, nall2, nloc1, *def, 3(1), 3(2)] + hessian1 = jnp.transpose( + hessian1, + (0, 2, 1, *range(3, def_ndim + 5)), + ) + nloc = hessian1.shape[2] + # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] + hessian = jnp.zeros( + [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nloc, *vdef.shape, 3, 3], + ) + hessian = scatter_sum( + hessian, + 1, + mapping_hess, + hessian1, + ) + # -> [nf, *def, nloc1, 3(1), nloc2, 3(2)] + hessian = jnp.transpose( + hessian, + ( + 0, + *range(3, def_ndim + 3), + 2, + def_ndim + 3, + 1, + def_ndim + 4, + ), + ) + # -> [nf, *def, nloc1 * 3, nloc2 * 3] + hessian = jnp.reshape( + hessian, + 
(hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3), + ) + else: + raise NotImplementedError("Only JAX arrays are supported.") + new_ret[kk_hess] = hessian + else: + new_ret[kk_hess] = None if vdef.c_differentiable: assert vdef.r_differentiable if model_ret[kk_derv_c] is not None: diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 5ca372c86a..7c97ff692f 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -8,6 +8,7 @@ ) from deepmd.dpmodel.output_def import ( get_deriv_name, + get_hessian_name, get_reduce_name, ) from deepmd.jax.env import ( @@ -87,6 +88,18 @@ def eval_output( ) model_predict[kk_derv_r] = extended_force + if vdef.r_hessian: + # [nf, *def, nall, 3, nall, 3] + hessian = jax.vmap(jax.hessian(eval_output, argnums=0))( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + kk_hessian = get_hessian_name(kk) + model_predict[kk_hessian] = hessian if vdef.c_differentiable: assert vdef.r_differentiable # avr: [nf, *def, nall, 3, 3] diff --git a/source/tests/jax/test_dp_hessian_model.py b/source/tests/jax/test_dp_hessian_model.py new file mode 100644 index 0000000000..bd7aaa62ef --- /dev/null +++ b/source/tests/jax/test_dp_hessian_model.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, +) +from deepmd.jax.model.ener_model import ( + EnergyModel, +) + +dtype = jnp.float64 + + +class TestCaseSingleFrameWithoutNlist: + def setUp(self) -> None: + # nloc == 3, nall == 4 + self.nloc = 3 + self.nf, self.nt = 1, 2 + self.coord = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype=np.float64, + ).reshape([1, self.nloc * 3]) + self.atype = 
np.array([0, 0, 1], dtype=int).reshape([1, self.nloc]) + self.cell = 2.0 * np.eye(3).reshape([1, 9]) + # sel = [5, 2] + self.sel = [16, 8] + self.sel_mix = [24] + self.natoms = [3, 3, 2, 1] + self.rcut = 2.2 + self.rcut_smth = 0.4 + self.atol = 1e-12 + + +class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist): + def setUp(self): + TestCaseSingleFrameWithoutNlist.setUp(self) + + def test_self_consistency(self): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map) + md1 = EnergyModel.deserialize(md0.serialize()) + md0.enable_hessian() + md1.enable_hessian() + args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]] + ret0 = md0.call(*args) + ret1 = md1.call(*args) + np.testing.assert_allclose( + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_redu"]), + to_numpy_array(ret1["energy_redu"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r"]), + to_numpy_array(ret1["energy_derv_r"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c_redu"]), + to_numpy_array(ret1["energy_derv_c_redu"]), + atol=self.atol, + ) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_r_derv_r"]), + to_numpy_array(ret1["energy_derv_r_derv_r"]), + atol=self.atol, + ) + ret0 = md0.call(*args, do_atomic_virial=True) + ret1 = md1.call(*args, do_atomic_virial=True) + np.testing.assert_allclose( + to_numpy_array(ret0["energy_derv_c"]), + to_numpy_array(ret1["energy_derv_c"]), + atol=self.atol, + ) diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py new file mode 100644 index 0000000000..00c7c4f09b --- /dev/null +++ 
b/source/tests/jax/test_make_hessian_model.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.output_def import ( + OutputVariableCategory, +) +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.jax.env import ( + jax, + jnp, +) +from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, +) +from deepmd.jax.model import ( + EnergyModel, +) + +from ..seed import ( + GLOBAL_SEED, +) + +dtype = jnp.float64 + + +def finite_hessian(f, x, delta=1e-6): + in_shape = x.shape + assert len(in_shape) == 1 + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape + in_shape) + for iidx in np.ndindex(*in_shape): + for jidx in np.ndindex(*in_shape): + i0 = np.zeros(in_shape) + i1 = np.zeros(in_shape) + i2 = np.zeros(in_shape) + i3 = np.zeros(in_shape) + i0[iidx] += delta + i2[iidx] += delta + i1[iidx] -= delta + i3[iidx] -= delta + i0[jidx] += delta + i1[jidx] += delta + i2[jidx] -= delta + i3[jidx] -= delta + y0 = f(x + i0) + y1 = f(x + i1) + y2 = f(x + i2) + y3 = f(x + i3) + res[(Ellipsis, *iidx, *jidx)] = (y0 + y3 - y1 - y2) / (4 * delta**2.0) + return res + + +class HessianTest: + def test( + self, + ) -> None: + # setup test case + places = 6 + delta = 1e-3 + natoms = self.nloc + nf = self.nf + nv = self.nv + generator = jax.random.key(GLOBAL_SEED) + cell0 = jax.random.uniform(generator, [3, 3], dtype=dtype) + cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * jnp.eye(3) + cell1 = jax.random.uniform(generator, [3, 3], dtype=dtype) + cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * jnp.eye(3) + cell = jnp.stack([cell0, cell1]) + coord = jax.random.uniform(generator, [nf, natoms, 3], dtype=dtype) + coord = jnp.matmul(coord, cell) + cell = cell.reshape([nf, 9]) + coord = coord.reshape([nf, natoms * 3]) + atype = jnp.stack( + [ + jnp.asarray([0, 0, 1], dtype=jnp.int64), + 
jnp.asarray([1, 0, 1], dtype=jnp.int64), + ] + ).reshape([nf, natoms]) + nfp, nap = 2, 3 + fparam = jax.random.uniform(generator, [nf, nfp], dtype=dtype) + aparam = jax.random.uniform(generator, [nf, natoms * nap], dtype=dtype) + # forward hess and value models + ret_dict0 = self.model_hess( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + ret_dict1 = self.model_valu( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + # compare hess and value models + np.testing.assert_allclose(ret_dict0["energy"], ret_dict1["energy"]) + ana_hess = ret_dict0["energy_derv_r_derv_r"] + + # compute finite difference + fnt_hess = [] + for ii in range(nf): + + def np_infer( + xx, + ): + ret = self.model_valu( + to_jax_array(xx)[None, ...], + atype[ii][None, ...], + box=cell[ii][None, ...], + fparam=fparam[ii][None, ...], + aparam=aparam[ii][None, ...], + ) + # detach + ret = {kk: to_numpy_array(ret[kk]) for kk in ret} + return ret + + def ff(xx): + return np_infer(xx)["energy_redu"] + + xx = to_numpy_array(coord[ii]) + fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) + + # compare finite difference with autodiff + fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3]) + np.testing.assert_almost_equal( + fnt_hess, to_numpy_array(ana_hess), decimal=places + ) + + +class TestDPModel(unittest.TestCase, HessianTest): + def setUp(self) -> None: + jax.random.key(2) + self.nf = 2 + self.nloc = 3 + self.rcut = 4.0 + self.rcut_smth = 3.0 + self.sel = [10, 10] + self.nt = 2 + self.nv = 1 + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ) + ft0 = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + # self.nv, + mixed_types=ds.mixed_types(), + neuron=[4, 4, 4], + ) + type_map = ["foo", "bar"] + self.model_hess = EnergyModel(ds, ft0, type_map=type_map) + self.model_hess.enable_hessian() + self.model_valu = EnergyModel.deserialize(self.model_hess.serialize()) + + def test_output_def(self) -> None: + 
self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) + self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) + self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian) + self.assertEqual( + self.model_hess.model_output_def()["energy_derv_r_derv_r"].category, + OutputVariableCategory.DERV_R_DERV_R, + ) From 36d751275eb48346a004578e2b4d0fe7a464d443 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 11 Mar 2025 20:31:07 +0800 Subject: [PATCH 2/6] skip Python 3.9 Signed-off-by: Jinzhe Zeng --- source/tests/jax/test_dp_hessian_model.py | 9 +++++++++ source/tests/jax/test_make_hessian_model.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/source/tests/jax/test_dp_hessian_model.py b/source/tests/jax/test_dp_hessian_model.py index bd7aaa62ef..89c066e980 100644 --- a/source/tests/jax/test_dp_hessian_model.py +++ b/source/tests/jax/test_dp_hessian_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import sys import unittest import numpy as np @@ -25,6 +26,10 @@ dtype = jnp.float64 +@unittest.skipIf( + sys.version_info < (3, 10), + "JAX requires Python 3.10 or later", +) class TestCaseSingleFrameWithoutNlist: def setUp(self) -> None: # nloc == 3, nall == 4 @@ -49,6 +54,10 @@ def setUp(self) -> None: self.atol = 1e-12 +@unittest.skipIf( + sys.version_info < (3, 10), + "JAX requires Python 3.10 or later", +) class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist): def setUp(self): TestCaseSingleFrameWithoutNlist.setUp(self) diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py index 00c7c4f09b..7c4bd12dea 100644 --- a/source/tests/jax/test_make_hessian_model.py +++ b/source/tests/jax/test_make_hessian_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import sys import unittest import numpy as np @@ -132,6 +133,10 @@ def ff(xx): ) +@unittest.skipIf( + sys.version_info < (3, 10), + "JAX requires 
Python 3.10 or later", +) class TestDPModel(unittest.TestCase, HessianTest): def setUp(self) -> None: jax.random.key(2) From 6a77d42333548c84343efadf317e6a5847c5209c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 11 Mar 2025 21:19:08 +0800 Subject: [PATCH 3/6] skip importing jax for py39 Signed-off-by: Jinzhe Zeng --- source/tests/jax/test_dp_hessian_model.py | 34 ++++++++-------- source/tests/jax/test_make_hessian_model.py | 44 +++++++++++---------- 2 files changed, 41 insertions(+), 37 deletions(-) diff --git a/source/tests/jax/test_dp_hessian_model.py b/source/tests/jax/test_dp_hessian_model.py index 89c066e980..798b893651 100644 --- a/source/tests/jax/test_dp_hessian_model.py +++ b/source/tests/jax/test_dp_hessian_model.py @@ -7,23 +7,25 @@ from deepmd.dpmodel.common import ( to_numpy_array, ) -from deepmd.jax.common import ( - to_jax_array, -) -from deepmd.jax.descriptor.se_e2_a import ( - DescrptSeA, -) -from deepmd.jax.env import ( - jnp, -) -from deepmd.jax.fitting.fitting import ( - EnergyFittingNet, -) -from deepmd.jax.model.ener_model import ( - EnergyModel, -) -dtype = jnp.float64 +if sys.version_info >= (3, 10): + from deepmd.jax.common import ( + to_jax_array, + ) + from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, + ) + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, + ) + from deepmd.jax.model.ener_model import ( + EnergyModel, + ) + + dtype = jnp.float64 @unittest.skipIf( diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py index 7c4bd12dea..79397990b3 100644 --- a/source/tests/jax/test_make_hessian_model.py +++ b/source/tests/jax/test_make_hessian_model.py @@ -10,28 +10,30 @@ from deepmd.dpmodel.output_def import ( OutputVariableCategory, ) -from deepmd.jax.common import ( - to_jax_array, -) -from deepmd.jax.descriptor.se_e2_a import ( - DescrptSeA, -) -from deepmd.jax.env import ( - jax, - jnp, -) -from 
deepmd.jax.fitting.fitting import ( - EnergyFittingNet, -) -from deepmd.jax.model import ( - EnergyModel, -) - -from ..seed import ( - GLOBAL_SEED, -) -dtype = jnp.float64 +if sys.version_info >= (3, 10): + from deepmd.jax.common import ( + to_jax_array, + ) + from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, + ) + from deepmd.jax.env import ( + jax, + jnp, + ) + from deepmd.jax.fitting.fitting import ( + EnergyFittingNet, + ) + from deepmd.jax.model import ( + EnergyModel, + ) + + from ..seed import ( + GLOBAL_SEED, + ) + + dtype = jnp.float64 def finite_hessian(f, x, delta=1e-6): From 65525fa0fcf68cd20b5a86fe9157c3ab2305b710 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 11 Mar 2025 22:31:58 +0800 Subject: [PATCH 4/6] fix(tests): adjust precision in Hessian test case Signed-off-by: Jinzhe Zeng --- source/tests/jax/test_make_hessian_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py index 79397990b3..185660e2be 100644 --- a/source/tests/jax/test_make_hessian_model.py +++ b/source/tests/jax/test_make_hessian_model.py @@ -69,7 +69,7 @@ def test( self, ) -> None: # setup test case - places = 6 + places = 5 delta = 1e-3 natoms = self.nloc nf = self.nf From df4ef252907d8f66314518ba450e1e93dc82e5ab Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 12 Mar 2025 12:56:35 +0800 Subject: [PATCH 5/6] set XLA_PYTHON_CLIENT_ALLOCATOR to platform Signed-off-by: Jinzhe Zeng --- .github/workflows/test_cuda.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index b6ea80cf32..47e129d2b4 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -64,6 +64,7 @@ jobs: CUDA_VISIBLE_DEVICES: 0 # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html XLA_PYTHON_CLIENT_PREALLOCATE: false + XLA_PYTHON_CLIENT_ALLOCATOR: platform - name: Convert models 
run: source/tests/infer/convert-models.sh - name: Download libtorch From 39ea7a8298723ff165375e9d98ce056ba68c5fb8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 13 Mar 2025 19:08:19 +0800 Subject: [PATCH 6/6] array api implementation for communicate_extended_output Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 19 +++ deepmd/dpmodel/model/transform_output.py | 179 +++++++++++------------ 2 files changed, 101 insertions(+), 97 deletions(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index e5c0557851..4d6db2521f 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -2,6 +2,7 @@ """Utilities for the array API.""" import array_api_compat +import numpy as np from packaging.version import ( Version, ) @@ -73,3 +74,21 @@ def xp_take_along_axis(arr, indices, axis): out = xp.take(arr, indices) out = xp.reshape(out, shape) return xp_swapaxes(out, axis, -1) + + +def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray: + """Reduces all values from the src tensor to the indices specified in the index tensor.""" + # jax only + if array_api_compat.is_jax_array(input): + from deepmd.jax.common import ( + scatter_sum, + ) + + return scatter_sum( + input, + dim, + index, + src, + ) + else: + raise NotImplementedError("Only JAX arrays are supported.") diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 9b71fda285..9d7873f081 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -3,6 +3,9 @@ import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_scatter_sum, +) from deepmd.dpmodel.common import ( GLOBAL_ENER_FLOAT_PRECISION, ) @@ -100,20 +103,12 @@ def communicate_extended_output( mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims))) mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims) force = xp.zeros(vldims + derv_r_ext_dims, 
dtype=vv.dtype) - # jax only - if array_api_compat.is_jax_array(force): - from deepmd.jax.common import ( - scatter_sum, - ) - - force = scatter_sum( - force, - 1, - mapping, - model_ret[kk_derv_r], - ) - else: - raise NotImplementedError("Only JAX arrays are supported.") + force = xp_scatter_sum( + force, + 1, + mapping, + model_ret[kk_derv_r], + ) new_ret[kk_derv_r] = force else: # name holders @@ -121,90 +116,80 @@ def communicate_extended_output( if vdef.r_hessian: kk_hess = get_hessian_name(kk) if model_ret[kk_hess] is not None: - # jax only - if array_api_compat.is_jax_array(force): - from deepmd.jax.common import ( - scatter_sum, - ) - from deepmd.jax.env import ( - jnp, - ) - - # [nf, *def, nall, 3, nall, 3] - hess_ = model_ret[kk_hess] - def_ndim = len(vdef.shape) - # [nf, nall1, nall2, *def, 3(1), 3(2)] - hess_1 = jnp.transpose( - hess_, - ( - 0, - def_ndim + 1, - def_ndim + 3, - *range(1, def_ndim + 1), - def_ndim + 2, - def_ndim + 4, - ), - ) - nall = hess_1.shape[1] - # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] - hessian1 = jnp.zeros( - [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype - ) - mapping_hess = xp.reshape( - mapping_, (mldims + [1] * (len(vdef.shape) + 3)) - ) - mapping_hess = xp.tile( - mapping_hess, - [1] * len(mldims) + [nall, *vdef.shape, 3, 3], - ) - hessian1 = scatter_sum( - hessian1, - 1, - mapping_hess, - hess_1, - ) - # [nf, nall2, nloc1, *def, 3(1), 3(2)] - hessian1 = jnp.transpose( - hessian1, - (0, 2, 1, *range(3, def_ndim + 5)), - ) - nloc = hessian1.shape[2] - # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] - hessian = jnp.zeros( - [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype - ) - mapping_hess = xp.reshape( - mapping_, (mldims + [1] * (len(vdef.shape) + 3)) - ) - mapping_hess = xp.tile( - mapping_hess, - [1] * len(mldims) + [nloc, *vdef.shape, 3, 3], - ) - hessian = scatter_sum( - hessian, + # [nf, *def, nall, 3, nall, 3] + hess_ = model_ret[kk_hess] + def_ndim = len(vdef.shape) + # [nf, nall1, nall2, *def, 3(1), 3(2)] 
+ hess_1 = xp.permute_dims( + hess_, + ( + 0, + def_ndim + 1, + def_ndim + 3, + *range(1, def_ndim + 1), + def_ndim + 2, + def_ndim + 4, + ), + ) + nall = hess_1.shape[1] + # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] + hessian1 = xp.zeros( + [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nall, *vdef.shape, 3, 3], + ) + hessian1 = xp_scatter_sum( + hessian1, + 1, + mapping_hess, + hess_1, + ) + # [nf, nall2, nloc1, *def, 3(1), 3(2)] + hessian1 = xp.permute_dims( + hessian1, + (0, 2, 1, *range(3, def_ndim + 5)), + ) + nloc = hessian1.shape[2] + # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] + hessian = xp.zeros( + [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype + ) + mapping_hess = xp.reshape( + mapping_, (mldims + [1] * (len(vdef.shape) + 3)) + ) + mapping_hess = xp.tile( + mapping_hess, + [1] * len(mldims) + [nloc, *vdef.shape, 3, 3], + ) + hessian = xp_scatter_sum( + hessian, + 1, + mapping_hess, + hessian1, + ) + # -> [nf, *def, nloc1, 3(1), nloc2, 3(2)] + hessian = xp.permute_dims( + hessian, + ( + 0, + *range(3, def_ndim + 3), + 2, + def_ndim + 3, 1, - mapping_hess, - hessian1, - ) - # -> [nf, *def, nloc1, 3(1), nloc2, 3(2)] - hessian = jnp.transpose( - hessian, - ( - 0, - *range(3, def_ndim + 3), - 2, - def_ndim + 3, - 1, - def_ndim + 4, - ), - ) - # -> [nf, *def, nloc1 * 3, nloc2 * 3] - hessian = jnp.reshape( - hessian, - (hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3), - ) - else: - raise NotImplementedError("Only JAX arrays are supported.") + def_ndim + 4, + ), + ) + # -> [nf, *def, nloc1 * 3, nloc2 * 3] + hessian = xp.reshape( + hessian, + (hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3), + ) + new_ret[kk_hess] = hessian else: new_ret[kk_hess] = None