We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7da1cfb commit 2309c97Copy full SHA for 2309c97
baal/utils/array_utils.py
@@ -15,7 +15,7 @@ def to_prob(probabilities: np.ndarray):
15
"""
16
not_bounded = np.min(probabilities) < 0 or np.max(probabilities) > 1.0
17
multiclass = probabilities.shape[1] > 1
18
- sum_to_one = np.allclose(probabilities.sum(1), 1)
+ sum_to_one = np.allclose(probabilities.sum(1), 1, rtol=1e-4)
19
if not_bounded or (multiclass and not sum_to_one):
20
if multiclass:
21
probabilities = softmax(probabilities, 1)
0 commit comments