Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions doctr/models/preprocessor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from typing import Any
from typing import Any, overload

import numpy as np
import torch
Expand Down Expand Up @@ -43,47 +43,84 @@
# Perform the division by 255 at the same time
self.normalize = T.Normalize(mean, std)

def batch_inputs(self, samples: list[torch.Tensor]) -> list[torch.Tensor]:
@overload
def batch_inputs(
self,
samples: list[torch.Tensor],
) -> list[torch.Tensor]: ...

@overload
def batch_inputs(
self,
samples: list[tuple[torch.Tensor, torch.Tensor]],
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ...

def batch_inputs(
self, samples: list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
) -> list[torch.Tensor] | tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Gather samples into batches for inference purposes

Args:
samples: list of samples of shape (C, H, W)
samples: list of samples of shape (C, H, W) or
list of tuples of samples and masks of shape (C, H, W) and (1, H, W) respectively

Returns:
list of batched samples (*, C, H, W)
list of batched samples (*, C, H, W) or tuple of lists of batched samples and masks
"""
num_batches = int(math.ceil(len(samples) / self.batch_size))
batches = [
torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0)
for idx in range(int(num_batches))
]

return batches
if isinstance(samples[0], tuple):
imgs, masks = zip(*samples)

img_batches = [
torch.stack(imgs[idx * self.batch_size : min((idx + 1) * self.batch_size, len(imgs))], dim=0)
for idx in range(num_batches)
]

mask_batches = [
torch.stack(masks[idx * self.batch_size : min((idx + 1) * self.batch_size, len(masks))], dim=0)
for idx in range(num_batches)
]

def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
return img_batches, mask_batches

return [
torch.stack(
samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))],
dim=0,
)
for idx in range(num_batches)
]

def sample_transforms(self, x: np.ndarray) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if x.ndim != 3:
raise AssertionError("expected list of 3D Tensors")
if x.dtype not in (np.uint8, np.float32, np.float16):
raise TypeError("unsupported data type for numpy.ndarray")
tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
# Resizing
tensor = self.resize(tensor)
if self.resize.return_padding_mask:
tensor, mask = self.resize(tensor)
else:
tensor = self.resize(tensor)
# Data type
if tensor.dtype == torch.uint8:
tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
else:
tensor = tensor.to(dtype=torch.float32)

return tensor
return (tensor, mask) if self.resize.return_padding_mask else tensor

def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
def __call__(
self, x: np.ndarray | list[np.ndarray]
) -> list[torch.Tensor] | tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Prepare document data for model forwarding

Args:
x: list of images (np.array) or a single image (np.array) of shape (H, W, C)

Returns:
list of page batches (*, C, H, W) ready for model inference
list of page batches (*, C, H, W) or tuple of lists of page batches and padding masks
"""
# Input type check
if isinstance(x, np.ndarray):
Expand All @@ -103,17 +140,28 @@
tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
else:
tensor = tensor.to(dtype=torch.float32)
batches = [tensor]
img_batches = [tensor]

if self.resize.return_padding_mask:
h, w = self.resize.size
mask = torch.zeros((x.shape[0], h, w), dtype=torch.bool)
mask_batches = [mask]

elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
# Sample transform (to tensor, resize)
samples = list(multithread_exec(self.sample_transforms, x))
# Batching
batches = self.batch_inputs(samples)
if self.resize.return_padding_mask:
img_batches, mask_batches = self.batch_inputs(samples)
else:
img_batches = self.batch_inputs(samples)
else:
raise TypeError(f"invalid input type: {type(x)}")

# Batch transforms (normalize)
batches = list(multithread_exec(self.normalize, batches))
if self.resize.return_padding_mask:
img_batches = list(multithread_exec(self.normalize, img_batches))
return img_batches, mask_batches

Check warning on line 164 in doctr/models/preprocessor/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/preprocessor/pytorch.py#L164

Using variable 'mask_batches' before assignment

return batches
img_batches = list(multithread_exec(self.normalize, img_batches))
return img_batches
41 changes: 37 additions & 4 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Resize(T.Resize):
if True, the image will be resized to fit within the target size while maintaining its aspect ratio
symmetric_pad: whether to symmetrically pad the image to the target size,
if True, the image will be padded equally on both sides to fit the target size
return_padding_mask: whether to return a padding mask indicating the padded areas of the image
"""

def __init__(
Expand All @@ -49,25 +50,43 @@ def __init__(
interpolation=F.InterpolationMode.BILINEAR,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
return_padding_mask: bool = False,
) -> None:
super().__init__(size if isinstance(size, (list, tuple)) else (size, size), interpolation, antialias=True)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.return_padding_mask = return_padding_mask

def forward(
self,
img: torch.Tensor,
target: np.ndarray | None = None,
) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
) -> (
torch.Tensor
| tuple[torch.Tensor, np.ndarray]
| tuple[torch.Tensor, np.ndarray, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor]
):
target_ratio = self.size[0] / self.size[1]
actual_ratio = img.shape[-2] / img.shape[-1]

if not self.preserve_aspect_ratio or (target_ratio == actual_ratio):
# If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one
# We can use with the regular resize
img = super().forward(img)

if self.return_padding_mask:
padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device)

if target is not None:
return super().forward(img), target
return super().forward(img)
if self.return_padding_mask:
return img, target, padding_mask
return img, target

if self.return_padding_mask:
return img, padding_mask

return img
else:
# Resize
if actual_ratio > target_ratio:
Expand All @@ -87,6 +106,13 @@ def forward(
# Pad image
img = pad(img, _pad)

if self.return_padding_mask:
h, w = self.size
padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device)

left, right, top, bottom = _pad
padding_mask[top : h - bottom, left : w - right] = True

# In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio)
if target is not None:
if self.symmetric_pad:
Expand All @@ -111,7 +137,14 @@ def forward(
else:
raise AssertionError("Boxes should be in the format (n_boxes, 4, 2) or (n_boxes, 4)")

return img, np.clip(target, 0, 1)
target = np.clip(target, 0, 1)

if self.return_padding_mask:
return img, target, padding_mask
return img, target

if self.return_padding_mask:
return img, padding_mask

return img

Expand Down
18 changes: 18 additions & 0 deletions tests/pytorch/test_models_preprocessor_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,22 @@ def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, e
assert all(b.dtype == torch.float32 for b in out)
assert all(b.shape[-2:] == output_size for b in out)
assert all(torch.all(b == expected_value) for b in out)

# Tests with padding mask
processor_mask = PreProcessor(output_size, batch_size, return_padding_mask=True)
with torch.no_grad():
out_masked = processor_mask(input_tensor)
imgs, masks = out_masked

assert isinstance(imgs, list) and len(imgs) == expected_batches
assert isinstance(masks, list) and len(masks) == expected_batches
assert all(isinstance(b, torch.Tensor) for b in imgs)
assert all(isinstance(m, torch.Tensor) for m in masks)
assert all(b.dtype == torch.float32 for b in imgs)
assert all(m.dtype == torch.bool for m in masks)
assert all(b.shape[-2:] == output_size for b in imgs)
assert all(m.shape[-2:] == output_size for m in masks)

# mask sanity: should contain both 0 and 1 somewhere (padding depends on Resize logic)
assert all(torch.is_tensor(m) for m in masks)
assert len(repr(processor).split("\n")) == 4
38 changes: 34 additions & 4 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def test_resize():
assert out.shape[-2:] == output_size
assert repr(transfo) == "Resize(output_size=(32, 32), interpolation='bilinear')"

# Test return_padding_mask without aspect ratio
transfo = Resize(output_size, return_padding_mask=True)
out, mask = transfo(input_t)
assert out.shape[-2:] == output_size
assert mask.shape == output_size
assert mask.dtype == torch.bool
assert torch.all(mask == 0)

# Test with preserve_aspect_ratio
output_size = (32, 32)
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
Expand All @@ -40,12 +48,30 @@ def test_resize():
assert not torch.all(out == 1)
assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1)

# Asymmetric padding mask
transfo = Resize(output_size, preserve_aspect_ratio=True, return_padding_mask=True)
out, mask = transfo(input_t)
assert mask.shape == output_size
assert mask.dtype == torch.bool
assert mask.any()
assert torch.any(mask[:, -5:])
assert torch.any(mask[:, 5:])

# Symmetric padding
transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True)
out = transfo(input_t)
assert out.shape[-2:] == output_size
assert torch.all(out[:, 0] == 0) and torch.all(out[:, -1] == 0)

# Symmetric padding mask
transfo = Resize(32, preserve_aspect_ratio=True, symmetric_pad=True, return_padding_mask=True)
out, mask = transfo(input_t)
assert mask.shape == output_size
assert mask.dtype == torch.bool
assert mask.any()
assert torch.any(mask[:, :5])
assert torch.any(mask[:, -5:])

expected = "Resize(output_size=(32, 32), interpolation='bilinear', preserve_aspect_ratio=True, symmetric_pad=True)"
assert repr(transfo) == expected

Expand All @@ -70,26 +96,30 @@ def test_resize():
for symmetric_pad in padding:
# Test with target boxes
target_boxes = np.array([[0.1, 0.1, 0.3, 0.4], [0.2, 0.2, 0.8, 0.8]])
transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad)
transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True)
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
out, new_target = transfo(input_t, target_boxes)
out, new_target, mask = transfo(input_t, target_boxes)

assert out.shape[-2:] == (64, 64)
assert new_target.shape == target_boxes.shape
assert np.all((0 <= new_target) & (new_target <= 1))
assert mask.shape == (64, 64)
assert mask.dtype == torch.bool

# Test with target polygons
target_boxes = np.array([
[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9], [0.1, 0.9]],
[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8], [0.2, 0.8]],
])
transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad)
transfo = Resize((64, 64), preserve_aspect_ratio=True, symmetric_pad=symmetric_pad, return_padding_mask=True)
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
out, new_target = transfo(input_t, target_boxes)
out, new_target, mask = transfo(input_t, target_boxes)

assert out.shape[-2:] == (64, 64)
assert new_target.shape == target_boxes.shape
assert np.all((0 <= new_target) & (new_target <= 1))
assert mask.shape == (64, 64)
assert mask.dtype == torch.bool

# Test with invalid target shape
input_t = torch.ones((3, 32, 64), dtype=torch.float32)
Expand Down
Loading