Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion deepmd/jax/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def freeze(
*,
checkpoint_folder: str,
output: str,
hessian: bool = False,
**kwargs: object,
) -> None:
"""Freeze a JAX checkpoint into a serialized model file.
Expand All @@ -30,6 +31,8 @@ def freeze(
output : str
Output model filename or prefix. The JAX model suffix is added when the
filename has no supported backend suffix.
hessian : bool, default=False
Whether to include the Hessian in the frozen model outputs.
**kwargs
Other CLI arguments accepted for backend entry-point compatibility.
"""
Expand All @@ -46,4 +49,4 @@ def freeze(
strict_prefer=True,
)
data = serialize_from_file(checkpoint_folder)
deserialize_to_file(output, data)
deserialize_to_file(output, data, hessian=hessian)
5 changes: 5 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
OutputVariableCategory.REDU,
OutputVariableCategory.DERV_R,
OutputVariableCategory.DERV_C_REDU,
OutputVariableCategory.DERV_R_DERV_R,
)
]

Expand Down Expand Up @@ -433,6 +434,10 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())

def get_has_hessian(self) -> bool:
    """Check if the model has Hessian output.

    Returns
    -------
    bool
        ``True`` when the model definition script has a truthy
        ``"hessian_mode"`` entry; ``False`` when the key is absent or falsy.
    """
    # Coerce so that a non-boolean JSON value (e.g. ``null``) cannot leak out
    # of a method annotated as returning ``bool``.
    return bool(self.get_model_def_script().get("hessian_mode", False))

def get_model(self) -> Any:
"""Get the JAX model as BaseModel.

Expand Down
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)


def deserialize_to_file(model_file: str, data: dict) -> None:
def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:
"""Deserialize the dictionary to a model file.

Parameters
Expand All @@ -30,10 +30,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
The model file to be saved.
data : dict
The dictionary to be deserialized.
hessian : bool, default=False
Whether to include the Hessian in the model outputs.
"""
if model_file.endswith(".savedmodel"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
model_def_script = data["model_def_script"].copy()
if hessian:
model.enable_hessian()
model_def_script["hessian_mode"] = True
call_lower = model.call_common_lower

tf_model = tf.Module()
Expand Down
21 changes: 20 additions & 1 deletion deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
Any,
)
Expand Down Expand Up @@ -30,6 +31,14 @@
r_differentiable=True,
c_differentiable=True,
),
"energy_hessian": OutputVariableDef(
"energy",
shape=[1],
reducible=True,
r_differentiable=True,
c_differentiable=True,
r_hessian=True,
),
"mask": OutputVariableDef(
"mask",
shape=[1],
Expand Down Expand Up @@ -171,7 +180,17 @@ def call(

def model_output_def(self) -> ModelOutputDef:
    """Build the output definition for this model.

    Each name returned by ``self.model_output_type()`` is looked up in
    ``OUTPUT_DEFS``. When the JSON stored in ``self.model_def_script`` has a
    truthy ``"hessian_mode"`` entry, ``"energy"`` is looked up under
    ``"energy_hessian"`` instead.
    """
    output_types = list(self.model_output_type())
    # Parse the script at most once instead of once per output type, and only
    # when an "energy" output is present (the only case that consults it).
    use_hessian = "energy" in output_types and bool(
        json.loads(self.model_def_script).get("hessian_mode", False)
    )
    return ModelOutputDef(
        FittingOutputDef(
            [
                OUTPUT_DEFS[
                    f"{tt}_hessian" if tt == "energy" and use_hessian else tt
                ]
                for tt in output_types
            ]
        )
    )

def call_lower(
Expand Down
17 changes: 13 additions & 4 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)


def deserialize_to_file(model_file: str, data: dict) -> None:
def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""Deserialize the dictionary to a model file.

Parameters
Expand All @@ -31,10 +31,15 @@
The model file to be saved.
data : dict
The dictionary to be deserialized.
hessian : bool, default=False
Whether to include the Hessian in the model outputs.
"""
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
model_def_script = data["model_def_script"].copy()
if hessian:
model.enable_hessian()
model_def_script["hessian_mode"] = True
_, state = nnx.split(model)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
Expand All @@ -48,7 +53,10 @@
)
elif model_file.endswith(".hlo"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
model_def_script = data["model_def_script"].copy()
if hessian:
model.enable_hessian()
model_def_script["hessian_mode"] = True
call_lower = model.call_common_lower

nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
Expand Down Expand Up @@ -113,6 +121,7 @@
serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize()

data = data.copy()
data["model_def_script"] = model_def_script
data.setdefault("@variables", {})
data["@variables"]["stablehlo"] = np.void(serialized)
data["@variables"]["stablehlo_atomic_virial"] = np.void(
Expand Down Expand Up @@ -142,7 +151,7 @@
deserialize_to_file as deserialize_to_savedmodel,
)

return deserialize_to_savedmodel(model_file, data)
return deserialize_to_savedmodel(model_file, data, hessian=hessian)
else:
raise ValueError("Unsupported file extension")

Expand Down
6 changes: 6 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="(Supported backend: PyTorch) Task head (alias: model branch) to freeze if in multi-task mode.",
)
parser_frz.add_argument(
"--hessian",
action="store_true",
default=False,
help="(Supported backend: JAX) Add the Hessian to the frozen model output.",
)

