From 6eabd6309c1804ad7f0fbd7a0ae825fadf624d93 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 02:16:21 +0800 Subject: [PATCH] feat(jax): freeze models with hessian output --- deepmd/jax/entrypoints/freeze.py | 5 ++- deepmd/jax/infer/deep_eval.py | 5 +++ deepmd/jax/jax2tf/serialization.py | 9 ++++-- deepmd/jax/model/hlo.py | 21 +++++++++++- deepmd/jax/utils/serialization.py | 17 +++++++--- deepmd/main.py | 6 ++++ source/tests/jax/test_training.py | 51 ++++++++++++++++++++++++++++-- 7 files changed, 103 insertions(+), 11 deletions(-) diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py index fbc126ffc7..03536d4031 100644 --- a/deepmd/jax/entrypoints/freeze.py +++ b/deepmd/jax/entrypoints/freeze.py @@ -18,6 +18,7 @@ def freeze( *, checkpoint_folder: str, output: str, + hessian: bool = False, **kwargs: object, ) -> None: """Freeze a JAX checkpoint into a serialized model file. @@ -30,6 +31,8 @@ def freeze( output : str Output model filename or prefix. The JAX model suffix is added when the filename has no supported backend suffix. + hessian : bool, default=False + Whether to include the Hessian in the frozen model outputs. **kwargs Other CLI arguments accepted for backend entry-point compatibility. """ @@ -46,4 +49,4 @@ def freeze( strict_prefer=True, ) data = serialize_from_file(checkpoint_folder) - deserialize_to_file(output, data) + deserialize_to_file(output, data, hessian=hessian) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 2e028225f7..2cfa362998 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -281,6 +281,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: OutputVariableCategory.REDU, OutputVariableCategory.DERV_R, OutputVariableCategory.DERV_C_REDU, + OutputVariableCategory.DERV_R_DERV_R, ) ] @@ -433,6 +434,10 @@ def get_model_def_script(self) -> dict: """Get model definition script.""" return json.loads(self.dp.get_model_def_script()) + def get_has_hessian(self) -> bool: + """Check if the model has Hessian output.""" + return self.get_model_def_script().get("hessian_mode", False) + def get_model(self) -> Any: """Get the JAX model as BaseModel. diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 4881ca98f8..0ab45fd56c 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -21,7 +21,7 @@ ) -def deserialize_to_file(model_file: str, data: dict) -> None: +def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None: """Deserialize the dictionary to a model file. Parameters @@ -30,10 +30,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: The model file to be saved. data : dict The dictionary to be deserialized. + hessian : bool, default=False + Whether to include the Hessian in the model outputs. """ if model_file.endswith(".savedmodel"): model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] + model_def_script = data["model_def_script"].copy() + if hessian: + model.enable_hessian() + model_def_script["hessian_mode"] = True call_lower = model.call_common_lower tf_model = tf.Module() diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 8c1e85c59c..b0d81c31ee 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import json from typing import ( Any, ) @@ -30,6 +31,14 @@ r_differentiable=True, c_differentiable=True, ), + "energy_hessian": OutputVariableDef( + "energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + r_hessian=True, + ), "mask": OutputVariableDef( "mask", shape=[1], @@ -171,7 +180,17 @@ def call( def model_output_def(self) -> ModelOutputDef: return ModelOutputDef( - FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) + FittingOutputDef( + [ + OUTPUT_DEFS[ + f"{tt}_hessian" + if tt == "energy" + and json.loads(self.model_def_script).get("hessian_mode", False) + else tt + ] + for tt in self.model_output_type() + ] + ) ) def call_lower( diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 14386d9f3d..3e24a139d0 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -22,7 +22,7 @@ ) -def deserialize_to_file(model_file: str, data: dict) -> None: +def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None: """Deserialize the dictionary to a model file. Parameters @@ -31,10 +31,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None: The model file to be saved. data : dict The dictionary to be deserialized. + hessian : bool, default=False + Whether to include the Hessian in the model outputs. """ if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] + model_def_script = data["model_def_script"].copy() + if hessian: + model.enable_hessian() + model_def_script["hessian_mode"] = True _, state = nnx.split(model) with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") @@ -48,7 +53,10 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ) elif model_file.endswith(".hlo"): model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] + model_def_script = data["model_def_script"].copy() + if hessian: + model.enable_hessian() + model_def_script["hessian_mode"] = True call_lower = model.call_common_lower nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") @@ -113,6 +121,7 @@ def call_lower_with_fixed_do_atomic_virial( serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize() data = data.copy() + data["model_def_script"] = model_def_script data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) data["@variables"]["stablehlo_atomic_virial"] = np.void( @@ -142,7 +151,7 @@ def call_lower_with_fixed_do_atomic_virial( deserialize_to_file as deserialize_to_savedmodel, ) - return deserialize_to_savedmodel(model_file, data) + return deserialize_to_savedmodel(model_file, data, hessian=hessian) else: raise ValueError("Unsupported file extension") diff --git a/deepmd/main.py b/deepmd/main.py index 43f40dc214..68fdbf6269 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -350,6 +350,12 @@ def main_parser() -> argparse.ArgumentParser: type=str, help="(Supported backend: PyTorch) Task head (alias: model branch) to freeze if in multi-task mode.", ) + parser_frz.add_argument( + "--hessian", + action="store_true", + default=False, + help="(Supported backend: JAX) Add the Hessian to the frozen model output.", + ) # * test script ******************************************************************** parser_tst = subparsers.add_parser( diff --git a/source/tests/jax/test_training.py b/source/tests/jax/test_training.py index d0713fe466..075bd9ebd7 100644 --- a/source/tests/jax/test_training.py +++ b/source/tests/jax/test_training.py @@ -14,10 +14,16 @@ from pathlib import ( Path, ) +from types import ( + SimpleNamespace, +) from unittest.mock import ( patch, ) +from deepmd.dpmodel.output_def import ( + OutputVariableCategory, +) from deepmd.jax.entrypoints.freeze import ( freeze, ) @@ -27,6 +33,12 @@ from deepmd.jax.entrypoints.train import ( update_sel, ) +from deepmd.jax.infer.deep_eval import ( + DeepEval, +) +from deepmd.jax.model.hlo import ( + HLO, +) from deepmd.utils.compat import ( convert_optimizer_v31_to_v32, ) @@ -169,17 +181,19 @@ def test_update_sel_uses_jax_neighbor_stat(self, get_nbor_stat, get_data) -> Non def test_freeze_entrypoint_uses_checkpoint_pointer( self, serialize_from_file, deserialize_to_file ) -> None: - """Freeze resolves the stable checkpoint pointer without Hessian options.""" + """Freeze resolves the stable checkpoint pointer and forwards Hessian.""" checkpoint_dir = self.work_dir / "ckpt" checkpoint_dir.mkdir() (checkpoint_dir / "checkpoint").write_text("model-1.jax") serialize_from_file.return_value = {"model": {}, "model_def_script": {}} - freeze(checkpoint_folder=str(checkpoint_dir), output="frozen_model") + freeze( + checkpoint_folder=str(checkpoint_dir), output="frozen_model", hessian=True + ) serialize_from_file.assert_called_once_with(str(checkpoint_dir / "model-1.jax")) deserialize_to_file.assert_called_once_with( - "frozen_model.hlo", serialize_from_file.return_value + "frozen_model.hlo", serialize_from_file.return_value, hessian=True ) @patch("deepmd.jax.entrypoints.main.freeze") @@ -191,8 +205,39 @@ def test_main_dispatches_freeze(self, freeze_entrypoint) -> None: log_path=None, checkpoint_folder=".", output="frozen_model", + hessian=False, ) main(args) freeze_entrypoint.assert_called_once() + + def test_hlo_hessian_mode_updates_output_def(self) -> None: + """HLO output definition should expose Hessian when requested.""" + hlo = object.__new__(HLO) + hlo._model_output_type = ["energy"] + hlo.model_def_script = json.dumps({"hessian_mode": True}) + + output_def = hlo.model_output_def() + + self.assertTrue(output_def["energy"].r_hessian) + self.assertIn("energy_derv_r_derv_r", output_def.keys()) + + def test_deep_eval_requests_hessian_for_hessian_model(self) -> None: + """Non-atomic JAX evaluation should request Hessian outputs.""" + hlo = object.__new__(HLO) + hlo._model_output_type = ["energy"] + hlo.model_def_script = json.dumps({"hessian_mode": True}) + deep_eval = object.__new__(DeepEval) + deep_eval.output_def = hlo.model_output_def() + deep_eval.dp = SimpleNamespace( + get_model_def_script=lambda: json.dumps({"hessian_mode": True}) + ) + + request_defs = deep_eval._get_request_defs(atomic=False) + + self.assertTrue(deep_eval.get_has_hessian()) + self.assertIn( + OutputVariableCategory.DERV_R_DERV_R, + {odef.category for odef in request_defs}, + )