diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index af59b4a8c7..280743af4d 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -37,6 +37,9 @@ def forward( self, inputs: list[np.ndarray], ) -> list[list[int] | list[float]]: + if len(inputs) == 0: + return [[], [], []] + # Dimension check if any(input.ndim != 3 for input in inputs): raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")