@@ -268,17 +268,9 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
268268 x = x [np .newaxis , ...]
269269 nimg = x .shape [0 ]
270270
271- image_scaling = None
272- Ly_0 = x .shape [1 ]
273- Lx_0 = x .shape [2 ]
274- Lz_0 = None
275- if do_3D or stitch_threshold > 0 :
276- Lz_0 = x .shape [0 ]
277- if diameter is not None :
271+ image_scaling = 1.0
272+ if diameter is not None and diameter > 0 :
278273 image_scaling = 30. / diameter
279- x = transforms .resize_image (x ,
280- Ly = int (x .shape [1 ] * image_scaling ),
281- Lx = int (x .shape [2 ] * image_scaling ))
282274
283275
284276 # normalize image
@@ -306,40 +298,17 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
306298 if do_normalization :
307299 x = transforms .normalize_img (x , ** normalize_params )
308300
309- # ajust the anisotropy when diameter is specified and images are resized:
310- if isinstance (anisotropy , (float , int )) and image_scaling :
311- anisotropy = image_scaling * anisotropy
312-
313301 dP , cellprob , styles = self ._run_net (
314- x ,
302+ x ,
303+ resample = resample ,
304+ rescale = image_scaling ,
315305 augment = augment ,
316306 batch_size = batch_size ,
317307 tile_overlap = tile_overlap ,
318308 bsize = bsize ,
319309 do_3D = do_3D ,
320310 anisotropy = anisotropy )
321311
322- if do_3D :
323- torch .cuda .empty_cache ()
324- gc .collect ()
325-
326- if resample :
327- # resize XY then YZ and then put channels first
328- dP = transforms .resize_image (dP .transpose (1 , 2 , 3 , 0 ), Ly = Ly_0 , Lx = Lx_0 , no_channels = False )
329- dP = transforms .resize_image (dP .transpose (1 , 0 , 2 , 3 ), Lx = Lx_0 , Ly = Lz_0 , no_channels = False )
330- dP = dP .transpose (3 , 1 , 0 , 2 )
331-
332- # resize cellprob:
333- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
334- cellprob = transforms .resize_image (cellprob .transpose (1 , 0 , 2 ), Lx = Lx_0 , Ly = Lz_0 , no_channels = True )
335- cellprob = cellprob .transpose (1 , 0 , 2 )
336-
337- # 2d case:
338- if resample and not do_3D :
339- # 2D images have N = 1 in batch dimension:
340- dP = transforms .resize_image (dP .transpose (1 , 2 , 3 , 0 ), Ly = Ly_0 , Lx = Lx_0 , no_channels = False ).transpose (3 , 0 , 1 , 2 )
341- cellprob = transforms .resize_image (cellprob , Ly = Ly_0 , Lx = Lx_0 , no_channels = True )
342-
343312 if do_3D and flow3D_smooth :
344313 if isinstance (flow3D_smooth , (int , float )):
345314 flow3D_smooth = [flow3D_smooth ]* 3
@@ -350,15 +319,21 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
350319 dP = gaussian_filter (dP , [0 , * flow3D_smooth ])
351320 else :
352321 models_logger .warning (f"Could not do flow smoothing with { flow3D_smooth } either because its len was not 3 or no items were > 0, skipping flow3D_smoothing" )
322+ torch .cuda .empty_cache ()
323+ gc .collect ()
353324
354325 if compute_masks :
355326 # use user niter if specified, otherwise scale niter (200) with diameter
356327 niter_scale = 1 if image_scaling is None else image_scaling
357328 niter = int (200 / niter_scale ) if niter is None or niter == 0 else niter
358- masks = self ._compute_masks ((Lz_0 or nimg , Ly_0 , Lx_0 ), dP , cellprob , flow_threshold = flow_threshold ,
359- cellprob_threshold = cellprob_threshold , min_size = min_size ,
360- max_size_fraction = max_size_fraction , niter = niter ,
361- stitch_threshold = stitch_threshold , do_3D = do_3D )
329+ masks = self ._compute_masks (x .shape , dP , cellprob ,
330+ flow_threshold = flow_threshold ,
331+ cellprob_threshold = cellprob_threshold ,
332+ min_size = min_size ,
333+ max_size_fraction = max_size_fraction ,
334+ niter = niter ,
335+ stitch_threshold = stitch_threshold ,
336+ do_3D = do_3D )
362337 else :
363338 masks = np .zeros (0 ) #pass back zeros if not compute_masks
364339
@@ -368,38 +343,54 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
368343
369344
370345 def _run_net (self , x ,
371- augment = False ,
372- batch_size = 8 , tile_overlap = 0.1 ,
373- bsize = 256 , anisotropy = 1.0 , do_3D = False ):
346+ rescale = 1.0 ,
347+ resample = True ,
348+ augment = False ,
349+ batch_size = 8 ,
350+ tile_overlap = 0.1 ,
351+ bsize = 256 ,
352+ anisotropy = 1.0 ,
353+ do_3D = False ):
374354 """ run network on image x """
375355 tic = time .time ()
376356 shape = x .shape
377357 nimg = shape [0 ]
378358
379-
380359 if do_3D :
381360 Lz , Ly , Lx = shape [:- 1 ]
382- if anisotropy is not None and anisotropy != 1.0 :
361+ if rescale != 1.0 or ( anisotropy is not None and anisotropy != 1.0 ) :
383362 models_logger .info (f"resizing 3D image with anisotropy={ anisotropy } " )
363+ anisotropy = 1.0 if anisotropy is None else anisotropy
364+ if rescale != 1.0 :
365+ x = transforms .resize_image (x , Ly = int (Ly * rescale ),
366+ Lx = int (Lx * rescale ))
384367 x = transforms .resize_image (x .transpose (1 ,0 ,2 ,3 ),
385- Ly = int (Lz * anisotropy ),
386- Lx = int (Lx )).transpose (1 ,0 ,2 ,3 )
368+ Ly = int (Lz * anisotropy * rescale ),
369+ Lx = int (Lx * rescale )).transpose (1 ,0 ,2 ,3 )
387370 yf , styles = run_3D (self .net , x ,
388371 batch_size = batch_size , augment = augment ,
389372 tile_overlap = tile_overlap ,
390373 bsize = bsize
391374 )
375+ if resample :
376+ if rescale != 1.0 or Lz != yf .shape [0 ]:
377+ models_logger .info ("resizing 3D flows and cellprobl to original image size" )
378+ if rescale != 1.0 :
379+ yf = transforms .resize_image (yf , Ly = Ly , Lx = Lx )
380+ if Lz != yf .shape [0 ]:
381+ yf = transforms .resize_image (yf .transpose (1 , 0 , 2 , 3 ), Ly = Lz , Lx = Lx ).transpose (1 , 0 , 2 , 3 )
392382 cellprob = yf [..., - 1 ]
393383 dP = yf [..., :- 1 ].transpose ((3 , 0 , 1 , 2 ))
394384 else :
395385 yf , styles = run_net (self .net , x , bsize = bsize , augment = augment ,
396386 batch_size = batch_size ,
397387 tile_overlap = tile_overlap ,
398- )
388+ rsz = rescale if rescale != 1.0 else None )
389+ if resample :
390+ if rescale != 1.0 :
391+ yf = transforms .resize_image (yf , shape [1 ], shape [2 ])
399392 cellprob = yf [..., - 1 ]
400393 dP = yf [..., - 3 :- 1 ].transpose ((3 , 0 , 1 , 2 ))
401- if yf .shape [- 1 ] > 3 :
402- styles = yf [..., :- 3 ]
403394
404395 styles = styles .squeeze ()
405396
0 commit comments