diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index 4a2013c9c1..4f7c4949bd 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -4,7 +4,7 @@ # See LICENSE or go to for full license details. import math -from typing import Any +from typing import Any, overload import numpy as np import torch @@ -43,47 +43,84 @@ def __init__( # Perform the division by 255 at the same time self.normalize = T.Normalize(mean, std) - def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]: + @overload + def batch_inputs( + self, + samples: list[torch.Tensor], + ) -> list[torch.Tensor]: ... + + @overload + def batch_inputs( + self, + samples: list[tuple[torch.Tensor, torch.Tensor]], + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ... + + def batch_inputs( + self, samples: list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]] + ) -> list[torch.Tensor] | tuple[list[torch.Tensor], list[torch.Tensor]]: """Gather samples into batches for inference purposes Args: - samples: list of samples of shape (C, H, W) + samples: list of samples of shape (C, H, W) or + list of tuples of samples and masks of shape (C, H, W) and (1, H, W) respectively Returns: - list of batched samples (*, C, H, W) + list of batched samples (*, C, H, W) or tuple of lists of batched samples and masks """ num_batches = int(math.ceil(len(samples) / self.batch_size)) - batches = [ - torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0) - for idx in range(int(num_batches)) - ] - return batches + if isinstance(samples[0], tuple): + imgs, masks = zip(*samples) + + img_batches = [ + torch.stack(imgs[idx * self.batch_size : min((idx + 1) * self.batch_size, len(imgs))], dim=0) + for idx in range(num_batches) + ] + + mask_batches = [ + torch.stack(masks[idx * self.batch_size : min((idx + 1) * self.batch_size, len(masks))], dim=0) + for idx in range(num_batches) + ] - def sample_transforms(self, x: np.ndarray) -> torch.Tensor: + return img_batches, mask_batches + + return [ + torch.stack( + samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], + dim=0, + ) + for idx in range(num_batches) + ] + + def sample_transforms(self, x: np.ndarray) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if x.ndim != 3: raise AssertionError("expected list of 3D Tensors") if x.dtype not in (np.uint8, np.float32, np.float16): raise TypeError("unsupported data type for numpy.ndarray") tensor = torch.from_numpy(x.copy()).permute(2, 0, 1) # Resizing - tensor = self.resize(tensor) + if self.resize.return_padding_mask: + tensor, mask = self.resize(tensor) + else: + tensor = self.resize(tensor) # Data type if tensor.dtype == torch.uint8: tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1) else: tensor = tensor.to(dtype=torch.float32) - return tensor + return (tensor, mask) if self.resize.return_padding_mask else tensor - def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]: + def __call__( + self, x: np.ndarray | list[np.ndarray] + ) -> list[torch.Tensor] | tuple[list[torch.Tensor], list[torch.Tensor]]: """Prepare document data for model forwarding Args: x: list of images (np.array) or a single image (np.array) of shape (H, W, C) Returns: - list of page batches (*, C, H, W) ready for model inference + list of page batches (*, C, H, W) or tuple of lists of page batches and padding masks """ # Input type check if isinstance(x, np.ndarray): @@ -103,17 +140,28 @@ def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]: tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1) else: tensor = tensor.to(dtype=torch.float32) - batches = [tensor] + img_batches = [tensor] + + if self.resize.return_padding_mask: + h, w = self.resize.size + mask = torch.zeros((x.shape[0], h, w), dtype=torch.bool) + mask_batches = [mask] elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x): # Sample transform (to tensor, resize) samples = list(multithread_exec(self.sample_transforms, x)) # Batching - batches = self.batch_inputs(samples) + if self.resize.return_padding_mask: + img_batches, mask_batches = self.batch_inputs(samples) + else: + img_batches = self.batch_inputs(samples) else: raise TypeError(f"invalid input type: {type(x)}") # Batch transforms (normalize) - batches = list(multithread_exec(self.normalize, batches)) + if self.resize.return_padding_mask: + img_batches = list(multithread_exec(self.normalize, img_batches)) + return img_batches, mask_batches - return batches + img_batches = list(multithread_exec(self.normalize, img_batches)) + return img_batches diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 6a52cfd5f5..760cb88c9f 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -41,6 +41,7 @@ class Resize(T.Resize): if True, the image will be resized to fit within the target size while maintaining its aspect ratio symmetric_pad: whether to symmetrically pad the image to the target size, if True, the image will be padded equally on both sides to fit the target size + return_padding_mask: whether to return a padding mask indicating the padded areas of the image """ def __init__( @@ -49,25 +50,43 @@ def __init__( interpolation=F.InterpolationMode.BILINEAR, preserve_aspect_ratio: bool = False, symmetric_pad: bool = False, + return_padding_mask: bool = False, ) -> None: super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True) self.preserve_aspect_ratio = preserve_aspect_ratio self.symmetric_pad = symmetric_pad + self.return_padding_mask = return_padding_mask def forward( self, img: torch.Tensor, target: np.ndarray | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]: + ) -> ( + torch.Tensor + | tuple[torch.Tensor, np.ndarray] + | tuple[torch.Tensor, np.ndarray, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor] + ): target_ratio = self.size[0] / self.size[1] actual_ratio = img.shape[-2] / img.shape[-1] if not self.preserve_aspect_ratio or (target_ratio == actual_ratio): # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one # We can use with the regular resize + img = super().forward(img) + + if self.return_padding_mask: + padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device) + if target is not None: - return super().forward(img), target - return super().forward(img) + if self.return_padding_mask: + return img, target, padding_mask + return img, target + + if self.return_padding_mask: + return img, padding_mask + + return img else: # Resize if actual_ratio > target_ratio: @@ -87,6 +106,13 @@ def forward( # Pad image img = pad(img, _pad) + if self.return_padding_mask: + h, w = self.size + padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) + + left, right, top, bottom = _pad + padding_mask[top : h - bottom, left : w - right] = True + # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) if target is not None: if self.symmetric_pad: @@ -111,7 +137,14 @@ def forward( else: raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)") - return img, np.clip(target, 0, 1) + target = np.clip(target, 0, 1) + + if self.return_padding_mask: + return img, target, padding_mask + return img, target + + if self.return_padding_mask: + return img, padding_mask return img diff --git a/tests/pytorch/test_models_preprocessor_pt.py b/tests/pytorch/test_models_preprocessor_pt.py index e3b5ff8e02..9b4d9a668d 100644 --- a/tests/pytorch/test_models_preprocessor_pt.py +++ b/tests/pytorch/test_models_preprocessor_pt.py @@ -40,4 +40,22 @@ def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, e assert all(b.dtype == torch.float32 for b in out) assert all(b.shape[-2:] == output_size for b in out) assert all(torch.all(b == expected_value) for b in out) + + # Tests with padding mask + processor_mask = PreProcessor(output_size, batch_size, return_padding_mask=True) + with torch.no_grad(): + out_masked = processor_mask(input_tensor) + imgs, masks = out_masked + + assert isinstance(imgs, list) and len(imgs) == expected_batches + assert isinstance(masks, list) and len(masks) == expected_batches + assert all(isinstance(b, torch.Tensor) for b in imgs) + assert all(isinstance(m, torch.Tensor) for m in masks) + assert all(b.dtype == torch.float32 for b in imgs) + assert all(m.dtype == torch.bool for m in masks) + assert all(b.shape[-2:] == output_size for b in imgs) + assert all(m.shape[-2:] == output_size for m in masks) + + # mask sanity: should contain both 0 and 1 somewhere (padding depends on Resize logic) + assert all(torch.is_tensor(m) for m in masks) assert len(repr(processor).split("\n")) == 4 diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py index b29f88e119..3ed262608b 100644 --- a/tests/pytorch/test_transforms_pt.py +++ b/tests/pytorch/test_transforms_pt.py @@ -29,6 +29,14 @@ def test_resize(): assert out.shape[-2:] == output_size assert repr(transfo) == "Resize(output_size=(32, 32), interpolation='bilinear')" + # Test return_padding_mask without aspect ratio + transfo = Resize(output_size, return_padding_mask=True) + out, mask = transfo(input_t) + assert out.shape[-2:] == output_size + assert mask.shape == output_size + assert mask.dtype == torch.bool + assert torch.all(mask == 0) + # Test with preserve_aspect_ratio output_size = (32, 32) input_t = torch.ones((3, 32, 64), dtype=torch.float32) @@ -40,12 +48,30 @@ def test_resize(): assert not torch.all(out == 1) assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1) + # Asymmetric padding mask + transfo = Resize(output_size, preserve_aspect_ratio=True, return_padding_mask=True) + out, mask = transfo(input_t) + assert mask.shape == output_size + assert mask.dtype == torch.bool + assert mask.any() + assert torch.any(mask[:, -5:]) + assert torch.any(mask[:, 5:]) + # Symmetric padding transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True) out = transfo(input_t) assert out.shape[-2:] == output_size assert torch.all(out[:, 0] == 0) and torch.all(out[:, -1] == 0) + # Symmetric padding mask + transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True, return_padding_mask=True) + out, mask = transfo(input_t) + assert mask.shape == output_size + assert mask.dtype == torch.bool + assert mask.any() + assert torch.any(mask[:, :5]) + assert torch.any(mask[:, -5:]) + expected = "Resize(output_size=(32, 32), interpolation='bilinear', preserve_aspect_ratio=True, symmetric_pad=True)" assert repr(transfo) == expected @@ -70,26 +96,30 @@ def test_resize(): for symmetric_pad in padding: # Test with target boxes target_boxes = np.array([[0.1, 0.1, 0.3, 0.4], [0.2, 0.2, 0.8, 0.8]]) - transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad) + transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True) input_t = torch.ones((3, 32, 64), dtype=torch.float32) - out, new_target = transfo(input_t, target_boxes) + out, new_target, mask = transfo(input_t, target_boxes) assert out.shape[-2:] == (64, 64) assert new_target.shape == target_boxes.shape assert np.all((0 <= new_target) & (new_target <= 1)) + assert mask.shape == (64, 64) + assert mask.dtype == torch.bool # Test with target polygons target_boxes = np.array([ [[0.1, 0.1], [0.9, 0.1], [0.9, 0.9], [0.1, 0.9]], [[0.2, 0.2], [0.8, 0.2], [0.8, 0.8], [0.2, 0.8]], ]) - transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad) + transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True) input_t = torch.ones((3, 32, 64), dtype=torch.float32) - out, new_target = transfo(input_t, target_boxes) + out, new_target, mask = transfo(input_t, target_boxes) assert out.shape[-2:] == (64, 64) assert new_target.shape == target_boxes.shape assert np.all((0 <= new_target) & (new_target <= 1)) + assert mask.shape == (64, 64) + assert mask.dtype == torch.bool # Test with invalid target shape input_t = torch.ones((3, 32, 64), dtype=torch.float32)