diff --git a/cellpose/train.py b/cellpose/train.py index 95ec4951..e45903bb 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Train the network with images for segmentation. Args: - net (object): The network model to train. + net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype. train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. @@ -357,6 +357,14 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, device = net.device + original_net_dtype = None + if device.type == 'mps' and net.dtype == torch.bfloat16: + # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ + original_net_dtype = torch.bfloat16 + train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead") + net.dtype = torch.float32 + net.to(torch.float32) + scale_range = 0.5 if scale_range is None else scale_range if isinstance(normalize, dict): @@ -533,4 +541,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename) + if original_net_dtype is not None: + net.dtype = original_net_dtype + net.to(original_net_dtype) + return filename, train_losses, test_losses