diff --git a/cellpose/core.py b/cellpose/core.py index 0505f464..f016bcde 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -77,7 +77,7 @@ def assign_device(use_torch=True, gpu=False, device=0): if isinstance(device, str): if device != "mps" or not(gpu and torch.backends.mps.is_available()): device = int(device) - if gpu and use_gpu(use_torch=True): + if gpu and use_gpu(gpu_number=device, use_torch=use_torch): try: if torch.cuda.is_available(): device = torch.device(f'cuda:{device}')