Skip to content

Commit de51608

Browse files
authored
Merge pull request #1278 from MouseLand/mps_bfloat16
Patch training bfloat16 on MPS bug
2 parents 5f3ea08 + f2b98f4 commit de51608

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

cellpose/train.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
319319
Train the network with images for segmentation.
320320
321321
Args:
322-
net (object): The network model to train.
322+
net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype.
323323
train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
324324
train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
325325
train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
@@ -357,6 +357,14 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
357357

358358
device = net.device
359359

360+
original_net_dtype = None
361+
if device.type == 'mps' and net.dtype == torch.bfloat16:
362+
# NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \
363+
original_net_dtype = torch.bfloat16
364+
train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead")
365+
net.dtype = torch.float32
366+
net.to(torch.float32)
367+
360368
scale_range = 0.5 if scale_range is None else scale_range
361369

362370
if isinstance(normalize, dict):
@@ -533,4 +541,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
533541

534542
net.save_model(filename)
535543

544+
if original_net_dtype is not None:
545+
net.dtype = original_net_dtype
546+
net.to(original_net_dtype)
547+
536548
return filename, train_losses, test_losses

0 commit comments

Comments
 (0)