Commit ba9420c

implemented cross entropy for multiple classes

1 parent 8073842 commit ba9420c

1 file changed: little-book-of-deep-learning/cross-entropy.py (35 additions & 1 deletion)
@@ -14,7 +14,8 @@ def binary_cross_entropy(y_true, y_pred):
         float: Binary cross-entropy loss.
     """
     epsilon = 1e-15  # to avoid log(0)
-    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
+    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
+    # we use np.clip to avoid log(0)
     loss = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
     return loss
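
The added comment is the whole point of the clip: log(0) is undefined (it yields -inf in floating point), so predictions are kept inside [epsilon, 1 - epsilon] before taking logs. As a quick sanity check, not part of the commit, the function can be compared against scikit-learn's log_loss; the binary labels below are hypothetical, since y_true for this example falls outside the hunk's context:

    import numpy as np
    from sklearn.metrics import log_loss  # reference implementation

    y_true = np.array([1, 0, 1, 1, 0])            # hypothetical labels (not shown in the diff)
    y_pred = np.array([0.9, 0.1, 0.8, 0.7, 0.3])  # probabilities from the example above

    epsilon = 1e-15
    p = np.clip(y_pred, epsilon, 1 - epsilon)
    ours = -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

    print(ours, log_loss(y_true, y_pred))  # the two values should agree closely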

@@ -23,3 +24,36 @@ def binary_cross_entropy(y_true, y_pred):
 y_pred = np.array([0.9, 0.1, 0.8, 0.7, 0.3])
 loss = binary_cross_entropy(y_true, y_pred)
 print(f'Binary Cross-Entropy Loss: {loss}')
+
+
+
+def multi_cross_entropy(y_true, y_pred):
+    """
+    Compute the categorical (multi-class) cross-entropy loss.
+
+    Parameters:
+    y_true (numpy.ndarray): One-hot encoded true labels, shape (n_samples, n_classes).
+    y_pred (numpy.ndarray): Predicted class probabilities, same shape.
+
+    Returns:
+    float: Categorical cross-entropy loss.
+    """
+    epsilon = 1e-15  # to avoid log(0)
+    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
+    # we use np.clip to avoid log(0)
+    loss = -np.mean(np.sum(y_true * np.log(y_pred), axis=1))
+    return loss
+
+
+y_true = np.array([
+    [1, 0, 0],
+    [0, 1, 0],
+    [0, 0, 1]
+])
+y_pred = np.array([
+    [0.9, 0.05, 0.05],
+    [0.1, 0.8, 0.1],
+    [0.2, 0.2, 0.6]
+])
+loss = multi_cross_entropy(y_true, y_pred)
+print(f'Categorical Cross-Entropy Loss: {loss}')
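
For one-hot labels each inner sum picks out the log-probability of the true class, so the loss is the mean negative log-probability assigned to the correct class: L = -(1/N) * sum_n sum_c y[n,c] * log(p[n,c]). A minimal cross-check (not part of the commit) against scikit-learn's log_loss, converting the one-hot rows back to class indices:

    import numpy as np
    from sklearn.metrics import log_loss  # reference implementation

    y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    y_pred = np.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])

    labels = np.argmax(y_true, axis=1)  # one-hot rows -> class indices [0, 1, 2]
    print(log_loss(labels, y_pred))     # ~0.2798, matching multi_cross_entropy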
