diff --git a/py/AILab_RMBG.py b/py/AILab_RMBG.py index 06c5a63..11ee1d6 100644 --- a/py/AILab_RMBG.py +++ b/py/AILab_RMBG.py @@ -87,6 +87,11 @@ def handle_model_error(message): print(f"[RMBG ERROR] {message}") raise RuntimeError(message) +def normalize_process_res(process_res, step=128, minimum=256, maximum=2048): + process_res = max(minimum, min(maximum, int(process_res))) + normalized = max(minimum, (process_res // step) * step) + return min(maximum, normalized) + class BaseModelLoader: def __init__(self): self.model = None @@ -243,9 +248,30 @@ def process_image(self, images, model_name, params): try: self.load_model(model_name) + normalized_process_res = normalize_process_res(params["process_res"]) + + def extract_results(outputs): + if isinstance(outputs, list) and len(outputs) > 0: + return outputs[-1].sigmoid().cpu() + if isinstance(outputs, dict) and 'logits' in outputs: + return outputs['logits'].sigmoid().cpu() + if isinstance(outputs, torch.Tensor): + return outputs.sigmoid().cpu() + + try: + if hasattr(outputs, 'last_hidden_state'): + return outputs.last_hidden_state.sigmoid().cpu() + for _, value in outputs.items(): + if isinstance(value, torch.Tensor): + return value.sigmoid().cpu() + except Exception: + pass + + handle_model_error("Unable to recognize model output format") + # Prepare batch processing transform_image = transforms.Compose([ - transforms.Resize((params["process_res"], params["process_res"])), + transforms.Resize((normalized_process_res, normalized_process_res)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) @@ -262,40 +288,43 @@ def process_image(self, images, model_name, params): input_batch = torch.cat(input_tensors, dim=0).to(device) with torch.no_grad(): - outputs = self.model(input_batch) - - if isinstance(outputs, list) and len(outputs) > 0: - results = outputs[-1].sigmoid().cpu() - elif isinstance(outputs, dict) and 'logits' in outputs: - results = outputs['logits'].sigmoid().cpu() - elif isinstance(outputs, torch.Tensor): - results = outputs.sigmoid().cpu() - else: - try: - if hasattr(outputs, 'last_hidden_state'): - results = outputs.last_hidden_state.sigmoid().cpu() - else: - for k, v in outputs.items(): - if isinstance(v, torch.Tensor): - results = v.sigmoid().cpu() - break - except: - handle_model_error("Unable to recognize model output format") - - masks = [] - - for i, (result, (orig_w, orig_h)) in enumerate(zip(results, original_sizes)): - result = result.squeeze() - result = result * (1 + (1 - params["sensitivity"])) - result = torch.clamp(result, 0, 1) - - result = F.interpolate(result.unsqueeze(0).unsqueeze(0), - size=(orig_h, orig_w), - mode='bilinear').squeeze() - - masks.append(tensor2pil(result)) + try: + results = extract_results(self.model(input_batch)) + + masks = [] + + for result, (orig_w, orig_h) in zip(results, original_sizes): + result = result.squeeze() + result = result * (1 + (1 - params["sensitivity"])) + result = torch.clamp(result, 0, 1) + + result = F.interpolate(result.unsqueeze(0).unsqueeze(0), + size=(orig_h, orig_w), + mode='bilinear').squeeze() + + masks.append(tensor2pil(result)) + + return masks + except Exception as batch_error: + if len(images) == 1 or "Sizes of tensors must match" not in str(batch_error): + raise + + print("[RMBG INFO] Batch inference failed due to tensor size mismatch; retrying images one by one.") + masks = [] + for img, (orig_w, orig_h) in zip(images, original_sizes): + single_input = transform_image(tensor2pil(img)).unsqueeze(0).to(device) + single_results = extract_results(self.model(single_input)) + result = single_results[0].squeeze() + result = result * (1 + (1 - params["sensitivity"])) + result = torch.clamp(result, 0, 1) + + result = F.interpolate(result.unsqueeze(0).unsqueeze(0), + size=(orig_h, orig_w), + mode='bilinear').squeeze() + + masks.append(tensor2pil(result)) - return masks + return masks except Exception as e: handle_model_error(f"Error in batch processing: {str(e)}") @@ -541,7 +570,7 @@ def INPUT_TYPES(s): }, "optional": { "sensitivity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": tooltips["sensitivity"]}), - "process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 8, "tooltip": tooltips["process_res"]}), + "process_res": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 128, "tooltip": tooltips["process_res"]}), "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1, "tooltip": tooltips["mask_blur"]}), "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1, "tooltip": tooltips["mask_offset"]}), "invert_output": ("BOOLEAN", {"default": False, "tooltip": tooltips["invert_output"]}),