We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a028801 commit 635bfc9Copy full SHA for 635bfc9
equitrain/backends/jax_predict.py
@@ -148,9 +148,7 @@ def _chunk_iterator():
148
149
def _graph_real_counts(graph):
150
try:
151
- pad_mask = np.asarray(
152
- jraph.get_graph_padding_mask(graph), dtype=bool
153
- )
+ pad_mask = np.asarray(jraph.get_graph_padding_mask(graph), dtype=bool)
154
n_node = np.asarray(graph.n_node)
155
156
# Loader may append zero-node graphs to satisfy multi-device padding.
0 commit comments