11from __future__ import annotations
22
3+ from collections .abc import Callable
4+
35import numpy as np
46
7+ from api import Progress
8+
59from ...utils .utils import get_h_w_c
6- from ..image_op import ImageOp , clipped
10+ from ..image_op import ImageOp , clipped , to_op
711from ..image_utils import as_target_channels
812
13+ ProgressImageOp = Callable [[np .ndarray , Progress | None ], np .ndarray ]
14+ """
15+ An image processing operation that takes an image and progress, and produces a new image.
16+ """
17+
918
1019def with_black_and_white_backgrounds (img : np .ndarray ) -> tuple [np .ndarray , np .ndarray ]:
1120 c = get_h_w_c (img )[2 ]
@@ -32,9 +41,10 @@ def convenient_upscale(
3241 img : np .ndarray ,
3342 model_in_nc : int ,
3443 model_out_nc : int ,
35- upscale : ImageOp ,
44+ upscale : ProgressImageOp ,
3645 separate_alpha : bool = False ,
3746 clip : bool = True ,
47+ progress : Progress | None = None ,
3848) -> np .ndarray :
3949 """
4050 Upscales the given image in an intuitive/convenient way.
@@ -45,45 +55,60 @@ def convenient_upscale(
4555 Additionally, guarantees that the number of channels of the output image will match
4656 that of the input image in cases where `model_in_nc` == `model_out_nc`, and match
4757 `model_out_nc` otherwise.
58+
59+ When multiple upscale operations are needed (e.g., for RGBA images), progress is
60+ automatically divided between them using sub-progress ranges.
4861 """
4962 in_img_c = get_h_w_c (img )[2 ]
5063
51- if clip :
52- upscale = clipped (upscale )
64+ def make_op (p : Progress | None ) -> ImageOp :
65+ """Create an ImageOp with the given progress bound."""
66+ op = to_op (upscale )(p )
67+ return clipped (op ) if clip else op
5368
5469 if model_in_nc != model_out_nc :
55- return upscale (as_target_channels (img , model_in_nc , True ))
70+ return make_op ( progress ) (as_target_channels (img , model_in_nc , True ))
5671
5772 if in_img_c == model_in_nc :
58- return upscale (img )
73+ return make_op ( progress ) (img )
5974
6075 if in_img_c == 4 :
6176 # Ignore alpha if single-color or not being replaced
6277 unique = np .unique (img [:, :, 3 ])
6378 if len (unique ) == 1 :
79+ op = make_op (progress )
6480 rgb = as_target_channels (
65- upscale (as_target_channels (img [:, :, :3 ], model_in_nc , True )), 3 , True
81+ op (as_target_channels (img [:, :, :3 ], model_in_nc , True )), 3 , True
6682 )
6783 unique_alpha = np .full (rgb .shape [:- 1 ], unique [0 ], np .float32 )
6884 return np .dstack ((rgb , unique_alpha ))
6985
7086 if separate_alpha :
7187 # Upscale the RGB channels and alpha channel separately
88+ # Split progress: 50% for RGB, 50% for alpha
89+ rgb_op = make_op (progress .sub_progress (0 , 0.5 ) if progress else None )
90+ alpha_op = make_op (progress .sub_progress (0.5 , 0.5 ) if progress else None )
91+
7292 rgb = as_target_channels (
73- upscale (as_target_channels (img [:, :, :3 ], model_in_nc , True )), 3 , True
93+ rgb_op (as_target_channels (img [:, :, :3 ], model_in_nc , True )), 3 , True
7494 )
7595 alpha = denoise_and_flatten_alpha (
76- upscale (as_target_channels (img [:, :, 3 ], model_in_nc , True ))
96+ alpha_op (as_target_channels (img [:, :, 3 ], model_in_nc , True ))
7797 )
7898 return np .dstack ((rgb , alpha ))
7999 else :
80100 # Transparency hack (white/black background difference alpha)
101+ # Split progress: 50% for black background, 50% for white background
81102 black , white = with_black_and_white_backgrounds (img )
103+
104+ black_op = make_op (progress .sub_progress (0 , 0.5 ) if progress else None )
105+ white_op = make_op (progress .sub_progress (0.5 , 0.5 ) if progress else None )
106+
82107 black_up = as_target_channels (
83- upscale (as_target_channels (black , model_in_nc , True )), 3 , True
108+ black_op (as_target_channels (black , model_in_nc , True )), 3 , True
84109 )
85110 white_up = as_target_channels (
86- upscale (as_target_channels (white , model_in_nc , True )), 3 , True
111+ white_op (as_target_channels (white , model_in_nc , True )), 3 , True
87112 )
88113
89114 # Interpolate between the alpha values to get a more defined alpha
@@ -92,6 +117,4 @@ def convenient_upscale(
92117
93118 return np .dstack ((black_up , alpha ))
94119
95- return as_target_channels (
96- upscale (as_target_channels (img , model_in_nc , True )), in_img_c , True
97- )
120+ return make_op (progress )(as_target_channels (img , model_in_nc , True ))
0 commit comments