Skip to content

Commit 93d024e

Browse files
authored
Merge pull request #1 from MouseLand/ring_artifacts-refactor
## Summary of Changes - Modularization of ring artifact detection code. - Improved logic for artifact removal and filtering. - Enhanced documentation and inline comments for clarity. - Compatibility updates to work with upstream changes from MouseLand. - Bug fixes identified during code review and testing. - Refined parameter settings and default behaviors for artifact processing.
2 parents 2eccc33 + e4cf3f8 commit 93d024e

File tree

8 files changed

+139
-55
lines changed

8 files changed

+139
-55
lines changed

.github/workflows/test_and_deploy.yml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
name: tests
55

6-
on:
6+
on:
77
push:
88
branches:
99
- main
@@ -14,7 +14,7 @@ on:
1414
- main
1515
workflow_dispatch:
1616

17-
jobs:
17+
jobs:
1818
test:
1919
name: ${{ matrix.platform }} py${{ matrix.python-version }}
2020
runs-on: ${{ matrix.platform }}
@@ -26,10 +26,12 @@ jobs:
2626

2727
steps:
2828
- uses: actions/checkout@v4
29+
2930
- name: Set up Python ${{ matrix.python-version }}
3031
uses: actions/setup-python@v5
3132
with:
3233
python-version: ${{ matrix.python-version }}
34+
3335
- name: Install dependencies
3436
run: |
3537
python -m pip install --upgrade pip
@@ -45,21 +47,25 @@ jobs:
4547

4648
deploy:
4749
# this will run when you have tagged a commit, starting with "v*"
48-
# and requires that you have put your twine API key in your
50+
# and requires that you have put your twine API key in your
4951
# github secrets (see readme for details)
5052
needs: [test]
5153
runs-on: ubuntu-latest
5254
if: contains(github.ref, 'tags')
55+
5356
steps:
5457
- uses: actions/checkout@v4
58+
5559
- name: Set up Python
56-
uses: actions/setup-python@v4
60+
uses: actions/setup-python@v5
5761
with:
5862
python-version: "3.x"
63+
5964
- name: Install dependencies
6065
run: |
6166
python -m pip install --upgrade pip
6267
pip install -U setuptools setuptools_scm wheel twine
68+
6369
- name: Build and publish
6470
env:
6571
TWINE_USERNAME: __token__

cellpose/gui/io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _load_image(parent, filename=None, load_seg=True, load_3D=False):
139139
except Exception as e:
140140
print("ERROR: images not compatible")
141141
print(f"ERROR: {e}")
142+
return
142143

143144
if parent.loaded:
144145
parent.reset()

cellpose/io.py

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,33 @@
4747
io_logger = logging.getLogger(__name__)
4848

