Skip to content

Commit 9e4cb87

Browse files
authored
add lno module (#91)
1 parent fa13b21 commit 9e4cb87

File tree

19 files changed

+2402
-3
lines changed

19 files changed

+2402
-3
lines changed

.github/workflows/install_pyscf.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ python -m pip cache purge
44
pip install wheel
55
pip install pytest
66
pip install pytest-cov
7-
pip install jax
8-
pip install pyscf
7+
pip install 'numpy<2.4'
8+
pip install 'jax<0.9'
9+
pip install 'pyscf<2.12.0'

.github/workflows/run_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export OMP_NUM_THREADS=1
77

88
coverage erase
99

10-
MODULES=("scipy" "gto" "cc" "fci" "gw" "mp" "tdscf" "lo" "pbc")
10+
MODULES=("scipy" "gto" "cc" "fci" "gw" "mp" "tdscf" "lo" "pbc" "lno")
1111

1212
FAILED=0
1313

examples/lno/00-lno_mp2.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
'''LNO-MP2 energy and gradient.
2+
'''
3+
import jax
4+
from pyscfad import gto, scf
5+
from pyscfad.mp import dfmp2
6+
from pyscfad import config
7+
from pyscfad.lno import LNOMP2
8+
9+
# use optimized backpropagation
10+
config.update('pyscfad_moleintor_opt', True)
11+
config.update('pyscfad_scf_implicit_diff', True)
12+
config.update('pyscfad_ccsd_implicit_diff', True)
13+
14+
atom = 'water_dimer.xyz'
15+
basis = 'ccpvdz'
16+
frozen = None
17+
18+
mol = gto.Mole(atom=atom, basis=basis)
19+
mol.verbose = 4
20+
mol.build(trace_exp=False, trace_ctr_coeff=False)
21+
22+
# canonical MP2
23+
def mp2_energy(mol):
24+
mf = scf.RHF(mol).density_fit()
25+
mf.kernel()
26+
27+
mymp = dfmp2.MP2(mf, frozen=frozen)
28+
mymp.kernel()
29+
return mymp.e_tot
30+
31+
e_mp2, jac_mp2 = jax.value_and_grad(mp2_energy)(mol)
32+
33+
# LNO-MP2
34+
thresh = 1e-4
35+
def lno_mp2_energy(mol):
36+
mf = scf.RHF(mol).density_fit()
37+
ehf = mf.kernel()
38+
39+
mfcc = LNOMP2(mf, thresh=thresh, frozen=frozen)
40+
mfcc.thresh_occ = thresh
41+
mfcc.thresh_vir = thresh
42+
mfcc.lo_type = 'iao'
43+
mfcc.kernel(frag_lolist=None)
44+
return ehf + mfcc.e_corr
45+
46+
e_lno_mp2, jac_lno_mp2 = jax.value_and_grad(lno_mp2_energy)(mol)
47+
48+
print(e_mp2, e_lno_mp2, e_mp2-e_lno_mp2)
49+
print(jac_lno_mp2.coords)
50+
print(abs(jac_lno_mp2.coords - jac_mp2.coords).max())

examples/lno/01-lno_ccsd_t.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
'''LNO-CCSD(T) energy and gradient
2+
'''
3+
import jax
4+
from pyscfad import gto, scf, mp, cc
5+
from pyscfad.cc import dfccsd
6+
from pyscfad import config
7+
from pyscfad.lno import LNOCCSD
8+
9+
config.update('pyscfad_moleintor_opt', True)
10+
config.update('pyscfad_scf_implicit_diff', True)
11+
config.update('pyscfad_ccsd_implicit_diff', True)
12+
13+
atom = 'water_dimer.xyz'
14+
basis = 'ccpvdz'
15+
frozen = 2
16+
17+
mol = gto.Mole(atom=atom, basis=basis)
18+
mol.verbose = 4
19+
mol.build(trace_exp=False, trace_ctr_coeff=False)
20+
21+
# canonical CCSD(T)
22+
mf = scf.RHF(mol).density_fit()
23+
mf.kernel()
24+
mycc = dfccsd.RCCSD(mf, frozen=frozen)
25+
eris = mycc.ao2mo()
26+
mycc.kernel(eris=eris)
27+
et = mycc.ccsd_t(eris=eris)
28+
29+
# LNO-CCSD(T)
30+
thresh = 1e-4
31+
def energy(mol):
32+
mf = scf.RHF(mol).density_fit()
33+
ehf = mf.kernel()
34+
35+
mmp = mp.dfmp2.MP2(mf, frozen=frozen)
36+
mmp.kernel(with_t2=False)
37+
38+
mfcc = LNOCCSD(mf, thresh=thresh, frozen=frozen)
39+
mfcc.thresh_occ = thresh
40+
mfcc.thresh_vir = thresh
41+
mfcc.lo_type = 'iao'
42+
mfcc.no_type = 'ie'
43+
mfcc.ccsd_t = True
44+
mfcc.kernel(frag_lolist=None)
45+
46+
ecc_pt2corrected = mfcc.e_corr_pt2corrected(mmp.e_corr)
47+
return ehf + ecc_pt2corrected
48+
49+
e, jac = jax.value_and_grad(energy)(mol)
50+
print(e, mycc.e_tot+et, e-(mycc.e_tot+et))
51+
print(jac.coords)

examples/lno/11-mpi_lno_ccsd_t.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
'''LNO-CCSD(T) with MPI parallelization.
2+
3+
run with:
4+
mpirun -n 2 python 11-mpi_lno_ccsd_t.py
5+
'''
6+
from mpi4py import MPI
7+
import jax
8+
import numpy
9+
from pyscfad import gto, scf, mp
10+
from pyscfad.cc import dfccsd
11+
from pyscfad import config
12+
from pyscfad.lno.ccsd_mpi import LNOCCSD
13+
14+
config.update('pyscfad_moleintor_opt', True)
15+
config.update('pyscfad_scf_implicit_diff', True)
16+
config.update('pyscfad_ccsd_implicit_diff', True)
17+
18+
atom = 'water_dimer.xyz'
19+
basis = 'ccpvdz'
20+
21+
mol = gto.Mole(atom=atom, basis=basis)
22+
mol.verbose = 4
23+
mol.build(trace_exp=False, trace_ctr_coeff=False)
24+
25+
comm = MPI.COMM_WORLD
26+
rank = comm.Get_rank()
27+
frozen = 2
28+
thresh_occ = 1e-3
29+
thresh_vir = 1e-4
30+
def energy(mol):
31+
mf = scf.RHF(mol).density_fit()
32+
ehf = mf.kernel()
33+
34+
mfcc = LNOCCSD(mf, frozen=frozen)
35+
mfcc.thresh_occ = thresh_occ
36+
mfcc.thresh_vir = thresh_vir
37+
mfcc.lo_type = 'iao'
38+
mfcc.ccsd_t = True
39+
mfcc.kernel(frag_lolist=None)
40+
41+
if rank == 0:
42+
mmp = mp.dfmp2.MP2(mf, frozen=frozen)
43+
mmp.kernel(with_t2=False)
44+
ecc_pt2corrected = mfcc.e_corr_pt2corrected(mmp.e_corr)
45+
etot = ehf + ecc_pt2corrected
46+
else:
47+
etot = mfcc.e_corr - mfcc.e_corr_pt2
48+
return etot
49+
50+
e, jac = jax.value_and_grad(energy)(mol)
51+
e = numpy.asarray(e)
52+
grad = numpy.asarray(jac.coords)
53+
54+
if rank == 0:
55+
etot = numpy.zeros_like(e)
56+
grad_tot = numpy.zeros_like(grad)
57+
else:
58+
etot = None
59+
grad_tot = None
60+
61+
comm.Reduce([e, MPI.DOUBLE], [etot, MPI.DOUBLE],
62+
op=MPI.SUM, root=0)
63+
64+
comm.Reduce([grad, MPI.DOUBLE], [grad_tot, MPI.DOUBLE],
65+
op=MPI.SUM, root=0)
66+
67+
if rank == 0:
68+
print(f'LNO-CCSD(T) energy: {etot}')
69+
print(f'LNO-CCSD(T) gradient:\n{grad_tot}')

examples/lno/water_dimer.xyz

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
6
2+
3+
O -1.485163346097 -0.114724564047 0.000000000000
4+
H -1.868415346097 0.762298435953 0.000000000000
5+
H -0.533833346097 0.040507435953 0.000000000000
6+
O 1.416468653903 0.111264435953 0.000000000000
7+
H 1.746241653903 -0.373945564047 -0.758561000000
8+
H 1.746241653903 -0.373945564047 0.758561000000

pyscfad/lno/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2023-2026 The PySCFAD Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
LNO methods
17+
"""
18+
from pyscfad.lno.mp2 import LNOMP2
19+
from pyscfad.lno.ccsd import LNOCCSD, LNOCCSD_T

pyscfad/lno/_checkpointed.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023-2026 The PySCFAD Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
from pyscfad import numpy as np
17+
18+
def make_mp2_rdm1_ie(Lia, Ljb, eia, ejb):
19+
naux1, nocc1, nvir1 = Lia.shape
20+
naux2, nocc2, nvir2 = Ljb.shape
21+
assert naux1 == naux2
22+
assert nvir1 == nvir2
23+
naux = naux1
24+
nvir = nvir1
25+
26+
@jax.checkpoint
27+
def fn(carry, x):
28+
dmvv, dmoo = carry
29+
La, ea = x
30+
buf = np.dot(La.T, Ljb.reshape(naux,-1)).reshape(nvir, nocc2, nvir)
31+
t2i = buf / (ea[:,None,None] + ejb[None,:,:])
32+
dmvv += np.dot(t2i.reshape(nvir, -1), t2i.reshape(nvir, -1).T)
33+
dmvv -= .5 * np.einsum('ajc,cjb->ab', t2i, t2i)
34+
dmvv += np.dot(t2i.reshape(-1,nvir).T, t2i.reshape(-1,nvir))
35+
dmvv -= .5 * np.einsum('cja,bjc->ab', t2i, t2i)
36+
37+
dmoo += np.einsum('aib,ajb->ij', t2i, t2i)
38+
dmoo -= .5 * np.einsum('aib,bja->ij', t2i, t2i)
39+
dmoo += np.einsum('bia,bja->ij', t2i, t2i)
40+
dmoo -= .5 * np.einsum('bia,ajb->ij', t2i, t2i)
41+
return (dmvv, dmoo), None
42+
43+
(dmvv, dmoo), _ = jax.lax.scan(fn, (np.zeros((nvir, nvir)), np.zeros((nocc2,nocc2))),
44+
(Lia.transpose(1,0,2), eia))
45+
return dmvv, dmoo

0 commit comments

Comments
 (0)