Skip to content

Commit 2c9a779

Browse files
authored
fix optimized DF-CCSD(T) (#78)
1 parent 9fe99a3 commit 2c9a779

File tree

4 files changed

+77
-23
lines changed

4 files changed

+77
-23
lines changed

examples/cc/02-df-ccsd_t.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Nuclear gradient of density-fitting CCSD(T)
3+
4+
Reference:
5+
CCSD(T) total energy: -76.30586034
6+
Nuclear gradient:
7+
[[ 5.81461842e-02 6.30418580e-17 8.76964385e-15]
8+
[-2.90730921e-02 -3.65174463e-17 -1.27052571e-01]
9+
[-2.90730921e-02 -2.65244117e-17 1.27052571e-01]]
10+
"""
11+
import jax
12+
from pyscfad import gto, scf
13+
from pyscfad.cc import dfccsd
14+
from pyscfad import config
15+
16+
# Setting `pyscfad_moleintor_opt` to `True` will use the
17+
# efficient back-propagation CPU implementation. However, only
18+
# 1st order derivative is available.
19+
config.update("pyscfad_moleintor_opt", True)
20+
config.update("pyscfad_scf_implicit_diff", True)
21+
config.update("pyscfad_ccsd_implicit_diff", True)
22+
23+
mol = gto.Mole()
24+
mol.atom = '''
25+
O 0.000000 0.000000 0.000000
26+
H 0.758602 0.000000 0.504284
27+
H 0.758602 0.000000 -0.504284
28+
'''
29+
mol.basis = 'aug-ccpvtz'
30+
mol.verbose = 4
31+
mol.incore_anyway = True
32+
mol.max_memory = 16000
33+
mol.build(trace_exp=False, trace_ctr_coeff=False)
34+
35+
def energy(mol):
36+
mf = scf.RHF(mol).density_fit()
37+
mf.kernel()
38+
mycc = dfccsd.RCCSD(mf)
39+
eris = mycc.ao2mo()
40+
mycc.kernel(eris=eris)
41+
et = mycc.ccsd_t(eris=eris)
42+
return mycc.e_tot + et
43+
44+
e, jac = jax.value_and_grad(energy)(mol)
45+
print(f"CCSD(T) total energy: {e:.8f}")
46+
print("Nuclear gradient:")
47+
print(jac.coords)

pyscfad/backend/_jax/pytree.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,12 @@ class PytreeNodeMeta(type):
110110
def __new__(mcls, name, bases, dct, **kwargs):
111111
cls = super().__new__(mcls, name, bases, dct, **kwargs)
112112

113-
_dynamic_attr = set()
114-
for base in cls.__mro__:
113+
# preserve the order of the attributes
114+
_dynamic_attr = []
115+
for base in reversed(cls.__mro__):
115116
if hasattr(base, '_dynamic_attr'):
116-
_dynamic_attr |= set(base._dynamic_attr)
117-
_dynamic_attr = tuple(_dynamic_attr)
117+
_dynamic_attr.extend(base._dynamic_attr)
118+
_dynamic_attr = tuple(dict.fromkeys(_dynamic_attr))
118119

119120
def _flatten(obj, keys=(), with_keys=False):
120121
if keys:

pyscfad/cc/ccsd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from pyscfad.tools.linear_solver import gen_gmres
1414

1515
# attributes explicitly appearing in :fun:`update_amps` are dynamic
16-
ERI_Tracers = {'fock', 'mo_energy',
17-
'oooo', 'ovoo', 'ovov', 'oovv', 'ovvo', 'ovvv', 'vvvv'}
16+
ERI_Tracers = ('fock', 'mo_energy',
17+
'oooo', 'ovoo', 'ovov', 'oovv', 'ovvo', 'ovvv', 'vvvv')
1818

1919
def _converged_iter(amp, mycc, eris):
2020
t1, t2 = mycc.vector_to_amplitudes(amp)

pyscfad/cc/ccsd_t.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import ctypes
22
import numpy
33
from jax import custom_vjp
4-
from jax.tree_util import tree_flatten, tree_unflatten
4+
from jax.tree_util import tree_flatten_with_path, tree_unflatten
55
from pyscf.lib import (
66
prange_tril,
77
num_threads,
88
current_memory,
9-
# load_library
109
)
1110
from pyscf.cc import ccsd_t as pyscf_ccsd_t
1211
from pyscfad.lib import logger
13-
#libcc = load_library('libcc')
1412
from pyscfadlib import libcc_vjp as libcc
1513

1614
def kernel(mycc, eris, t1=None, t2=None, verbose=logger.NOTE):
@@ -36,24 +34,32 @@ def _ccsd_t_kernel_fwd(eris, t1, t2):
3634
def _ccsd_t_kernel_bwd(res, ybar):
3735
eris, t1, t2 = res
3836

39-
leaves, tree = tree_flatten(eris)
40-
assert len(leaves) == 9
41-
shapes = [leaf.shape for leaf in leaves]
42-
del leaves
37+
# TODO clean up tree unflatten
38+
path_vals, treedef = tree_flatten_with_path(eris)
39+
keys = [item[0][0].name for item in path_vals]
40+
shapes = [item[1].shape for item in path_vals]
41+
path_vals = None
4342

4443
t1_bar, t2_bar, fock_bar, mo_energy_bar,\
4544
ovoo_bar, ovov_bar, ovvv_bar = _ccsd_t_energy_vjp(eris, t1, t2, ybar, max_memory)
4645

47-
leaves = [fock_bar,
48-
mo_energy_bar,
49-
numpy.zeros(shapes[2]),
50-
ovoo_bar,
51-
ovov_bar,
52-
numpy.zeros(shapes[5]),
53-
numpy.zeros(shapes[6]),
54-
ovvv_bar,
55-
numpy.zeros(shapes[8])]
56-
eris_bar = tree_unflatten(tree, leaves)
46+
leaves = [None] * len(keys)
47+
key_to_bar = {
48+
'fock': fock_bar,
49+
'mo_energy': mo_energy_bar,
50+
'ovoo': ovoo_bar,
51+
'ovov': ovov_bar,
52+
'ovvv': ovvv_bar,
53+
}
54+
55+
for k, val in key_to_bar.items():
56+
leaves[keys.index(k)] = val
57+
58+
for i, leaf in enumerate(leaves):
59+
if leaf is None:
60+
leaves[i] = numpy.zeros(shapes[i])
61+
62+
eris_bar = tree_unflatten(treedef, leaves)
5763
return eris_bar, t1_bar, t2_bar
5864

5965
_ccsd_t_kernel.defvjp(_ccsd_t_kernel_fwd, _ccsd_t_kernel_bwd)

0 commit comments

Comments
 (0)