diff --git a/cellpose/core.py b/cellpose/core.py index 00533e5c..b48d83cc 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -234,7 +234,7 @@ def run_net(net, imgs, batch_size=8, augment=False, tile=True, tile_overlap=0.1, # slices from padding # slc = [slice(0, self.nclasses) for n in range(imgs.ndim)] # changed from imgs.shape[n]+1 for first slice size slc = [slice(0, imgs.shape[n] + 1) for n in range(imgs.ndim)] - slc[-3] = slice(0, 3) + slc[-3] = slice(0, net.nout) slc[-2] = slice(ysub[0], ysub[-1] + 1) slc[-1] = slice(xsub[0], xsub[-1] + 1) slc = tuple(slc) @@ -286,7 +286,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 yf = np.zeros((Lz, nout, imgi.shape[-2], imgi.shape[-1]), np.float32) styles = [] if ny * nx > batch_size: - ziterator = trange(Lz, file=tqdm_out) + ziterator = (trange(Lz, file=tqdm_out, mininterval=30) + if Lz > 1 else range(Lz)) for i in ziterator: yfi, stylei = _run_tiled(net, imgi[i], augment=augment, bsize=bsize, tile_overlap=tile_overlap) @@ -297,7 +298,8 @@ def _run_tiled(net, imgi, batch_size=8, augment=False, bsize=224, tile_overlap=0 ntiles = ny * nx nimgs = max(2, int(np.round(batch_size / ntiles))) niter = int(np.ceil(Lz / nimgs)) - ziterator = trange(niter, file=tqdm_out) + ziterator = (trange(niter, file=tqdm_out, mininterval=30) + if Lz > 1 else range(niter)) for k in ziterator: IMGa = np.zeros((ntiles * nimgs, nchan, ly, lx), np.float32) for i in range(min(Lz - k * nimgs, nimgs)): diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 8deb533f..189774bc 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -210,7 +210,7 @@ def img_norm(imgi): def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7, ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None, - ds=None): + ds=None, uniform_blur=False, partial_blur=False): """Adds noise to the input image. Args: @@ -234,30 +234,50 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp """ device = lbl.device imgi = torch.zeros_like(lbl) + Ly, Lx = lbl.shape[-2:] diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device) #ds0 = 1 if ds is None else ds.item() ds = ds * torch.ones( (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds + # downsample + ii = [] + idownsample = np.random.rand(len(lbl)) < downsample + if (ds is None and idownsample.sum() > 0.) or not iso: + ds = torch.ones(len(lbl), dtype=torch.long, device=device) + ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),), + device=device) + ii = torch.nonzero(ds > 1).flatten() + elif ds is not None and (ds > 1).sum(): + ii = torch.nonzero(ds > 1).flatten() + # add gaussian blur - iblur = np.random.rand(len(lbl)) < blur + iblur = torch.rand(len(lbl), device=device) < blur + iblur[ii] = True if iblur.sum() > 0: if sigma0 is None: - # was 10 - xrand = np.random.exponential(1, size=iblur.sum()) - xrand = np.clip(xrand * 0.5, 0.1, 1.0) - xrand *= gblur - sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( - device) - #(1 + torch.rand(iblur.sum(), device=device)) - if not iso: - sr = diams[iblur] / 30. * 2 * (1 + - torch.rand(iblur.sum(), device=device)) - sigma1 = (torch.rand(iblur.sum(), device=device) > 0.66) * sr + if uniform_blur and iso: + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = ds[ii].float() / 2. / gblur + sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur]) + sigma1 = sigma0.clone() + elif not iso: + xr = torch.rand(len(lbl), device=device) + if len(ii) > 0: + xr[ii] = (ds[ii].float()) / gblur + xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35 + xr[ii] = torch.clip(xr[ii], 0.05, 1.5) + sigma0 = diams[iblur] / 30. * gblur * xr[iblur] + sigma1 = sigma0.clone() / 10. else: - sigma1 = sigma0.clone( - ) #+ torch.randint(0, 3, size=(len(sigma0.clone()),), device=device) + xrand = np.random.exponential(1, size=iblur.sum()) + xrand = np.clip(xrand * 0.5, 0.1, 1.0) + xrand *= gblur + sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to( + device) + sigma1 = sigma0.clone() else: sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device) sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device) @@ -278,22 +298,28 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1) gfilt /= gfilt.sum(axis=(1, 2), keepdims=True) - imgi[iblur] = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1), + lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1), padding=gfilt.shape[-1] // 2, groups=gfilt.shape[0]).transpose(1, 0) + if partial_blur: + #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100) + imgi[iblur] = lbl[iblur].clone() + Lxc = int(Lx * 0.85) + ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32), + torch.arange(0, Lxc, dtype=torch.float32), + indexing="ij") + mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2)) + mask -= mask.min() + mask /= mask.max() + lbl_blur_crop = lbl_blur[:, :, :, :Lxc] + imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask + + (1-mask) * imgi[iblur, :, :, :Lxc]) + else: + imgi[iblur] = lbl_blur imgi[~iblur] = lbl[~iblur] - # downsample - ii = [] - idownsample = np.random.rand(len(lbl)) < downsample - if (ds is None and idownsample.sum() > 0.) or not iso: - ds = torch.ones(len(lbl), dtype=torch.long, device=device) - ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),), - device=device) - ii = torch.nonzero(ds > 1) - elif ds is not None and (ds > 1).sum(): - ii = torch.nonzero(ds > 1) + # apply downsample for k in ii: i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]] imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear") @@ -320,8 +346,8 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7, downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30, - ds_max=7, iso=True, rotate=True, - device=None, xy=(224, 224), + ds_max=7, uniform_blur=False, iso=True, rotate=True, + device=torch.device("cuda"), xy=(224, 224), nchan_noise=1, keep_raw=True): """ Applies random rotation, resizing, and noise to the input data. @@ -478,7 +504,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1, augment=False, resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, - min_size=15, niter=None, interp=True): + min_size=15, niter=None, interp=True, bsize=224): """ Restore array or list of images using the image restoration model, and then segment. @@ -541,9 +567,10 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels, channel_axis=channel_axis, z_axis=z_axis, + do_3D=do_3D, normalize=normalize_params, rescale=rescale, diameter=diameter, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) # turn off special normalization for segmentation normalize_params = normalize_default @@ -559,7 +586,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, invert=invert, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold, min_size=min_size, niter=niter, - interp=interp) + interp=interp, bsize=bsize) return masks, flows, styles, img_restore @@ -660,7 +687,8 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, self.net_type = "cellpose_denoise" def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, - normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1): + normalize=True, rescale=None, diameter=None, tile=True, do_3D=False, + tile_overlap=0.1, bsize=224): """ Restore array or list of images using the image restoration model. @@ -714,11 +742,12 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2)) else channels, channel_axis=channel_axis, z_axis=z_axis, normalize=normalize, + do_3D=do_3D, rescale=rescale[i] if isinstance(rescale, list) or isinstance(rescale, np.ndarray) else rescale, diameter=diameter[i] if isinstance(diameter, list) or isinstance(diameter, np.ndarray) else diameter, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize) imgs.append(imgi) if isinstance(x, np.ndarray): imgs = np.array(imgs) @@ -727,7 +756,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, else: # reshape image x = transforms.convert_image(x, channels, channel_axis=channel_axis, - z_axis=z_axis) + z_axis=z_axis, do_3D=do_3D) if x.ndim < 4: squeeze = True x = x[np.newaxis, ...] @@ -767,18 +796,18 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, x[..., c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize)[...,0] else: x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1], batch_size=batch_size, normalize=normalize, rescale=rescale0, tile=tile, - tile_overlap=tile_overlap) + tile_overlap=tile_overlap, bsize=bsize)[...,0] x = x[0] if squeeze else x return x def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, tile=True, - tile_overlap=0.1): + tile_overlap=0.1, bsize=224): """ Run image restoration model on a single channel. @@ -818,40 +847,49 @@ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, tile=True, do_normalization = True if normalize_params["normalize"] else False - tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO) - iterator = trange(nimg, file=tqdm_out, - mininterval=30) if nimg > 1 else range(nimg) - imgs = np.zeros((*x.shape[:-1], 1), np.float32) - for i in iterator: - img = np.asarray(x[i]) - if do_normalization: - img = transforms.normalize_img(img, **normalize_params) - if rescale != 1.0: - img = transforms.resize_image(img, rsz=[rescale, rescale]) - if img.ndim == 2: - img = img[:, :, np.newaxis] - yf, style = run_net(net, img, batch_size=batch_size, augment=False, - tile=tile, tile_overlap=tile_overlap) - img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) - - if img.ndim == 2: - img = img[:, :, np.newaxis] - imgs[i] = img - del yf, style + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, + tile=tile, tile_overlap=tile_overlap) + yf = transforms.resize_image(yf, shape[1], shape[2]) + imgs = yf + del yf, style + + # imgs = np.zeros((*x.shape[:-1], 1), np.float32) + # for i in iterator: + # img = np.asarray(x[i]) + # if do_normalization: + # img = transforms.normalize_img(img, **normalize_params) + # if rescale != 1.0: + # img = transforms.resize_image(img, rsz=[rescale, rescale]) + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # yf, style = run_net(net, img, batch_size=batch_size, augment=False, + # tile=tile, tile_overlap=tile_overlap, bsize=bsize) + # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) + + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # imgs[i] = img + # del yf, style net_time = time.time() - tic if nimg > 1: denoise_logger.info("imgs denoised in %2.2fs" % (net_time)) - return imgs.squeeze() + return imgs def train(net, train_data=None, train_labels=None, train_files=None, test_data=None, test_labels=None, test_files=None, train_probs=None, test_probs=None, lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None, save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, - iso=True, downsample=0., learning_rate=0.005, n_epochs=500, momentum=0.9, + iso=True, uniform_blur=False, downsample=0., ds_max=7, + learning_rate=0.005, n_epochs=500, weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, - nimg_test_per_epoch=None): + nimg_test_per_epoch=None, model_name=None): # net properties device = net.device @@ -866,21 +904,23 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N d = datetime.datetime.now() if save_path is not None: - filename = "" - lstrs = ["per", "seg", "rec"] - for k, (l, s) in enumerate(zip(lam, lstrs)): - filename += f"{s}_{l:.2f}_" - if poisson.sum() > 0: - filename += "poisson_" - if blur.sum() > 0: - if iso: + if model_name is None: + filename = "" + lstrs = ["per", "seg", "rec"] + for k, (l, s) in enumerate(zip(lam, lstrs)): + filename += f"{s}_{l:.2f}_" + if not iso: + filename += "aniso_" + if poisson.sum() > 0: + filename += "poisson_" + if blur.sum() > 0: filename += "blur_" - else: - filename += "bluraniso_" - if downsample.sum() > 0: - filename += "downsample_" - filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") - filename = os.path.join(save_path, filename) + if downsample.sum() > 0: + filename += "downsample_" + filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") + filename = os.path.join(save_path, filename) + else: + filename = os.path.join(save_path, model_name) print(filename) for i in range(len(poisson)): denoise_logger.info( @@ -939,6 +979,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch nbatch = 0 + train_losses, test_losses = [], [] for iepoch in range(n_epochs): np.random.seed(iepoch) rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), @@ -948,35 +989,54 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N for param_group in optimizer.param_groups: param_group["lr"] = learning_rate[iepoch] lavg, lavg_per, nsum = 0, 0, 0 - for ibatch in range(0, nimg_per_epoch, batch_size): - inds = rperm[ibatch:ibatch + batch_size] + for ibatch in range(0, nimg_per_epoch, batch_size * nnoise): + inds = rperm[ibatch : ibatch + batch_size * nnoise] if train_data is None: imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds] lbls = [io.imread(train_labels_files[i])[1:] for i in inds] else: imgs = [train_data[i][:nchan] for i in inds] lbls = [train_labels[i][1:] for i in inds] - inoise = nbatch % nnoise - img, lbl, scale = random_rotate_and_resize_noise( - imgs, lbls, diam_train[inds].copy(), poisson=poisson[inoise], - beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, - downsample=downsample[inoise], diam_mean=diam_mean, device=device) - #print(torch.isnan(img).sum()) - if torch.isnan(img).sum(): - import pdb - pdb.set_trace() - optimizer.zero_grad() - loss, loss_per = train_loss(net, img[:, :nchan], net1=net1, - img=img[:, nchan:], lbl=lbl, lam=lam) - - loss.backward() - optimizer.step() - lavg += loss.item() * img.shape[0] - lavg_per += loss_per.item() * img.shape[0] + #inoise = nbatch % nnoise + rnoise = np.random.permutation(nnoise) + for i, inoise in enumerate(rnoise): + if i * batch_size < len(imgs): + imgi, lbli, scale = random_rotate_and_resize_noise( + imgs[i * batch_size : (i + 1) * batch_size], + lbls[i * batch_size : (i + 1) * batch_size], + diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(), + poisson=poisson[inoise], + beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, + downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, + device=device) + if i == 0: + img = imgi + lbl = lbli + else: + img = torch.cat((img, imgi), axis=0) + lbl = torch.cat((lbl, lbli), axis=0) + + if nnoise > 0: + iperm = np.random.permutation(img.shape[0]) + img, lbl = img[iperm], lbl[iperm] + + for i in range(nnoise): + optimizer.zero_grad() + imgi = img[i * batch_size: (i + 1) * batch_size] + lbli = lbl[i * batch_size: (i + 1) * batch_size] + if imgi.shape[0] > 0: + loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, + img=imgi[:, nchan:], lbl=lbli, lam=lam) + loss.backward() + optimizer.step() + lavg += loss.item() * imgi.shape[0] + lavg_per += loss_per.item() * imgi.shape[0] + nsum += len(img) nbatch += 1 - if iepoch % 10 == 0 or iepoch < 10: + if iepoch % 5 == 0 or iepoch < 10: lavg = lavg / nsum lavg_per = lavg_per / nsum if test_data is not None or test_files is not None: @@ -1000,32 +1060,29 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N img, lbl, scale = random_rotate_and_resize_noise( imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise], beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise], - iso=iso, downsample=downsample[inoise], diam_mean=diam_mean, - device=device) + iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, device=device) loss, loss_per = test_loss(net, img[:, :nchan], net1=net1, img=img[:, nchan:], lbl=lbl, lam=lam) lavgt += loss.item() * img.shape[0] nsum += len(img) + lavgt = lavgt / nsum denoise_logger.info( "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f" - % (iepoch, time.time() - tic, lavg, lavg_per, lavgt / nsum, + % (iepoch, time.time() - tic, lavg, lavg_per, lavgt, learning_rate[iepoch])) + test_losses.append(lavgt) else: denoise_logger.info( "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) - elif iepoch < 50: - lavg = lavg / nsum - lavg_per = lavg_per / nsum - denoise_logger.info( - "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % - (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) - + train_losses.append(lavg) + if save_path is not None: - if iepoch == n_epochs - 1 or iepoch % save_every == 1: + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): if save_each: #separate files as model progresses - filename0 = filename + "_epoch_" + str(iepoch) + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" else: filename0 = filename denoise_logger.info(f"saving network parameters to {filename0}") @@ -1033,7 +1090,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N else: filename = save_path - return filename + return filename, train_losses, test_losses if __name__ == "__main__": @@ -1070,6 +1127,8 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="scale of gaussian blurring stddev") training_args.add_argument("--downsample", default=0., type=float, help="fraction of images to downsample") + training_args.add_argument("--ds_max", default=7, type=int, + help="max downsampling factor") training_args.add_argument("--lam_per", default=1.0, type=float, help="weighting of perceptual loss") training_args.add_argument("--lam_seg", default=1.5, type=float, @@ -1084,6 +1143,9 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N help="learning rate. Default: %(default)s") training_args.add_argument("--n_epochs", default=2000, type=int, help="number of epochs. Default: %(default)s") + training_args.add_argument( + "--save_each", default=False, action="store_true", + help="save each epoch as separate model") training_args.add_argument( "--nimg_per_epoch", default=0, type=int, help="number of images per epoch. Default is length of training images") @@ -1094,33 +1156,59 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N io.logger_setup() args = parser.parse_args() + lams = [args.lam_per, args.lam_seg, args.lam_rec] + print("lam", lams) if len(args.noise_type) > 0: noise_type = args.noise_type + uniform_blur = False + iso = True if noise_type == "poisson": poisson = 0.8 blur = 0. downsample = 0. beta = 0.7 gblur = 1.0 + elif noise_type == "blur_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0. + beta = 0.1 + gblur = 0.5 elif noise_type == "blur": poisson = 0.8 blur = 0.8 downsample = 0. beta = 0.1 + gblur = 10.0 + uniform_blur = True + elif noise_type == "downsample_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 gblur = 1.0 elif noise_type == "downsample": poisson = 0.8 blur = 0.8 downsample = 0.8 - beta = 0.01 - gblur = 0.5 + beta = 0.03 + gblur = 5.0 + uniform_blur = True elif noise_type == "all": poisson = [0.8, 0.8, 0.8] blur = [0., 0.8, 0.8] downsample = [0., 0., 0.8] - beta = [0.7, 0.1, 0.01] - gblur = [0., 1.0, 0.5] + beta = [0.7, 0.1, 0.03] + gblur = [0., 10.0, 5.0] + uniform_blur = True + elif noise_type == "aniso": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.1 + gblur = args.ds_max * 1.5 + iso = False else: raise ValueError(f"{noise_type} noise_type is not supported") else: @@ -1136,8 +1224,7 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N train_data, labels, train_files, train_probs = None, None, None, None test_data, test_labels, test_files, test_probs = None, None, None, None if len(args.file_list) == 0: - output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0, - 0) + output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0) images, labels, image_names, test_images, test_labels, image_names_test = output train_data = [] for i in range(len(images)): @@ -1185,10 +1272,11 @@ def train(net, train_data=None, train_labels=None, train_files=None, test_data=N model.net, train_data=train_data, train_labels=labels, train_files=train_files, test_data=test_data, test_labels=test_labels, test_files=test_files, train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta, - blur=blur, gblur=gblur, downsample=downsample, iso=True, n_epochs=args.n_epochs, + blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max, + iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs, learning_rate=args.learning_rate, - lam=[args.lam_per, args.lam_seg, args.lam_rec - ], seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, + lam=lams, + seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch, nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path) diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index 94296419..dd518b13 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -733,32 +733,31 @@ def make_buttons(self): self.l0.addWidget(self.denoiseBox, b, 0, 1, 9) b0 = 0 - self.denoiseBoxG.addWidget(QLabel("mode:"), b0, 0, 1, 3) - + # DENOISING self.DenoiseButtons = [] nett = [ - "filter image (settings below)", "clear restore/filter", + "filter image (settings below)", "denoise (please set cell diameter first)", "deblur (please set cell diameter first)", "upsample to 30. diameter (cyto3) or 17. diameter (nuclei) (please set cell diameter first) (disabled in 3D)", + "one-click model trained to denoise+deblur+upsample (please set cell diameter first)" ] - self.denoise_text = ["filter", "none", "denoise", "deblur", "upsample"] + self.denoise_text = ["none", "filter", "denoise", "deblur", "upsample", "one-click"] self.restore = None self.ratio = 1. - jj = 3 + jj = 0 + w = 3 for j in range(len(self.denoise_text)): self.DenoiseButtons.append( guiparts.DenoiseButton(self, self.denoise_text[j])) - w = 3 self.denoiseBoxG.addWidget(self.DenoiseButtons[-1], b0, jj, 1, w) - jj += w self.DenoiseButtons[-1].setFixedWidth(75) self.DenoiseButtons[-1].setToolTip(nett[j]) self.DenoiseButtons[-1].setFont(self.medfont) - b0 += 1 if j == 1 else 0 - jj = 0 if j == 1 else jj + b0 += 1 if j%2==1 else 0 + jj = 0 if j%2==1 else jj + w # b0+=1 self.save_norm = QCheckBox("save restored/filtered image") @@ -767,22 +766,23 @@ def make_buttons(self): self.save_norm.setChecked(True) # self.denoiseBoxG.addWidget(self.save_norm, b0, 0, 1, 8) - b0 += 1 - label = QLabel("Cellpose3 model type:") + b0 -= 3 + label = QLabel("restore-dataset:") label.setToolTip( - "choose model type and click [denoise], [deblur], or [upsample]") + "choose dataset and click [denoise], [deblur], [upsample], or [one-click]") label.setFont(self.medfont) - self.denoiseBoxG.addWidget(label, b0, 0, 1, 4) + self.denoiseBoxG.addWidget(label, b0, 6, 1, 3) + b0 += 1 self.DenoiseChoose = QComboBox() self.DenoiseChoose.setFont(self.medfont) - self.DenoiseChoose.addItems(["one-click", "nuclei"]) - self.DenoiseChoose.setFixedWidth(100) + self.DenoiseChoose.addItems(["cyto3", "cyto2", "nuclei"]) + self.DenoiseChoose.setFixedWidth(85) tipstr = "choose model type and click [denoise], [deblur], or [upsample]" self.DenoiseChoose.setToolTip(tipstr) - self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 5, 1, 4) + self.denoiseBoxG.addWidget(self.DenoiseChoose, b0, 6, 1, 3) - b0 += 1 + b0 += 2 # FILTERING self.filtBox = QCollapsible("custom filter settings") self.filtBox._toggle_btn.setFont(self.medfont) @@ -1019,7 +1019,7 @@ def enable_buttons(self): for i in range(len(self.DenoiseButtons)): self.DenoiseButtons[i].setEnabled(True) if self.load_3D: - self.DenoiseButtons[-1].setEnabled(False) + self.DenoiseButtons[-2].setEnabled(False) self.ModelButtonB.setEnabled(True) self.SizeButton.setEnabled(True) self.newmodel.setEnabled(True) @@ -2213,7 +2213,7 @@ def compute_restore(self): self.DenoiseChoose.setCurrentIndex(1) if "upsample" in self.restore: i = self.DenoiseChoose.currentIndex() - diam_up = 30. if i == 0 else 17. + diam_up = 30. if i==0 or i==1 else 17. print(diam_up, self.ratio) self.Diameter.setText(str(diam_up / self.ratio)) self.compute_denoise_model(model_type=model_type) @@ -2264,16 +2264,16 @@ def compute_denoise_model(self, model_type=None): self.progress.setValue(0) try: tic = time.time() - nstr = "cyto3" if self.DenoiseChoose.currentText( - ) == "one-click" else "nuclei" - print(model_type) + nstr = self.DenoiseChoose.currentText() + nstr.replace("-", "") self.clear_restore() model_name = model_type + "_" + nstr + print(model_name) # denoising model self.denoise_model = denoise.DenoiseModel(gpu=self.useGPU.isChecked(), model_type=model_name) self.progress.setValue(10) - diam_up = 30. if "cyto3" in model_name else 17. + diam_up = 30. if "cyto" in model_name else 17. # params channels = self.get_channels() diff --git a/cellpose/gui/gui3d.py b/cellpose/gui/gui3d.py index 8804a7e9..8a3f9eea 100644 --- a/cellpose/gui/gui3d.py +++ b/cellpose/gui/gui3d.py @@ -38,8 +38,8 @@ def avg3d(C): """ Ly, Lx = C.shape # pad T by 2 - T = np.zeros((Ly + 2, Lx + 2), np.float32) - M = np.zeros((Ly, Lx), np.float32) + T = np.zeros((Ly + 2, Lx + 2), "float32") + M = np.zeros((Ly, Lx), "float32") T[1:-1, 1:-1] = C.copy() y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), indexing="ij") @@ -244,7 +244,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): vc = stroke[iz, 2] if iz.sum() > 0: # get points inside drawn points - mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8") pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), axis=-1)[:, np.newaxis, :] mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) @@ -265,7 +265,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): elif ioverlap.sum() > 0: ar, ac = ar[~ioverlap], ac[~ioverlap] # compute outline of new mask - mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8) + mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8") mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) @@ -282,7 +282,7 @@ def add_mask(self, points=None, color=(100, 200, 50), dense=True): pix = np.append(pix, np.vstack((ars, acs)), axis=-1) mall = mall[:, pix[0].min():pix[0].max() + 1, - pix[1].min():pix[1].max() + 1].astype(np.float32) + pix[1].min():pix[1].max() + 1].astype("float32") ymin, xmin = pix[0].min(), pix[1].min() if len(zdraw) > 1: mall, zfill = interpZ(mall, zdraw - zmin) @@ -422,15 +422,15 @@ def update_ortho(self): for j in range(2): if j == 0: if self.view == 0: - image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2) + image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy() else: image = self.stack_filtered[zmin:zmax, :, - x].transpose(1, 0, 2) + x].transpose(1, 0, 2).copy() else: image = self.stack[ zmin:zmax, - y, :] if self.view == 0 else self.stack_filtered[zmin:zmax, - y, :] + y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax, + y, :].copy() if self.nchan == 1: # show single channel image = image[..., 0] @@ -458,11 +458,13 @@ def update_ortho(self): self.imgOrtho[j].setLevels( self.saturation[0][self.currentZ]) elif self.color == 4: - image = image.astype(np.float32).mean(axis=-1).astype(np.uint8) + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") self.imgOrtho[j].setImage(image, autoLevels=False, lut=None) self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) elif self.color == 5: - image = image.astype(np.float32).mean(axis=-1).astype(np.uint8) + if image.ndim > 2: + image = image.astype("float32").mean(axis=2).astype("uint8") self.imgOrtho[j].setImage(image, autoLevels=False, lut=self.cmap[0]) self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ]) @@ -470,7 +472,7 @@ def update_ortho(self): self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect) else: - image = np.zeros((10, 10), np.uint8) + image = np.zeros((10, 10), "uint8") self.imgOrtho[0].setImage(image, autoLevels=False, lut=None) self.imgOrtho[0].setLevels([0.0, 255.0]) self.imgOrtho[1].setImage(image, autoLevels=False, lut=None) @@ -478,8 +480,8 @@ def update_ortho(self): zrange = zmax - zmin self.layer_ortho = [ - np.zeros((self.Ly, zrange, 4), np.uint8), - np.zeros((zrange, self.Lx, 4), np.uint8) + np.zeros((self.Ly, zrange, 4), "uint8"), + np.zeros((zrange, self.Lx, 4), "uint8") ] if self.masksOn: for j in range(2): @@ -488,7 +490,7 @@ def update_ortho(self): else: cp = self.cellpix[zmin:zmax, y] self.layer_ortho[j][..., :3] = self.cellcolors[cp, :] - self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype(np.uint8) + self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8") if self.selected > 0: self.layer_ortho[j][cp == self.selected] = np.array( [255, 255, 255, self.opacity]) @@ -499,7 +501,7 @@ def update_ortho(self): op = self.outpix[zmin:zmax, :, x].T else: op = self.outpix[zmin:zmax, y] - self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype(np.uint8) + self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8") for j in range(2): self.layerOrtho[j].setImage(self.layer_ortho[j]) diff --git a/cellpose/models.py b/cellpose/models.py index 6198041d..ac8bdafa 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -35,7 +35,7 @@ "lowhigh": None, "percentile": None, "normalize": True, - "norm3D": False, + "norm3D": True, "sharpen_radius": 0, "smooth_radius": 0, "tile_norm_blocksize": 0, @@ -263,7 +263,7 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, if (pretrained_model and not Path(pretrained_model).exists() and np.any([pretrained_model == s for s in all_models])): model_type = pretrained_model - + # check if model_type is builtin or custom user model saved in .cellpose/models if model_type is not None and np.any([model_type == s for s in all_models]): if np.any([model_type == s for s in MODEL_NAMES]): @@ -286,6 +286,10 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, models_logger.warning( "pretrained_model path does not exist, using default model") use_default = True + elif pretrained_model: + if pretrained_model[-13:] == "nucleitorch_0": + builtin = True + self.diam_mean = 17. builtin = True if use_default else builtin self.pretrained_model = model_path( @@ -503,37 +507,18 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non del yf else: tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO) - iterator = trange(nimg, file=tqdm_out, - mininterval=30) if nimg > 1 else range(nimg) - styles = np.zeros((nimg, self.nbase[-1]), np.float32) + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, augment=augment, + tile=tile, tile_overlap=tile_overlap) if resample: - dP = np.zeros((2, nimg, shape[1], shape[2]), np.float32) - cellprob = np.zeros((nimg, shape[1], shape[2]), np.float32) - else: - dP = np.zeros( - (2, nimg, int(shape[1] * rescale), int(shape[2] * rescale)), - np.float32) - cellprob = np.zeros( - (nimg, int(shape[1] * rescale), int(shape[2] * rescale)), - np.float32) - for i in iterator: - img = np.asarray(x[i]) - if do_normalization: - img = transforms.normalize_img(img, **normalize_params) - if rescale != 1.0: - img = transforms.resize_image(img, rsz=rescale) - yf, style = run_net(self.net, img, bsize=bsize, augment=augment, - tile=tile, tile_overlap=tile_overlap) - if resample: - yf = transforms.resize_image(yf, shape[1], shape[2]) - - cellprob[i] = yf[:, :, 2] - dP[:, i] = yf[:, :, :2].transpose((2, 0, 1)) - if self.nclasses == 4: - if i == 0: - bd = np.zeros_like(cellprob) - bd[i] = yf[:, :, 3] - styles[i][:len(style)] = style + yf = transforms.resize_image(yf, shape[1], shape[2]) + dP = np.moveaxis(yf[..., :2], source=-1, destination=0).copy() + cellprob = yf[..., 2] + styles = style del yf, style styles = styles.squeeze() diff --git a/cellpose/resnet_torch.py b/cellpose/resnet_torch.py index 56b0a36a..808b26be 100644 --- a/cellpose/resnet_torch.py +++ b/cellpose/resnet_torch.py @@ -199,6 +199,7 @@ class CPnet(nn.Module): def __init__(self, nbase, nout, sz, mkldnn=False, conv_3D=False, max_pool=True, diam_mean=30.): super().__init__() + self.nchan = nbase[0] self.nbase = nbase self.nout = nout self.sz = sz diff --git a/cellpose/train.py b/cellpose/train.py index faa2665b..fd669043 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -331,7 +331,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, test_probs=None, load_files=True, batch_size=8, learning_rate=0.005, n_epochs=2000, weight_decay=1e-5, momentum=0.9, SGD=False, channels=None, channel_axis=None, rgb=False, normalize=True, compute_flows=False, - save_path=None, save_every=100, nimg_per_epoch=None, + save_path=None, save_every=100, save_each=False, nimg_per_epoch=None, nimg_test_per_epoch=None, rescale=True, scale_range=None, bsize=224, min_train_masks=5, model_name=None): """ @@ -362,6 +362,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, compute_flows (bool, optional): Boolean - whether to compute flows during training. Defaults to False. save_path (str, optional): String - where to save the trained model. Defaults to None. save_every (int, optional): Integer - save the network every [save_every] epochs. Defaults to 100. + save_each (bool, optional): Boolean - save the network to a new filename at every [save_each] epoch. Defaults to False. nimg_per_epoch (int, optional): Integer - minimum number of images to train on per epoch. Defaults to None. nimg_test_per_epoch (int, optional): Integer - minimum number of images to test on per epoch. Defaults to None. rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to True. @@ -444,10 +445,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, t0 = time.time() model_name = f"cellpose_{t0}" if model_name is None else model_name save_path = Path.cwd() if save_path is None else Path(save_path) - model_path = save_path / "models" / model_name + filename = save_path / "models" / model_name (save_path / "models").mkdir(exist_ok=True) - train_logger.info(f">>> saving model to {model_path}") + train_logger.info(f">>> saving model to {filename}") lavg, nsum = 0, 0 for iepoch in range(n_epochs): @@ -518,15 +519,21 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, lavgt /= len(rperm) lavg /= nsum train_logger.info( - f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.4f}, time {time.time()-t0:.2f}s" + f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s" ) lavg, nsum = 0, 0 - if iepoch > 0 and iepoch % save_every == 0: - net.save_model(model_path) - net.save_model(model_path) + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each and iepoch != n_epochs - 1: #separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" + else: + filename0 = filename + train_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + + net.save_model(filename) - return model_path + return filename def train_size(net, pretrained_model, train_data=None, train_labels=None, diff --git a/cellpose/transforms.py b/cellpose/transforms.py index 4fcacfcd..c9154594 100644 --- a/cellpose/transforms.py +++ b/cellpose/transforms.py @@ -469,9 +469,9 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if len(to_squeeze) > 0: channel_axis = update_axis( channel_axis, to_squeeze, - x.ndim) if channel_axis is not None else channel_axis + x.ndim) if channel_axis is not None else None z_axis = update_axis(z_axis, to_squeeze, - x.ndim) if z_axis is not None else z_axis + x.ndim) if z_axis is not None else None x = x.squeeze() # put z axis first @@ -480,7 +480,19 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if channel_axis is not None: channel_axis += 1 z_axis = 0 - + elif z_axis is None and x.ndim > 2 and channels is not None and min(x.shape) > 5 : + # if there are > 5 channels and channels!=None, assume first dimension is z + min_dim = min(x.shape) + if min_dim != channel_axis: + z_axis = (x.shape).index(min_dim) + if z_axis != 0: + x = move_axis(x, m_axis=z_axis, first=True) + if channel_axis is not None: + channel_axis += 1 + transforms_logger.warning(f"z_axis not specified, assuming it is dim {z_axis}") + transforms_logger.warning(f"if this is actually the channel_axis, use 'model.eval(channel_axis={z_axis}, ...)'") + z_axis = 0 + if z_axis is not None: if x.ndim == 3: x = x[..., np.newaxis] @@ -500,7 +512,7 @@ def convert_image(x, channels, channel_axis=None, z_axis=None, do_3D=False, ncha if channel_axis is None: x = move_min_dim(x) - + if x.ndim > 3: transforms_logger.info( "multi-stack tiff read in as having %d planes %d channels" % @@ -723,13 +735,14 @@ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEA else: imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), np.float32) for i, img in enumerate(img0): - imgs[i] = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + imgi = cv2.resize(img, (Lx, Ly), interpolation=interpolation) + imgs[i] = imgi if imgi.ndim > 2 else imgi[..., np.newaxis] else: imgs = cv2.resize(img0, (Lx, Ly), interpolation=interpolation) return imgs -def pad_image_ND(img0, div=16, extra=1, min_size=None): +def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False): """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D). Args: @@ -758,7 +771,13 @@ def pad_image_ND(img0, div=16, extra=1, min_size=None): ypad2 = extra * div // 2 + Lpad - Lpad // 2 if img0.ndim > 3: - pads = np.array([[0, 0], [0, 0], [xpad1, xpad2], [ypad1, ypad2]]) + if zpad: + Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3]) + zpad1 = extra * div // 2 + Lpad // 2 + zpad2 = extra * div // 2 + Lpad - Lpad // 2 + else: + zpad1, zpad2 = 0, 0 + pads = np.array([[0, 0], [zpad1, zpad2], [xpad1, xpad2], [ypad1, ypad2]]) else: pads = np.array([[0, 0], [xpad1, xpad2], [ypad1, ypad2]]) @@ -767,8 +786,11 @@ def pad_image_ND(img0, div=16, extra=1, min_size=None): Ly, Lx = img0.shape[-2:] ysub = np.arange(xpad1, xpad1 + Ly) xsub = np.arange(ypad1, ypad1 + Lx) - - return I, ysub, xsub + if zpad: + zsub = np.arange(zpad1, zpad1 + img0.shape[-3]) + return I, ysub, xsub, zsub + else: + return I, ysub, xsub def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False, @@ -826,7 +848,7 @@ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=Fal # generate random augmentation parameters flip = np.random.rand() > .5 theta = np.random.rand() * np.pi * 2 if rotate else 0. - scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand() + scale[n] = 2 ** (-2 + 5 * np.random.rand())#(1 - scale_range / 2) + scale_range * np.random.rand() if rescale is not None: scale[n] *= 1. / rescale[n] dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1], diff --git a/docs/restore.rst b/docs/restore.rst index c3864c44..be15e18c 100644 --- a/docs/restore.rst +++ b/docs/restore.rst @@ -5,11 +5,14 @@ Image Restoration The image restoration module ``denoise`` provides functions for restoring degraded images. There are two main classes, ``DenoiseModel`` for image restoration only, and -``CellposeDenoiseModel`` for image restoration and then segmentation. There are three types -of image restoration provided, denoising, deblurring, and upsampling, and for each of these -there are two models, one trained on the full ``cyto3`` training set and one trained on -the ``nuclei`` training set: ``'denoise_cyto3'``, ``'deblur_cyto3'``, ``'upsample_cyto3'``, -``'denoise_nuclei'``, ``'deblur_nuclei'``, ``'upsample_nuclei'``. +``CellposeDenoiseModel`` for image restoration and then segmentation. There are four types +of image restoration provided, denoising, deblurring, upsampling and one-click (trained on +all degradation types), and for each of these +there are three models, one trained on the full ``cyto3`` training set, one trained on the +``cyto2`` training set, and one trained on the ``nuclei`` training set: +``'denoise_cyto3'``, ``'deblur_cyto3'``, ``'upsample_cyto3'``, ``'oneclick_cyto3'``, +``'denoise_cyto2'``, ``'deblur_cyto2'``, ``'upsample_cyto2'``, ``'oneclick_cyto2'``, +``'denoise_nuclei'``, ``'deblur_nuclei'``, ``'upsample_nuclei'``, ``'oneclick_nuclei'``. DenoiseModel -------------- @@ -70,5 +73,36 @@ For more details refer to the API section. Command line usage --------------------- -These models can be used on the command line with input ``--restore_type`` and flag -``--chan2_restore``. +These models can be used on the command line with model_type input using ``--restore_type`` +and add flag ``--chan2_restore`` for restoring the optional nuclear channel, e.g.: + +:: + + python -m cellpose --dir /path/to/images --model_type cyto3 --restore_type denoise_cyto3 --diameter 25 --chan2_restore --chan 2 --chan2 1 + +Training your own models +-------------------------- + +It is also possible to train your own models for image restoration using the +``cellpose.denoise`` module. For example, to train a denoising (Poisson noise) +model with the cyto2 segmentation model with train_data and train_labels +(images and ``_flows.tif``): + +:: + + from cellpose import denoise + model = denoise.DenoiseModel(gpu=True, nchan=1) + + io.logger_setup() + model_path = model.train(train_data, train_labels, test_data=None, test_labels=None, + save_path=save_path, iso=True, + blur=0., downsample=0., poisson=0.8, + n_epochs=2000, learning_rate=0.001, + seg_model_type="/home/carsen/.cellpose/models/cyto2torch_0") + + +This training can also be performed on the command line: + +:: + + python cellpose/denoise.py --dir /path/to/images --noise_type poisson --seg_model_type cyto2 --diam_mean 30. \ No newline at end of file diff --git a/paper/3.0/analysis.py b/paper/3.0/analysis.py index 62e3a400..2a9eb1de 100644 --- a/paper/3.0/analysis.py +++ b/paper/3.0/analysis.py @@ -24,47 +24,7 @@ device = torch.device("cuda") try: - import segmentation_models_pytorch as smp - - class Transformer(nn.Module): - - def __init__(self, pretrained_model=None, encoder="mit_b5", - encoder_weights="imagenet", decoder="FPN"): - super().__init__() - net_fcn = smp.FPN if decoder == "FPN" else smp.MAnet - self.net = net_fcn( - encoder_name=encoder, - encoder_weights=encoder_weights if pretrained_model is None else - None, # use `imagenet` pre-trained weights for encoder initialization - in_channels=3, - classes=3, - activation=None) - self.nout = 3 - self.mkldnn = False - if pretrained_model is not None: - state_dict = torch.load(pretrained_model) - if list(state_dict.keys())[0][:7] == "module.": - from collections import OrderedDict - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[ - 7:] # remove 'module.' of DataParallel/DistributedDataParallel - new_state_dict[name] = v - self.net.load_state_dict(new_state_dict) - else: - self.load_state_dict(state_dict) - - def forward(self, X): - X = torch.cat( - (X, torch.zeros( - (X.shape[0], 1, X.shape[2], X.shape[3]), device=X.device)), dim=1) - y = self.net(X) - return y, torch.zeros((X.shape[0], 256), device=X.device) - - @property - def device(self): - return next(self.parameters()).device - + from cellpose.segformer import Transformer except Exception as e: print(e) print("need to install segmentation_models_pytorch to run transformer") @@ -75,49 +35,56 @@ def device(self): def seg_eval_cp3(folder, noise_type="poisson"): """ need to download test_poisson.npy, test_blur.npy, test_downsample.npy (for cells and/or nuclei) - - (was computed with old flows, but results similar with new flows) """ + """ ctypes = ["cyto2", "nuclei"] - for ctype in ctypes: + for c, ctype in enumerate(ctypes): + print(ctype) + pretrained_models = [f"/home/carsen/.cellpose/models/{model_names[noise_type]}{istr}_{ctype}" + for istr in ["_rec", "_seg", "_per", ""]] + pretrained_models.extend([f"/home/carsen/.cellpose/models/{model_names[noise_type]}_cyto3", + f"/home/carsen/.cellpose/models/oneclick_{ctype}", + f"/home/carsen/.cellpose/models/oneclick_cyto3"]) + + seg_model = models.CellposeModel(gpu=True, model_type=ctype) + folder_name = ctype - diam_mean = 30 if ctype == "cyto2" else 17 root = Path(folder) / f"images_{folder_name}/" - + model_name = model_names[noise_type] + nimg_test = 68 if ctype=="cyto2" else 111 + diam_mean = 30. if ctype == "cyto2" else 17. ### cellpose enhance dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", - allow_pickle=True).item() - test_noisy = dat["test_noisy"] - masks_true = dat["masks_true"] - diam_test = dat["diam_test"] if "diam_test" in dat else 30. * np.ones( + allow_pickle=True).item() + test_noisy = dat["test_noisy"][:nimg_test] + masks_true = dat["masks_true"][:nimg_test] + diam_test = dat["diam_test"][:nimg_test] if "diam_test" in dat else diam_mean * np.ones( len(test_noisy)) - istr = ["rec", "seg", "per", "perseg"] - for k in range(len(istr)): - model_name = model_names[noise_type] - if istr[k] != "perseg": - model_name += "_" + istr[k] - model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type=f"{model_name}_{ctype}") - imgs2 = model.eval([test_noisy[i][0] for i in range(len(test_noisy))], - diameter=diam_test, channel_axis=0) - print(imgs2[0].shape) - seg_model = models.CellposeModel(gpu=True, model_type=ctype) - masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0], - diameter=diam_test, channel_axis=0, - normalize=True) - flows = [flow[0] for flow in flows2] + thresholds = np.arange(0.5, 1.05, 0.05) + istrs = ["rec", "seg", "per", "perseg", "noise_spec", "data_spec", "gen"] + + print(pretrained_models) + aps = [] + for istr, pretrained_model in zip(istrs, pretrained_models): + dn_model = denoise.DenoiseModel(gpu=True, nchan=1, + diam_mean = 30 if "cyto" in pretrained_model else 17, + pretrained_model=pretrained_model) + dn_model.pretrained_model = "test" + imgs2 = dn_model.eval([test_noisy[i][0] for i in range(len(test_noisy))], + diameter=diam_test, channel_axis=0) + + masks2, flows, styles2 = seg_model.eval(imgs2, channels=[1, 0], + diameter=diam_test, channel_axis=-1, + normalize=True) + + ap, tp, fp, fn = metrics.average_precision(masks_true, masks2, threshold=thresholds) + print(f"{noise_type} {istr} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - ap, tp, fp, fn = metrics.average_precision(masks_true, masks2) - if ctype == "cyto2": - print(f"{istr[k]} AP@0.5 \t = {ap[:68,0].mean(axis=0):.3f}") - else: - print(f"{istr[k]} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - - dat[f"test_{istr[k]}"] = imgs2 - dat[f"masks_{istr[k]}"] = masks2 - dat[f"flows_{istr[k]}"] = flows - - #np.save(root / "noisy_test" / f"test_{noise_type}_cp3.npy", dat) + dat[f"test_{istr}"] = imgs2 + dat[f"masks_{istr}"] = masks2 + dat[f"flows_{istr}"] = flows + aps.append(ap) + np.save(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", dat) if noise_type == "poisson": ### cellpose retrained @@ -137,7 +104,7 @@ def seg_eval_cp3(folder, noise_type="poisson"): dat[f"masks_retrain"] = masks2 - #np.save(root / "noisy_test" / f"test_{noise_type}_cp_retrain.npy", dat) + np.save(root / "noisy_test" / f"test_{noise_type}_cp_retrain.npy", dat) def blind_denoising(folder): @@ -350,48 +317,6 @@ def specialist_training(root): noise2void.train_test_specialist(root, n_epochs=100, lr=4e-4, test=True) -def seg_eval_oneclick(folder): - noise_types = ["poisson", "blur", "downsample"] - ctypes = ["cyto2", "nuclei"] - for c, ctype in enumerate(ctypes): - folder_name = ctype - diam_mean = 30. - root = Path(f"/media/carsen/ssd4/datasets_cellpose/images_{folder_name}/") - print(ctype) - for n, noise_type in enumerate(noise_types): - print(noise_type) - ### cellpose enhance - dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", - allow_pickle=True).item() - test_noisy = dat["test_noisy"] - masks_true = dat["masks_true"] - diam_test = dat["diam_test"] if "diam_test" in dat else 30. * np.ones( - len(test_noisy)) - - model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type=model_names[noise_type] + "_cyto3", - device=torch.device("cuda")) - imgs2 = model.eval([test_noisy[i][0] for i in range(len(test_noisy))], - diameter=diam_test, channel_axis=0) - - seg_model = models.CellposeModel(gpu=True, model_type=ctype, - device=torch.device("cuda")) - masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0], - diameter=diam_test, channel_axis=0, - normalize=True) - istr = "generalist" - ap, tp, fp, fn = metrics.average_precision(masks_true, masks2) - if ctype == "cyto2": - print(f"{istr} AP@0.5 \t = {ap[:68,0].mean(axis=0):.3f}") - else: - print(f"{istr} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}") - - dat[f"test_{istr}"] = imgs2 - dat[f"masks_{istr}"] = masks2 - - np.save(root / "noisy_test" / f"test_{noise_type}_generalist_cp3.npy", dat) - - def cyto3_comparisons(folder): """ diameters computed from generalist model cyto3 will need segmentation_models_pytorch to run transformer """ @@ -402,21 +327,20 @@ def cyto3_comparisons(folder): ] net_types = ["generalist", "specialist", "transformer"] - for net_type in net_types: + for net_type in net_types[-1:]: if net_type == "generalist": seg_model = models.Cellpose(gpu=True, model_type="cyto3") elif net_type == "transformer": - seg_model = models.CellposeModel(gpu=True, pretrained_model=None) pretrained_model = "/home/carsen/.cellpose/models/transformer_cp3" - seg_model.net = Transformer(pretrained_model=pretrained_model, - decoder="MAnet").to(device) + seg_model = models.CellposeModel(gpu=True, backbone="transformer", + pretrained_model=pretrained_model) for f in folders: if net_type == "specialist": seg_model = models.CellposeModel(gpu=True, model_type=f"{f}_cp3") root = Path(folder) / f"images_{f}" channels = [1, 2] if f == "tissuenet" or f == "cyto2" else [1, 0] - tifs = (root / "test").glob("*.tif") + tifs = natsorted((root / "test").glob("*.tif")) tifs = [tif for tif in tifs] tifs = [ tif for tif in tifs @@ -424,7 +348,7 @@ def cyto3_comparisons(folder): ] if net_type != "generalist": d = np.load( - f"/media/carsen/ssd4/datasets_cellpose/{f}_generalist_masks.npy", + Path(folder) / f"{f}_generalist_masks.npy", allow_pickle=True).item() diams = d["diams"] else: @@ -456,7 +380,7 @@ def cyto3_comparisons(folder): dat["performance"] = [ap, tp, fp, fn] dat["diams"] = diams - #p.save(f"/media/carsen/ssd4/datasets_cellpose/{f}_{net_type}_masks.npy", dat) + #np.save(f"/media/carsen/ssd4/datasets_cellpose/{f}_{net_type}_masks.npy", dat) if __name__ == '__main__': diff --git a/paper/3.0/figures.py b/paper/3.0/figures.py index 6e3cbd7e..02bef712 100644 --- a/paper/3.0/figures.py +++ b/paper/3.0/figures.py @@ -71,7 +71,7 @@ def load_benchmarks(folder, noise_type="poisson", ctype="cyto2", imgs_all.append(test_n2s) masks_all.append(masks_n2s) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", allow_pickle=True).item() istr = ["rec", "per", "seg", "perseg"] for k in range(len(istr)): @@ -621,13 +621,14 @@ def suppfig_nuclei(folder, save_fig=False): def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_data", + folder3="/media/carsen/ssd4/denoising/ribo_denoise/", save_fig=False): thresholds = np.arange(0.5, 1.05, 0.05) - fig = plt.figure(figsize=(14, 8), dpi=100) + fig = plt.figure(figsize=(14, 12), dpi=100) yratio = 14 / 8 - grid = plt.GridSpec(5, 8, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, - wspace=0.05, hspace=0.25) + grid = plt.GridSpec(7, 8, figure=fig, left=0.02, right=0.97, top=0.98, bottom=0.08, + wspace=0.12, hspace=0.25) transl = mtransforms.ScaledTranslation(-18 / 72, 10 / 72, fig.dpi_scale_trans) il = 0 @@ -639,7 +640,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d kk=[0, 1, 5], seg=[0, 0, 0], dy=0.015) dat = np.load(f"{folder2}/cp_masks.npy", allow_pickle=True).item() - grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-2:, :], + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-4:-2, :], wspace=0.05, hspace=0.1) iex = 10 @@ -666,7 +667,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[0, k]) pos = ax.get_position().bounds - ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.set_position([pos[0], pos[1] - 0.01, pos[2], pos[3]]) ax.imshow(img, vmin=0, vmax=1, cmap="gray") ax.set_title(titlesd[k], color="k" if k < 2 else cols[-2], fontsize="medium") ax.set_xlim(xlim) @@ -679,7 +680,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[1, k]) pos = ax.get_position().bounds - ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.set_position([pos[0], pos[1] - 0.01, pos[2], pos[3]]) ax.imshow(img, vmin=0, vmax=1, cmap="gray") ax.set_xlim(xlim) ax.set_ylim(ylim) @@ -706,7 +707,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d transl = mtransforms.ScaledTranslation(-40 / 72, 15 / 72, fig.dpi_scale_trans) ax = plt.subplot(grid1[:, 3]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.05, pos[1] - 0.03, pos[2] * 0.7, pos[3]]) + ax.set_position([pos[0] + 0.05, pos[1] - 0.0, pos[2] * 0.7, pos[3]]) nl = 0 titlesd = titles.copy() titlesd[7] = "Cellpose3" @@ -731,7 +732,7 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax = plt.subplot(grid1[:, 4]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.05, pos[1] - 0.03, pos[2] * 0.7, pos[3]]) + ax.set_position([pos[0] + 0.05, pos[1] - 0.0, pos[2] * 0.7, pos[3]]) kk = [1, 2, 3, 7] for k in range(len(aps)): means = np.array([aps[k][nl][:, 0].mean(axis=0) for nl in [0, 2, 1]]) @@ -746,6 +747,117 @@ def fig2(folder, folder2="/media/carsen/ssd4/denoising/Projection_Flywing/test_d ax.set_xticklabels(["2%", "3%", "5%"]) ax.set_xlabel("laser power") + + dat = np.load(f"{folder3}/ribo_denoise_n2v.npy", allow_pickle=True).item() + ap_n2v = dat["ap_n2v"] + dat = np.load(f"{folder3}/ribo_denoise_n2s.npy", allow_pickle=True).item() + ap_n2s = dat["ap_n2s"] + dat = np.load(f"{folder3}/ribo_denoise.npy", allow_pickle=True).item() + navgs = dat["navgs"] + + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[-2:, :], + wspace=0.05, hspace=0.1) + iex = 3 + nl = 2 + ylim = [350, 500] + xlim = [200, 500] + + transl = mtransforms.ScaledTranslation(-18 / 72, 26 / 72, fig.dpi_scale_trans) + outlines_gt = utils.outlines_list(dat["masks_clean"][iex].copy(), + multiprocessing=False) + titlest = ["clean (300 frames averaged)", "noisy (4 frames averaged)", "denoised (Cellpose3)"] + for k in range(3): + if k == 0: + img = dat["clean"][iex].copy() + elif k == 1: + img = dat["noisy"][nl][iex].copy() + maskk = dat["masks_noisy"][nl][iex].copy() + ap = dat["ap_noisy"][nl][iex, 0] + else: + img = dat["imgs_dn"][nl][iex].copy() + maskk = dat["masks_dn"][nl][iex].copy() + ap = dat["ap_dn"][nl][iex, 0] + img = transforms.normalize99(img) + + ax = plt.subplot(grid1[0, k]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.imshow(img, vmin=0., vmax=0.75, cmap="gray") + ax.set_title(titlest[k], color="k" if k < 2 else [0,0.5,0], fontsize="medium") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + if k == 0: + ax.text(0, 1.25, "Denoising two-photon imaging in mice", fontsize="large", + fontstyle="italic", transform=ax.transAxes) + il = plot_label(ltr, il, ax, transl, fs_title) + + ax = plt.subplot(grid1[1, k]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1] - 0.05, pos[2], pos[3]]) + ax.imshow(img, vmin=0, vmax=1, cmap="gray") + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.axis("off") + #ax.set_title("segmentation") + if k == 0: + for o in outlines_gt: + ax.plot(o[:, 0], o[:, 1], color=[1, 0, 1], lw=1, ls="--") + ax.text(1, -0.15, "ground-truth", ha="right", transform=ax.transAxes) + else: + outlines = utils.outlines_list(maskk, multiprocessing=False) + for o in outlines: + ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1, ls="--") + ax.text(1, -0.15, f"AP@0.5 = {ap:.2f}", ha="right", transform=ax.transAxes) + + grid11 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=grid1[:, -2:], + wspace=0.3, hspace=0.1) + + + transl = mtransforms.ScaledTranslation(-35 / 72, 25 / 72, fig.dpi_scale_trans) + ax = plt.subplot(grid11[:, 0]) + pos = ax.get_position().bounds + ax.set_position([pos[0] + 0.05, pos[1] - 0.04, pos[2] * 0.8, pos[3]*0.9]) + aps = [dat["ap_noisy"], ap_n2v, ap_n2s, dat["ap_dn"]] + theight = [-0.9, 3, 2, 4] + kk = [1, 2, 3, 7] + titlesd[1] = "noisy\n(4 frames\naveraged)" + for k in range(len(aps)): + means = aps[k][nl, :12].mean(axis=0) + ax.plot(thresholds, means, color=cols[kk[k]]) + ax.text(1.15, 0.62 + theight[k] * 0.09, titlesd[kk[k]], + color=cols[kk[k]], + transform=ax.transAxes, ha="right") + ax.set_ylim([0, 0.8]) + ax.text(-0.18, 1.13, "Segmentation performance", fontstyle="italic", + transform=ax.transAxes, fontsize="large") + ax.set_ylabel("average precision (AP)") + ax.set_xlabel("IoU threshold") + ax.set_xticks(np.arange(0.5, 1.05, 0.25)) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([0, 0.83]) + ax.set_xlim([0.5, 1.0]) + il = plot_label(ltr, il, ax, transl, fs_title) + + ifrs = [slice(0, 12), slice(12, 20)] + for i, ifr in enumerate(ifrs): + ax = plt.subplot(grid11[:, i + 1]) + pos = ax.get_position().bounds + ax.set_position([pos[0] + 0.04 - i*0.01, pos[1] - 0.04, pos[2] * 0.8, pos[3]*0.9]) + nifr = ifr.stop - ifr.start + for k in range(len(aps)): + means = np.array([aps[k][nl][ifr, 0].mean(axis=0) for nl in range(len(aps[k]))]) + sems = np.array([aps[k][nl][ifr, 0].std(axis=0) / (nifr**0.5) for nl in range(len(aps[k]))]) + ax.errorbar(np.arange(0, len(means)), means, sems, color=cols[kk[k]]) + + ax.set_xticks(np.arange(0, len(navgs), 2)) + ax.set_xticklabels([f"{navg}" for navg in navgs[::2]]) + ax.set_yticks(np.arange(0, 1.1, 0.2)) + ax.set_ylim([0, 0.83]) + ax.set_xlabel("# of frames averaged") + ax.set_title("dense expression" if i==0 else "sparse expression", fontsize="medium") + + if save_fig: os.makedirs("figs/", exist_ok=True) fig.savefig("figs/fig2.pdf", dpi=100) @@ -1021,7 +1133,17 @@ def load_benchmarks_specialist(folder, thresholds=np.arange(0.5, 1.05, 0.05)): imgs_all.append(test_care) masks_all.append(masks_care) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", + dat = np.load(root / "noisy_test" / f"test_{noise_type}_denoiseg_specialist.npy", + allow_pickle=True).item() + test_dns = dat["test_denoiseg"][:nimg_test] + masks_dns = dat["masks_denoiseg"][:nimg_test] + imgs_all.append(test_dns) + masks_all.append(masks_dns) + masks_dns = dat["masks_denoiseg_seg"][:nimg_test] + imgs_all.append(test_dns) + masks_all.append(masks_dns) + + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", allow_pickle=True).item() istr = ["rec", "per", "seg", "perseg"] for k in range(len(istr)): @@ -1069,17 +1191,22 @@ def suppfig_specialist(folder, save_fig=True): legstr0 = [] for ls in legstr[:-1]: legstr0.append(" ".join(ls.split(" ")[1:])) - legstr0.insert(4, "CARE") + legstr0[-1] = u"\u2013 " + legstr0[-1] + legstr0.insert(4, u"\u2013 CARE") + legstr0.insert(5, u"\u2013 denoiseg") + legstr0.insert(6, "-- denoiseg\n(segmentation)") cols0 = list(cols[:-1].copy()) cols0.insert(4, [1, 0.5, 1]) + cols0.insert(5, 0.4*np.ones(3)) + cols0.insert(6, 0.4*np.ones(3)) print(len(cols0)) - legstr0[-1] = "Cellpose3\n(per. + seg.)" + legstr0[-1] = u"\u2013 Cellpose3\n(per. + seg.)" il = 0 fig = plt.figure(figsize=(9, 5), dpi=100) yratio = 9 / 5 - grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, + grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1, wspace=0.15, hspace=0.2) titles = ["train - clean", "train - noisy", "test - noisy"] @@ -1094,12 +1221,12 @@ def suppfig_specialist(folder, save_fig=True): imset = imgs_all[1].copy() ax = plt.subplot(grid[0, j]) pos = ax.get_position().bounds - ax.set_position([pos[0] - 0.015 * j, pos[1] - 0.04, pos[2], pos[3]]) + ax.set_position([pos[0] - 0.02 * j, pos[1] - 0.04, pos[2], pos[3]]) ly, lx = 128, 128 dy, dx = 20, 30 ni = 5 img0 = np.ones((ly + (ni - 1) * dy, lx + (ni - 1) * dx)) - ii = np.arange(0, 5)[::-1] if j == 2 else np.arange(0, 20 * ni, 20) + ii = np.arange(0, 5)[::-1] if j == 2 else np.arange(1, 20 * ni, 20)[::-1] if j < 2: x0, y0 = 20, 20 else: @@ -1119,17 +1246,18 @@ def suppfig_specialist(folder, save_fig=True): ax.text(0.02, 1.2, "Specialist dataset", fontsize="large", fontstyle="italic", transform=ax.transAxes) - transl = mtransforms.ScaledTranslation(-50 / 72, 8 / 72, fig.dpi_scale_trans) + transl = mtransforms.ScaledTranslation(-45 / 72, 8 / 72, fig.dpi_scale_trans) ax = plt.subplot(grid[0, -1]) pos = ax.get_position().bounds - ax.set_position([pos[0] + 0.03, pos[1] - 0.03, pos[2] * 0.8, + ax.set_position([pos[0] + 0.01, pos[1] - 0.03, pos[2] * 0.8, pos[3] * 1]) #+pos[3]*0.15-0.03, pos[2], pos[3]*0.7]) il = plot_label(ltr, il, ax, transl, fs_title) - theight = [0, 1, 2, 3, 4, 5, 6, 7, 5.1] - for k in [1, 2, 3, 4, 8]: - ax.plot(thresholds, aps[k, :, :].mean(axis=0), color=cols0[k]) + theight = [0, 0, 4, 3, 6, 5, 1, 5, 7, 8, 7.1] + for k in [1, 2, 3, 4, 5, 6, 10]: + ax.plot(thresholds, aps[k, :, :].mean(axis=0), color=cols0[k], + lw=3 if k==4 else 1, ls="--" if k==6 else "-") #ax.errorbar(thresholds, aps[k,:,:].mean(axis=0), aps[k,:,:].std(axis=0) / 10**0.5, color=cols0[k]) - ax.text(0.59, 0.55 + 0.08 * theight[k], legstr0[k], color=cols0[k], + ax.text(0.7, 0.3 + 0.09 * theight[k], legstr0[k], color=cols0[k], transform=ax.transAxes) ax.set_ylim([0, 0.8]) ax.set_ylabel("average precision (AP)") @@ -1139,11 +1267,11 @@ def suppfig_specialist(folder, save_fig=True): transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans) - kk = [2, 3, 4, 8] + kk = [2, 3, 4, 10] iex = 8 ylim = [10, 310] xlim = [100, 500] - legstr0[-1] = "Cellpose3 (per. + seg.)" + legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)" for j, k in enumerate(kk): ax = plt.subplot(grid[1, j]) pos = ax.get_position().bounds @@ -1156,7 +1284,7 @@ def suppfig_specialist(folder, save_fig=True): ax.axis("off") ax.set_ylim(ylim) ax.set_xlim(xlim) - ax.set_title(legstr0[k], color=cols0[k], fontsize="medium") + ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium") ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right", transform=ax.transAxes) if j == 0: @@ -1170,6 +1298,128 @@ def suppfig_specialist(folder, save_fig=True): os.makedirs("figs/", exist_ok=True) fig.savefig("figs/suppfig_specialist.pdf", dpi=100) +def suppfig_impr(folder, save_fig=True): + aps_all = [[], []] + imgs_all, masks_all = [[], []], [[], []] + inds_all = [[], []] + diams = [[], []] + noise_types = ["poisson", "blur", "downsample"] + for noise_type in noise_types: + for j, ctype in enumerate(["cyto2", "nuclei"]): + nimg_test = 68 if ctype == "cyto2" else 111 + folder_name = ctype + root = Path(f"{folder}/images_{folder_name}/") + + dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy", + allow_pickle=True).item() + test_data = dat["test_data"][:nimg_test] + test_noisy = dat["test_noisy"][:nimg_test] + masks_noisy = dat["masks_noisy"][:nimg_test] + masks_true = dat["masks_true"][:nimg_test] + masks_data = dat["masks_orig"][:nimg_test] + diam_test = dat["diam_test"][:nimg_test] + noise_levels = dat["noise_levels"][:nimg_test] + + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", + allow_pickle=True).item() + + masks_denoised = dat["masks_perseg"][:nimg_test] + test_denoised = dat["test_perseg"][:nimg_test] + thresholds=np.arange(0.5, 1.05, 0.05) + ap_c, tp_d, fp_d, fn_d = metrics.average_precision(masks_true, masks_data, + threshold=thresholds) + ap_d, tp_d, fp_d, fn_d = metrics.average_precision(masks_true, masks_denoised, + threshold=thresholds) + ap_n, tp_n, fp_n, fn_n = metrics.average_precision(masks_true, masks_noisy, + threshold=thresholds) + + aps_all[j].append([ap_c, ap_n, ap_d]) + igood = np.nonzero(ap_d[:,0] > 0)[0] + impr = (ap_d[igood,0] - ap_n[igood,0]) / ap_n[igood,0] + ii = np.hstack((impr.argsort()[-2:][::-1], impr.argsort()[:2])) + ii = igood[ii] + imgs_all[j].append([np.array([test_data[i].squeeze(), test_noisy[i].squeeze(), test_denoised[i].squeeze()]) + for i in ii]) + masks_all[j].append([np.array([masks_data[i].squeeze(), masks_noisy[i].squeeze(), masks_denoised[i].squeeze()]) + for i in ii]) + diams[j].append(dat["diam_test"][ii]) + inds_all[j].append(ii) + + colors = [["darkblue", "royalblue", [0.46, 1, 0], "cyan", "orange", "maroon"], + ["darkblue", [0.46, 1, 0], "dodgerblue"]] + + titles = [["CellImageLibrary", "Cells : fluorescent", "Cells : nonfluorescent", + "Cell membranes", "Microscopy : other", "Non-microscopy"], + ["DSB 2018 / kaggle", "MoNuSeg (H&E)", "ISBI 2009 (fluorescent)"]] + + cinds = [[np.arange(0, 11), np.arange(11,28,1,int), np.arange(28,33,1,int), + np.arange(33,42,1,int), np.arange(42,55,1,int), + np.arange(55,68,1,int)], + [np.arange(0, 75), np.arange(75, 103), np.arange(103, 111)]] + + ddeg = ["noisy", "blurry", "downsampled"] + dcorr = ["denoised", "deblurred", "upsampled"] + dtitle = ["Denoising", "Deblurring", "Upsampling"] + + fig = plt.figure(figsize=(14,8)) + yratio = 14/10 + grid = plt.GridSpec(2, 5, hspace=0.3, wspace=0.5, + left=0.05, right=0.97, top=0.95, bottom=0.05) + il = 0 + transl = mtransforms.ScaledTranslation(-45 / 72, 5 / 72, fig.dpi_scale_trans) + + for c, ctype in enumerate(["cyto2", "nuclei"]): + for d in range(3): + imgs = imgs_all[c][d] + masks = masks_all[c][d] + inds = inds_all[c][d] + aps = aps_all[c][d] + + ax = plt.subplot(grid[c, d + 2*(d>0)]) + pos = ax.get_position().bounds + ax.set_position([pos[0], pos[1]+(pos[3]-pos[2]*yratio), pos[2], pos[2]*yratio]) + for k in range(len(cinds[c])): + ax.scatter(aps[1][cinds[c][k],0], aps[2][cinds[c][k],0], marker="x", + label=titles[c][k], color=colors[c][k]) + ax.plot([0, 1], [0, 1], color="k", lw=1, ls="--") + ax.set_xlabel(f"{ddeg[d]}, AP@0.5") + ax.set_ylabel(f"{dcorr[d]}, AP@0.5", color=[0, 0.5, 0]) + ax.text(-0.2, 1.05, dtitle[d], fontsize="large", transform=ax.transAxes, + fontstyle="italic") + il = plot_label(ltr, il, ax, transl, fs_title) + if d==0: + ax.legend(loc="lower center", bbox_to_anchor=(0.5, -1.3+c*0.4), fontsize="small") + + dstr = ["clean", "noisy", "denoised"] + diam_mean = 30 if ctype=="cyto2" else 17 + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(3, 4, subplot_spec=grid[c, 1:3], + wspace=0.15, hspace=0.05) + for j in range(4): + Ly, Lx = imgs[j][0].shape + yinds, xinds = plot.interesting_patch(masks[j][0], + bsize=min(Ly, Lx, int(300 * diams[0][0][j] / diam_mean))) + for k in range(3): + ax = plt.subplot(grid1[k, j]) + pos = ax.get_position().bounds + ax.set_position([pos[0]-0.01, pos[1] - 0.015*k, *pos[2:]]) + ax.imshow(imgs[j][k], vmin=0, vmax=1, cmap="gray") + ax.axis("off") + #outlines = utils.outlines_list(masks[j][k], multiprocessing=False) + #for o in outlines: + # ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--") + ax.set_ylim([yinds[0], yinds[-1]+1]) + ax.set_xlim([xinds[0], xinds[-1]+1]) + ax.text(1, -0.01, f"AP@0.5 = {aps[k][inds[j],0]:.2f}", ha="right", + va="top", fontsize="small", transform=ax.transAxes) + if j%2==0 and k==0: + istr = ["most improved", "least improved"] + ax.set_title(f"{istr[j//2]}", fontsize="medium", fontstyle="italic") + if j==0: + ax.text(-0.05, 0.5, dstr[k], ha="right", va="center", + rotation=90, transform=ax.transAxes, + color="k" if k<2 else [0., 0.5, 0]) + fig.savefig("figs/suppfig_impr.pdf", dpi=300) + def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", thresholds=np.arange(0.5, 1.05, 0.05)): @@ -1193,20 +1443,15 @@ def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", imgs_all.append(test_data) imgs_all.append(test_noisy) - dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3.npy", - allow_pickle=True).item() - test_dn = dat["test_perseg"][:nimg_test] - masks_dn = dat["masks_perseg"][:nimg_test] - imgs_all.append(test_dn) - masks_all.append(masks_dn) - - dat = np.load(root / "noisy_test" / f"test_{noise_type}_generalist_cp3.npy", - allow_pickle=True).item() - test_dn = dat["test_generalist"][:nimg_test] - masks_dn = dat["masks_generalist"][:nimg_test] - imgs_all.append(test_dn) - masks_all.append(masks_dn) - + dat = np.load(root / "noisy_test" / f"test_{noise_type}_cp3_all.npy", + allow_pickle=True).item() + istrs = ["perseg", "noise_spec", "data_spec", "gen"] + for istr in istrs: + test_dn = dat[f"test_{istr}"][:nimg_test] + masks_dn = dat[f"masks_{istr}"][:nimg_test] + imgs_all.append(test_dn) + masks_all.append(masks_dn) + # benchmarking aps = [] tps = [] @@ -1225,7 +1470,7 @@ def load_benchmarks_generalist(folder, noise_type="poisson", ctype="cyto2", diam_test) -def fig5(folder, save_fig=True): +def fig6(folder, save_fig=True): folders = [ "cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC", "bact_phase", "bact_fluor", "deepbacs" @@ -1248,8 +1493,9 @@ def fig5(folder, save_fig=True): diams = [utils.diameters(lbl)[0] for lbl in lbls] + gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039" model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, - model_type="denoise_cyto3") + pretrained_model=gen_model) seg_model = models.CellposeModel(gpu=True, model_type="cyto3") pscales = [1.5, 20., 1.5, 1., 5., 40., 3.] denoise.deterministic() @@ -1267,7 +1513,6 @@ def fig5(folder, save_fig=True): imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5, flow_threshold=0.4, augment=True, bsize=224, niter=2000 if folders[i - 2] == "bact_phase" else None)[0]) - api = np.array( [metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)]) @@ -1283,65 +1528,70 @@ def fig5(folder, save_fig=True): print(ctype, noise_type, aps0[1:, :, 0].mean(axis=1)) aps.append(aps0) - fig = plt.figure(figsize=(14, 10), dpi=100) - yratio = 14 / 10 - grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.98, top=0.96, bottom=0.1, - wspace=0.05, hspace=0.3) + fig = plt.figure(figsize=(14, 7), dpi=100) + yratio = 14 / 7 + grid = plt.GridSpec(3, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.1, + wspace=0.05, hspace=0.2) - grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 6, subplot_spec=grid[0, :], + grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(1, 8, subplot_spec=grid[0, :], wspace=0.4, hspace=0.15) - transl = mtransforms.ScaledTranslation(-40 / 72, 10 / 72, fig.dpi_scale_trans) + transl = mtransforms.ScaledTranslation(-0 / 72, 3 / 72, fig.dpi_scale_trans) il = 0 noise_type = ["poisson", "blur", "downsample"][i % 3] + ax = plt.subplot(grid1[0:2]) + pos = ax.get_position().bounds + im = plt.imread("figs/cellpose3_models.png") + yr = im.shape[0] / im.shape[1] + w = 0.22 + ax.set_position([0.0, pos[1]-0.08, w, w*yratio*yr]) + plt.imshow(im) + ax.axis("off") + ax.text(0.08, 1.02, "General restoration models", transform=ax.transAxes, + fontstyle="italic", fontsize="large") + il = plot_label(ltr, il, ax, transl, fs_title) + + transl = mtransforms.ScaledTranslation(-40 / 72, 20 / 72, fig.dpi_scale_trans) thresholds = np.arange(0.5, 1.05, 0.05) - cols0 = np.array(cols)[[0, 0, 7, 7]].copy() - cols0[-1] = np.array([0, 1, 0]) - cols0 = np.clip(cols0, 0, 1) - lss0 = ["-", "-", "-", "--"] - legstr0 = ["", u"\u2013 noisy image", u"\u2013 dataset-specific", u"-- one-click"] - theight = [0, 1, 3, 2] + cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162], + [246, 198, 173], [192, 71, 29], ]) + cols0 = cols0 / 255 + lss0 = ["-", "-", "-","-", "-", "-"] + legstr0 = ["", u"\u2013 noisy image", u"\u2013 original", + u"\u2013 noise-specific", "\u2013 data-specific", u"-- one-click"] + theight = [0, 0,4,3,2,1] for i in range(6): ctype = "cellpose test set" if i < 3 else "nuclei test set" noise_type = ["denoising", "deblurring", "upsampling"][i % 3] - ax = plt.subplot(grid1[i]) + ax = plt.subplot(grid1[i+2]) pos = ax.get_position().bounds ax.set_position([ - pos[0] + (5 - i) * 0.01 - 0.02 + 0.03 * (i > 2), pos[1] - 0.05, + pos[0] + 0.025 * (i>2), pos[1] - 0.05, # (5 - i) * 0.01 - 0.02 + 0.03 * (i > 2) pos[2] * 0.92, pos[3] ]) - for k in range(1, len(aps[0])): - ax.plot(thresholds, aps[i][k].mean(axis=0), color=cols0[k], ls=lss0[k]) - if i == 0 or i == 3: - ax.text(0.43, 0.62 + 0.1 * theight[k], legstr0[k], color=cols0[k], - transform=ax.transAxes) + ax.plot(thresholds, aps[i][k].mean(axis=0), color=cols0[k], ls=lss0[k], lw=1) if i == 0 or i == 3: ax.set_ylabel("average precision (AP)") + ax.set_xlabel("IoU threshold") il = plot_label(ltr, il, ax, transl, fs_title) if i == 1 or i == 4: - ax.text(0.5, 1.3, ctype, transform=ax.transAxes, ha="center", + ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center", fontsize="large") - if i == 0: - ax.text(-0.35, 1.35, "One-click models", transform=ax.transAxes, - fontstyle="italic", fontsize="large") - + ax.set_ylim([0, 0.72]) - ax.set_xlabel("IoU threshold") - ax.set_xticks(np.arange(0.5, 1.05, 0.1)) + ax.set_xticks(np.arange(0.5, 1.05, 0.25)) ax.set_xlim([0.5, 1.0]) ax.set_title(f"{noise_type}", fontsize="medium") - #yr, xr = 200, 240 - - titlesj = ["clean", "noisy", "denoised"] + titlesj = ["clean", "noisy", "denoised (one-click)"] titlesi = [ "Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast", "Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs" ] - colsj = cols0[[0, 1, 3]] + colsj = cols0[[0, 1, -1]] ly0 = 250 @@ -1355,15 +1605,10 @@ def fig5(folder, save_fig=True): mask_gt = lbls[i].copy() #outlines_gt = utils.outlines_list(mask_gt, multiprocessing=False) - for j in range(3): - #img = np.zeros((*imgs[j][i].shape[1:], 3)) - #img[:,:,1:] = imgs[j][i][[0,1]].transpose(1,2,0) - if imgs[j][i].ndim == 3: - imgs[j][i] = imgs[j][i][0] - - img = np.clip(transforms.normalize99(imgs[j][i].copy()), 0, 1) + for j in range(1, 3): + img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1) for k in range(2): - ax = plt.subplot(grid[j + 1, 2 * i + k]) + ax = plt.subplot(grid[j, 2 * i + k]) pos = ax.get_position().bounds ax.set_position([ pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07, @@ -1401,11 +1646,9 @@ def fig5(folder, save_fig=True): if k == 0 and j == 0: ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes, fontsize="medium") - if save_fig: os.makedirs("figs/", exist_ok=True) - fig.savefig("figs/fig5.pdf", dpi=100) - + fig.savefig("figs/fig6.pdf", dpi=150) def load_seg_generalist(folder): folders = [ @@ -1473,7 +1716,7 @@ def load_seg_generalist(folder): return apcs, api, imgs, masks_true, masks_pred -def suppfig_generalist(folder, save_fig=True): +def fig5(folder, save_fig=True): thresholds = np.arange(0.5, 1.05, 0.05) apcs, api, imgs, masks_true, masks_pred = load_seg_generalist(folder) titlesi = [ @@ -1583,7 +1826,7 @@ def suppfig_generalist(folder, save_fig=True): if save_fig: os.makedirs("figs/", exist_ok=True) - fig.savefig("figs/suppfig_generalist.pdf", dpi=100) + fig.savefig("figs/fig5.pdf", dpi=100) if __name__ == "__main__": @@ -1617,10 +1860,14 @@ def suppfig_generalist(folder, save_fig=True): fig4(folder, save_fig=0, ctype="nuclei") plt.show() + # ex images + suppfig_impr(folder, save_fig=0) + plt.show() + # one-click + supergeneralist fig5(folder, save_fig=0) plt.show() - suppfig_generalist(folder, save_fig=0) + fig6(folder, save_fig=0) plt.show()