diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index c83e35dab3..56af5f4f43 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -79,6 +79,7 @@ def __init__( pair_exclude_types: list[tuple[int, int]] = [], rcond: Optional[float] = None, preset_out_bias: Optional[dict[str, np.ndarray]] = None, + data_stat_protect: float = 1e-2, ) -> None: torch.nn.Module.__init__(self) BaseAtomicModel_.__init__(self) @@ -87,6 +88,7 @@ def __init__( self.reinit_pair_exclude(pair_exclude_types) self.rcond = rcond self.preset_out_bias = preset_out_bias + self.data_stat_protect = data_stat_protect def init_out_stat(self) -> None: """Initialize the output bias.""" diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 1b20eeb217..5a5655b72c 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -299,6 +299,9 @@ def wrapped_sampler(): return sampled self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.fitting_net.compute_input_stats( + wrapped_sampler, protection=self.data_stat_protect + ) self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) def get_dim_fparam(self) -> int: diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 37e664e82a..8d451f087f 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -254,6 +254,7 @@ def get_standard_model(model_params): preset_out_bias = _convert_preset_out_bias_to_array( preset_out_bias, model_params["type_map"] ) + data_stat_protect = model_params.get("data_stat_protect", 1e-2) if fitting_net_type == "dipole": modelcls = DipoleModel @@ -275,6 +276,7 @@ def get_standard_model(model_params): atom_exclude_types=atom_exclude_types, pair_exclude_types=pair_exclude_types, preset_out_bias=preset_out_bias, + data_stat_protect=data_stat_protect, ) if model_params.get("hessian_mode"): model.enable_hessian() diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index dcbbbe9602..f32592b977 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -4,6 +4,7 @@ abstractmethod, ) from typing import ( + Callable, Optional, Union, ) @@ -71,6 +72,84 @@ def share_params(self, base_class, shared_level, resume=False) -> None: else: raise NotImplementedError + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + protection: float = 1e-2, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the fittings from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + protection : float + Divided-by-zero protection + """ + if callable(merged): + sampled = merged() + else: + sampled = merged + # stat fparam + if self.numb_fparam > 0: + cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0) + cat_data = torch.reshape(cat_data, [-1, self.numb_fparam]) + fparam_avg = torch.mean(cat_data, dim=0) + fparam_std = torch.std(cat_data, dim=0, unbiased=False) + fparam_std = torch.where( + fparam_std < protection, + torch.tensor( + protection, dtype=fparam_std.dtype, device=fparam_std.device + ), + fparam_std, + ) + fparam_inv_std = 1.0 / fparam_std + self.fparam_avg.copy_( + torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype) + ) + self.fparam_inv_std.copy_( + torch.tensor( + fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype + ) + ) + # stat aparam + if self.numb_aparam > 0: + sys_sumv = [] + sys_sumv2 = [] + sys_sumn = [] + for ss_ in [frame["aparam"] for frame in sampled]: + ss = torch.reshape(ss_, [-1, self.numb_aparam]) + sys_sumv.append(torch.sum(ss, dim=0)) + sys_sumv2.append(torch.sum(ss * ss, dim=0)) + sys_sumn.append(ss.shape[0]) + sumv = torch.sum(torch.stack(sys_sumv), dim=0) + sumv2 = torch.sum(torch.stack(sys_sumv2), dim=0) + sumn = sum(sys_sumn) + aparam_avg = sumv / sumn + aparam_std = torch.sqrt(sumv2 / sumn - (sumv / sumn) ** 2) + aparam_std = torch.where( + aparam_std < protection, + torch.tensor( + protection, dtype=aparam_std.dtype, device=aparam_std.device + ), + aparam_std, + ) + aparam_inv_std = 1.0 / aparam_std + self.aparam_avg.copy_( + torch.tensor(aparam_avg, device=env.DEVICE, dtype=self.aparam_avg.dtype) + ) + self.aparam_inv_std.copy_( + torch.tensor( + aparam_inv_std, device=env.DEVICE, dtype=self.aparam_inv_std.dtype + ) + ) + class GeneralFitting(Fitting): """Construct a general fitting net. diff --git a/source/tests/pt/test_fitting_stat.py b/source/tests/pt/test_fitting_stat.py new file mode 100644 index 0000000000..bc02b539a0 --- /dev/null +++ b/source/tests/pt/test_fitting_stat.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.pt.model.descriptor import ( + DescrptSeA, +) +from deepmd.pt.model.task import ( + EnergyFittingNet, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + + +def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): + merged_output_stat = [] + nsys = len(sys_natoms) + ndof = len(avgs) + for ii in range(nsys): + sys_dict = {} + tmp_data_f = [] + tmp_data_a = [] + for jj in range(ndof): + rng = np.random.default_rng(2025 * ii + 220 * jj) + tmp_data_f.append( + rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1)) + ) + rng = np.random.default_rng(220 * ii + 1636 * jj) + tmp_data_a.append( + rng.normal( + loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii]) + ) + ) + tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0)) + tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0)) + sys_dict["fparam"] = to_torch_tensor(tmp_data_f) + sys_dict["aparam"] = to_torch_tensor(tmp_data_a) + merged_output_stat.append(sys_dict) + return merged_output_stat + + +def _brute_fparam_pt(data, ndim): + adata = [to_numpy_array(ii["fparam"]) for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +def _brute_aparam_pt(data, ndim): + adata = [to_numpy_array(ii["aparam"]) for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +class TestEnerFittingStat(unittest.TestCase): + def test(self) -> None: + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ) + avgs = [0, 10, 100] + stds = [2, 0.4, 0.00001] + sys_natoms = [10, 100] + sys_nframes = [5, 2] + all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds) + frefa, frefs = _brute_fparam_pt(all_data, len(avgs)) + arefa, arefs = _brute_aparam_pt(all_data, len(avgs)) + fitting.compute_input_stats(all_data, protection=1e-2) + frefs_inv = 1.0 / frefs + arefs_inv = 1.0 / arefs + frefs_inv[frefs_inv > 100] = 100 + arefs_inv[arefs_inv > 100] = 100 + np.testing.assert_almost_equal(frefa, to_numpy_array(fitting.fparam_avg)) + np.testing.assert_almost_equal( + frefs_inv, to_numpy_array(fitting.fparam_inv_std) + ) + np.testing.assert_almost_equal(arefa, to_numpy_array(fitting.aparam_avg)) + np.testing.assert_almost_equal( + arefs_inv, to_numpy_array(fitting.aparam_inv_std) + )