Skip to content

Commit 908faaf

Browse files
authored
Merge pull request #1408 from yuriyzubov/ring_artifacts
add support for anisotropic values of the flow3D_smooth
2 parents 4149f51 + ef481bf commit 908faaf

File tree

4 files changed

+30
-20
lines changed

4 files changed

+30
-20
lines changed

cellpose/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def get_arg_parser():
109109
"--min_size", required=False, default=15, type=int,
110110
help="minimum number of pixels per mask, can turn off with -1")
111111
algorithm_args.add_argument(
112-
"--flow3D_smooth", required=False, default=0, type=float,
113-
help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
112+
"--flow3D_smooth", required=False, default=0, type=float, nargs='+',
113+
help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing. Pass a list of values to allow smoothing of the ZYX axes independently")
114114
algorithm_args.add_argument(
115115
"--flow_threshold", default=0.4, type=float, help=
116116
"flow error threshold, 0 turns off this optional QC step. Default: %(default)s")

cellpose/models.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
187187
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
188188
cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
189189
do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
190-
flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
190+
flow3D_smooth (int or float or list of (int or float), optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. List smooths the ZYX axes independently and must be length 3. Defaults to 0.
191191
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
192192
stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
193193
min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
@@ -319,18 +319,11 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
319319
do_3D=do_3D,
320320
anisotropy=anisotropy)
321321

322-
if do_3D:
323-
if flow3D_smooth > 0:
324-
models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
325-
dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
322+
if do_3D:
326323
torch.cuda.empty_cache()
327324
gc.collect()
328325

329326
if resample:
330-
# upsample flows flows before computing them:
331-
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
332-
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
333-
334327
# resize XY then YZ and then put channels first
335328
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False)
336329
dP = transforms.resize_image(dP.transpose(1, 0, 2, 3), Lx=Lx_0, Ly=Lz_0, no_channels=False)
@@ -341,16 +334,22 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
341334
cellprob = transforms.resize_image(cellprob.transpose(1, 0, 2), Lx=Lx_0, Ly=Lz_0, no_channels=True)
342335
cellprob = cellprob.transpose(1, 0, 2)
343336

344-
345337
# 2d case:
346338
if resample and not do_3D:
347-
# upsample flows before computing them:
348-
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
349-
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
350-
351339
# 2D images have N = 1 in batch dimension:
352340
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False).transpose(3, 0, 1, 2)
353341
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
342+
343+
if do_3D and flow3D_smooth:
344+
if isinstance(flow3D_smooth, (int, float)):
345+
flow3D_smooth = [flow3D_smooth]*3
346+
if isinstance(flow3D_smooth, list) and len(flow3D_smooth) == 1:
347+
flow3D_smooth = flow3D_smooth*3
348+
if len(flow3D_smooth) == 3 and any(v > 0 for v in flow3D_smooth):
349+
models_logger.info(f"smoothing flows with ZYX sigma={flow3D_smooth}")
350+
dP = gaussian_filter(dP, [0, *flow3D_smooth])
351+
else:
352+
models_logger.warning(f"Could not do flow smoothing with {flow3D_smooth} either because its len was not 3 or no items were > 0, skipping flow3D_smoothing")
354353

355354
if compute_masks:
356355
# use user niter if specified, otherwise scale niter (200) with diameter

docs/do3d.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ If you see many cells that are fragmented, you can smooth the flows before the d
5757
are run in 3D using the ``flow3D_smooth`` parameter, which specifies the standard deviation of
5858
a Gaussian for smoothing the flows. The default is 0.0, which means no smoothing. Alternatively/additionally,
5959
you may want to train a model on 2D slices from your 3D data to improve the segmentation (see below).
60+
*If there are ring-like artifacts in your masks*, increasing ``flow3D_smooth`` can help remove them.
61+
You can specify the ZYX flow smoothing independently for each axis by passing a list of values to the ``flow3D_smooth``
62+
argument. For example: ``flow3D_smooth = [2, 0, 0]``
6063

6164
The network can rescale images using the user diameter and the model ``diam_mean`` (30),
6265
so for example if you input a diameter of 90,

tests/test_output.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def clear_output(data_dir, image_names):
4646
(False, True, None),
4747
(False, False, None),
4848
(True, False, None),
49-
(True, False, 40)
49+
(True, False, 40),
5050
]
5151
)
5252
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter):
@@ -172,13 +172,21 @@ def test_cli_3D_diam_anisotropy_shape(data_dir, image_names_3d, diam, aniso):
172172
compare_mask_shapes(data_dir, image_names_3d[0], "3D")
173173
clear_output(data_dir, image_names_3d)
174174

175-
175+
@pytest.mark.parametrize('flow3D_smooth',
176+
[None, 2, [1., 0., 0.]])
176177
@pytest.mark.slow
177-
def test_cli_3D_one_img(data_dir, image_names_3d):
178+
def test_cli_3D_one_img(data_dir, image_names_3d, flow3D_smooth):
178179
clear_output(data_dir, image_names_3d)
179180
use_gpu = torch.cuda.is_available() or torch.backends.mps.is_available()
180181
gpu_string = "--use_gpu" if use_gpu else ""
181-
cmd = f"python -m cellpose --image_path {str(data_dir / '3D' / image_names_3d[0])} --do_3D --save_tif {gpu_string} --verbose"
182+
183+
flow_string = ''
184+
if isinstance(flow3D_smooth, (float, int)):
185+
flow_string = f" --flow3D_smooth {flow3D_smooth}"
186+
elif isinstance(flow3D_smooth, list):
187+
flow_string = f" --flow3D_smooth {' '.join([str(f) for f in flow3D_smooth])}"
188+
189+
cmd = f"python -m cellpose --image_path {str(data_dir / '3D' / image_names_3d[0])} --do_3D --save_tif {gpu_string} --verbose{flow_string}"
182190
print(cmd)
183191
try:
184192
cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode()

0 commit comments

Comments
 (0)