From 9d254084277eabe8ad20d9974b0f5199b5a70ddc Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Thu, 8 Jan 2026 10:57:21 -0500 Subject: [PATCH 1/4] update train_seg function to always train in float32 --- cellpose/train.py | 20 +++++++++++--------- cellpose/vit_sam.py | 1 + 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 401c0efc..6094bfc9 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Train the network with images for segmentation. Args: - net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype. + net (object): The network model to train. If `net` is a bfloat16 model it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned as the original dtype for consistency. train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None. train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None. train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None. @@ -356,11 +356,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, device = net.device - original_net_dtype = None - if device.type == 'mps' and net.dtype == torch.bfloat16: + original_net_dtype = net.dtype + if net.dtype == torch.bfloat16: # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ - original_net_dtype = torch.bfloat16 - train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead") + train_logger.info(">>> converting bfloat16 network to float32 for training") net.dtype = torch.float32 net.to(torch.float32) @@ -539,9 +538,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename0) net.save_model(filename) - - if original_net_dtype is not None: - net.dtype = original_net_dtype - net.to(original_net_dtype) + if original_net_dtype == torch.bfloat16: + train_logger.info(">>> converting network back to bfloat16") + net.dtype = torch.bfloat16 + net.to(torch.bfloat16) + # if original_net_dtype is not None: + # net.dtype = original_net_dtype + # net.to(original_net_dtype) return filename, train_losses, test_losses diff --git a/cellpose/vit_sam.py b/cellpose/vit_sam.py index 7d93378a..52cb6b06 100644 --- a/cellpose/vit_sam.py +++ b/cellpose/vit_sam.py @@ -90,6 +90,7 @@ def load_model(self, PATH, device, strict = False): if w2_data == None: raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.') + # models are always saved as float32 if keys[0][:7] == "module.": from collections import OrderedDict new_state_dict = OrderedDict() From 3915921282f69197506a23c5abb2116d33658ee0 Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Thu, 8 Jan 2026 11:04:30 -0500 Subject: [PATCH 2/4] refactor train_seg and Transformer class to improve dtype handling --- cellpose/train.py | 6 ++---- cellpose/vit_sam.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 6094bfc9..3df474a4 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -360,8 +360,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, if net.dtype == torch.bfloat16: # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ train_logger.info(">>> converting bfloat16 network to float32 for training") - net.dtype = torch.float32 - net.to(torch.float32) + net.dtype(torch.float32) scale_range = 0.5 if scale_range is None else scale_range @@ -540,8 +539,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename) if original_net_dtype == torch.bfloat16: train_logger.info(">>> converting network back to bfloat16") - net.dtype = torch.bfloat16 - net.to(torch.bfloat16) + net.dtype(torch.bfloat16) # if original_net_dtype is not None: # net.dtype = original_net_dtype # net.to(original_net_dtype) diff --git a/cellpose/vit_sam.py b/cellpose/vit_sam.py index 52cb6b06..cdac9724 100644 --- a/cellpose/vit_sam.py +++ b/cellpose/vit_sam.py @@ -49,7 +49,7 @@ def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4, for blk in self.encoder.blocks: blk.window_size = 0 - self.dtype = dtype + self._dtype = dtype if self.dtype != torch.float32: self = self.to(self.dtype) @@ -104,6 +104,27 @@ def load_model(self, PATH, device, strict = False): if self.dtype != torch.float32: self = self.to(self.dtype) + @property + def dtype(self): + """ + Get the data type of the model. + + Returns: + torch.dtype: The data type of the model. + """ + return self._dtype + + @dtype.setter + def dtype(self, value): + """ + Set the data type of the model. + + Args: + value (torch.dtype): The data type to set for the model. + """ + if self._dtype != value: + self.to(value) + self._dtype = value @property def device(self): From 7bb28552838deb9ea097eca05d7143ac3f4501bd Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Thu, 8 Jan 2026 11:29:44 -0500 Subject: [PATCH 3/4] use properties and setter correctly in Transformer dtype --- cellpose/train.py | 4 ++-- cellpose/vit_sam.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index 3df474a4..fab75fb8 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -360,7 +360,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, if net.dtype == torch.bfloat16: # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ train_logger.info(">>> converting bfloat16 network to float32 for training") - net.dtype(torch.float32) + net.dtype = torch.float32 scale_range = 0.5 if scale_range is None else scale_range @@ -539,7 +539,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename) if original_net_dtype == torch.bfloat16: train_logger.info(">>> converting network back to bfloat16") - net.dtype(torch.bfloat16) + net.dtype = torch.bfloat16 # if original_net_dtype is not None: # net.dtype = original_net_dtype # net.to(original_net_dtype) diff --git a/cellpose/vit_sam.py b/cellpose/vit_sam.py index cdac9724..70332eda 100644 --- a/cellpose/vit_sam.py +++ b/cellpose/vit_sam.py @@ -50,8 +50,8 @@ def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4, blk.window_size = 0 self._dtype = dtype - if self.dtype != torch.float32: - self = self.to(self.dtype) + if dtype != torch.float32: + self.dtype = dtype def forward(self, x): # same progression as SAM until readout From 1783aea431ec036274be4ff7db12cb8df705c523 Mon Sep 17 00:00:00 2001 From: Michael Rariden Date: Thu, 8 Jan 2026 11:32:39 -0500 Subject: [PATCH 4/4] refactor train_seg to use autocast for dtype handling and improve logging --- cellpose/train.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/cellpose/train.py b/cellpose/train.py index fab75fb8..aa312c74 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -460,11 +460,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, X = torch.from_numpy(imgi).to(device) lbl = torch.from_numpy(lbl).to(device) - if X.dtype != net.dtype: - X = X.to(net.dtype) - lbl = lbl.to(net.dtype) - - y = net(X)[0] + with torch.autocast(device_type=device.type, dtype=net.dtype): + y = net(X)[0] loss = _loss_fn_seg(lbl, y, device) if y.shape[1] > 3: loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) @@ -508,11 +505,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, X = torch.from_numpy(imgi).to(device) lbl = torch.from_numpy(lbl).to(device) - if X.dtype != net.dtype: - X = X.to(net.dtype) - lbl = lbl.to(net.dtype) - - y = net(X)[0] + with torch.autocast(device_type=device.type, dtype=net.dtype): + y = net(X)[0] loss = _loss_fn_seg(lbl, y, device) if y.shape[1] > 3: loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) @@ -537,11 +531,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, net.save_model(filename0) net.save_model(filename) - if original_net_dtype == torch.bfloat16: - train_logger.info(">>> converting network back to bfloat16") - net.dtype = torch.bfloat16 - # if original_net_dtype is not None: - # net.dtype = original_net_dtype - # net.to(original_net_dtype) + if original_net_dtype != torch.float32: + train_logger.info(f">>> converting network back to {original_net_dtype} after training") + net.dtype = original_net_dtype return filename, train_losses, test_losses