# * test script ********************************************************************
parser_tst = subparsers.add_parser(
Expand Down
51 changes: 48 additions & 3 deletions source/tests/jax/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
from pathlib import (
Path,
)
from types import (
SimpleNamespace,
)
from unittest.mock import (
patch,
)

from deepmd.dpmodel.output_def import (
OutputVariableCategory,
)
from deepmd.jax.entrypoints.freeze import (
freeze,
)
Expand All @@ -27,6 +33,12 @@
from deepmd.jax.entrypoints.train import (
update_sel,
)
from deepmd.jax.infer.deep_eval import (
DeepEval,
)
from deepmd.jax.model.hlo import (
HLO,
)
from deepmd.utils.compat import (
convert_optimizer_v31_to_v32,
)
Expand Down Expand Up @@ -169,17 +181,19 @@ def test_update_sel_uses_jax_neighbor_stat(self, get_nbor_stat, get_data) -> Non
def test_freeze_entrypoint_uses_checkpoint_pointer(
self, serialize_from_file, deserialize_to_file
) -> None:
"""Freeze resolves the stable checkpoint pointer without Hessian options."""
"""Freeze resolves the stable checkpoint pointer and forwards Hessian."""
checkpoint_dir = self.work_dir / "ckpt"
checkpoint_dir.mkdir()
(checkpoint_dir / "checkpoint").write_text("model-1.jax")
serialize_from_file.return_value = {"model": {}, "model_def_script": {}}

freeze(checkpoint_folder=str(checkpoint_dir), output="frozen_model")
freeze(
checkpoint_folder=str(checkpoint_dir), output="frozen_model", hessian=True
)

serialize_from_file.assert_called_once_with(str(checkpoint_dir / "model-1.jax"))
deserialize_to_file.assert_called_once_with(
"frozen_model.hlo", serialize_from_file.return_value
"frozen_model.hlo", serialize_from_file.return_value, hessian=True
)

@patch("deepmd.jax.entrypoints.main.freeze")
Expand All @@ -191,8 +205,39 @@ def test_main_dispatches_freeze(self, freeze_entrypoint) -> None:
log_path=None,
checkpoint_folder=".",
output="frozen_model",
hessian=False,
)

main(args)

freeze_entrypoint.assert_called_once()

def test_hlo_hessian_mode_updates_output_def(self) -> None:
    """With ``hessian_mode`` set, the HLO energy output must carry a Hessian."""
    # Allocate without running ``HLO.__init__``; only the two attributes
    # assigned below are prepared before ``model_output_def`` is called.
    model = object.__new__(HLO)
    model._model_output_type = ["energy"]
    model.model_def_script = json.dumps({"hessian_mode": True})

    definitions = model.model_output_def()
    energy_def = definitions["energy"]

    self.assertTrue(energy_def.r_hessian)
    self.assertIn("energy_derv_r_derv_r", definitions.keys())

def test_deep_eval_requests_hessian_for_hessian_model(self) -> None:
    """A Hessian-enabled model makes non-atomic evaluation request the Hessian."""
    script = json.dumps({"hessian_mode": True})

    # Both objects are allocated without running their ``__init__``; only the
    # attributes assigned here are prepared for the calls below.
    model = object.__new__(HLO)
    model._model_output_type = ["energy"]
    model.model_def_script = script

    evaluator = object.__new__(DeepEval)
    evaluator.output_def = model.model_output_def()
    # Stand-in for the wrapped model: exposes only the script getter.
    evaluator.dp = SimpleNamespace(get_model_def_script=lambda: script)

    request_defs = evaluator._get_request_defs(atomic=False)

    self.assertTrue(evaluator.get_has_hessian())
    requested_categories = {odef.category for odef in request_defs}
    self.assertIn(OutputVariableCategory.DERV_R_DERV_R, requested_categories)
Loading