From f1dd0602ed86e556f903f0c81bba8eab2490cbde Mon Sep 17 00:00:00 2001 From: OratHelm Date: Fri, 19 Jul 2024 17:42:26 +0200 Subject: [PATCH 1/2] MPS compatibility to use GPU on Apple Silicon From what I've tested, it works well using the GPU (or with the CPU too, it's your choice), from the GUI or from a Python script. However, I couldn't get the training to work, unfortunately. There are no errors but the generated model doesn't find any cells, and at all epochs train_loss=nan, test_loss=0.0000, except when using the GUI where train_loss at epoch 0 is non-zero. --- cellpose/core.py | 50 ++++++++++++++++++++++++++++---------------- cellpose/denoise.py | 6 ++++-- cellpose/dynamics.py | 21 +++++++++++-------- cellpose/models.py | 5 ++++- cellpose/train.py | 5 ++++- 5 files changed, 56 insertions(+), 31 deletions(-) diff --git a/cellpose/core.py b/cellpose/core.py index aa8399e8..00533e5c 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -45,13 +45,13 @@ def use_gpu(gpu_number=0, use_torch=True): def _use_gpu_torch(gpu_number=0): """ - Checks if CUDA is available and working with PyTorch. + Checks if CUDA or MPS is available and working with PyTorch. Args: gpu_number (int): The GPU device number to use (default is 0). Returns: - bool: True if CUDA is available and working, False otherwise. + bool: True if CUDA or MPS is available and working, False otherwise. """ try: device = torch.device("cuda:" + str(gpu_number)) @@ -59,7 +59,14 @@ def _use_gpu_torch(gpu_number=0): core_logger.info("** TORCH CUDA version installed and working. **") return True except: - core_logger.info("TORCH CUDA version not installed/working.") + pass + try: + device = torch.device('mps:' + str(gpu_number)) + _ = torch.zeros([1, 2, 3]).to(device) + core_logger.info('** TORCH MPS version installed and working. **') + return True + except: + core_logger.info('Neither TORCH CUDA nor MPS version not installed/working.') return False @@ -76,28 +83,35 @@ def assign_device(use_torch=True, gpu=False, device=0): torch.device: The assigned device. bool: True if GPU is used, False otherwise. """ - mac = False - cpu = True + if isinstance(device, str): - if device == "mps": - mac = True - else: + if device != "mps" or not(gpu and torch.backends.mps.is_available()): device = int(device) if gpu and use_gpu(use_torch=True): - device = torch.device(f"cuda:{device}") - gpu = True - cpu = False - core_logger.info(">>>> using GPU") - elif mac: try: - device = torch.device("mps") - gpu = True - cpu = False - core_logger.info(">>>> using GPU") + if torch.cuda.is_available(): + device = torch.device(f'cuda:{device}') + core_logger.info(">>>> using GPU (CUDA)") + gpu = True + cpu = False except: + gpu = False cpu = True + try: + if torch.backends.mps.is_available(): + device = torch.device('mps') + core_logger.info(">>>> using GPU (MPS)") + gpu = True + cpu = False + except: gpu = False - + cpu = True + else: + device = torch.device('cpu') + core_logger.info('>>>> using CPU') + gpu = False + cpu = True + if cpu: device = torch.device("cpu") core_logger.info(">>>> using CPU") diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 1270632d..8deb533f 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -321,7 +321,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7, downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30, ds_max=7, iso=True, rotate=True, - device=torch.device("cuda"), xy=(224, 224), + device=None, xy=(224, 224), nchan_noise=1, keep_raw=True): """ Applies random rotation, resizing, and noise to the input data. @@ -349,7 +349,9 @@ def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, b torch.Tensor: The augmented labels. float: The scale factor applied to the image. """ - + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + diams = 30 if diams is None else diams random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1)) random_rsc = diams / random_diam #/ random_diam diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index a73130ed..d38e07ba 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -24,7 +24,7 @@ from . import resnet_torch TORCH_ENABLED = True -torch_GPU = torch.device("cuda") +torch_GPU = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None torch_CPU = torch.device("cpu") @@ -54,8 +54,7 @@ def _extend_centers(T, y, x, ymed, xmed, Lx, niter): return T -def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, - device=torch.device("cuda")): +def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=None): """Runs diffusion on GPU to generate flows for training images or quality control. Args: @@ -71,9 +70,13 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, """ if device is None: - device = torch.device("cuda") - - T = torch.zeros(shape, dtype=torch.double, device=device) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + + if device.type == "mps": + T = torch.zeros(shape, dtype=torch.float, device=device) + else: + T = torch.zeros(shape, dtype=torch.double, device=device) for i in range(n_iter): T[tuple(meds.T)] += 1 Tneigh = T[tuple(neighbors)] @@ -148,7 +151,7 @@ def masks_to_flows_gpu(masks, device=None, niter=None): - meds_p (float, 2D or 3D array): cell centers """ if device is None: - device = torch.device("cuda") + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None Ly0, Lx0 = masks.shape Ly, Lx = Ly0 + 2, Lx0 + 2 @@ -205,7 +208,7 @@ def masks_to_flows_gpu_3d(masks, device=None): - mu_c (float, 2D or 3D array): zeros """ if device is None: - device = torch.device("cuda") + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None Lz0, Ly0, Lx0 = masks.shape Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2 @@ -468,7 +471,7 @@ def steps2D_interp(p, dP, niter, device=None): """ shape = dP.shape[1:] - if device is not None and device.type == "cuda": + if device is not None and (device.type == "cuda" or device.type == "mps"): shape = np.array(shape)[[ 1, 0 ]].astype("float") - 1 # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1 diff --git a/cellpose/models.py b/cellpose/models.py index 5329d7f5..d1efc6bd 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -297,7 +297,10 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, sdevice, gpu = assign_device(use_torch=True, gpu=gpu) self.device = device if device is not None else sdevice if device is not None: - device_gpu = self.device.type == "cuda" + if torch.cuda.is_available(): + device_gpu = self.device.type == "cuda" + elif torch.backends.mps.is_available(): + device_gpu = self.device.type == "mps" self.gpu = gpu if device is None else device_gpu if not self.gpu: self.mkldnn = check_mkl(True) diff --git a/cellpose/train.py b/cellpose/train.py index 01d5fc64..faa2665b 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -156,7 +156,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, test_probs=None, load_files=True, min_train_masks=5, compute_flows=False, channels=None, channel_axis=None, rgb=False, normalize_params={"normalize": False - }, device=torch.device("cuda")): + }, device=None): """ Process train and test data. @@ -183,6 +183,9 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, Returns: tuple: A tuple containing the processed train and test data and sampling probabilities and diameters. """ + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + if train_data is not None and train_labels is not None: # if data is loaded nimg = len(train_data) From 6b7fbc73a4d26588fbba5194673432e7e63f390d Mon Sep 17 00:00:00 2001 From: OratHelm Date: Thu, 25 Jul 2024 15:56:36 +0200 Subject: [PATCH 2/2] Adding an omission --- cellpose/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cellpose/models.py b/cellpose/models.py index d1efc6bd..6198041d 100644 --- a/cellpose/models.py +++ b/cellpose/models.py @@ -301,6 +301,8 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, device_gpu = self.device.type == "cuda" elif torch.backends.mps.is_available(): device_gpu = self.device.type == "mps" + else: + device_gpu = False self.gpu = gpu if device is None else device_gpu if not self.gpu: self.mkldnn = check_mkl(True)