Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
Train the network with images for segmentation.

Args:
net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype.
net (object): The network model to train. If `net` is a bfloat16 model it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned as the original dtype for consistency.
train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
Expand Down Expand Up @@ -356,13 +356,11 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,

device = net.device

original_net_dtype = None
if device.type == 'mps' and net.dtype == torch.bfloat16:
original_net_dtype = net.dtype
if net.dtype == torch.bfloat16:
# NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \
original_net_dtype = torch.bfloat16
train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead")
train_logger.info(">>> converting bfloat16 network to float32 for training")
net.dtype = torch.float32
net.to(torch.float32)

scale_range = 0.5 if scale_range is None else scale_range

Expand Down Expand Up @@ -462,11 +460,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
X = torch.from_numpy(imgi).to(device)
lbl = torch.from_numpy(lbl).to(device)

if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
with torch.autocast(device_type=device.type, dtype=net.dtype):
y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
Expand Down Expand Up @@ -510,11 +505,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
X = torch.from_numpy(imgi).to(device)
lbl = torch.from_numpy(lbl).to(device)

if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
with torch.autocast(device_type=device.type, dtype=net.dtype):
y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
Expand All @@ -539,9 +531,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
net.save_model(filename0)

net.save_model(filename)

if original_net_dtype is not None:
if original_net_dtype != torch.float32:
train_logger.info(f">>> converting network back to {original_net_dtype} after training")
net.dtype = original_net_dtype
net.to(original_net_dtype)

return filename, train_losses, test_losses
28 changes: 25 additions & 3 deletions cellpose/vit_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
for blk in self.encoder.blocks:
blk.window_size = 0

self.dtype = dtype
if self.dtype != torch.float32:
self = self.to(self.dtype)
self._dtype = dtype
if dtype != torch.float32:
self.dtype = dtype

def forward(self, x):
# same progression as SAM until readout
Expand Down Expand Up @@ -90,6 +90,7 @@ def load_model(self, PATH, device, strict = False):
if w2_data == None:
raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.')

# models are always saved as float32
if keys[0][:7] == "module.":
from collections import OrderedDict
new_state_dict = OrderedDict()
Expand All @@ -103,6 +104,27 @@ def load_model(self, PATH, device, strict = False):
if self.dtype != torch.float32:
self = self.to(self.dtype)

@property
def dtype(self):
"""
Get the data type of the model.

Returns:
torch.dtype: The data type of the model.
"""
return self._dtype

@dtype.setter
def dtype(self, value):
"""
Set the data type of the model.

Args:
value (torch.dtype): The data type to set for the model.
"""
if self._dtype != value:
self.to(value)
self._dtype = value

@property
def device(self):
Expand Down