Skip to content

Commit 849ca11

Browse files
Merge pull request #1417 from MouseLand/refactor_eval
Refactor eval
2 parents b475fc7 + b4c7608 commit 849ca11

File tree

1 file changed: +41 additions, −50 deletions

cellpose/models.py

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,9 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
268268
x = x[np.newaxis, ...]
269269
nimg = x.shape[0]
270270

271-
image_scaling = None
272-
Ly_0 = x.shape[1]
273-
Lx_0 = x.shape[2]
274-
Lz_0 = None
275-
if do_3D or stitch_threshold > 0:
276-
Lz_0 = x.shape[0]
277-
if diameter is not None:
271+
image_scaling = 1.0
272+
if diameter is not None and diameter > 0:
278273
image_scaling = 30. / diameter
279-
x = transforms.resize_image(x,
280-
Ly=int(x.shape[1] * image_scaling),
281-
Lx=int(x.shape[2] * image_scaling))
282274

283275

284276
# normalize image
@@ -306,40 +298,17 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
306298
if do_normalization:
307299
x = transforms.normalize_img(x, **normalize_params)
308300

309-
# ajust the anisotropy when diameter is specified and images are resized:
310-
if isinstance(anisotropy, (float, int)) and image_scaling:
311-
anisotropy = image_scaling * anisotropy
312-
313301
dP, cellprob, styles = self._run_net(
314-
x,
302+
x,
303+
resample=resample,
304+
rescale=image_scaling,
315305
augment=augment,
316306
batch_size=batch_size,
317307
tile_overlap=tile_overlap,
318308
bsize=bsize,
319309
do_3D=do_3D,
320310
anisotropy=anisotropy)
321311

322-
if do_3D:
323-
torch.cuda.empty_cache()
324-
gc.collect()
325-
326-
if resample:
327-
# resize XY then YZ and then put channels first
328-
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False)
329-
dP = transforms.resize_image(dP.transpose(1, 0, 2, 3), Lx=Lx_0, Ly=Lz_0, no_channels=False)
330-
dP = dP.transpose(3, 1, 0, 2)
331-
332-
# resize cellprob:
333-
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
334-
cellprob = transforms.resize_image(cellprob.transpose(1, 0, 2), Lx=Lx_0, Ly=Lz_0, no_channels=True)
335-
cellprob = cellprob.transpose(1, 0, 2)
336-
337-
# 2d case:
338-
if resample and not do_3D:
339-
# 2D images have N = 1 in batch dimension:
340-
dP = transforms.resize_image(dP.transpose(1, 2, 3, 0), Ly=Ly_0, Lx=Lx_0, no_channels=False).transpose(3, 0, 1, 2)
341-
cellprob = transforms.resize_image(cellprob, Ly=Ly_0, Lx=Lx_0, no_channels=True)
342-
343312
if do_3D and flow3D_smooth:
344313
if isinstance(flow3D_smooth, (int, float)):
345314
flow3D_smooth = [flow3D_smooth]*3
@@ -350,15 +319,21 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
350319
dP = gaussian_filter(dP, [0, *flow3D_smooth])
351320
else:
352321
models_logger.warning(f"Could not do flow smoothing with {flow3D_smooth} either because its len was not 3 or no items were > 0, skipping flow3D_smoothing")
322+
torch.cuda.empty_cache()
323+
gc.collect()
353324

354325
if compute_masks:
355326
# use user niter if specified, otherwise scale niter (200) with diameter
356327
niter_scale = 1 if image_scaling is None else image_scaling
357328
niter = int(200/niter_scale) if niter is None or niter == 0 else niter
358-
masks = self._compute_masks((Lz_0 or nimg, Ly_0, Lx_0), dP, cellprob, flow_threshold=flow_threshold,
359-
cellprob_threshold=cellprob_threshold, min_size=min_size,
360-
max_size_fraction=max_size_fraction, niter=niter,
361-
stitch_threshold=stitch_threshold, do_3D=do_3D)
329+
masks = self._compute_masks(x.shape, dP, cellprob,
330+
flow_threshold=flow_threshold,
331+
cellprob_threshold=cellprob_threshold,
332+
min_size=min_size,
333+
max_size_fraction=max_size_fraction,
334+
niter=niter,
335+
stitch_threshold=stitch_threshold,
336+
do_3D=do_3D)
362337
else:
363338
masks = np.zeros(0) #pass back zeros if not compute_masks
364339

@@ -368,38 +343,54 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
368343

369344

370345
def _run_net(self, x,
371-
augment=False,
372-
batch_size=8, tile_overlap=0.1,
373-
bsize=256, anisotropy=1.0, do_3D=False):
346+
rescale=1.0,
347+
resample=True,
348+
augment=False,
349+
batch_size=8,
350+
tile_overlap=0.1,
351+
bsize=256,
352+
anisotropy=1.0,
353+
do_3D=False):
374354
""" run network on image x """
375355
tic = time.time()
376356
shape = x.shape
377357
nimg = shape[0]
378358

379-
380359
if do_3D:
381360
Lz, Ly, Lx = shape[:-1]
382-
if anisotropy is not None and anisotropy != 1.0:
361+
if rescale != 1.0 or (anisotropy is not None and anisotropy != 1.0):
383362
models_logger.info(f"resizing 3D image with anisotropy={anisotropy}")
363+
anisotropy = 1.0 if anisotropy is None else anisotropy
364+
if rescale != 1.0:
365+
x = transforms.resize_image(x, Ly=int(Ly*rescale),
366+
Lx=int(Lx*rescale))
384367
x = transforms.resize_image(x.transpose(1,0,2,3),
385-
Ly=int(Lz*anisotropy),
386-
Lx=int(Lx)).transpose(1,0,2,3)
368+
Ly=int(Lz*anisotropy*rescale),
369+
Lx=int(Lx*rescale)).transpose(1,0,2,3)
387370
yf, styles = run_3D(self.net, x,
388371
batch_size=batch_size, augment=augment,
389372
tile_overlap=tile_overlap,
390373
bsize=bsize
391374
)
375+
if resample:
376+
if rescale != 1.0 or Lz != yf.shape[0]:
377+
models_logger.info("resizing 3D flows and cellprobl to original image size")
378+
if rescale != 1.0:
379+
yf = transforms.resize_image(yf, Ly=Ly, Lx=Lx)
380+
if Lz != yf.shape[0]:
381+
yf = transforms.resize_image(yf.transpose(1, 0, 2, 3), Ly=Lz, Lx=Lx).transpose(1, 0, 2, 3)
392382
cellprob = yf[..., -1]
393383
dP = yf[..., :-1].transpose((3, 0, 1, 2))
394384
else:
395385
yf, styles = run_net(self.net, x, bsize=bsize, augment=augment,
396386
batch_size=batch_size,
397387
tile_overlap=tile_overlap,
398-
)
388+
rsz=rescale if rescale !=1.0 else None)
389+
if resample:
390+
if rescale != 1.0:
391+
yf = transforms.resize_image(yf, shape[1], shape[2])
399392
cellprob = yf[..., -1]
400393
dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
401-
if yf.shape[-1] > 3:
402-
styles = yf[..., :-3]
403394

404395
styles = styles.squeeze()
405396

0 commit comments

Comments (0)