Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix lint error
  • Loading branch information
fishjojo committed Jun 15, 2025
commit 6a5ee81cc368bd6ecd2e73955bc896fb99f41aea
10 changes: 5 additions & 5 deletions pyscfad/cc/ccsd_t.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
import numpy
from jax import custom_vjp
from jax.tree_util import tree_flatten_with_path, tree_unflatten

Check warning on line 4 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L4

Added line #L4 was not covered by tests
from pyscf.lib import (
prange_tril,
num_threads,
Expand Down Expand Up @@ -35,31 +35,31 @@
eris, t1, t2 = res

# TODO clean up tree unflatten
path_vals, treedef = tree_flatten_with_path(eris)
keys = [item[0][0].name for item in path_vals]
shapes = [item[1].shape for item in path_vals]
path_vals = None

Check warning on line 41 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L38-L41

Added lines #L38 - L41 were not covered by tests

t1_bar, t2_bar, fock_bar, mo_energy_bar,\
ovoo_bar, ovov_bar, ovvv_bar = _ccsd_t_energy_vjp(eris, t1, t2, ybar, max_memory)

leaves = [None] * len(keys)
key_to_bar = {

Check warning on line 47 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L46-L47

Added lines #L46 - L47 were not covered by tests
"fock": fock_bar,
"mo_energy": mo_energy_bar,
"ovoo": ovoo_bar,
"ovov": ovov_bar,
"ovvv": ovvv_bar,
'fock': fock_bar,
'mo_energy': mo_energy_bar,
'ovoo': ovoo_bar,
'ovov': ovov_bar,
'ovvv': ovvv_bar,
}

for k, val in key_to_bar.items():
leaves[keys.index(k)] = val

Check warning on line 56 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L55-L56

Added lines #L55 - L56 were not covered by tests

for i, leaf in enumerate(leaves):
if leaf is None:
leaves[i] = numpy.zeros(shapes[i])

Check warning on line 60 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L58-L60

Added lines #L58 - L60 were not covered by tests

eris_bar = tree_unflatten(treedef, leaves)

Check warning on line 62 in pyscfad/cc/ccsd_t.py

View check run for this annotation

Codecov / codecov/patch

pyscfad/cc/ccsd_t.py#L62

Added line #L62 was not covered by tests
return eris_bar, t1_bar, t2_bar

_ccsd_t_kernel.defvjp(_ccsd_t_kernel_fwd, _ccsd_t_kernel_bwd)
Expand Down