Skip to content

Commit 402288e

Browse files
Updating docstrings to Google format
1 parent 89ab381 commit 402288e

File tree

8 files changed

+60
-65
lines changed

8 files changed

+60
-65
lines changed

cellpose/core.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def assign_device(use_torch=True, gpu=False, device=0):
8080
device (int or str, optional): The device index or name to be used. Defaults to 0.
8181
8282
Returns:
83-
torch.device: The assigned device.
84-
bool: True if GPU is used, False otherwise.
83+
torch.device, bool (True if GPU is used, False otherwise)
8584
"""
8685

8786
if isinstance(device, str):
@@ -212,9 +211,9 @@ def run_net(net, imgi, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
212211
bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
213212
214213
Returns:
215-
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
216-
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
217-
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
214+
Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
215+
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
216+
style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
218217
"""
219218
# run network
220219
nout = net.nout
@@ -300,9 +299,9 @@ def run_3D(net, imgs, batch_size=8, augment=False,
300299
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
301300
302301
Returns:
303-
y (np.ndarray): output of network, if tiled it is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
304-
y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
305-
style (np.ndarray): 1D array of size 256 summarizing the style of the image, if tiled it is averaged over tiles.
302+
Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Lz x Ly x Lx x 4].
303+
y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
304+
style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
306305
"""
307306
sstr = ["YX", "ZY", "ZX"]
308307
pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]

cellpose/denoise.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -554,10 +554,10 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
554554
interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
555555
556556
Returns:
557-
masks (list, np.ndarray): labelled image(s), where 0=no masks; 1,2,...=mask labels
558-
flows (list): list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration
559-
styles (list, np.ndarray): style vector summarizing each image of size 256.
560-
imgs (list of 2D/3D arrays): Restored images
557+
A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
558+
flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
559+
styles: style vector summarizing each image of size 256;
560+
imgs: Restored images.
561561
"""
562562

563563
if isinstance(normalize, dict):
@@ -729,9 +729,9 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
729729
diameter (float, optional): diameter for each image,
730730
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
731731
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
732-
732+
733733
Returns:
734-
imgs (list of 2D/3D arrays): Restored images
734+
list: A list of 2D/3D arrays of restored images
735735
736736
"""
737737
if isinstance(x, list) or x.squeeze().ndim == 5:
@@ -836,7 +836,7 @@ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
836836
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
837837
838838
Returns:
839-
imgs (list of 2D/3D arrays): Restored images
839+
list: A list of 2D/3D arrays of restored images
840840
841841
"""
842842
if isinstance(normalize, dict):

cellpose/dynamics.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ def get_centers(masks, slices):
105105
slices (ndarray): The slices of the masks.
106106
107107
Returns:
108-
tuple containing
109-
- centers (ndarray): The centers of the masks.
110-
- ext (ndarray): The extents of the masks.
108+
A tuple containing the centers of the masks and the extents of the masks.
111109
"""
112110
centers = np.zeros((len(slices), 2), "int32")
113111
ext = np.zeros((len(slices),), "int32")
@@ -131,16 +129,20 @@ def get_centers(masks, slices):
131129
def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
132130
"""Convert masks to flows using diffusion from center pixel.
133131
134-
Center of masks where diffusion starts is defined using COM.
132+
Center of masks where diffusion starts is defined by pixel closest to median within the mask.
135133
136134
Args:
137135
masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
136+
device (torch.device, optional): The device to run the computation on. Defaults to torch.device("cpu").
137+
niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
138138
139139
Returns:
140-
tuple containing
141-
- mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1].
142-
If masks are 3D, flows in Z = mu[0].
143-
- meds_p (float, 2D or 3D array): cell centers
140+
Returns:
141+
A tuple containing (mu, meds_p). mu is float 3D or 4D array of flows in (Z)XY.
142+
meds_p are cell centers.
144146
"""
145147
if device is None:
146148
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
@@ -200,11 +202,12 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
200202
201203
Args:
202204
masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
205+
device (torch.device, optional): The device to run the computation on. Defaults to None.
206+
niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
203207
204208
Returns:
205-
tuple containing
206-
- mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. If masks are 3D, flows in Z = mu[0].
207-
- mu_c (float, 2D or 3D array): zeros
209+
np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
210+
208211
"""
209212
if device is None:
210213
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
@@ -264,24 +267,22 @@ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
264267
# put into original image
265268
mu0 = np.zeros((3, Lz0, Ly0, Lx0))
266269
mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
267-
mu_c = np.zeros_like(mu0)
268-
return mu0, mu_c
270+
return mu0
269271

270272

