Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
00d5ce9
Merge pull request #1642 from amcadmus/devel
amcadmus Apr 16, 2022
f7aec87
change default NN precision from `float64` to `default` (#1644)
njzjz Apr 18, 2022
8d291b5
fix variable declaration error (#1651)
denghuilu Apr 22, 2022
a6ca734
update TF installation doc (#1652)
njzjz Apr 23, 2022
de7ba72
migrate test_cc from conda to docker (#1650)
njzjz Apr 23, 2022
18ac81f
use float constants and functions in float functions (#1647)
njzjz Apr 28, 2022
2bf51f3
fix bug of aparam size, should be nlocal_real (#1664)
amcadmus Apr 28, 2022
85a3a0e
convert tabulate data from np.ndarray to tf.Tensor (#1657)
njzjz Apr 29, 2022
d1fa9e9
reset the graph before freezing the compressed model (#1658)
njzjz Apr 29, 2022
f8533cb
fix rcut in hybrid model compression (#1663)
njzjz Apr 29, 2022
a2443d9
add free_energy to ase calculator (#1667)
njzjz Apr 29, 2022
a31544c
rewrite data doc (#1668)
njzjz Apr 29, 2022
edfef49
migrate sphinx mathjax from jsdelivr to cdnjs (#1669)
njzjz Apr 29, 2022
28fd987
Documentation improvements (#1673)
chazeon May 3, 2022
cc136bb
provide valid_data the same type_map as train_data (#1677)
njzjz May 3, 2022
92230a3
deepmodeling.org -> deepmodeling.com (#1678)
njzjz May 3, 2022
899d102
fix compress training (#1680)
njzjz May 4, 2022
e9b27a7
doc: add information abotu supported versions of dependencies (#1683)
njzjz May 6, 2022
087ae56
supports dp convert-from 0.12 (#1685)
njzjz May 6, 2022
44cf60a
fix bug of model compression training with se_e2_r type descriptor (#…
denghuilu May 9, 2022
f275ce7
Change Typo (#1687)
likefallwind May 9, 2022
89e7149
doc: add Interfaces out of DeePMD-kit (#1691)
njzjz May 9, 2022
2038016
fix grappler compilation error with TF 1.15 ~ 2.6 (#1697)
njzjz May 11, 2022
cf8bc56
set default fparam and aparam stat and recover from graph (#1695)
njzjz May 11, 2022
f1b8dca
fix git permission issue (#1716)
njzjz May 18, 2022
c2e6b45
optimize format_nlist_i_cpu (#1717)
njzjz May 18, 2022
b534355
use net-wise tabulate range (#1665)
njzjz May 19, 2022
962c9a8
implement parallelism for neighbor stat (#1624)
njzjz May 19, 2022
eb2a3c3
render equations in markdown files (#1721)
njzjz May 21, 2022
12dd522
fix tf_cxx_abi in TF 2.9 (#1723)
njzjz May 23, 2022
b11e33f
update the latest state of easy installation (#1726)
njzjz May 24, 2022
6bd3bda
correct type behavior when atomic energy is requested (#1727)
njzjz May 25, 2022
f2b5c2c
prevent explicit slash in the path (#1713)
njzjz May 25, 2022
28707e6
throw warning in C++ if env is not set (#1728)
njzjz May 25, 2022
c4ad552
in model_devi, assumes nopbc if box is set to None (#1704)
wanghan-iapcm May 25, 2022
2288e43
avoid static CUDA linking (#1731)
njzjz May 26, 2022
88012a2
add Loss abstract class (#1733)
njzjz May 31, 2022
901b23c
prevent from linking TF lib when determining TF version (#1734)
njzjz May 31, 2022
68932fa
fix finding TF 2.9 ABI (#1736)
njzjz Jun 2, 2022
e4cb2a7
Automatically label new pull requests based on the paths of files bei…
njzjz Jun 3, 2022
41af8e0
using int64 within the memory allocation operations (#1737)
denghuilu Jun 4, 2022
32b7a03
add `enable_atom_ener_coeff` option for energy loss (#1743)
njzjz Jun 7, 2022
ec1e816
replace GPU 1./sqrt with rsqrt (#1741)
njzjz Jun 7, 2022
47f9c0d
add DPRc docs (#1750)
njzjz Jun 8, 2022
3a14368
fix typos in docs and docstrings (#1752)
njzjz Jun 11, 2022
01cb2fa
docs: switch to dargs directive (#1753)
njzjz Jun 11, 2022
8494c1f
docs: fix emoji in PDF (#1754)
njzjz Jun 11, 2022
5ea22c9
add a script to build TF C++ library from source (#1755)
njzjz Jun 13, 2022
8c58455
add auto cli docs (#1751)
njzjz Jun 13, 2022
d329cfa
search TF from user site-packages (#1764)
njzjz Jun 17, 2022
049bfd4
set a proper std when there is no atoms in the data (#1765)
njzjz Jun 17, 2022
3270244
build_tf.py: expose CC and CXX env to bazel (#1766)
njzjz Jun 17, 2022
17d75d7
docs: add links to parameter keys (#1767)
njzjz Jun 19, 2022
2c07c78
add argument tests to check examples (#1770)
njzjz Jun 19, 2022
731cbbc
reduce training steps in tests (#1771)
njzjz Jun 19, 2022
48a525e
bump manylinux image to 2014 (#1780)
njzjz Jun 25, 2022
c84ac9e
deprecated docstring_parameter; use sphinx rst_epilog instead (#1783)
njzjz Jun 25, 2022
9da5b3c
add __init__.py to deepmd/train/ (#1784)
njzjz Jun 25, 2022
8e0bb19
remove run_doxygen from sphinx conf.py (#1785)
njzjz Jun 25, 2022
efc9cf5
docs: fix arg reference (#1786)
njzjz Jun 25, 2022
0d92a1e
bump LAMMPS version to stable_23Jun2022 (#1779)
njzjz Jun 26, 2022
06763c2
Merge branch 'devel'
Jun 26, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Python:
- deepmd/**/*
- source/tests/**/*
Docs: doc/**/*
Examples: examples/**/*
Core: source/lib/**/*
CUDA: source/lib/src/cuda/**/*
ROCM: source/lib/src/rocm/**/*
OP: source/op/**/*
C++: source/api_cc/**/*
LAMMPS: source/lmp/**/*
Gromacs: source/gmx/**/*
i-Pi: source/ipi/**/*
2 changes: 2 additions & 0 deletions .github/workflows/build_cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ jobs:
- variant: cpu
- variant: cuda
steps:
- name: work around permission issue
run: git config --global --add safe.directory /__w/deepmd-kit/deepmd-kit
- uses: actions/checkout@master
with:
submodules: true
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ jobs:
os: [ubuntu-18.04] #, windows-latest, macos-latest]

steps:
- name: work around permission issue
run: git config --global --add safe.directory /__w/deepmd-kit/deepmd-kit
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
name: Install Python
Expand All @@ -26,7 +28,7 @@ jobs:
- name: Build wheels
env:
CIBW_BUILD: "cp36-* cp37-* cp38-* cp39-* cp310-*"
CIBW_MANYLINUX_X86_64_IMAGE: ghcr.io/deepmodeling/manylinux2010_x86_64_tensorflow
CIBW_MANYLINUX_X86_64_IMAGE: ghcr.io/deepmodeling/manylinux2014_x86_64_tensorflow
CIBW_BEFORE_BUILD: pip install tensorflow
CIBW_SKIP: "*-win32 *-manylinux_i686 *-musllinux*"
run: |
Expand Down
14 changes: 14 additions & 0 deletions .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
triage:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v4
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
2 changes: 2 additions & 0 deletions .github/workflows/lint_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ jobs:
python-version: [3.8]

steps:
- name: work around permission issue
run: git config --global --add safe.directory /__w/deepmd-kit/deepmd-kit
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
Expand Down
8 changes: 7 additions & 1 deletion .github/workflows/test_cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ jobs:
testpython:
name: Test C++
runs-on: ubuntu-latest
container: ghcr.io/deepmodeling/deepmd-kit-test-cc:latest
steps:
- name: work around permission issue
run: git config --global --add safe.directory /__w/deepmd-kit/deepmd-kit
- uses: actions/checkout@master
- run: source/install/test_cc.sh
- run: source/install/test_cc_local.sh
env:
tensorflow_root: /usr/local
- run: source/install/codecov.sh
2 changes: 2 additions & 0 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ jobs:

container: ghcr.io/deepmodeling/deepmd-kit-test-environment:py${{ matrix.python }}-gcc${{ matrix.gcc }}-tf${{ matrix.tf }}
steps:
- name: work around permission issue
run: git config --global --add safe.directory /__w/deepmd-kit/deepmd-kit
- uses: actions/checkout@master
- name: pip cache
uses: actions/cache@v2
Expand Down
24 changes: 10 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
- [Install GROMACS](doc/install/install-gromacs.md)
- [Building conda packages](doc/install/build-conda.md)
- [Data](doc/data/index.md)
- [Data conversion](doc/data/data-conv.md)
- [System](doc/data/system.md)
- [Formats of a system](doc/data/data-conv.md)
- [Prepare data with dpdata](doc/data/dpdata.md)
- [Model](doc/model/index.md)
- [Overall](doc/model/overall.md)
Expand All @@ -99,6 +100,7 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
- [Fit `tensor` like `Dipole` and `Polarizability`](doc/model/train-fitting-tensor.md)
- [Train a Deep Potential model using `type embedding` approach](doc/model/train-se-e2-a-tebd.md)
- [Deep potential long-range](doc/model/dplr.md)
- [Deep Potential - Range Correction (DPRc)](doc/model/dprc.md)
- [Training](doc/train/index.md)
- [Training a model](doc/train/training.md)
- [Advanced options](doc/train/training-advanced.md)
Expand All @@ -121,37 +123,31 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
- [LAMMPS commands](doc/third-party/lammps-command.md)
- [Run path-integral MD with i-PI](doc/third-party/ipi.md)
- [Run MD with GROMACS](doc/third-party/gromacs.md)
- [Interfaces out of DeePMD-kit](doc/third-party/out-of-deepmd-kit.md)

# Code structure

The code is organized as follows:

* `data/raw`: tools manipulating the raw data files.

* `examples`: examples.

* `deepmd`: DeePMD-kit python modules.

* `source/api_cc`: source code of DeePMD-kit C++ API.

* `source/ipi`: source code of i-PI client.

* `source/lib`: source code of DeePMD-kit library.

* `source/lmp`: source code of Lammps module.

* `source/gmx`: source code of Gromacs plugin.

* `source/op`: tensorflow op implementation. working with library.


# Troubleshooting

- [Model compatibility](doc/troubleshooting/model-compatability.md)
- [Model compatibility](doc/troubleshooting/model_compatability.md)
- [Installation](doc/troubleshooting/installation.md)
- [The temperature undulates violently during early stages of MD](doc/troubleshooting/md-energy-undulation.md)
- [MD: cannot run LAMMPS after installing a new version of DeePMD-kit](doc/troubleshooting/md-version-compatibility.md)
- [Do we need to set rcut < half boxsize?](doc/troubleshooting/howtoset-rcut.md)
- [How to set sel?](doc/troubleshooting/howtoset-sel.md)
- [The temperature undulates violently during early stages of MD](doc/troubleshooting/md_energy_undulation.md)
- [MD: cannot run LAMMPS after installing a new version of DeePMD-kit](doc/troubleshooting/md_version_compatibility.md)
- [Do we need to set rcut < half boxsize?](doc/troubleshooting/howtoset_rcut.md)
- [How to set sel?](doc/troubleshooting/howtoset_sel.md)
- [How to control the number of nodes used by a job?](doc/troubleshooting/howtoset_num_nodes.md)
- [How to tune Fitting/embedding-net size?](doc/troubleshooting/howtoset_netsize.md)

Expand Down
4 changes: 3 additions & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class DP(Calculator):
"""

name = "DP"
implemented_properties = ["energy", "forces", "virial", "stress"]
implemented_properties = ["energy", "free_energy", "forces", "virial", "stress"]

def __init__(
self,
Expand Down Expand Up @@ -102,6 +102,8 @@ def calculate(
atype = [self.type_dict[k] for k in symbols]
e, f, v = self.dp.eval(coords=coord, cells=cell, atom_types=atype)
self.results['energy'] = e[0][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
self.results['free_energy'] = e[0][0]
self.results['forces'] = f[0]
self.results['virial'] = v[0].reshape(3, 3)

Expand Down
26 changes: 4 additions & 22 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def add_data_requirement(
high_prec: bool = False,
type_sel: bool = None,
repeat: int = 1,
default: float = 0.,
):
"""Specify data requirements for training.

Expand All @@ -116,6 +117,8 @@ def add_data_requirement(
select only certain type of atoms, by default None
repeat : int, optional
if specify repaeat data `repeat` times, by default 1
default : float, optional, default=0.
default value of data
"""
data_requirement[key] = {
"ndof": ndof,
Expand All @@ -124,6 +127,7 @@ def add_data_requirement(
"high_prec": high_prec,
"type_sel": type_sel,
"repeat": repeat,
"default": default,
}


Expand Down Expand Up @@ -444,28 +448,6 @@ def expand_sys_str(root_dir: Union[str, Path]) -> List[str]:
return matches


def docstring_parameter(*sub: Tuple[str, ...]):
"""Add parameters to object docstring.

Parameters
----------
sub: Tuple[str, ...]
list of strings that will be inserted into prepared locations in docstring.

Note
----
Can be used on both object and classes.
"""

@wraps
def dec(obj: "_OBJ") -> "_OBJ":
if obj.__doc__ is not None:
obj.__doc__ = obj.__doc__.format(*sub)
return obj

return dec


def get_np_precision(precision: "_PRECISION") -> np.dtype:
"""Get numpy precision constant from string.

Expand Down
2 changes: 1 addition & 1 deletion deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def build (self,
dout = tf.reshape(dout, [-1, ii.get_dim_out()])
all_dout.append(dout)
dout = tf.concat(all_dout, axis = 1)
dout = tf.reshape(dout, [-1, natoms[0] * self.get_dim_out()])
dout = tf.reshape(dout, [-1, natoms[0], self.get_dim_out()])
return dout


Expand Down
2 changes: 1 addition & 1 deletion deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def prod_force_virial(self,
"""
[net_deriv] = tf.gradients (atom_ener, self.descrpt)
tf.summary.histogram('net_derivative', net_deriv)
net_deriv_reshape = tf.reshape (net_deriv, [-1, natoms[0] * self.ndescrpt])
net_deriv_reshape = tf.reshape (net_deriv, [np.cast['int64'](-1), natoms[0] * np.cast['int64'](self.ndescrpt)])
force = op_module.prod_force (net_deriv_reshape,
self.descrpt_deriv,
self.nlist,
Expand Down
38 changes: 18 additions & 20 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typing import Tuple, List, Dict, Any

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.common import get_activation_func, get_precision, cast_precision
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
from deepmd.env import op_module
Expand Down Expand Up @@ -88,9 +87,9 @@ class DescrptSeA (DescrptSe):
set_davg_zero
Set the shift of embedding net input to zero.
activation_function
The activation function in the embedding net. Supported options are {0}
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are {1}
The precision of the embedding net parameters. Supported options are |PRECISION|
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed

Expand All @@ -101,7 +100,6 @@ class DescrptSeA (DescrptSe):
systems. In Proceedings of the 32nd International Conference on Neural Information Processing
Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441–4451.
"""
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
def __init__ (self,
rcut: float,
rcut_smth: float,
Expand Down Expand Up @@ -517,7 +515,7 @@ def prod_force_virial(self,
"""
[net_deriv] = tf.gradients (atom_ener, self.descrpt_reshape)
tf.summary.histogram('net_derivative', net_deriv)
net_deriv_reshape = tf.reshape (net_deriv, [-1, natoms[0] * self.ndescrpt])
net_deriv_reshape = tf.reshape (net_deriv, [np.cast['int64'](-1), natoms[0] * np.cast['int64'](self.ndescrpt)])
force \
= op_module.prod_force_se_a (net_deriv_reshape,
self.descrpt_deriv,
Expand Down Expand Up @@ -553,14 +551,14 @@ def _pass_filter(self,
else:
type_embedding = None
start_index = 0
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
inputs = tf.reshape(inputs, [-1, natoms[0], self.ndescrpt])
output = []
output_qmat = []
if not (self.type_one_side and len(self.exclude_types) == 0) and type_embedding is None:
for type_i in range(self.ntypes):
inputs_i = tf.slice (inputs,
[ 0, start_index* self.ndescrpt],
[-1, natoms[2+type_i]* self.ndescrpt] )
[ 0, start_index, 0],
[-1, natoms[2+type_i], -1] )
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
if self.type_one_side:
# reuse NN parameters for all types to support type_one_side along with exclude_types
Expand All @@ -569,8 +567,8 @@ def _pass_filter(self,
else:
filter_name = 'filter_type_'+str(type_i)+suffix
layer, qmat = self._filter(inputs_i, type_i, name=filter_name, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i], self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i], self.get_dim_rot_mat_1() * 3])
output.append(layer)
output_qmat.append(qmat)
start_index += natoms[2+type_i]
Expand All @@ -579,8 +577,8 @@ def _pass_filter(self,
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3])
output.append(layer)
output_qmat.append(qmat)
output = tf.concat(output, axis = 1)
Expand Down Expand Up @@ -635,7 +633,7 @@ def _compute_dstats_sys_smth (self,

def _compute_std (self,sumv2, sumv, sumn) :
if sumn == 0:
return 1e-2
return 1. / self.rcut_r
val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn))
if np.abs(val) < 1e-2:
val = 1e-2
Expand Down Expand Up @@ -720,12 +718,12 @@ def _filter_lower(
raise RuntimeError('compression of type embedded descriptor is not supported at the moment')
# natom x 4 x outputs_size
if self.compress and (not is_exclude):
info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
if self.type_one_side:
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
return op_module.tabulate_fusion_se_a(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
if self.type_one_side:
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
info = [self.lower[net], self.upper[net], self.upper[net] * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
return op_module.tabulate_fusion_se_a(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
if (not is_exclude):
# with (natom x nei_type_i) x out_size
Expand Down
4 changes: 2 additions & 2 deletions deepmd/descriptor/se_a_ebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ def _pass_filter(self,
seed = self.seed,
trainable = trainable,
activation_fn = self.filter_activation_fn)
output = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
output_qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
output = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()])
output_qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3])
return output, output_qmat


Expand Down
10 changes: 4 additions & 6 deletions deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.common import add_data_requirement,get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.utils.argcheck import list_to_doc
from deepmd.common import add_data_requirement
from deepmd.utils.sess import run_sess
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -43,13 +42,12 @@ class DescrptSeAEf (Descriptor):
set_davg_zero
Set the shift of embedding net input to zero.
activation_function
The activation function in the embedding net. Supported options are {0}
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are {1}
The precision of the embedding net parameters. Supported options are |PRECISION|
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
"""
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
def __init__(self,
rcut: float,
rcut_smth: float,
Expand Down Expand Up @@ -230,7 +228,7 @@ def build (self,
self.dout_vert = tf.reshape(self.dout_vert, [nframes * natoms[0], self.descrpt_vert.get_dim_out()])
self.dout_para = tf.reshape(self.dout_para, [nframes * natoms[0], self.descrpt_para.get_dim_out()])
self.dout = tf.concat([self.dout_vert, self.dout_para], axis = 1)
self.dout = tf.reshape(self.dout, [nframes, natoms[0] * self.get_dim_out()])
self.dout = tf.reshape(self.dout, [nframes, natoms[0], self.get_dim_out()])
self.qmat = self.descrpt_vert.qmat + self.descrpt_para.qmat

tf.summary.histogram('embedding_net_output', self.dout)
Expand Down
Loading