Skip to content

Commit 8ef0b4d

Browse files
Finish individual node progress feature (#3252)
* Add way for individual nodes to send progress events
* Individual node progress for PyTorch tiling
* Implement for onnx and ncnn as well
* Add support for tensorrt progress
* (wip) support multiple upscale passes
* Better split progress wrapper
* Reuse existing split progress infra
* Lint + format
* Pyright fix
* More pyright fixes
1 parent deff8b5 commit 8ef0b4d

File tree

15 files changed

+226
-68
lines changed

15 files changed

+226
-68
lines changed

backend/src/nodes/impl/ncnn/auto_split.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from ncnn import ncnn # type: ignore
1313

1414
use_gpu = False
15+
16+
from api import Progress
1517
from logger import logger
1618

1719
from ...utils.utils import get_h_w_c
@@ -27,6 +29,7 @@ def ncnn_auto_split(
2729
blob_vkallocator, # noqa: ANN001
2830
staging_vkallocator, # noqa: ANN001
2931
tiler: Tiler,
32+
progress: Progress | None = None,
3033
) -> np.ndarray:
3134
def upscale(img: np.ndarray, _: object):
3235
ex = net.create_extractor()
@@ -91,4 +94,4 @@ def upscale(img: np.ndarray, _: object):
9194
# Re-raise the exception if not an OOM error
9295
raise
9396

94-
return auto_split(img, upscale, tiler)
97+
return auto_split(img, upscale, tiler, progress=progress)

backend/src/nodes/impl/onnx/auto_split.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import onnxruntime as ort
88

9+
from api import Progress
910
from nodes.impl.onnx.model import SizeReq
1011

1112
from ..upscale.auto_split import Tiler, auto_split
@@ -101,6 +102,7 @@ def onnx_auto_split(
101102
change_shape: bool,
102103
tiler: Tiler,
103104
size_req: SizeReq | None = None,
105+
progress: Progress | None = None,
104106
) -> np.ndarray:
105107
input_name = session.get_inputs()[0].name
106108
output_name = session.get_outputs()[0].name
@@ -134,6 +136,6 @@ def upscale(img: np.ndarray, _: object):
134136
raise
135137

136138
try:
137-
return auto_split(img, upscale, tiler)
139+
return auto_split(img, upscale, tiler, progress=progress)
138140
finally:
139141
gc.collect()

backend/src/nodes/impl/pytorch/auto_split.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,20 @@ def pytorch_auto_split(
112112
device: torch.device,
113113
use_fp16: bool,
114114
tiler: Tiler,
115-
progress: Progress,
115+
progress: Progress | None = None,
116116
) -> np.ndarray:
117117
dtype = torch.float16 if use_fp16 else torch.float32
118118
if model.dtype != dtype or model.device != device:
119119
model = model.to(device, dtype)
120120

121121
def upscale(img: np.ndarray, _: object):
122-
progress.check_aborted()
123-
if progress.paused:
124-
# clear resources before pausing
125-
gc.collect()
126-
safe_cuda_cache_empty()
127-
progress.suspend()
122+
if progress is not None:
123+
progress.check_aborted()
124+
if progress.paused:
125+
# clear resources before pausing
126+
gc.collect()
127+
safe_cuda_cache_empty()
128+
progress.suspend()
128129

129130
input_tensor = None
130131
try:
@@ -165,4 +166,4 @@ def upscale(img: np.ndarray, _: object):
165166
# Re-raise the exception if not an OOM error
166167
raise
167168

168-
return auto_split(img, upscale, tiler)
169+
return auto_split(img, upscale, tiler, progress=progress)

backend/src/nodes/impl/tensorrt/auto_split.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88

9+
from api import Progress
10+
911
from ..upscale.auto_split import Tiler, auto_split
1012
from .inference import get_tensorrt_session
1113
from .model import TensorRTEngine
@@ -59,6 +61,7 @@ def tensorrt_auto_split(
5961
engine: TensorRTEngine,
6062
tiler: Tiler,
6163
gpu_index: int = 0,
64+
progress: Progress | None = None,
6265
) -> np.ndarray:
6366
"""
6467
Run TensorRT inference with automatic tiling for large images.
@@ -76,6 +79,12 @@ def tensorrt_auto_split(
7679
is_fp16 = engine.precision == "fp16"
7780

7881
def upscale(img: np.ndarray, _: object):
82+
if progress is not None:
83+
progress.check_aborted()
84+
if progress.paused:
85+
gc.collect()
86+
progress.suspend()
87+
7988
try:
8089
# Convert to appropriate precision
8190
lr_img = img.astype(np.float16) if is_fp16 else img.astype(np.float32)
@@ -113,6 +122,6 @@ def upscale(img: np.ndarray, _: object):
113122
raise
114123

115124
try:
116-
return auto_split(img, upscale, tiler)
125+
return auto_split(img, upscale, tiler, progress=progress)
117126
finally:
118127
gc.collect()

backend/src/nodes/impl/upscale/auto_split.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77

8+
from api import Progress
89
from logger import logger
910

1011
from ...utils.utils import Region, Size, get_h_w_c
@@ -25,6 +26,7 @@ def auto_split(
2526
upscale: SplitImageOp,
2627
tiler: Tiler,
2728
overlap: int = 16,
29+
progress: Progress | None = None,
2830
) -> np.ndarray:
2931
"""
3032
Splits the image into tiles according to the given tiler.
@@ -49,6 +51,7 @@ def auto_split(
4951
starting_tile_size=tiler.starting_tile_size(w, h, c),
5052
split_tile_size=tiler.split,
5153
overlap=overlap,
54+
progress=progress,
5255
)
5356

5457

@@ -62,6 +65,7 @@ def _exact_split(
6265
starting_tile_size: Size,
6366
split_tile_size: Callable[[Size], Size],
6467
overlap: int,
68+
progress: Progress | None = None,
6569
) -> np.ndarray:
6670
h, w, c = get_h_w_c(img)
6771
logger.debug(
@@ -89,6 +93,7 @@ def no_split_upscale(i: np.ndarray, r: Region) -> np.ndarray:
8993
exact_size=starting_tile_size,
9094
upscale=no_split_upscale,
9195
overlap=min(max_overlap, overlap),
96+
progress=progress,
9297
)
9398
except _SplitEx:
9499
starting_tile_size = split_tile_size(starting_tile_size)
@@ -102,6 +107,7 @@ def _max_split(
102107
starting_tile_size: Size,
103108
split_tile_size: Callable[[Size], Size],
104109
overlap: int,
110+
progress: Progress | None = None,
105111
) -> np.ndarray:
106112
"""
107113
Splits the image into tiles with at most the given tile size.
@@ -126,6 +132,8 @@ def _max_split(
126132
# the image might be small enough so that we don't have to split at all
127133
upscale_result = upscale(img, img_region)
128134
if not isinstance(upscale_result, Split):
135+
if progress is not None:
136+
progress.set_progress(1.0)
129137
return upscale_result
130138

131139
# the image was too large
@@ -161,6 +169,7 @@ def _max_split(
161169
tile_count_y = math.ceil(h / max_tile_size[1])
162170
tile_size_x = math.ceil(w / tile_count_x)
163171
tile_size_y = math.ceil(h / tile_count_y)
172+
total_tiles = tile_count_x * tile_count_y
164173

165174
logger.debug(
166175
"Currently %dx%d tiles each %dx%dpx.",
@@ -171,6 +180,7 @@ def _max_split(
171180
)
172181

173182
prev_row_result: TileBlender | None = None
183+
tiles_processed = 0
174184

175185
for y in range(tile_count_y):
176186
if y < start_y:
@@ -238,6 +248,11 @@ def _max_split(
238248
upscale_result, TileOverlap(pad.left * scale, pad.right * scale)
239249
)
240250

251+
# Report progress after each tile
252+
tiles_processed += 1
253+
if progress is not None:
254+
progress.set_progress(tiles_processed / total_tiles)
255+
241256
if restart:
242257
break
243258

backend/src/nodes/impl/upscale/basic_upscale.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import numpy as np
66

7-
from nodes.impl.image_op import ImageOp
7+
from api import Progress
88
from nodes.impl.image_utils import BorderType, create_border
99
from nodes.impl.resize import ResizeFilter, resize
1010
from nodes.utils.utils import Padding, get_h_w_c
1111

12-
from .convenient_upscale import convenient_upscale
12+
from .convenient_upscale import ProgressImageOp, convenient_upscale
1313

1414

1515
@dataclass
@@ -49,20 +49,25 @@ def to_border_type(self) -> BorderType:
4949

5050
def _custom_scale_upscale(
5151
img: np.ndarray,
52-
upscale: ImageOp,
52+
upscale: ProgressImageOp,
5353
natural_scale: int,
5454
custom_scale: int,
5555
separate_alpha: bool,
56+
progress: Progress | None = None,
5657
) -> np.ndarray:
5758
if custom_scale == natural_scale:
58-
return upscale(img)
59+
return upscale(img, progress)
5960

6061
# number of iterations we need to do to reach the desired scale
6162
# e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations
6263
iterations = max(1, math.ceil(math.log(custom_scale, natural_scale)))
6364
org_h, org_w, _ = get_h_w_c(img)
64-
for _ in range(iterations):
65-
img = upscale(img)
65+
for i in range(iterations):
66+
# Split progress evenly across iterations
67+
iter_progress = (
68+
progress.sub_progress(i / iterations, 1 / iterations) if progress else None
69+
)
70+
img = upscale(img, iter_progress)
6671

6772
# resize, if necessary
6873
target_size = (
@@ -83,21 +88,23 @@ def _custom_scale_upscale(
8388

8489
def basic_upscale(
8590
img: np.ndarray,
86-
upscale: ImageOp,
91+
upscale: ProgressImageOp,
8792
upscale_info: UpscaleInfo,
8893
scale: int,
8994
separate_alpha: bool,
9095
padding: PaddingType = PaddingType.NONE,
9196
clip: bool = True,
97+
progress: Progress | None = None,
9298
):
93-
def inner_upscale(img: np.ndarray) -> np.ndarray:
99+
def inner_upscale(img: np.ndarray, p: Progress | None) -> np.ndarray:
94100
return convenient_upscale(
95101
img,
96102
upscale_info.in_nc,
97103
upscale_info.out_nc,
98104
upscale,
99105
separate_alpha,
100106
clip=clip,
107+
progress=p,
101108
)
102109

103110
if not upscale_info.supports_custom_scale and scale != upscale_info.scale:
@@ -114,6 +121,7 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
114121
natural_scale=upscale_info.scale,
115122
custom_scale=scale,
116123
separate_alpha=separate_alpha,
124+
progress=progress,
117125
)
118126

119127
if padding != PaddingType.NONE:

backend/src/nodes/impl/upscale/convenient_upscale.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable
4+
35
import numpy as np
46

7+
from api import Progress
8+
59
from ...utils.utils import get_h_w_c
6-
from ..image_op import ImageOp, clipped
10+
from ..image_op import ImageOp, clipped, to_op
711
from ..image_utils import as_target_channels
812

13+
ProgressImageOp = Callable[[np.ndarray, Progress | None], np.ndarray]
14+
"""
15+
An image processing operation that takes an image and progress, and produces a new image.
16+
"""
17+
918

1019
def with_black_and_white_backgrounds(img: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
1120
c = get_h_w_c(img)[2]
@@ -32,9 +41,10 @@ def convenient_upscale(
3241
img: np.ndarray,
3342
model_in_nc: int,
3443
model_out_nc: int,
35-
upscale: ImageOp,
44+
upscale: ProgressImageOp,
3645
separate_alpha: bool = False,
3746
clip: bool = True,
47+
progress: Progress | None = None,
3848
) -> np.ndarray:
3949
"""
4050
Upscales the given image in an intuitive/convenient way.
@@ -45,45 +55,60 @@ def convenient_upscale(
4555
Additionally, guarantees that the number of channels of the output image will match
4656
that of the input image in cases where `model_in_nc` == `model_out_nc`, and match
4757
`model_out_nc` otherwise.
58+
59+
When multiple upscale operations are needed (e.g., for RGBA images), progress is
60+
automatically divided between them using sub-progress ranges.
4861
"""
4962
in_img_c = get_h_w_c(img)[2]
5063

51-
if clip:
52-
upscale = clipped(upscale)
64+
def make_op(p: Progress | None) -> ImageOp:
65+
"""Create an ImageOp with the given progress bound."""
66+
op = to_op(upscale)(p)
67+
return clipped(op) if clip else op
5368

5469
if model_in_nc != model_out_nc:
55-
return upscale(as_target_channels(img, model_in_nc, True))
70+
return make_op(progress)(as_target_channels(img, model_in_nc, True))
5671

5772
if in_img_c == model_in_nc:
58-
return upscale(img)
73+
return make_op(progress)(img)
5974

6075
if in_img_c == 4:
6176
# Ignore alpha if single-color or not being replaced
6277
unique = np.unique(img[:, :, 3])
6378
if len(unique) == 1:
79+
op = make_op(progress)
6480
rgb = as_target_channels(
65-
upscale(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
81+
op(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
6682
)
6783
unique_alpha = np.full(rgb.shape[:-1], unique[0], np.float32)
6884
return np.dstack((rgb, unique_alpha))
6985

7086
if separate_alpha:
7187
# Upscale the RGB channels and alpha channel separately
88+
# Split progress: 50% for RGB, 50% for alpha
89+
rgb_op = make_op(progress.sub_progress(0, 0.5) if progress else None)
90+
alpha_op = make_op(progress.sub_progress(0.5, 0.5) if progress else None)
91+
7292
rgb = as_target_channels(
73-
upscale(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
93+
rgb_op(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
7494
)
7595
alpha = denoise_and_flatten_alpha(
76-
upscale(as_target_channels(img[:, :, 3], model_in_nc, True))
96+
alpha_op(as_target_channels(img[:, :, 3], model_in_nc, True))
7797
)
7898
return np.dstack((rgb, alpha))
7999
else:
80100
# Transparency hack (white/black background difference alpha)
101+
# Split progress: 50% for black background, 50% for white background
81102
black, white = with_black_and_white_backgrounds(img)
103+
104+
black_op = make_op(progress.sub_progress(0, 0.5) if progress else None)
105+
white_op = make_op(progress.sub_progress(0.5, 0.5) if progress else None)
106+
82107
black_up = as_target_channels(
83-
upscale(as_target_channels(black, model_in_nc, True)), 3, True
108+
black_op(as_target_channels(black, model_in_nc, True)), 3, True
84109
)
85110
white_up = as_target_channels(
86-
upscale(as_target_channels(white, model_in_nc, True)), 3, True
111+
white_op(as_target_channels(white, model_in_nc, True)), 3, True
87112
)
88113

89114
# Interpolate between the alpha values to get a more defined alpha
@@ -92,6 +117,4 @@ def convenient_upscale(
92117

93118
return np.dstack((black_up, alpha))
94119

95-
return as_target_channels(
96-
upscale(as_target_channels(img, model_in_nc, True)), in_img_c, True
97-
)
120+
return make_op(progress)(as_target_channels(img, model_in_nc, True))

0 commit comments

Comments (0)