Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def assign_device(use_torch=True, gpu=False, device=0):
return device, gpu


def _to_device(x, device):
def _to_device(x, device, dtype=torch.float32):
"""
Converts the input tensor or numpy array to the specified device.

Expand All @@ -121,7 +121,7 @@ def _to_device(x, device):
torch.Tensor: The converted tensor on the specified device.
"""
if not isinstance(x, torch.Tensor):
X = torch.from_numpy(x).to(device, dtype=torch.float32)
X = torch.from_numpy(x).to(device, dtype=dtype)
return X
else:
return x
Expand All @@ -137,7 +137,8 @@ def _from_device(X):
Returns:
numpy.ndarray: The converted NumPy array.
"""
x = X.detach().cpu().numpy()
# The cast is so numpy conversion always works
x = X.detach().cpu().to(torch.float32).numpy()
return x


Expand All @@ -151,7 +152,7 @@ def _forward(net, x):
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
"""
X = _to_device(x, device=net.device)
X = _to_device(x, device=net.device, dtype=net.dtype)
net.eval()
with torch.no_grad():
y, style = net(X)[:2]
Expand Down
6 changes: 4 additions & 2 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CellposeModel():
"""

def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
diam_mean=None, device=None, nchan=None):
diam_mean=None, device=None, nchan=None, use_bfloat16=True):
"""
Initialize the CellposeModel.

Expand All @@ -99,6 +99,7 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo).
diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value.
device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
"""
if diam_mean is not None:
models_logger.warning(
Expand Down Expand Up @@ -139,7 +140,8 @@ def __init__(self, gpu=False, pretrained_model="cpsam", model_type=None,
)

self.pretrained_model = pretrained_model
self.net = Transformer().to(self.device)
dtype = torch.bfloat16 if use_bfloat16 else torch.float32
self.net = Transformer(dtype=dtype).to(self.device)

if os.path.exists(self.pretrained_model):
models_logger.info(f">>>> loading model {self.pretrained_model}")
Expand Down
7 changes: 6 additions & 1 deletion cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _loss_fn_seg(lbl, y, device):
veci = 5. * lbl[:, -2:]
loss = criterion(y[:, -3:-1], veci)
loss /= 2.
loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).float())
loss2 = criterion2(y[:, -1], (lbl[:, -3] > 0.5).to(y.dtype))
loss = loss + loss2
return loss

Expand Down Expand Up @@ -454,6 +454,11 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
# network and loss optimization
X = torch.from_numpy(imgi).to(device)
lbl = torch.from_numpy(lbl).to(device)

if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
Expand Down
9 changes: 7 additions & 2 deletions cellpose/vit_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class Transformer(nn.Module):
def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
checkpoint=None):
checkpoint=None, dtype=torch.float32):
super(Transformer, self).__init__()

# instantiate the vit model, default to not loading SAM
Expand Down Expand Up @@ -49,6 +49,8 @@ def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
for blk in self.encoder.blocks:
blk.window_size = 0

self.dtype = dtype

def forward(self, x):
# same progression as SAM until readout
x = self.encoder.patch_embed(x)
Expand All @@ -59,7 +61,7 @@ def forward(self, x):
if self.training and self.rdrop > 0:
nlay = len(self.encoder.blocks)
rdrop = (torch.rand((len(x), nlay), device=x.device) <
torch.linspace(0, self.rdrop, nlay, device=x.device)).float()
torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
for i, blk in enumerate(self.encoder.blocks):
mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
x = x * mask + blk(x) * (1-mask)
Expand Down Expand Up @@ -90,6 +92,9 @@ def load_model(self, PATH, device, strict = False):
else:
self.load_state_dict(state_dict, strict = strict)

if self.dtype != torch.float32:
self = self.to(self.dtype)


@property
def device(self):
Expand Down
Loading