Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def __init__(
model_data = load_dp_model(model_file)
self.dp = HLO(
stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
stablehlo_atomic_virial=model_data["@variables"][
"stablehlo_atomic_virial"
].tobytes(),
model_def_script=model_data["model_def_script"],
**model_data["constants"],
)
Expand Down
60 changes: 54 additions & 6 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,65 @@ def eval_output(
assert vdef.r_differentiable
# avr: [nf, *def, nall, 3, 3]
avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
# the correction sums to zero, which does not contribute to global virial
if do_atomic_virial:

def eval_ce(
cc_ext,
extended_atype,
nlist,
mapping,
fparam,
aparam,
*,
_kk=kk,
_atom_axis=atom_axis - 1,
):
# atomic_ret[_kk]: [nf, nloc, *def]
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext[None, ...],
extended_atype[None, ...],
nlist[None, ...],
mapping=mapping[None, ...] if mapping is not None else None,
fparam=fparam[None, ...] if fparam is not None else None,
aparam=aparam[None, ...] if aparam is not None else None,
)
nloc = nlist.shape[0]
cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...]
cc_loc = jnp.reshape(cc_loc, [nloc, *[1] * def_ndim, 3])
Comment thread
njzjz marked this conversation as resolved.
# [*def, 3]
return jnp.sum(
atomic_ret[_kk][0, ..., None] * cc_loc, axis=_atom_axis
)

# extended_virial_corr: [nf, *def, 3, nall, 3]
extended_virial_corr = jax.vmap(jax.jacrev(eval_ce, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
# move the first 3 to the last
# [nf, *def, nall, 3, 3]
extended_virial_corr = jnp.transpose(
extended_virial_corr,
[
0,
*range(1, def_ndim + 1),
def_ndim + 2,
def_ndim + 3,
def_ndim + 1,
],
)
avr += extended_virial_corr
# to [...,3,3] -> [...,9]
# avr: [nf, *def, nall, 9]
avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
# extended_virial: [nf, nall, *def, 9]
extended_virial = jnp.transpose(
avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
)

# the correction sums to zero, which does not contribute to global virial
# cannot jit
# if do_atomic_virial:
# raise NotImplementedError("Atomic virial is not implemented yet.")
# to [...,3,3] -> [...,9]
model_predict[kk_derv_c] = extended_virial
return model_predict
11 changes: 9 additions & 2 deletions deepmd/jax/model/hlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class HLO(BaseModel):
def __init__(
self,
stablehlo,
stablehlo_atomic_virial,
model_def_script,
type_map,
rcut,
Expand All @@ -58,6 +59,9 @@ def __init__(
sel,
) -> None:
self._call_lower = jax_export.deserialize(stablehlo).call
self._call_lower_atomic_virial = jax_export.deserialize(
stablehlo_atomic_virial
).call
self.stablehlo = stablehlo
self.type_map = type_map
self.rcut = rcut
Expand Down Expand Up @@ -170,14 +174,17 @@ def call_lower(
aparam: Optional[jnp.ndarray] = None,
do_atomic_virial: bool = False,
):
return self._call_lower(
if do_atomic_virial:
call_lower = self._call_lower_atomic_virial
else:
call_lower = self._call_lower
return call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
do_atomic_virial,
)

def get_type_map(self) -> list[str]:
Expand Down
49 changes: 37 additions & 12 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,48 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
call_lower = model.call_lower

nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
exported = jax_export.export(jax.jit(call_lower))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
if model.get_dim_aparam()
else None, # aparam
False, # do_atomic_virial

def exported_whether_do_atomic_virial(do_atomic_virial):
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, nlist_start, fparam, aparam
):
return call_lower(
coord,
atype,
nlist,
nlist_start,
fparam,
aparam,
do_atomic_virial=do_atomic_virial,
)

return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
jax.ShapeDtypeStruct(
(nf, nloc + nghost, 3), jnp.float64
), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
if model.get_dim_aparam()
else None, # aparam
)

exported = exported_whether_do_atomic_virial(do_atomic_virial=False)
exported_atomic_virial = exported_whether_do_atomic_virial(
do_atomic_virial=True
)
serialized: bytearray = exported.serialize()
serialized_atomic_virial = exported_atomic_virial.serialize()
data = data.copy()
data.setdefault("@variables", {})
data["@variables"]["stablehlo"] = np.void(serialized)
data["@variables"]["stablehlo_atomic_virial"] = np.void(
serialized_atomic_virial
)
data["constants"] = {
"type_map": model.get_type_map(),
"rcut": model.get_rcut(),
Expand Down
9 changes: 9 additions & 0 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ def test_deep_eval(self):
aparam=aparam,
)
rets.append(ret)
ret = deep_eval.eval(
self.coords,
self.box,
self.atype,
fparam=fparam,
aparam=aparam,
do_atomic_virial=True,
)
rets.append(ret)
for ret in rets[1:]:
for vv1, vv2 in zip(rets[0], ret):
if np.isnan(vv2).all():
Expand Down
10 changes: 9 additions & 1 deletion source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
{},
suffix=suffix,
)
return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
return [
ret["energy"],
ret["atom_ener"],
ret["force"],
ret["virial"],
ret["atom_virial"],
], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand All @@ -69,6 +75,7 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any:
numpy_to_torch(coords),
numpy_to_torch(atype),
box=numpy_to_torch(box),
do_atomic_virial=True,
Comment thread
njzjz marked this conversation as resolved.
).items()
}

Expand All @@ -83,5 +90,6 @@ def assert_jax_array(arr):
numpy_to_jax(coords),
numpy_to_jax(atype),
box=numpy_to_jax(box),
do_atomic_virial=True,
).items()
}
11 changes: 10 additions & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,30 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret["energy"].ravel(),
SKIP_FLAG,
SKIP_FLAG,
SKIP_FLAG,
)
elif backend is self.RefBackend.PT:
return (
ret["energy"].ravel(),
ret["atom_energy"].ravel(),
ret["force"].ravel(),
ret["virial"].ravel(),
ret["atom_virial"].ravel(),
)
elif backend is self.RefBackend.TF:
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
return (
ret[0].ravel(),
ret[1].ravel(),
ret[2].ravel(),
ret[3].ravel(),
ret[4].ravel(),
)
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
ret["energy_derv_r"].ravel(),
ret["energy_derv_c_redu"].ravel(),
ret["energy_derv_c"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")