Skip to content

Commit da59f05

Browse files
committed
fix pbcintor
1 parent 3ed1a22 commit da59f05

File tree

14 files changed

+585
-317
lines changed

14 files changed

+585
-317
lines changed

pyscfad/gto/mole_lite.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,9 @@
4747
)
4848

4949
from pyscfad import numpy as np
50-
#from pyscfad import pytree
5150
from pyscfad import ops
52-
#from pyscfad.ops import jit
5351
from pyscfad.gto.mole import energy_nuc
52+
from pyscfad.gto import moleintor_lite
5453

5554
Array = Any
5655

@@ -92,12 +91,13 @@ class Mole(MoleBase):
9291
Parameters
9392
----------
9493
symbols : tuple of str
95-
Atomic symbols.
94+
Atomic symbols (mutually exclusive with ``numbers``).
9695
coords : array
9796
Atomic coordinates (in Bohr).
9897
basis : dict or str
99-
Atom-centered contracted Gaussian basis set parameters
100-
(including exponents and contraction coefficients).
98+
A string indicating the Gaussian basis set to use, or
99+
a dictionary including the basis set parameters
100+
(exponents and contraction coefficients).
101101
numbers : tuple of ints
102102
Atomic numbers (mutually exclusive with ``symbols``).
103103
charge : int
@@ -106,10 +106,18 @@ class Mole(MoleBase):
106106
2S (number of alpha electrons minus number of beta electrons).
107107
cart : bool
108108
Whether to use Cartesian Gaussian basis.
109+
verbose : int
110+
Printing level.
109111
trace_coords : bool
110112
Whether to trace atomic coordinates for gradient calculations.
111113
trace_basis : bool
112114
Whether to trace basis set parameters for gradient calculations.
115+
116+
Notes
117+
-----
118+
The molecular composition (i.e., ``symbols`` or ``numbers``)
119+
must be static as input. For dynamic molecular composition,
120+
refer to :class:`pyscfad.ml.gto.MolePad`.
113121
"""
114122
def __init__(
115123
self,
@@ -127,7 +135,6 @@ def __init__(
127135
if numbers is not None:
128136
if symbols is not None:
129137
raise KeyError("Only one of 'symbols' and 'numbers' can be specified.")
130-
#numbers = numpy.asarray(numbers, dtype=int)
131138
self.symbols = tuple(_symbol(i) for i in numbers)
132139
else:
133140
self.symbols = _format_symbols(symbols)
@@ -187,7 +194,7 @@ def copy(
187194
return newmol
188195

189196
def build(self, *args, **kwargs):
190-
pass
197+
return self
191198

192199
def intor(
193200
self,
@@ -199,7 +206,8 @@ def intor(
199206
shls_slice: tuple[int, ...] | None = None,
200207
grids: Array | None = None,
201208
) -> Array:
202-
from pyscfad.gto import moleintor_lite
209+
del out, grids
210+
203211
intor_name = self._add_suffix(intor_name)
204212
if "ECP" in intor_name:
205213
raise NotImplementedError
@@ -215,7 +223,6 @@ def intor(
215223
comp=comp,
216224
hermi=hermi,
217225
aosym=aosym,
218-
out=out,
219226
trace_coords=self.trace_coords,
220227
trace_basis=self.trace_basis,
221228
)
@@ -309,14 +316,14 @@ def from_pyscf(
309316
charge=mol.charge,
310317
spin=mol.spin,
311318
cart=mol.cart,
319+
verbose=mol.verbose,
312320
trace_coords=trace_coords,
313321
trace_basis=trace_basis,
314322
)
315323
return dmol
316324

317325
def to_pyscf(
318326
self,
319-
verbose: int | None = None,
320327
output: str | None = None,
321328
max_memory: int | None = None,
322329
) -> MoleBase:
@@ -336,7 +343,7 @@ def to_pyscf(
336343
spin=self.spin,
337344
cart=self.cart,
338345
unit="AU",
339-
verbose=verbose,
346+
verbose=self.verbose,
340347
output=output,
341348
max_memory=max_memory,
342349
dump_input=False,

pyscfad/gto/moleintor_lite.py

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# limitations under the License.
1414

1515
from typing import Any
16-
import warnings
1716
from functools import partial
1817
import numpy
1918

19+
from jax.custom_derivatives import SymbolicZero
2020
from pyscf.gto.mole import (
2121
ATOM_OF,
2222
PTR_COORD,
@@ -161,8 +161,6 @@ def _get_shape(
161161
"hermi",
162162
"aosym",
163163
"ao_loc",
164-
"cintopt",
165-
"out",
166164
"trace_coords",
167165
"trace_basis",
168166
"aoslices",
@@ -178,15 +176,11 @@ def getints(
178176
hermi: int = 0,
179177
aosym: str = "s1",
180178
ao_loc: Array | None = None,
181-
cintopt: Any | None = None,
182-
out: Any | None = None,
183179
trace_coords: bool = False,
184180
trace_basis: bool = False,
185181
aoslices: Array | None = None, # for padding
186182
):
187183
from pyscfad.gto._pyscf_moleintor import getints as callback
188-
if out is not None:
189-
warnings.warn("Pre-allocated 'out' is not used")
190184

191185
shape = _get_shape(
192186
intor_name,
@@ -211,8 +205,6 @@ def getints(
211205
comp=comp,
212206
hermi=hermi,
213207
ao_loc=ao_loc,
214-
cintopt=cintopt,
215-
out=None,
216208
)
217209
return out
218210

@@ -226,8 +218,6 @@ def getints_jvp(
226218
hermi,
227219
aosym,
228220
ao_loc,
229-
cintopt,
230-
out,
231221
trace_coords,
232222
trace_basis,
233223
aoslices,
@@ -250,8 +240,6 @@ def getints_jvp(
250240
hermi=hermi,
251241
aosym=aosym,
252242
ao_loc=ao_loc,
253-
cintopt=cintopt,
254-
out=out,
255243
trace_coords=trace_coords,
256244
trace_basis=trace_basis,
257245
aoslices=aoslices,
@@ -262,34 +250,35 @@ def getints_jvp(
262250
intor_ip_bra = intor_ip_ket = None
263251
intor_ip_bra, intor_ip_ket = int1e_dr1_name(intor_name)
264252

265-
if trace_coords and (intor_ip_bra or intor_ip_ket):
266-
if intor_name.startswith("int1e_rinv"):
267-
rc_deriv = PTR_RINV_ORIG
268-
elif intor_name.startswith("int1e_r"):
269-
rc_deriv = PTR_COMMON_ORIG
270-
else:
271-
rc_deriv = None
272-
273-
tangent_out += _gen_int1e_jvp_r0(
274-
intor_ip_bra,
275-
intor_ip_ket,
276-
atm,
277-
bas,
278-
env,
279-
env_dot,
280-
shls_slice,
281-
comp,
282-
hermi,
283-
aosym,
284-
ao_loc,
285-
trace_coords,
286-
trace_basis,
287-
aoslices,
288-
rc_deriv,
289-
).reshape(tangent_out.shape)
290-
291-
if trace_basis:
292-
raise NotImplementedError("basis set parameter derivative not supported")
253+
if not isinstance(env_dot, SymbolicZero):
254+
if trace_coords and (intor_ip_bra or intor_ip_ket):
255+
if intor_name.startswith("int1e_rinv"):
256+
rc_deriv = PTR_RINV_ORIG
257+
elif intor_name.startswith("int1e_r"):
258+
rc_deriv = PTR_COMMON_ORIG
259+
else:
260+
rc_deriv = None
261+
262+
tangent_out += _gen_int1e_jvp_r0(
263+
intor_ip_bra,
264+
intor_ip_ket,
265+
atm,
266+
bas,
267+
env,
268+
env_dot,
269+
shls_slice,
270+
comp,
271+
hermi,
272+
aosym,
273+
ao_loc,
274+
trace_coords,
275+
trace_basis,
276+
aoslices,
277+
rc_deriv,
278+
).reshape(tangent_out.shape)
279+
280+
if trace_basis:
281+
raise NotImplementedError("basis set parameter derivative not supported")
293282
return primal_out, tangent_out
294283

295284

@@ -323,8 +312,6 @@ def _gen_int1e_jvp_r0(
323312
hermi=0,
324313
aosym=aosym,
325314
ao_loc=ao_loc,
326-
cintopt=None,
327-
out=None,
328315
trace_coords=trace_coords,
329316
trace_basis=trace_basis,
330317
aoslices=aoslices,
@@ -364,8 +351,6 @@ def _gen_int1e_jvp_r0(
364351
hermi=0,
365352
aosym=aosym,
366353
ao_loc=ao_loc,
367-
cintopt=None,
368-
out=None,
369354
trace_coords=trace_coords,
370355
trace_basis=trace_basis,
371356
aoslices=aoslices,

pyscfad/ml/gto/mole_pad.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from pyscfad import numpy as np
3434
from pyscfad import ops
35+
from pyscfad.gto import moleintor_lite
3536
from pyscfad.gto.mole_lite import MoleLite
3637
from pyscfad.ml.gto.basis_array import BasisArray
3738

@@ -160,7 +161,8 @@ def intor(
160161
shls_slice: tuple[int, ...] | None = None,
161162
grids: Array | None = None,
162163
) -> Array:
163-
from pyscfad.gto import moleintor_lite
164+
del out, grids
165+
164166
intor_name = self._add_suffix(intor_name)
165167
if "ECP" in intor_name:
166168
raise NotImplementedError
@@ -180,7 +182,6 @@ def intor(
180182
hermi=hermi,
181183
aosym=aosym,
182184
ao_loc=ao_loc,
183-
out=out,
184185
trace_coords=self.trace_coords,
185186
trace_basis=self.trace_basis,
186187
aoslices=aoslices,

pyscfad/pbc/gto/_pbcintor.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2021-2025 Xing Zhang
1+
# Copyright 2021-2025 The PySCFAD Authors
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -24,7 +24,79 @@
2424
@partial(custom_jvp, nondiff_argnums=tuple(range(1,7)))
2525
def _pbc_intor(mol, intor, comp=None, hermi=0, kpts=None, kpt=None,
2626
shls_slice=None):
27-
return Cell.pbc_intor(mol.view(Cell), intor, comp, hermi, kpts, kpt, shls_slice)
27+
import ctypes
28+
from pyscf.gto import moleintor
29+
from pyscf import lib
30+
from pyscf.gto.mole import conc_env
31+
from pyscf.pbc.gto import _pbcintor
32+
libpbc = _pbcintor.libpbc
33+
34+
cell1 = cell2 = mol
35+
intor, comp = moleintor._get_intor_and_comp(cell1._add_suffix(intor), comp)
36+
37+
if kpts is None:
38+
if kpt is not None:
39+
kpts_lst = numpy.reshape(kpt, (1,3))
40+
else:
41+
kpts_lst = numpy.zeros((1,3))
42+
else:
43+
kpts_lst = numpy.reshape(kpts, (-1,3))
44+
nkpts = len(kpts_lst)
45+
46+
pcell = cell1.copy(deep=False)
47+
pcell.precision = min(cell1.precision, cell2.precision)
48+
pcell._atm, pcell._bas, pcell._env = \
49+
atm, bas, env = conc_env(cell1._atm, cell1._bas, cell1._env,
50+
cell2._atm, cell2._bas, cell2._env)
51+
52+
if shls_slice is None:
53+
shls_slice = (0, cell1.nbas, 0, cell2.nbas)
54+
i0, i1, j0, j1 = shls_slice[:4]
55+
j0 += cell1.nbas
56+
j1 += cell1.nbas
57+
ao_loc = moleintor.make_loc(bas, intor)
58+
ni = ao_loc[i1] - ao_loc[i0]
59+
nj = ao_loc[j1] - ao_loc[j0]
60+
out = numpy.empty((nkpts,comp,ni,nj), dtype=numpy.complex128)
61+
62+
if hermi == 0:
63+
aosym = 's1'
64+
else:
65+
aosym = 's2'
66+
fill = getattr(libpbc, 'PBCnr2c_fill_k'+aosym)
67+
fintor = getattr(moleintor.libcgto, intor)
68+
cintopt = lib.c_null_ptr()
69+
70+
rcut = max(cell1.rcut, cell2.rcut)
71+
Ls = numpy.asarray(cell1.get_lattice_Ls(rcut=rcut), order='C')
72+
expkL = numpy.asarray(numpy.exp(1j*numpy.dot(kpts_lst, Ls.T)), order='C')
73+
drv = libpbc.PBCnr2c_drv
74+
75+
drv(fintor, fill, out.ctypes.data_as(ctypes.c_void_p),
76+
ctypes.c_int(nkpts), ctypes.c_int(comp), ctypes.c_int(len(Ls)),
77+
Ls.ctypes.data_as(ctypes.c_void_p),
78+
expkL.ctypes.data_as(ctypes.c_void_p),
79+
(ctypes.c_int*4)(i0, i1, j0, j1),
80+
ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt,
81+
atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(pcell.natm),
82+
bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(pcell.nbas),
83+
env.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(env.size))
84+
85+
mat = []
86+
for k, kpt in enumerate(kpts_lst):
87+
v = out[k]
88+
if hermi != 0:
89+
for ic in range(comp):
90+
lib.hermi_triu(v[ic], hermi=hermi, inplace=True)
91+
if comp == 1:
92+
v = v[0]
93+
if abs(kpt).sum() < 1e-9: # gamma_point
94+
v = v.real
95+
mat.append(v)
96+
97+
if kpts is None or numpy.shape(kpts) == (3,): # A single k-point
98+
mat = mat[0]
99+
return mat
28100

29101
@_pbc_intor.defjvp
30102
def _pbc_intor_jvp(intor, comp, hermi, kpts, kpt, shls_slice,
@@ -74,6 +146,7 @@ def _int1e_jvp_r0(mol, mol_t, intor, hermi, kpts, kpt, shls_slice):
74146
tangent_out = tangent_out[0]
75147
return tangent_out
76148

149+
# FIXME use pyscfad's Ls
77150
@partial(custom_vjp, nondiff_argnums=tuple(range(1,7)))
78151
def _pbc_intor_rev(mol, intor, comp=None, hermi=0, kpts=None, kpt=None,
79152
shls_slice=None):

0 commit comments

Comments
 (0)