From 7e545ec8d02b3d3106953d7023d1c4c211cf03cb Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Mon, 21 Jul 2025 15:10:44 -0400 Subject: [PATCH 1/2] warn when training with bfloat16 on MPS; fallback to float32 --- cellpose/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cellpose/train.py b/cellpose/train.py index 95ec4951..528b871a 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -357,6 +357,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, device = net.device + if device.type == 'mps' and net.dtype == torch.bfloat16: + # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype + train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead") + net.dtype = torch.float32 + net.to(torch.float32) + scale_range = 0.5 if scale_range is None else scale_range if isinstance(normalize, dict): From f2b98f4f80f1f85a2607778625772fbc049c6ea5 Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Mon, 21 Jul 2025 15:31:51 -0400 Subject: [PATCH 2/2] revert back to originial dtype after training --- cellpose/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 528b871a..e45903bb 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Train the network with images for segmentation. Args: - net (object): The network model to train. + net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype. train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. @@ -357,8 +357,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, device = net.device + original_net_dtype = None if device.type == 'mps' and net.dtype == torch.bfloat16: - # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype + # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ + original_net_dtype = torch.bfloat16 train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead") net.dtype = torch.float32 net.to(torch.float32) @@ -539,4 +541,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename) + if original_net_dtype is not None: + net.dtype = original_net_dtype + net.to(original_net_dtype) + return filename, train_losses, test_losses