Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions cellpose/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,20 @@ def boundary_scores(masks_true, masks_pred, scales):
return precision, recall, fscore


def _label_overlap(masks_true, masks_pred):
return csr_matrix((np.ones((masks_true.size,), "int"),
(masks_true.flatten(), masks_pred.flatten())),
shape=(masks_true.max() + 1, masks_pred.max() + 1))


def aggregated_jaccard_index(masks_true, masks_pred):
"""
AJI = intersection of all matched masks / union of all masks
"""
AJI = intersection of all matched masks / union of all masks

Args:
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
where 0=NO masks; 1,2... are mask labels
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
np.ndarray (int) where 0=NO masks; 1,2... are mask labels

Returns:
Expand All @@ -80,26 +86,26 @@ def aggregated_jaccard_index(masks_true, masks_pred):


def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
"""
"""
Average precision estimation: AP = TP / (TP + FP + FN)

This function is based heavily on the *fast* stardist matching functions
(https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)

Args:
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
where 0=NO masks; 1,2... are mask labels
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
np.ndarray (int) where 0=NO masks; 1,2... are mask labels

Returns:
ap (array [len(masks_true) x len(threshold)]):
ap (array [len(masks_true) x len(threshold)]):
average precision at thresholds
tp (array [len(masks_true) x len(threshold)]):
tp (array [len(masks_true) x len(threshold)]):
number of true positives at thresholds
fp (array [len(masks_true) x len(threshold)]):
fp (array [len(masks_true) x len(threshold)]):
number of false positives at thresholds
fn (array [len(masks_true) x len(threshold)]):
fn (array [len(masks_true) x len(threshold)]):
number of false negatives at thresholds
"""
not_list = False
Expand Down Expand Up @@ -149,7 +155,7 @@ def _intersection_over_union(masks_true, masks_pred):
How it works:
The overlap matrix is a lookup table of the area of intersection
between each set of labels (true and predicted). The true labels
are taken to be along axis 0, and the predicted labels are taken
are taken to be along axis 0, and the predicted labels are taken
to be along axis 1. The sum of the overlaps along axis 0 is thus
an array giving the total overlap of the true labels with each of
the predicted labels, and likewise the sum over axis 1 is the
Expand All @@ -159,13 +165,11 @@ def _intersection_over_union(masks_true, masks_pred):
column vectors gives a 2D array with the areas of every label pair
added together. This is equivalent to the union of the label areas
except for the duplicated overlap area, so the overlap matrix is
subtracted to find the union matrix.
subtracted to find the union matrix.
"""
if masks_true.size != masks_pred.size:
raise ValueError("masks_true.size != masks_pred.size")
overlap = csr_matrix((np.ones((masks_true.size,), "int"),
(masks_true.flatten(), masks_pred.flatten())),
shape=(masks_true.max()+1, masks_pred.max()+1))
overlap = _label_overlap(masks_true, masks_pred)
overlap = overlap.toarray()
n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
Expand Down
Loading