diff --git a/cellpose/train.py b/cellpose/train.py index 53a97a33..95ec4951 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -502,6 +502,11 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, xy=(bsize, bsize))[:2] X = torch.from_numpy(imgi).to(device) lbl = torch.from_numpy(lbl).to(device) + + if X.dtype != net.dtype: + X = X.to(net.dtype) + lbl = lbl.to(net.dtype) + y = net(X)[0] loss = _loss_fn_seg(lbl, y, device) if y.shape[1] > 3: