Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ec820ff
Init branch
Chengqian-Zhang Dec 25, 2024
2cac1c6
Delete files
Chengqian-Zhang Dec 25, 2024
c561efe
Delete file:test_neighbor_stat.py
Chengqian-Zhang Dec 25, 2024
af5e589
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 25, 2024
e7dfe91
Merge branch 'devel' into fitting_stat
Chengqian-Zhang Feb 17, 2025
bc7a0aa
Add aparam
Chengqian-Zhang Feb 17, 2025
95229c5
Delete fparam.npy
Chengqian-Zhang Feb 17, 2025
9e0b2ff
Add unittest of fitting_stat
Chengqian-Zhang Feb 17, 2025
66cab94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2025
3072fb1
Add protection
Chengqian-Zhang Feb 17, 2025
1a4836c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2025
1b0184b
Merge branch 'devel' into fitting_stat
Chengqian-Zhang Feb 18, 2025
3e5cbe4
Add near zero UT
Chengqian-Zhang Feb 18, 2025
f968b76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
2d808af
Delete useless import
Chengqian-Zhang Feb 18, 2025
4f1e009
Merge branch 'fitting_stat' of github.com:Chengqian-Zhang/deepmd-kit …
Chengqian-Zhang Feb 18, 2025
68b0f61
Fix UT
Chengqian-Zhang Feb 18, 2025
1bd0854
Delete checkpoint
Chengqian-Zhang Feb 18, 2025
9010763
Rerun UT
Chengqian-Zhang Feb 18, 2025
3bc746d
Delete ignore
Chengqian-Zhang Feb 20, 2025
bdd7b3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2025
b766b0c
Add noqa
Chengqian-Zhang Feb 20, 2025
b034bee
Solve conflict
Chengqian-Zhang Feb 20, 2025
7dd3609
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2025
e651bf4
Fix pre-commit
Chengqian-Zhang Feb 20, 2025
9cb04eb
Solve pre-commit
Chengqian-Zhang Feb 20, 2025
1b033a9
Delete ignore
Chengqian-Zhang Feb 20, 2025
3aed1b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
pair_exclude_types: list[tuple[int, int]] = [],
rcond: Optional[float] = None,
preset_out_bias: Optional[dict[str, np.ndarray]] = None,
data_stat_protect: float = 1e-2,
) -> None:
torch.nn.Module.__init__(self)
BaseAtomicModel_.__init__(self)
Expand All @@ -87,6 +88,7 @@ def __init__(
self.reinit_pair_exclude(pair_exclude_types)
self.rcond = rcond
self.preset_out_bias = preset_out_bias
self.data_stat_protect = data_stat_protect

def init_out_stat(self) -> None:
"""Initialize the output bias."""
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def wrapped_sampler():
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def get_dim_fparam(self) -> int:
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def get_standard_model(model_params):
preset_out_bias = _convert_preset_out_bias_to_array(
preset_out_bias, model_params["type_map"]
)
data_stat_protect = model_params.get("data_stat_protect", 1e-2)

if fitting_net_type == "dipole":
modelcls = DipoleModel
Expand All @@ -275,6 +276,7 @@ def get_standard_model(model_params):
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
preset_out_bias=preset_out_bias,
data_stat_protect=data_stat_protect,
)
if model_params.get("hessian_mode"):
model.enable_hessian()
Expand Down
79 changes: 79 additions & 0 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
abstractmethod,
)
from typing import (
Callable,
Optional,
Union,
)
Expand Down Expand Up @@ -71,6 +72,84 @@ def share_params(self, base_class, shared_level, resume=False) -> None:
else:
raise NotImplementedError

def compute_input_stats(
    self,
    merged: Union[Callable[[], list[dict]], list[dict]],
    protection: float = 1e-2,
) -> None:
    """
    Compute the input statistics (mean and inverse stddev) for the fittings from packed data.

    The statistics are written in place into ``self.fparam_avg`` /
    ``self.fparam_inv_std`` (when ``self.numb_fparam > 0``) and into
    ``self.aparam_avg`` / ``self.aparam_inv_std`` (when ``self.numb_aparam > 0``).

    Parameters
    ----------
    merged : Union[Callable[[], list[dict]], list[dict]]
        - list[dict]: A list of data samples from various data systems.
            Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
            originating from the `i`-th data system.
        - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
            only when needed. Since the sampling process can be slow and memory-intensive,
            the lazy function helps by only sampling once.
    protection : float
        Division-by-zero protection: any standard deviation smaller than this
        value is replaced by it before the inverse is taken.
    """
    sampled = merged() if callable(merged) else merged

    # stat fparam
    if self.numb_fparam > 0:
        cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
        cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
        fparam_avg = torch.mean(cat_data, dim=0)
        fparam_std = torch.std(cat_data, dim=0, unbiased=False)
        fparam_inv_std = 1.0 / torch.clamp(fparam_std, min=protection)
        # copy_ converts dtype and device itself, so no intermediate
        # torch.tensor(...) copy (which warns when given a tensor) is needed.
        self.fparam_avg.copy_(fparam_avg.detach())
        self.fparam_inv_std.copy_(fparam_inv_std.detach())

    # stat aparam
    if self.numb_aparam > 0:
        # Accumulate per-system sums so systems with different atom counts
        # can be combined without concatenating their tensors.
        sys_sumv = []
        sys_sumv2 = []
        sumn = 0
        for frame in sampled:
            ss = torch.reshape(frame["aparam"], [-1, self.numb_aparam])
            sys_sumv.append(torch.sum(ss, dim=0))
            sys_sumv2.append(torch.sum(ss * ss, dim=0))
            sumn += ss.shape[0]
        sumv = torch.sum(torch.stack(sys_sumv), dim=0)
        sumv2 = torch.sum(torch.stack(sys_sumv2), dim=0)
        aparam_avg = sumv / sumn
        # E[x^2] - E[x]^2 can come out slightly negative from rounding when the
        # data are (nearly) constant; sqrt would then give NaN, and NaN is not
        # caught by the protection threshold. Clamp the variance at zero first.
        aparam_var = torch.clamp(sumv2 / sumn - aparam_avg**2, min=0.0)
        aparam_std = torch.sqrt(aparam_var)
        aparam_inv_std = 1.0 / torch.clamp(aparam_std, min=protection)
        self.aparam_avg.copy_(aparam_avg.detach())
        self.aparam_inv_std.copy_(aparam_inv_std.detach())


class GeneralFitting(Fitting):
"""Construct a general fitting net.
Expand Down
103 changes: 103 additions & 0 deletions source/tests/pt/test_fitting_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.pt.model.descriptor import (
DescrptSeA,
)
from deepmd.pt.model.task import (
EnergyFittingNet,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)


def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
    """Build deterministic fake ``fparam``/``aparam`` samples, one dict per system.

    For system ``ii`` and component ``jj`` the values are drawn from a normal
    distribution with mean ``avgs[jj]`` and scale ``stds[jj]``, using a
    generator seeded from ``ii`` and ``jj`` so repeated calls give the same data.

    Returns a list of dicts with keys ``"fparam"`` (drawn with shape
    ``(nframes, 1)`` per component) and ``"aparam"`` (drawn with shape
    ``(nframes, natoms)`` per component); the component axis is moved last
    before conversion with ``to_torch_tensor``.
    """
    ndof = len(avgs)

    def _draw(seed, jj, shape):
        # A fresh generator per draw keeps every component reproducible.
        rng = np.random.default_rng(seed)
        return rng.normal(loc=avgs[jj], scale=stds[jj], size=shape)

    merged_output_stat = []
    for ii, natoms in enumerate(sys_natoms):
        frame_cols = [
            _draw(2025 * ii + 220 * jj, jj, (sys_nframes[ii], 1))
            for jj in range(ndof)
        ]
        atom_cols = [
            _draw(220 * ii + 1636 * jj, jj, (sys_nframes[ii], natoms))
            for jj in range(ndof)
        ]
        # Stacked lists are (ndof, nframes, X); move ndof to the last axis.
        merged_output_stat.append(
            {
                "fparam": to_torch_tensor(np.transpose(frame_cols, (1, 2, 0))),
                "aparam": to_torch_tensor(np.transpose(atom_cols, (1, 2, 0))),
            }
        )
    return merged_output_stat


def _brute_param_stats(data, key, ndim):
    """Return the per-component mean and stddev of ``data[i][key]`` over all systems.

    Every system's array is flattened to ``(-1, ndim)`` and all rows are
    concatenated once before taking ``np.average`` and ``np.std`` (NumPy's
    default ``ddof=0``) along axis 0.
    """
    rows = [np.reshape(to_numpy_array(sys_data[key]), [-1, ndim]) for sys_data in data]
    # One concatenate instead of growing the array system by system.
    all_data = np.concatenate(rows, axis=0)
    avg = np.average(all_data, axis=0)
    std = np.std(all_data, axis=0)
    return avg, std


def _brute_fparam_pt(data, ndim):
    """Brute-force reference ``(avg, std)`` of the ``"fparam"`` entries in ``data``."""
    return _brute_param_stats(data, "fparam", ndim)


def _brute_aparam_pt(data, ndim):
    """Brute-force reference ``(avg, std)`` of the ``"aparam"`` entries in ``data``."""
    return _brute_param_stats(data, "aparam", ndim)


class TestEnerFittingStat(unittest.TestCase):
    """Compare fitting-net fparam/aparam statistics with brute-force NumPy references."""

    def test(self) -> None:
        descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
        fitting = EnergyFittingNet(
            descrpt.get_ntypes(),
            descrpt.get_dim_out(),
            neuron=[240, 240, 240],
            resnet_dt=True,
            numb_fparam=3,
            numb_aparam=3,
        )

        protection = 1e-2
        # Largest inverse stddev the protection floor allows (1 / 1e-2 == 100).
        inv_std_cap = 1.0 / protection

        avgs = [0, 10, 100]
        # The last stddev is far below `protection`, so its inverse gets capped.
        stds = [2, 0.4, 0.00001]
        sys_natoms = [10, 100]
        sys_nframes = [5, 2]
        nparam = len(avgs)

        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
        frefa, frefs = _brute_fparam_pt(all_data, nparam)
        arefa, arefs = _brute_aparam_pt(all_data, nparam)
        fitting.compute_input_stats(all_data, protection=protection)

        frefs_inv = np.minimum(1.0 / frefs, inv_std_cap)
        arefs_inv = np.minimum(1.0 / arefs, inv_std_cap)

        expected_vs_actual = [
            (frefa, fitting.fparam_avg),
            (frefs_inv, fitting.fparam_inv_std),
            (arefa, fitting.aparam_avg),
            (arefs_inv, fitting.aparam_inv_std),
        ]
        for reference, computed in expected_vs_actual:
            np.testing.assert_almost_equal(reference, to_numpy_array(computed))