Skip to content

Commit 9f4be4b

Browse files
committed
refactor xtb
1 parent 79c0ff3 commit 9f4be4b

File tree

12 files changed

+682
-624
lines changed

12 files changed

+682
-624
lines changed

pyscfad/backend/_jax/lax/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@ def _eigh_gen_jvp_rule(primals, tangents, *, lower, itype, deg_thresh):
105105
precision=lax.Precision.HIGHEST)
106106

107107
if type(at) is ad_util.Zero:
108-
vt_at_v = lax.zeros_like_array(a)
108+
vt_at_v = jnp.zeros_like(a)
109109
else:
110110
vt_at_v = dot(_H(v), dot(at, v))
111111

pyscfad/ml/scf/hf_pad.py

Lines changed: 11 additions & 209 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021-2025 The PySCFAD Authors
1+
# Copyright 2025-2026 The PySCFAD Authors
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,224 +14,26 @@
1414

1515
import numpy
1616
from pyscfad import numpy as np
17-
from pyscfad.lib import logger
1817
from pyscfad.scf import hf_lite as hf
1918
from pyscfad.scipy.linalg import eigh
20-
from pyscfad.ops import stop_grad
21-
22-
from functools import partial
23-
from jax import custom_jvp
24-
from jax.lax import while_loop
25-
26-
def _fermi_entropy(occ):
27-
occ = occ / 2.0
28-
occ_safe = np.where(np.logical_or(occ < 1e-10, occ > 1 - 1e-10), 0.5, occ)
29-
ent_term = occ_safe * np.log(occ_safe) + (1 - occ_safe) * np.log(1 - occ_safe)
30-
31-
return -2 * np.where(
32-
np.logical_or(occ < 1e-10, occ > 1 - 1e-10),
33-
0.,
34-
ent_term
35-
).sum()
36-
37-
def _fermi_smearing_occ(mu, mo_energy, sigma, mo_mask):
38-
de = (mo_energy - mu) / sigma
39-
de_ = np.where(np.less(de, 40.), de, 0)
40-
occ = np.where(np.less(de, 40.), 1. / (np.exp(de_) + 1.), 0.)
41-
occ = np.where(mo_mask, occ, 0)
42-
return occ
43-
44-
@custom_jvp
45-
def _smearing_solve_mu(mo_es, nocc, sigma, mo_mask):
46-
def cond_fun(value):
47-
mu, nerr = value
48-
return nerr**2 > 1e-12
49-
50-
def body_fun(value):
51-
''' One Halley step '''
52-
mu, nerr = value
53-
occ = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask)
54-
grad = occ * (1.-occ) / sigma
55-
hess = grad * (1.-2*occ) / sigma
56-
occ = np.sum(occ)
57-
nerr = occ - nocc
58-
grad = np.sum(grad)
59-
hess = np.sum(hess)
60-
dmu = -nerr * grad / (grad**2 - .5 * hess * nerr)
61-
return mu + dmu, nerr
62-
63-
mu, nerr = while_loop(cond_fun, body_fun, (mo_es[nocc-1], 1e2))
64-
return np.array([mu])
65-
66-
@_smearing_solve_mu.defjvp
67-
def _smearing_solve_mu_jvp(primals, tangents):
68-
mo_es, nocc, sigma, mo_mask = primals
69-
dmo_e, _, _, _ = tangents
70-
71-
mu = _smearing_solve_mu(mo_es, nocc, sigma, mo_mask)
72-
occ = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask)
73-
dndmu = occ * (1.-occ) / sigma
74-
dndmu = np.where(np.abs(dndmu) < 1e-10, 0., dndmu)
75-
return mu, np.dot(dndmu, dmo_e)[None] / (1e-13 + np.sum(dndmu))
76-
77-
def _smearing_optimize(mo_es, nocc, sigma, mo_mask):
78-
mu = _smearing_solve_mu(mo_es, nocc, sigma, mo_mask)
79-
mo_occs = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask)
80-
return mu, mo_occs
81-
82-
def make_mo_mask(mo_energy, mo_coeff, ao_mask):
83-
mask_fake_ao = np.asarray(1 - ao_mask, dtype=bool)
84-
mo_coeff_fake_ao = np.where(mask_fake_ao[:,None], mo_coeff, 0)
85-
tmp = np.linalg.norm(mo_coeff_fake_ao, axis=0)
86-
#mask = np.where(np.logical_and(mo_energy>1e8, tmp>1e-12), False, True)
87-
mask = np.where(np.logical_and(abs(mo_energy)<1e-12, tmp>1e-12), False, True)
88-
return mask
89-
90-
def get_occ(mf, mo_energy=None, mo_coeff=None):
91-
# NOTE assuming mo_energy is in ascending order
92-
if mo_energy is None:
93-
mo_energy = mf.mo_energy
94-
if mo_coeff is None:
95-
mo_coeff = mf.mo_coeff
96-
nmo = mo_energy.size
97-
#e_sort = mo_energy
98-
99-
mask = make_mo_mask(mo_energy, mo_coeff, mf.mol.ao_mask)
100-
nocc = mf.tot_electrons // 2
101-
102-
if mf.sigma is not None and mf.sigma > 0:
103-
mu, mo_occ = _smearing_optimize(mo_energy, stop_grad(nocc), stop_grad(mf.sigma), stop_grad(mask))
104-
mo_occ *= 2
105-
else:
106-
pick = (np.cumsum(mask) <= nocc) & mask
107-
mo_occ = np.where(pick, 2., 0.)
108-
109-
e_homo = np.max(np.where(pick, mo_energy, -np.inf))
110-
e_lumo = np.min(np.where(mask & ~pick, mo_energy, np.inf))
111-
if mf.verbose >= logger.DEBUG:
112-
logger.debug(mf, " HOMO = %.15g LUMO = %.15g", e_homo, e_lumo)
113-
114-
if mf.verbose >= logger.DEBUG:
115-
numpy.set_printoptions(threshold=nmo)
116-
logger.debug(mf, " mo_energy =\n%s", mo_energy)
117-
numpy.set_printoptions(threshold=1000)
118-
return mo_occ
119-
120-
def get_homo_lumo_energy(mf, mo_energy=None, mo_coeff=None):
121-
if mf.sigma is not None and mf.sigma > 0:
122-
raise NotImplementedError
123-
124-
if mo_energy is None:
125-
mo_energy = mf.mo_energy
126-
if mo_coeff is None:
127-
mo_coeff = mf.mo_coeff
128-
129-
mask = make_mo_mask(mo_energy, mo_coeff, mf.mol.ao_mask)
130-
nocc = mf.tot_electrons // 2
131-
pick = (np.cumsum(mask) <= nocc) & mask
132-
133-
e_homo = np.max(np.where(pick, mo_energy, -np.inf))
134-
e_lumo = np.min(np.where(mask & ~pick, mo_energy, np.inf))
135-
if mf.verbose >= logger.DEBUG:
136-
logger.debug(mf, " HOMO = %.15g LUMO = %.15g", e_homo, e_lumo)
137-
return e_homo, e_lumo
138-
139-
def make_rdm1(mo_coeff, mo_occ, **kwargs):
140-
dm = (mo_coeff * mo_occ) @ mo_coeff.conj().T
141-
return dm
142-
143-
def get_grad(mo_coeff, mo_occ, fock_ao):
144-
fock_mo = mo_coeff.conj().T @ fock_ao @ mo_coeff
145-
mask = np.where(mo_occ > 0, True, False)
146-
g = 2 * fock_mo * ((1-mask[:,None]) * mask[None,:])
147-
return g.ravel()
14819

