diff --git a/cellpose/core.py b/cellpose/core.py
index aa8399e8..00533e5c 100644
--- a/cellpose/core.py
+++ b/cellpose/core.py
@@ -45,13 +45,13 @@ def use_gpu(gpu_number=0, use_torch=True):
 
 def _use_gpu_torch(gpu_number=0):
     """
-    Checks if CUDA is available and working with PyTorch.
+    Checks if CUDA or MPS is available and working with PyTorch.
 
     Args:
         gpu_number (int): The GPU device number to use (default is 0).
 
     Returns:
-        bool: True if CUDA is available and working, False otherwise.
+        bool: True if CUDA or MPS is available and working, False otherwise.
     """
     try:
         device = torch.device("cuda:" + str(gpu_number))
@@ -59,7 +59,15 @@ def _use_gpu_torch(gpu_number=0):
         core_logger.info("** TORCH CUDA version installed and working. **")
         return True
     except:
-        core_logger.info("TORCH CUDA version not installed/working.")
+        pass
+    try:
+        # CUDA unavailable -- probe the Apple-silicon MPS backend instead.
+        device = torch.device("mps:" + str(gpu_number))
+        _ = torch.zeros([1, 2, 3]).to(device)
+        core_logger.info("** TORCH MPS version installed and working. **")
+        return True
+    except:
+        core_logger.info("Neither TORCH CUDA nor MPS version installed/working.")
     return False
 
 
@@ -76,28 +84,23 @@ def assign_device(use_torch=True, gpu=False, device=0):
         torch.device: The assigned device.
         bool: True if GPU is used, False otherwise.
     """
-    mac = False
-    cpu = True
-    if isinstance(device, str):
-        if device == "mps":
-            mac = True
-        else:
-            device = int(device)
-    if gpu and use_gpu(use_torch=True):
-        device = torch.device(f"cuda:{device}")
-        gpu = True
-        cpu = False
-        core_logger.info(">>>> using GPU")
-    elif mac:
-        try:
-            device = torch.device("mps")
-            gpu = True
-            cpu = False
-            core_logger.info(">>>> using GPU")
-        except:
-            cpu = True
-            gpu = False
+    cpu = True
+    # A string device is either "mps" or a CUDA index such as "0".
+    if isinstance(device, str) and device != "mps":
+        device = int(device)
+    if gpu and use_gpu(use_torch=True):
+        # Prefer CUDA when present, otherwise fall back to Apple MPS.
+        if device != "mps" and torch.cuda.is_available():
+            device = torch.device(f"cuda:{device}")
+            gpu = True
+            cpu = False
+            core_logger.info(">>>> using GPU (CUDA)")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+            gpu = True
+            cpu = False
+            core_logger.info(">>>> using GPU (MPS)")
 
     if cpu:
         device = torch.device("cpu")
         core_logger.info(">>>> using CPU")
diff --git a/cellpose/denoise.py b/cellpose/denoise.py
index 1270632d..8deb533f 100644
--- a/cellpose/denoise.py
+++ b/cellpose/denoise.py
@@ -321,7 +321,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp
 
 def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7,
                                    blur=0.7, downsample=0.0, beta=0.7, gblur=1.0,
                                    diam_mean=30, ds_max=7, iso=True, rotate=True,
-                                   device=torch.device("cuda"), xy=(224, 224),
+                                   device=None, xy=(224, 224),
                                    nchan_noise=1, keep_raw=True):
     """
@@ -349,7 +349,10 @@ def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, b
         torch.Tensor: The augmented labels.
         float: The scale factor applied to the image.
     """
-
+    if device is None:
+        # Default to the best available backend instead of assuming CUDA.
+        device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
+
     diams = 30 if diams is None else diams
     random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
     random_rsc = diams / random_diam  #/ random_diam
diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py
index a73130ed..d38e07ba 100644
--- a/cellpose/dynamics.py
+++ b/cellpose/dynamics.py
@@ -24,7 +24,8 @@
 from . import resnet_torch
 
 TORCH_ENABLED = True
-torch_GPU = torch.device("cuda")
+# Best available accelerator; falls back to CPU so the device is always usable.
+torch_GPU = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
 torch_CPU = torch.device("cpu")
 
 
@@ -54,8 +55,7 @@ def _extend_centers(T, y, x, ymed, xmed, Lx, niter):
     return T
 
 
-def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
-                        device=torch.device("cuda")):
+def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=None):
     """Runs diffusion on GPU to generate flows for training images or quality control.
 
     Args:
@@ -72,8 +72,10 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
     """
     if device is None:
-        device = torch.device("cuda")
-
-    T = torch.zeros(shape, dtype=torch.double, device=device)
+        device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
+
+    # MPS has no float64 support, so use float32 there.
+    dtype = torch.float if device.type == "mps" else torch.double
+    T = torch.zeros(shape, dtype=dtype, device=device)
     for i in range(n_iter):
         T[tuple(meds.T)] += 1
         Tneigh = T[tuple(neighbors)]
@@ -148,7 +150,7 @@ def masks_to_flows_gpu(masks, device=None, niter=None):
         - meds_p (float, 2D or 3D array): cell centers
     """
     if device is None:
-        device = torch.device("cuda")
+        device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
     Ly0, Lx0 = masks.shape
     Ly, Lx = Ly0 + 2, Lx0 + 2
 
@@ -205,7 +207,7 @@ def masks_to_flows_gpu_3d(masks, device=None):
         - mu_c (float, 2D or 3D array): zeros
     """
     if device is None:
-        device = torch.device("cuda")
+        device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
     Lz0, Ly0, Lx0 = masks.shape
     Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
 
@@ -468,7 +470,7 @@ def steps2D_interp(p, dP, niter, device=None):
 
     """
     shape = dP.shape[1:]
-    if device is not None and device.type == "cuda":
+    if device is not None and device.type in ("cuda", "mps"):
         shape = np.array(shape)[[
             1, 0
         ]].astype("float") - 1  # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
diff --git a/cellpose/models.py b/cellpose/models.py
index 5329d7f5..6198041d 100644
--- a/cellpose/models.py
+++ b/cellpose/models.py
@@ -297,7 +297,8 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
         sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
         self.device = device if device is not None else sdevice
         if device is not None:
-            device_gpu = self.device.type == "cuda"
+            # Both CUDA and Apple MPS devices count as "GPU".
+            device_gpu = self.device.type in ("cuda", "mps")
         self.gpu = gpu if device is None else device_gpu
         if not self.gpu:
             self.mkldnn = check_mkl(True)
diff --git a/cellpose/train.py b/cellpose/train.py
index 01d5fc64..faa2665b 100644
--- a/cellpose/train.py
+++ b/cellpose/train.py
@@ -156,7 +156,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
                         test_probs=None, load_files=True, min_train_masks=5,
                         compute_flows=False, channels=None, channel_axis=None,
                         rgb=False, normalize_params={"normalize": False
-                        }, device=torch.device("cuda")):
+                        }, device=None):
     """
     Process train and test data.
@@ -183,6 +183,10 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
     Returns:
         tuple: A tuple containing the processed train and test data and sampling probabilities and diameters.
     """
+    if device is None:
+        # Default to the best available backend instead of assuming CUDA.
+        device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
+
     if train_data is not None and train_labels is not None:
         # if data is loaded
         nimg = len(train_data)