271-
def masks_to_flows_cpu(masks, device=None, niter=None):
273+
def masks_to_flows_cpu(masks, niter=None, device=None):
272274
"""Convert masks to flows using diffusion from center pixel.
273275
274276
Center of masks where diffusion starts is defined to be the closest pixel to the mean of all pixels that is inside the mask.
275277
Result of diffusion is converted into flows by computing the gradients of the diffusion density map.
276278
277279
Args:
278280
masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels
279-
281+
niter (int, optional): Number of iterations for computing flows. Defaults to None.
282+
280283
Returns:
281-
tuple containing
282-
- mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1].
283-
If masks are 3D, flows in Z = mu[0].
284-
- meds (float, 2D or 3D array): cell centers
284+
A tuple containing (mu, meds). mu is float 3D or 4D array of flows in (Z)XY.
285+
meds are cell centers.
285286
"""
286287
Ly, Lx = masks.shape
287288
mu = np.zeros((2, Ly, Lx), np.float64)
@@ -327,8 +328,7 @@ def masks_to_flows(masks, device=torch.device("cpu"), niter=None):
327328
masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels
328329
329330
Returns:
330-
mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1].
331-
If masks are 3D, flows in Z = mu[0].
331+
np.ndarray: mu is float 3D or 4D array of flows in (Z)XY.
332332
"""
333333
if masks.max() == 0:
334334
dynamics_logger.warning("empty masks!")
@@ -583,9 +583,8 @@ def follow_flows(dP, inds, niter=200, interp=True, device=torch.device("cpu")):
583583
device (torch.device, optional): Device to use for computation. Default is None.
584584
585585
Returns:
586-
tuple containing:
587-
- p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
588-
- inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
586+
A tuple containing (p, inds): p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx];
587+
inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
589588
"""
590589
shape = np.array(dP.shape[1:]).astype(np.int32)
591590
ndim = len(inds)

cellpose/io.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,8 @@ def load_train_test_data(train_dir, test_dir=None, image_filter=None,
457457
look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.
458458
459459
Returns:
460-
images (list): A list of training images.
461-
labels (list): A list of labels corresponding to the training images.
462-
image_names (list): A list of names of the training images.
463-
test_images (list, optional): A list of testing images. None if test_dir is not provided.
464-
test_labels (list, optional): A list of labels corresponding to the testing images. None if test_dir is not provided.
465-
test_image_names (list, optional): A list of names of the testing images. None if test_dir is not provided.
460+
tuple: (images, labels, image_names, test_images, test_labels, test_image_names); the test values are None if test_dir is not provided.
461+
466462
"""
467463
images, labels, image_names = load_images_labels(train_dir, mask_filter,
468464
image_filter, look_one_level_down)

cellpose/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def flow_error(maski, dP_net, device=None):
245245
dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
246246
247247
Returns:
248-
flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks.
248+
A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
249249
dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
250250
"""
251251
if dP_net.shape[1:] != maski.shape:

cellpose/models.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,12 @@ def eval(self, x, batch_size=8, channels=[0, 0], channel_axis=None, invert=False
162162
do_3D (bool, optional): Set to True to run 3D segmentation on 4D image input. Defaults to False.
163163
164164
Returns:
165-
tuple containing
166-
- masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels.
167-
- flows (list of lists 2D arrays or list of 3D arrays):
168-
- flows[k][0] = XY flow in HSV 0-255
169-
- flows[k][1] = XY flows at each pixel
170-
- flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics)
171-
- flows[k][3] = final pixel locations after Euler integration
172-
- styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image.
173-
- diams (list of diameters or float): List of diameters or float (if do_3D=True).
165+
A tuple containing (masks, flows, styles, diams): masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels;
166+
flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel;
167+
flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics);
168+
flows[k][3] = final pixel locations after Euler integration;
169+
styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image;
170+
diams (list of diameters or float): List of diameters or float (if do_3D=True).
174171
175172
"""
176173

@@ -435,10 +432,11 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
435432
progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
436433
437434
Returns:
438-
A tuple containing:
439-
- masks (list, np.ndarray): labelled image(s), where 0=no masks; 1,2,...=mask labels
440-
- flows (list): list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration
441-
- styles (list, np.ndarray): style vector summarizing each image of size 256.
435+
masks (list of 2D arrays or single 3D array): Labelled image, where 0=no masks; 1,2,...=mask labels;
436+
flows (list of lists 2D arrays or list of 3D arrays): flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY flows at each pixel;
437+
flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics);
438+
flows[k][3] = final pixel locations after Euler integration;
439+
styles (list of 1D arrays of length 256 or single 1D array): Style vector summarizing each image, also used to estimate size of objects in image.
442440
443441
"""
444442
if isinstance(x, list) or x.squeeze().ndim == 5:
@@ -733,9 +731,9 @@ def eval(self, x, channels=None, channel_axis=None, normalize=True, invert=False
733731
734732
735733
Returns:
736-
A tuple containing:
737-
- diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps.
738-
- diam_style (np.ndarray): Estimated diameters from style alone.
734+
A tuple containing (diam, diam_style):
735+
diam (np.ndarray): Final estimated diameters from images x or styles style after running both steps;
736+
diam_style (np.ndarray): Estimated diameters from style alone.
739737
"""
740738
if isinstance(x, list):
741739
self.timing = []

cellpose/train.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
371371
model_name (str, optional): String - name of the network. Defaults to None.
372372
373373
Returns:
374-
Path: path to saved model weights
375-
np.ndarray: training losses
376-
np.ndarray: test losses
374+
tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
375+
377376
"""
378377
device = net.device
379378

docs/train.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
Training
22
---------------------------
33

4+
.. warning::
5+
MPS support for PyTorch is incomplete, so training on Macs with MPS may give NaNs;
6+
if so, please use the CPU instead.
7+
48
At the beginning of training, cellpose computes the flow field representation for each
59
mask image (``dynamics.labels_to_flows``).
610

0 commit comments

Comments
 (0)