diff --git a/cellpose/metrics.py b/cellpose/metrics.py index c89ea565..b27604f8 100644 --- a/cellpose/metrics.py +++ b/cellpose/metrics.py @@ -55,14 +55,20 @@ def boundary_scores(masks_true, masks_pred, scales): return precision, recall, fscore +def _label_overlap(masks_true, masks_pred): + return csr_matrix((np.ones((masks_true.size,), "int"), + (masks_true.flatten(), masks_pred.flatten())), + shape=(masks_true.max() + 1, masks_pred.max() + 1)) + + def aggregated_jaccard_index(masks_true, masks_pred): - """ - AJI = intersection of all matched masks / union of all masks - + """ + AJI = intersection of all matched masks / union of all masks + Args: - masks_true (list of np.ndarrays (int) or np.ndarray (int)): + masks_true (list of np.ndarrays (int) or np.ndarray (int)): where 0=NO masks; 1,2... are mask labels - masks_pred (list of np.ndarrays (int) or np.ndarray (int)): + masks_pred (list of np.ndarrays (int) or np.ndarray (int)): np.ndarray (int) where 0=NO masks; 1,2... are mask labels Returns: @@ -80,26 +86,26 @@ def aggregated_jaccard_index(masks_true, masks_pred): def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]): - """ + """ Average precision estimation: AP = TP / (TP + FP + FN) This function is based heavily on the *fast* stardist matching functions (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py) Args: - masks_true (list of np.ndarrays (int) or np.ndarray (int)): + masks_true (list of np.ndarrays (int) or np.ndarray (int)): where 0=NO masks; 1,2... are mask labels - masks_pred (list of np.ndarrays (int) or np.ndarray (int)): + masks_pred (list of np.ndarrays (int) or np.ndarray (int)): np.ndarray (int) where 0=NO masks; 1,2... are mask labels Returns: - ap (array [len(masks_true) x len(threshold)]): + ap (array [len(masks_true) x len(threshold)]): average precision at thresholds - tp (array [len(masks_true) x len(threshold)]): + tp (array [len(masks_true) x len(threshold)]): number of true positives at thresholds - fp (array [len(masks_true) x len(threshold)]): + fp (array [len(masks_true) x len(threshold)]): number of false positives at thresholds - fn (array [len(masks_true) x len(threshold)]): + fn (array [len(masks_true) x len(threshold)]): number of false negatives at thresholds """ not_list = False @@ -149,7 +155,7 @@ def _intersection_over_union(masks_true, masks_pred): How it works: The overlap matrix is a lookup table of the area of intersection between each set of labels (true and predicted). The true labels - are taken to be along axis 0, and the predicted labels are taken + are taken to be along axis 0, and the predicted labels are taken to be along axis 1. The sum of the overlaps along axis 0 is thus an array giving the total overlap of the true labels with each of the predicted labels, and likewise the sum over axis 1 is the @@ -159,13 +165,11 @@ def _intersection_over_union(masks_true, masks_pred): column vectors gives a 2D array with the areas of every label pair added together. This is equivalent to the union of the label areas except for the duplicated overlap area, so the overlap matrix is - subtracted to find the union matrix. + subtracted to find the union matrix. """ if masks_true.size != masks_pred.size: raise ValueError("masks_true.size != masks_pred.size") - overlap = csr_matrix((np.ones((masks_true.size,), "int"), - (masks_true.flatten(), masks_pred.flatten())), - shape=(masks_true.max()+1, masks_pred.max()+1)) + overlap = _label_overlap(masks_true, masks_pred) overlap = overlap.toarray() n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) n_pixels_true = np.sum(overlap, axis=1, keepdims=True)