|
1 | | -# Copyright 2021-2025 The PySCFAD Authors |
| 1 | +# Copyright 2025-2026 The PySCFAD Authors |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
14 | 14 |
|
15 | 15 | import numpy |
16 | 16 | from pyscfad import numpy as np |
17 | | -from pyscfad.lib import logger |
18 | 17 | from pyscfad.scf import hf_lite as hf |
19 | 18 | from pyscfad.scipy.linalg import eigh |
20 | | -from pyscfad.ops import stop_grad |
21 | | - |
22 | | -from functools import partial |
23 | | -from jax import custom_jvp |
24 | | -from jax.lax import while_loop |
25 | | - |
26 | | -def _fermi_entropy(occ): |
27 | | - occ = occ / 2.0 |
28 | | - occ_safe = np.where(np.logical_or(occ < 1e-10, occ > 1 - 1e-10), 0.5, occ) |
29 | | - ent_term = occ_safe * np.log(occ_safe) + (1 - occ_safe) * np.log(1 - occ_safe) |
30 | | - |
31 | | - return -2 * np.where( |
32 | | - np.logical_or(occ < 1e-10, occ > 1 - 1e-10), |
33 | | - 0., |
34 | | - ent_term |
35 | | - ).sum() |
36 | | - |
37 | | -def _fermi_smearing_occ(mu, mo_energy, sigma, mo_mask): |
38 | | - de = (mo_energy - mu) / sigma |
39 | | - de_ = np.where(np.less(de, 40.), de, 0) |
40 | | - occ = np.where(np.less(de, 40.), 1. / (np.exp(de_) + 1.), 0.) |
41 | | - occ = np.where(mo_mask, occ, 0) |
42 | | - return occ |
43 | | - |
44 | | -@custom_jvp |
45 | | -def _smearing_solve_mu(mo_es, nocc, sigma, mo_mask): |
46 | | - def cond_fun(value): |
47 | | - mu, nerr = value |
48 | | - return nerr**2 > 1e-12 |
49 | | - |
50 | | - def body_fun(value): |
51 | | - ''' One Halley step ''' |
52 | | - mu, nerr = value |
53 | | - occ = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask) |
54 | | - grad = occ * (1.-occ) / sigma |
55 | | - hess = grad * (1.-2*occ) / sigma |
56 | | - occ = np.sum(occ) |
57 | | - nerr = occ - nocc |
58 | | - grad = np.sum(grad) |
59 | | - hess = np.sum(hess) |
60 | | - dmu = -nerr * grad / (grad**2 - .5 * hess * nerr) |
61 | | - return mu + dmu, nerr |
62 | | - |
63 | | - mu, nerr = while_loop(cond_fun, body_fun, (mo_es[nocc-1], 1e2)) |
64 | | - return np.array([mu]) |
65 | | - |
66 | | -@_smearing_solve_mu.defjvp |
67 | | -def _smearing_solve_mu_jvp(primals, tangents): |
68 | | - mo_es, nocc, sigma, mo_mask = primals |
69 | | - dmo_e, _, _, _ = tangents |
70 | | - |
71 | | - mu = _smearing_solve_mu(mo_es, nocc, sigma, mo_mask) |
72 | | - occ = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask) |
73 | | - dndmu = occ * (1.-occ) / sigma |
74 | | - dndmu = np.where(np.abs(dndmu) < 1e-10, 0., dndmu) |
75 | | - return mu, np.dot(dndmu, dmo_e)[None] / (1e-13 + np.sum(dndmu)) |
76 | | - |
77 | | -def _smearing_optimize(mo_es, nocc, sigma, mo_mask): |
78 | | - mu = _smearing_solve_mu(mo_es, nocc, sigma, mo_mask) |
79 | | - mo_occs = _fermi_smearing_occ(mu, mo_es, sigma, mo_mask) |
80 | | - return mu, mo_occs |
81 | | - |
82 | | -def make_mo_mask(mo_energy, mo_coeff, ao_mask): |
83 | | - mask_fake_ao = np.asarray(1 - ao_mask, dtype=bool) |
84 | | - mo_coeff_fake_ao = np.where(mask_fake_ao[:,None], mo_coeff, 0) |
85 | | - tmp = np.linalg.norm(mo_coeff_fake_ao, axis=0) |
86 | | - #mask = np.where(np.logical_and(mo_energy>1e8, tmp>1e-12), False, True) |
87 | | - mask = np.where(np.logical_and(abs(mo_energy)<1e-12, tmp>1e-12), False, True) |
88 | | - return mask |
89 | | - |
90 | | -def get_occ(mf, mo_energy=None, mo_coeff=None): |
91 | | - # NOTE assuming mo_energy is in ascending order |
92 | | - if mo_energy is None: |
93 | | - mo_energy = mf.mo_energy |
94 | | - if mo_coeff is None: |
95 | | - mo_coeff = mf.mo_coeff |
96 | | - nmo = mo_energy.size |
97 | | - #e_sort = mo_energy |
98 | | - |
99 | | - mask = make_mo_mask(mo_energy, mo_coeff, mf.mol.ao_mask) |
100 | | - nocc = mf.tot_electrons // 2 |
101 | | - |
102 | | - if mf.sigma is not None and mf.sigma > 0: |
103 | | - mu, mo_occ = _smearing_optimize(mo_energy, stop_grad(nocc), stop_grad(mf.sigma), stop_grad(mask)) |
104 | | - mo_occ *= 2 |
105 | | - else: |
106 | | - pick = (np.cumsum(mask) <= nocc) & mask |
107 | | - mo_occ = np.where(pick, 2., 0.) |
108 | | - |
109 | | - e_homo = np.max(np.where(pick, mo_energy, -np.inf)) |
110 | | - e_lumo = np.min(np.where(mask & ~pick, mo_energy, np.inf)) |
111 | | - if mf.verbose >= logger.DEBUG: |
112 | | - logger.debug(mf, " HOMO = %.15g LUMO = %.15g", e_homo, e_lumo) |
113 | | - |
114 | | - if mf.verbose >= logger.DEBUG: |
115 | | - numpy.set_printoptions(threshold=nmo) |
116 | | - logger.debug(mf, " mo_energy =\n%s", mo_energy) |
117 | | - numpy.set_printoptions(threshold=1000) |
118 | | - return mo_occ |
119 | | - |
120 | | -def get_homo_lumo_energy(mf, mo_energy=None, mo_coeff=None): |
121 | | - if mf.sigma is not None and mf.sigma > 0: |
122 | | - raise NotImplementedError |
123 | | - |
124 | | - if mo_energy is None: |
125 | | - mo_energy = mf.mo_energy |
126 | | - if mo_coeff is None: |
127 | | - mo_coeff = mf.mo_coeff |
128 | | - |
129 | | - mask = make_mo_mask(mo_energy, mo_coeff, mf.mol.ao_mask) |
130 | | - nocc = mf.tot_electrons // 2 |
131 | | - pick = (np.cumsum(mask) <= nocc) & mask |
132 | | - |
133 | | - e_homo = np.max(np.where(pick, mo_energy, -np.inf)) |
134 | | - e_lumo = np.min(np.where(mask & ~pick, mo_energy, np.inf)) |
135 | | - if mf.verbose >= logger.DEBUG: |
136 | | - logger.debug(mf, " HOMO = %.15g LUMO = %.15g", e_homo, e_lumo) |
137 | | - return e_homo, e_lumo |
138 | | - |
139 | | -def make_rdm1(mo_coeff, mo_occ, **kwargs): |
140 | | - dm = (mo_coeff * mo_occ) @ mo_coeff.conj().T |
141 | | - return dm |
142 | | - |
143 | | -def get_grad(mo_coeff, mo_occ, fock_ao): |
144 | | - fock_mo = mo_coeff.conj().T @ fock_ao @ mo_coeff |
145 | | - mask = np.where(mo_occ > 0, True, False) |
146 | | - g = 2 * fock_mo * ((1-mask[:,None]) * mask[None,:]) |
147 | | - return g.ravel() |
148 | 19 |
|
149 | 20 | class SCF(hf.SCF): |
150 | | - def __init__(self, mol, **kwargs): |
151 | | - super().__init__(mol, **kwargs) |
152 | | - self.sigma = None |
153 | | - |
154 | | - def get_occ(self, mo_energy=None, mo_coeff=None): |
155 | | - return get_occ(self, mo_energy=mo_energy, mo_coeff=mo_coeff) |
156 | | - |
157 | | - @property |
158 | | - def tot_electrons(self): |
159 | | - return self.mol.tot_electrons() |
160 | | - |
161 | | - def make_rdm1(self, mo_coeff=None, mo_occ=None, **kwargs): |
162 | | - if mo_coeff is None: |
163 | | - mo_coeff = self.mo_coeff |
164 | | - if mo_occ is None: |
165 | | - mo_occ = self.mo_occ |
166 | | - return make_rdm1(mo_coeff, mo_occ, **kwargs) |
167 | | - |
168 | | - def get_grad(self, mo_coeff, mo_occ, fock=None): |
169 | | - if fock is None: |
170 | | - dm1 = self.make_rdm1(mo_coeff, mo_occ) |
171 | | - fock = self.get_hcore(self.mol) + self.get_veff(self.mol, dm1) |
172 | | - return get_grad(mo_coeff, mo_occ, fock) |
| 21 | + # NOTE too large shift will give large errors |
| 22 | + padding_level_shift = 1e6 |
173 | 23 |
|
174 | 24 | def _eigh(self, h, s): |
175 | 25 | ao_mask = self.mol.ao_mask |
176 | 26 | mask = np.asarray(1 - ao_mask, dtype=np.int32) |
177 | 27 | s = s + np.diag(mask) |
178 | | - #h = np.where(np.outer(ao_mask, ao_mask), h, 0.) |
179 | | - #h1e_diag = np.diag(h) |
180 | | - #h1e_diag = np.where(ao_mask, h1e_diag, 1e10) |
181 | | - #h = np.fill_diagonal(h, h1e_diag, inplace=False) |
| 28 | + h = np.where(np.outer(ao_mask, ao_mask), h, 0.) |
| 29 | + h = h + np.diag(mask) * self.padding_level_shift |
182 | 30 | return eigh(h, s) |
183 | 31 |
|
184 | | - get_homo_lumo_energy = get_homo_lumo_energy |
| 32 | + def mo_mask(self, mo_energy=None, mo_coeff=None): |
| 33 | + if mo_energy is None: |
| 34 | + mo_energy = self.mo_energy |
185 | 35 |
|
186 | | -SCFPad = SCF |
| 36 | + thr = .99 * self.padding_level_shift |
| 37 | + return np.where(np.greater(mo_energy, thr), False, True) |
187 | 38 |
|
188 | | - |
189 | | -if __name__ == "__main__": |
190 | | - import numpy |
191 | | - import jax |
192 | | - from pyscfad.ml.gto import make_basis_array, MolePad |
193 | | - #from pyscfad.xtb import basis as xtb_basis |
194 | | - #bfile = xtb_basis.get_basis_filename() |
195 | | - |
196 | | - import os |
197 | | - from pyscf.gto import basis |
198 | | - bfile = os.path.dirname(basis.__file__) + "/sto-3g.dat" |
199 | | - basis = make_basis_array(bfile, max_number=8) |
200 | | - |
201 | | - numbers = np.array([8, 1, 1, 0], dtype=np.int32) |
202 | | - coords = np.array( |
203 | | - [ |
204 | | - [0.00000, 0.00000, 0.00000], |
205 | | - [1.43355, 0.00000, -0.95296], |
206 | | - [1.43355, 0.00000, 0.95296], |
207 | | - [0.00000, 0.00000, 0.00000], |
208 | | - ] |
209 | | - ) |
210 | | - |
211 | | - @jax.jit |
212 | | - def energy(numbers, coords): |
213 | | - mol = MolePad(numbers, coords, basis=basis, verbose=4) |
214 | | - dm0 = np.zeros((mol.nao, mol.nao)) |
215 | | - for i in range(5): |
216 | | - dm0 = dm0.at[i,i].set(2.) |
217 | | - mf = SCFPad(mol) |
218 | | - ehf = mf.kernel(dm0=dm0) |
219 | | - return ehf |
220 | | - |
221 | | - ehf = energy(numbers, coords) |
222 | | - print(ehf) |
223 | | - print(energy._cache_size()) |
224 | | - |
225 | | - numbers = np.array([7, 1, 1, 1], dtype=np.int32) |
226 | | - coords = np.array( |
227 | | - [ |
228 | | - [-0.80650, -1.00659, 0.02850], |
229 | | - [-0.50540, -0.31299, 0.68220], |
230 | | - [ 0.00620, -1.41579, -0.38500], |
231 | | - [-1.32340, -0.54779, -0.69350], |
232 | | - ] |
233 | | - ) / 0.52917721067121 |
234 | | - |
235 | | - ehf = energy(numbers, coords) |
236 | | - print(ehf) |
237 | | - print(energy._cache_size()) |
| 39 | +SCFPad = SCF |
0 commit comments