14920
class SCF(hf.SCF):
150-
def __init__(self, mol, **kwargs):
151-
super().__init__(mol, **kwargs)
152-
self.sigma = None
153-
154-
def get_occ(self, mo_energy=None, mo_coeff=None):
155-
return get_occ(self, mo_energy=mo_energy, mo_coeff=mo_coeff)
156-
157-
@property
158-
def tot_electrons(self):
159-
return self.mol.tot_electrons()
160-
161-
def make_rdm1(self, mo_coeff=None, mo_occ=None, **kwargs):
162-
if mo_coeff is None:
163-
mo_coeff = self.mo_coeff
164-
if mo_occ is None:
165-
mo_occ = self.mo_occ
166-
return make_rdm1(mo_coeff, mo_occ, **kwargs)
167-
168-
def get_grad(self, mo_coeff, mo_occ, fock=None):
169-
if fock is None:
170-
dm1 = self.make_rdm1(mo_coeff, mo_occ)
171-
fock = self.get_hcore(self.mol) + self.get_veff(self.mol, dm1)
172-
return get_grad(mo_coeff, mo_occ, fock)
21+
# NOTE too large shift will give large errors
22+
padding_level_shift = 1e6
17323

17424
def _eigh(self, h, s):
17525
ao_mask = self.mol.ao_mask
17626
mask = np.asarray(1 - ao_mask, dtype=np.int32)
17727
s = s + np.diag(mask)
178-
#h = np.where(np.outer(ao_mask, ao_mask), h, 0.)
179-
#h1e_diag = np.diag(h)
180-
#h1e_diag = np.where(ao_mask, h1e_diag, 1e10)
181-
#h = np.fill_diagonal(h, h1e_diag, inplace=False)
28+
h = np.where(np.outer(ao_mask, ao_mask), h, 0.)
29+
h = h + np.diag(mask) * self.padding_level_shift
18230
return eigh(h, s)
18331