4949
def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
50+
"""Set up logging to a file and stdout (or a file replacement).
51+
52+
Creates the log directory if it doesn't exist, removes any existing log
53+
file, and configures the root logger to write INFO-level and above messages
54+
to both a log file and stdout (or a replacement file).
55+
56+
Parameters
57+
----------
58+
cp_path : str, optional
59+
Directory name under the user's home directory for log output.
60+
Default is ".cellpose".
61+
logfile_name : str, optional
62+
Name of the log file created inside cp_path. Default is "run.log".
63+
stdout_file_replacement : str or None, optional
64+
If provided, log output is written to this file path instead of stdout.
65+
66+
Returns
67+
-------
68+
logger : logging.Logger
69+
Configured logger for this module. Only INFO and above messages are
70+
emitted by default. To enable debug output, call
71+
``logger.setLevel(logging.DEBUG)`` on the returned logger.
72+
73+
Notes
74+
-----
75+
The log file is deleted and recreated on each call.
76+
"""
5077
cp_dir = pathlib.Path.home().joinpath(cp_path)
5178
cp_dir.mkdir(exist_ok=True)
5279
log_file = cp_dir.joinpath(logfile_name)
@@ -189,6 +216,28 @@ def imread(filename):
189216
if not ND2:
190217
io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
191218
return None
219+
else:
220+
with nd2.ND2File(filename) as nd2_file:
221+
img = nd2_file.asarray()
222+
sizes = nd2_file.sizes
223+
224+
kept_axes = [nd2.AXIS.Y, nd2.AXIS.X, nd2.AXIS.CHANNEL, nd2.AXIS.Z]
225+
# For multi-dimensional data (T, P, etc.), take first frame/position
226+
# Work backwards through axes to avoid index shifting
227+
for i, (ax_name, size) in reversed(list(enumerate(sizes.items()))):
228+
# Keep Y, X, C, Z; remove or reduce everything else
229+
if ax_name not in kept_axes:
230+
if size > 1:
231+
io_logger.warning(
232+
f"ND2 file has {size} {ax_name} - using first only"
233+
)
234+
# Take first element (works for both size=1 and size>1)
235+
img = np.take(img, 0, axis=i)
236+
237+
# Result should now be YX, CYX, ZYX, or CZYX depending on original axes
238+
# nd2 preserves axis order from sizes dict (usually C, Z, Y, X)
239+
return img
240+
192241
elif ext == ".nrrd":
193242
if not NRRD:
194243
io_logger.critical(
@@ -230,40 +279,47 @@ def imread_2D(img_file):
230279
img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
231280
"""
232281
img = imread(img_file)
282+
if img is None:
283+
raise ValueError(f"could not read image file {img_file}")
233284
return transforms.convert_image(img, do_3D=False)
234285

235286

236287
def imread_3D(img_file):
237288
"""
238289
Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.
239290
240-
If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
241-
Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.
242-
291+
For grayscale images (3D array), axis 0 is assumed to be the Z axis (e.g., Z x Y x X).
292+
For multichannel images (4D array), the channel axis is assumed to be the smallest dimension,
293+
and the Z axis is assumed to be the first remaining axis after the channel axis is removed.
294+
295+
Use ``cellpose.io.imread()`` to load the full image without automatic axis selection,
296+
then specify ``z_axis`` and ``channel_axis`` manually when calling ``model.eval``.
297+
243298
Args:
244299
img_file (str): The path to the image file.
245300
246301
Returns:
247-
img_out (numpy.ndarray): The image data as a NumPy array.
302+
img_out (numpy.ndarray): The image data as a NumPy array with channels last, or None if loading fails.
248303
"""
249304
img = imread(img_file)
305+
if img is None:
306+
raise ValueError(f"could not read image file {img_file}")
250307

251308
dimension_lengths = list(img.shape)
252309

253310
# grayscale images:
254311
if img.ndim == 3:
255312
channel_axis = None
256313
# guess at z axis:
257-
z_axis = np.argmin(dimension_lengths)
314+
z_axis = 0
258315

259316
elif img.ndim == 4:
260317
# guess at channel axis:
261318
channel_axis = np.argmin(dimension_lengths)
262-
263-
# guess at z axis:
264-
# set channel axis to max so argmin works:
265-
dimension_lengths[channel_axis] = max(dimension_lengths)
266-
z_axis = np.argmin(dimension_lengths)
319+
dimensions = list(range(img.ndim))
320+
dimensions.pop(channel_axis)
321+
# guess at z axis as the first remaining dimension:
322+
z_axis = dimensions[0]
267323

268324
else:
269325
raise ValueError(f'image shape error, 3D image must 3 or 4 dimensional. Number of dimensions: {img.ndim}')

cellpose/models.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
187187
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
188188
cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
189189
do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
190-
flow3D_smooth (int or list[int], optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
190+
flow3D_smooth (int or float or list of (int or float), optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. List smooths the ZYX axes independently and must be length 3. Defaults to 0.
191191
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
192192
stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
193193
min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
@@ -320,20 +320,10 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
320320
anisotropy=anisotropy)
321321

322322
if do_3D:
323-
324-
if isinstance(flow3D_smooth, int):
325-
flow3D_smooth = [flow3D_smooth]*3
326-
if any(v > 0 for v in flow3D_smooth):
327-
models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
328-
dP = gaussian_filter(dP, [0, *flow3D_smooth])
329323
torch.cuda.empty_cache()
330324
gc.collect()
331325

332326
if resample:
333-
# upsample flows flows before computing them:
334-
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
335-
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
336-
337327
# resize XY then YZ and then put channels first
338328
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False)
339329
dP = transforms.resize_image(dP.transpose(1, 0, 2, 3), Lx=Lx_0, Ly=Lz_0, no_channels=False)
@@ -344,16 +334,20 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
344334
cellprob = transforms.resize_image(cellprob.transpose(1, 0, 2), Lx=Lx_0, Ly=Lz_0, no_channels=True)
345335
cellprob = cellprob.transpose(1, 0, 2)
346336

347-
348337
# 2d case:
349338
if resample and not do_3D:
350-
# upsample flows before computing them:
351-
# dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
352-
# cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
353-
354339
# 2D images have N = 1 in batch dimension:
355340
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False).transpose(3, 0, 1, 2)
356341
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
342+
343+
if do_3D and flow3D_smooth > 0:
344+
if isinstance(flow3D_smooth, int) or isinstance(flow3D_smooth, float):
345+
flow3D_smooth = [flow3D_smooth]*3
346+
if len(flow3D_smooth) == 3 and any(v > 0 for v in flow3D_smooth):
347+
models_logger.info(f"smoothing flows with ZYX sigma={flow3D_smooth}")
348+
dP = gaussian_filter(dP, [0, *flow3D_smooth])
349+
else:
350+
models_logger.warning(f"Could not do flow smoothing with {flow3D_smooth} either because its len was not 3 or no items were > 0, skipping flow3D_smoothing")
357351

358352
if compute_masks:
359353
# use user niter if specified, otherwise scale niter (200) with diameter

cellpose/transforms.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
612612
x_out[..., 0] = x
613613
x = x_out
614614
del x_out
615-
transforms_logger.info(f'processing grayscale image with {x.shape[0], x.shape[1]} HW')
615+
transforms_logger.debug(f'processing grayscale image with {x.shape[0], x.shape[1]} HW')
616616
elif ndim == 3:
617617
# assume 2d with channels
618618
# find dim with smaller size between first and last dims
@@ -629,7 +629,7 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
629629
x_out[..., :num_channels] = x[..., :num_channels]
630630
x = x_out
631631
del x_out
632-
transforms_logger.info(f'processing image with {x.shape[0], x.shape[1]} HW, and {x.shape[2]} channels')
632+
transforms_logger.debug(f'processing image with {x.shape[0], x.shape[1]} HW, and {x.shape[2]} channels')
633633
elif ndim == 4:
634634
# assume batch of 2d with channels
635635

@@ -642,7 +642,7 @@ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
642642
x_out[..., :num_channels] = x[..., :num_channels]
643643
x = x_out
644644
del x_out
645-
transforms_logger.info(f'processing image batch with {x.shape[0]} images, {x.shape[1], x.shape[2]} HW, and {x.shape[3]} channels')
645+
transforms_logger.debug(f'processing image batch with {x.shape[0]} images, {x.shape[1], x.shape[2]} HW, and {x.shape[3]} channels')
646646
else:
647647
# something is wrong: yell
648648
expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
@@ -776,10 +776,6 @@ def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
776776
transforms_logger.critical(error_message)
777777
raise ValueError(error_message)
778778

779-
# Move channel axis back to the original position
780-
if axis != -1 and axis != img_norm.ndim - 1:
781-
img_norm = np.moveaxis(img_norm, -1, axis)
782-
783779
# The transformer can get confused if a channel is all 1's instead of all 0's:
784780
for i, chan_did_normalize in enumerate(cgood):
785781
if not chan_did_normalize:
@@ -788,6 +784,10 @@ def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
788784
if img_norm.ndim == 4:
789785
img_norm[:, :, :, i] = 0
790786

787+
# Move channel axis back to the original position
788+
if axis != -1 and axis != img_norm.ndim - 1:
789+
img_norm = np.moveaxis(img_norm, -1, axis)
790+
791791
return img_norm
792792

793793
def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):

docs/do3d.rst

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,21 @@ then the GUI will automatically run 3D segmentation and display it in the GUI. W
2525
the command line for progress. It is recommended to use a GPU to speed up processing.
2626

2727
In the CLI/notebook, you need to specify the ``z_axis`` and the ``channel_axis``
28-
parameters to specify the axis (0-based) of the image which corresponds to the image channels and to the z axis.
29-
For example an image with 2 channels of shape (1024,1024,2,105,1) can be
30-
specified with ``channel_axis=2`` and ``z_axis=3``. These parameters can be specified using the command line
31-
with ``--channel_axis`` or ``--z_axis`` or as inputs to ``model.eval`` for
28+
parameters to specify the axis (0-based) of the image which corresponds to the image channels and to the z axis.
29+
For example an image with 2 channels of shape (1024,1024,2,105,1) can be
30+
specified with ``channel_axis=2`` and ``z_axis=3``. These parameters can be specified using the command line
31+
with ``--channel_axis`` or ``--z_axis`` or as inputs to ``model.eval`` for
3232
the ``CellposeModel`` model.
3333

34+
As a convenience, :func:`cellpose.io.imread_3D` will attempt to load a 3D image and
35+
automatically guess the axes. For grayscale images (3D array), axis 0 is assumed
36+
to be the Z axis (e.g., Z x Y x X). For multichannel images (4D array), the
37+
channel axis is assumed to be the smallest dimension, and the Z axis is assumed to
38+
be the first remaining axis after the channel axis is identified (e.g., for a
39+
Z x C x Y x X image, channel axis = 1 and z axis = 0). If your image does not
40+
follow these conventions, use ``cellpose.io.imread`` and specify ``z_axis`` and
41+
``channel_axis`` manually.
42+
3443
Volumetric stacks do not always have the same sampling in XY as they do in Z.
3544
Therefore you can set an ``anisotropy`` parameter in CLI/notebook to allow for differences in
3645
sampling, e.g. set to 2.0 if Z is sampled half as dense as X or Y, and then in the algorithm
@@ -48,6 +57,9 @@ If you see many cells that are fragmented, you can smooth the flows before the d
4857
are run in 3D using the ``flow3D_smooth`` parameter, which specifies the standard deviation of
4958
a Gaussian for smoothing the flows. The default is 0.0, which means no smoothing. Alternatively/additionally,
5059
you may want to train a model on 2D slices from your 3D data to improve the segmentation (see below).
60+
*If there are ring-like artifacts in your masks*, increasing ``flow3D_smooth`` can help remove them.
61+
You can specify the ZYX flow smoothing independently for each axis by passing a list of values to the ``flow3D_smooth``
62+
argument. For example: ``flow3D_smooth = [2, 0, 0]``
5163

5264
The network can rescale images using the user diameter and the model ``diam_mean`` (30),
5365
so for example if you input a diameter of 90,

tests/test_output.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,28 @@ def clear_output(data_dir, image_names):
3939
os.remove(npy_output)
4040

4141

42-
@pytest.mark.parametrize('compute_masks, resample, diameter',
42+
@pytest.mark.parametrize('compute_masks, resample, diameter, flow3D_smooth',
4343
[
44-
(True, True, 40),
45-
(True, True, None),
46-
(False, True, None),
47-
(False, False, None),
48-
(True, False, None),
49-
(True, False, 40)
44+
(True, True, 40, None),
45+
(True, True, None, None),
46+
(False, True, None, None),
47+
(False, False, None, None),
48+
(True, False, None, None),
49+
(True, False, 40, None),
50+
(False, True, None, 2),
51+
(False, True, None, [2, 0, 0]),
52+
(False, False, None, 2),
5053
]
5154
)
52-
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter):
55+
def test_class_2D_one_img(data_dir, image_names, cellposemodel_fixture_24layer, compute_masks, resample, diameter, flow3D_smooth):
5356
clear_output(data_dir, image_names)
5457

5558
img_file = data_dir / '2D' / image_names[0]
5659

5760
img = io.imread_2D(img_file)
5861
# flowps = io.imread(img_file.parent / (img_file.stem + "_cp4_gt_flowps.tif"))
5962

60-
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample, diameter=diameter)
63+
masks_pred, _, _ = cellposemodel_fixture_24layer.eval(img, normalize=True, compute_masks=compute_masks, resample=resample, diameter=diameter, flow3D_smooth=flow3D_smooth)
6164

6265
if not compute_masks:
6366
# not compute_masks won't return masks so can't check

0 commit comments

Comments
 (0)