Skip to content

Commit 635bfc9

Browse files
committed
jax_predict.py formating
1 parent a028801 commit 635bfc9

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

equitrain/backends/jax_predict.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,7 @@ def _chunk_iterator():
148148

149149
def _graph_real_counts(graph):
150150
try:
151-
pad_mask = np.asarray(
152-
jraph.get_graph_padding_mask(graph), dtype=bool
153-
)
151+
pad_mask = np.asarray(jraph.get_graph_padding_mask(graph), dtype=bool)
154152
n_node = np.asarray(graph.n_node)
155153

156154
# Loader may append zero-node graphs to satisfy multi-device padding.

0 commit comments

Comments
 (0)