Skip to content

Commit 4f2f7f4

Browse files
committed
Fix jax mean for 0 counts.
1 parent 2833835 commit 4f2f7f4

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ You can find a [table](training/results/README.md) of common benchmark datasets
295295

296296
Some known issues to be aware of, if using and making new models or layers with `kgcnn`.
297297
* Jagged or nested Tensors loading into models for PyTorch backend is not working.
298-
* ForceModel does not support all backends.
299298
* BatchNormalization layer dos not support padding yet.
300299
* Keras AUC metric does not seem to work for torch cuda.
301300

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ causing clashes with built-in functions. We catch defaults to be at least as bac
1616
* Added simple ragged support for ``train_force.py``
1717
* Implemented random equivariant initialize for PAiNN
1818
* Implemented charge and dipole output for HDNNP2nd
19+
* Implemented jax backend for force models.
1920

2021

2122
v4.0.0

kgcnn/backend/_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def scatter_reduce_mean(indices, values, shape):
5353
zeros = jnp.zeros(shape, values.dtype)
5454
counts = jnp.zeros(shape, values.dtype)
5555
counts = counts.at[indices].add(jnp.ones_like(values))
56-
return zeros.at[indices].add(values)/counts
56+
inverse_counts = jnp.nan_to_num(jnp.reciprocal(counts))
57+
return zeros.at[indices].add(values)*inverse_counts
5758

5859

5960
def scatter_reduce_softmax(indices, values, shape, normalize: bool = False):

0 commit comments

Comments
 (0)