Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,28 @@ def use_gpu(gpu_number=0, use_torch=True):

def _use_gpu_torch(gpu_number=0):
"""
Checks if CUDA is available and working with PyTorch.
Checks if CUDA or MPS is available and working with PyTorch.

Args:
gpu_number (int): The GPU device number to use (default is 0).

Returns:
bool: True if CUDA is available and working, False otherwise.
bool: True if CUDA or MPS is available and working, False otherwise.
"""
try:
device = torch.device("cuda:" + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
core_logger.info("** TORCH CUDA version installed and working. **")
return True
except:
core_logger.info("TORCH CUDA version not installed/working.")
pass
try:
device = torch.device('mps:' + str(gpu_number))
_ = torch.zeros([1, 2, 3]).to(device)
core_logger.info('** TORCH MPS version installed and working. **')
return True
except:
core_logger.info('Neither TORCH CUDA nor MPS version not installed/working.')
return False


Expand All @@ -76,28 +83,35 @@ def assign_device(use_torch=True, gpu=False, device=0):
torch.device: The assigned device.
bool: True if GPU is used, False otherwise.
"""
mac = False
cpu = True

if isinstance(device, str):
if device == "mps":
mac = True
else:
if device != "mps" or not(gpu and torch.backends.mps.is_available()):
device = int(device)
if gpu and use_gpu(use_torch=True):
device = torch.device(f"cuda:{device}")
gpu = True
cpu = False
core_logger.info(">>>> using GPU")
elif mac:
try:
device = torch.device("mps")
gpu = True
cpu = False
core_logger.info(">>>> using GPU")
if torch.cuda.is_available():
device = torch.device(f'cuda:{device}')
core_logger.info(">>>> using GPU (CUDA)")
gpu = True
cpu = False
except:
gpu = False
cpu = True
try:
if torch.backends.mps.is_available():
device = torch.device('mps')
core_logger.info(">>>> using GPU (MPS)")
gpu = True
cpu = False
except:
gpu = False

cpu = True
else:
device = torch.device('cpu')
core_logger.info('>>>> using CPU')
gpu = False
cpu = True

if cpu:
device = torch.device("cpu")
core_logger.info(">>>> using CPU")
Expand Down
6 changes: 4 additions & 2 deletions cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsamp
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
ds_max=7, iso=True, rotate=True,
device=torch.device("cuda"), xy=(224, 224),
device=None, xy=(224, 224),
nchan_noise=1, keep_raw=True):
"""
Applies random rotation, resizing, and noise to the input data.
Expand Down Expand Up @@ -349,7 +349,9 @@ def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, b
torch.Tensor: The augmented labels.
float: The scale factor applied to the image.
"""

if device == None:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

diams = 30 if diams is None else diams
random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
random_rsc = diams / random_diam #/ random_diam
Expand Down
21 changes: 12 additions & 9 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from . import resnet_torch

TORCH_ENABLED = True
torch_GPU = torch.device("cuda")
torch_GPU = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
torch_CPU = torch.device("cpu")


Expand Down Expand Up @@ -54,8 +54,7 @@ def _extend_centers(T, y, x, ymed, xmed, Lx, niter):
return T


def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
device=torch.device("cuda")):
def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200, device=None):
"""Runs diffusion on GPU to generate flows for training images or quality control.

Args:
Expand All @@ -71,9 +70,13 @@ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,

"""
if device is None:
device = torch.device("cuda")

T = torch.zeros(shape, dtype=torch.double, device=device)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None


if device.type == "mps":
T = torch.zeros(shape, dtype=torch.float, device=device)
else:
T = torch.zeros(shape, dtype=torch.double, device=device)
for i in range(n_iter):
T[tuple(meds.T)] += 1
Tneigh = T[tuple(neighbors)]
Expand Down Expand Up @@ -148,7 +151,7 @@ def masks_to_flows_gpu(masks, device=None, niter=None):
- meds_p (float, 2D or 3D array): cell centers
"""
if device is None:
device = torch.device("cuda")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

Ly0, Lx0 = masks.shape
Ly, Lx = Ly0 + 2, Lx0 + 2
Expand Down Expand Up @@ -205,7 +208,7 @@ def masks_to_flows_gpu_3d(masks, device=None):
- mu_c (float, 2D or 3D array): zeros
"""
if device is None:
device = torch.device("cuda")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

Lz0, Ly0, Lx0 = masks.shape
Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
Expand Down Expand Up @@ -468,7 +471,7 @@ def steps2D_interp(p, dP, niter, device=None):
"""

shape = dP.shape[1:]
if device is not None and device.type == "cuda":
if device is not None and (device.type == "cuda" or device.type == "mps"):
shape = np.array(shape)[[
1, 0
]].astype("float") - 1 # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
Expand Down
7 changes: 6 additions & 1 deletion cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,12 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
self.device = device if device is not None else sdevice
if device is not None:
device_gpu = self.device.type == "cuda"
if torch.cuda.is_available():
device_gpu = self.device.type == "cuda"
elif torch.backends.mps.is_available():
device_gpu = self.device.type == "mps"
else:
device_gpu = False
self.gpu = gpu if device is None else device_gpu
if not self.gpu:
self.mkldnn = check_mkl(True)
Expand Down
5 changes: 4 additions & 1 deletion cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
test_probs=None, load_files=True, min_train_masks=5,
compute_flows=False, channels=None, channel_axis=None,
rgb=False, normalize_params={"normalize": False
}, device=torch.device("cuda")):
}, device=None):
"""
Process train and test data.

Expand All @@ -183,6 +183,9 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
Returns:
tuple: A tuple containing the processed train and test data and sampling probabilities and diameters.
"""
if device == None:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

if train_data is not None and train_labels is not None:
# if data is loaded
nimg = len(train_data)
Expand Down