184-
get_homo_lumo_energy = get_homo_lumo_energy
32+
def mo_mask(self, mo_energy=None, mo_coeff=None):
33+
if mo_energy is None:
34+
mo_energy = self.mo_energy
18535

186-
SCFPad = SCF
36+
thr = .99 * self.padding_level_shift
37+
return np.where(np.greater(mo_energy, thr), False, True)
18738

188-
189-
if __name__ == "__main__":
190-
import numpy
191-
import jax
192-
from pyscfad.ml.gto import make_basis_array, MolePad
193-
#from pyscfad.xtb import basis as xtb_basis
194-
#bfile = xtb_basis.get_basis_filename()
195-
196-
import os
197-
from pyscf.gto import basis
198-
bfile = os.path.dirname(basis.__file__) + "/sto-3g.dat"
199-
basis = make_basis_array(bfile, max_number=8)
200-
201-
numbers = np.array([8, 1, 1, 0], dtype=np.int32)
202-
coords = np.array(
203-
[
204-
[0.00000, 0.00000, 0.00000],
205-
[1.43355, 0.00000, -0.95296],
206-
[1.43355, 0.00000, 0.95296],
207-
[0.00000, 0.00000, 0.00000],
208-
]
209-
)
210-
211-
@jax.jit
212-
def energy(numbers, coords):
213-
mol = MolePad(numbers, coords, basis=basis, verbose=4)
214-
dm0 = np.zeros((mol.nao, mol.nao))
215-
for i in range(5):
216-
dm0 = dm0.at[i,i].set(2.)
217-
mf = SCFPad(mol)
218-
ehf = mf.kernel(dm0=dm0)
219-
return ehf
220-
221-
ehf = energy(numbers, coords)
222-
print(ehf)
223-
print(energy._cache_size())
224-
225-
numbers = np.array([7, 1, 1, 1], dtype=np.int32)
226-
coords = np.array(
227-
[
228-
[-0.80650, -1.00659, 0.02850],
229-
[-0.50540, -0.31299, 0.68220],
230-
[ 0.00620, -1.41579, -0.38500],
231-
[-1.32340, -0.54779, -0.69350],
232-
]
233-
) / 0.52917721067121
234-
235-
ehf = energy(numbers, coords)
236-
print(ehf)
237-
print(energy._cache_size())
39+
SCFPad = SCF

pyscfad/ml/xtb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021-2025 The PySCFAD Authors
1+
# Copyright 2025-2026 The PySCFAD Authors
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)