Skip to content

Commit 8e254d7

Browse files
committed
fix scipy gmres dtype error
1 parent 6a5ee81 commit 8e254d7

File tree

7 files changed

+50
-28
lines changed

7 files changed

+50
-28
lines changed

pyscfad/_src/scipy/sparse/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def _matvec(x):
1111
# NOTE result may not be writable
1212
# (required by scipy>=1.12), so make a copy
1313
return numpy.array(ops.to_numpy(Ax))
14-
A = LinearOperator((b.size, b.size), matvec=_matvec)
14+
A = LinearOperator((b.size, b.size), matvec=_matvec, dtype=b.dtype)
1515
return A
1616

1717
def gmres(A_or_matvec, b, x0=None, *,

pyscfad/cc/test/conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,20 @@ def get_mol():
1717
yield mol
1818

1919
config.reset()
20+
21+
@pytest.fixture
22+
def get_opt_mol():
23+
config.update("pyscfad_moleintor_opt", True)
24+
config.update('pyscfad_scf_implicit_diff', True)
25+
config.update('pyscfad_ccsd_implicit_diff', True)
26+
27+
mol = gto.Mole()
28+
mol.atom = 'H 0. 0. 0.; F 0. 0. 1.1'
29+
mol.basis = '631g'
30+
mol.verbose = 0
31+
mol.incore_anyway = True
32+
mol.max_memory = 7000
33+
mol.build(trace_exp=False, trace_ctr_coeff=False)
34+
yield mol
35+
36+
config.reset()
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
import pytest
21
import numpy
32
import jax
4-
from pyscfad import gto, scf, cc
5-
from pyscfad import config
6-
7-
config.update('pyscfad_scf_implicit_diff', True)
8-
config.update('pyscfad_ccsd_implicit_diff', True)
3+
from pyscfad import scf, cc
94

105
def test_nuc_grad(get_mol):
116
mol = get_mol
@@ -16,8 +11,7 @@ def energy(mol):
1611
mycc.kernel()
1712
et = mycc.ccsd_t()
1813
return mycc.e_tot + et
19-
with jax.disable_jit():
20-
g1 = jax.grad(energy)(mol).coords
14+
g1 = jax.grad(energy)(mol).coords
2115
g0 = numpy.array([[0., 0., -8.60709468e-02],
2216
[0., 0., 8.60709468e-02]])
2317
assert(abs(g1-g0).max() < 1e-6)

pyscfad/cc/test/test_dcsd.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import numpy
32
import jax
43
from pyscfad import scf
@@ -12,8 +11,7 @@ def energy(mol):
1211
mycc = dfdcsd.RDCSD(mf)
1312
mycc.kernel()
1413
return mycc.e_tot
15-
with jax.disable_jit():
16-
g1 = jax.grad(energy)(mol).coords
14+
g1 = jax.grad(energy)(mol).coords
1715
# finite difference
1816
g0 = numpy.array([[0., 0., -0.08500490828],
1917
[0., 0., 0.08500490828]])

pyscfad/cc/test/test_opt_dfccsd.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy
2+
import jax
3+
from pyscfad import scf
4+
from pyscfad.cc import dfccsd
5+
6+
def test_dfccsdt_nuc_grad(get_opt_mol):
7+
mol = get_opt_mol
8+
def energy(mol):
9+
mf = scf.RHF(mol).density_fit()
10+
mf.kernel()
11+
12+
mycc = dfccsd.RCCSD(mf)
13+
eris = mycc.ao2mo()
14+
mycc.kernel(eris=eris)
15+
et = mycc.ccsd_t(eris=eris)
16+
return mycc.e_tot + et
17+
18+
e, jac = jax.value_and_grad(energy)(mol)
19+
20+
e0 = -100.10156178822595
21+
assert abs(e - e0) < 1e-6
22+
23+
g0 = numpy.array([[0., 0., -0.0860735932],
24+
[0., 0., 0.0860735932]])
25+
assert abs(jac.coords-g0).max() < 1e-6

pyscfad/cc/test/test_rccsd.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import numpy
32
import jax
43
from pyscfad import scf, cc
@@ -11,8 +10,7 @@ def energy(mol):
1110
mycc = cc.RCCSD(mf)
1211
mycc.kernel()
1312
return mycc.e_tot
14-
with jax.disable_jit():
15-
g1 = jax.grad(energy)(mol).coords
13+
g1 = jax.grad(energy)(mol).coords
1614
g0 = numpy.array([[0., 0., -0.0873564848],
1715
[0., 0., 0.0873564848]])
1816
assert(abs(g1-g0).max() < 1e-6)
@@ -25,8 +23,7 @@ def energy(mol):
2523
mycc = cc.dfccsd.RCCSD(mf)
2624
mycc.kernel()
2725
return mycc.e_tot
28-
with jax.disable_jit():
29-
g1 = jax.grad(energy)(mol).coords
26+
g1 = jax.grad(energy)(mol).coords
3027
# finite difference
3128
g0 = numpy.array([[0., 0., -0.0873569023],
3229
[0., 0., 0.0873569023]])

pyscfad/cc/test/test_rccsd_hess.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
1-
import pytest
21
import numpy
32
import jax
4-
from pyscf import lib
5-
from pyscfad import gto, scf, cc
6-
from pyscfad import config
7-
8-
config.update('pyscfad_scf_implicit_diff', True)
9-
config.update('pyscfad_ccsd_implicit_diff', True)
3+
from pyscfad import scf, cc
104

115
def test_nuc_hessian(get_mol):
126
mol = get_mol
@@ -16,8 +10,7 @@ def energy(mol):
1610
mycc = cc.RCCSD(mf)
1711
mycc.kernel()
1812
return mycc.e_tot
19-
with jax.disable_jit():
20-
h1 = jax.jacrev(jax.jacrev(energy))(mol).coords.coords
13+
h1 = jax.jacrev(jax.jacrev(energy))(mol).coords.coords
2114
h0 = numpy.array(
2215
[[[[ 4.20246014e-02, 0, 0],
2316
[-4.20246014e-02, 0, 0]],
@@ -33,5 +26,3 @@ def energy(mol):
3326
[0, 0, 1.53241780e-01]]]]
3427
)
3528
assert(abs(h1-h0).max() < 5e-5)
36-
#f = lib.fp(h1)
37-
#assert(abs(f - -0.18529155578401263) < 1e-6)

0 commit comments

Comments
 (0)