Skip to content
5 changes: 4 additions & 1 deletion backend/src/nodes/impl/ncnn/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ncnn import ncnn # type: ignore

use_gpu = False

from api import Progress
from logger import logger

from ...utils.utils import get_h_w_c
Expand All @@ -27,6 +29,7 @@ def ncnn_auto_split(
blob_vkallocator, # noqa: ANN001
staging_vkallocator, # noqa: ANN001
tiler: Tiler,
progress: Progress | None = None,
) -> np.ndarray:
def upscale(img: np.ndarray, _: object):
ex = net.create_extractor()
Expand Down Expand Up @@ -91,4 +94,4 @@ def upscale(img: np.ndarray, _: object):
# Re-raise the exception if not an OOM error
raise

return auto_split(img, upscale, tiler)
return auto_split(img, upscale, tiler, progress=progress)
4 changes: 3 additions & 1 deletion backend/src/nodes/impl/onnx/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import onnxruntime as ort

from api import Progress
from nodes.impl.onnx.model import SizeReq

from ..upscale.auto_split import Tiler, auto_split
Expand Down Expand Up @@ -101,6 +102,7 @@ def onnx_auto_split(
change_shape: bool,
tiler: Tiler,
size_req: SizeReq | None = None,
progress: Progress | None = None,
) -> np.ndarray:
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
Expand Down Expand Up @@ -134,6 +136,6 @@ def upscale(img: np.ndarray, _: object):
raise

try:
return auto_split(img, upscale, tiler)
return auto_split(img, upscale, tiler, progress=progress)
finally:
gc.collect()
17 changes: 9 additions & 8 deletions backend/src/nodes/impl/pytorch/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,20 @@ def pytorch_auto_split(
device: torch.device,
use_fp16: bool,
tiler: Tiler,
progress: Progress,
progress: Progress | None = None,
) -> np.ndarray:
dtype = torch.float16 if use_fp16 else torch.float32
if model.dtype != dtype or model.device != device:
model = model.to(device, dtype)

def upscale(img: np.ndarray, _: object):
progress.check_aborted()
if progress.paused:
# clear resources before pausing
gc.collect()
safe_cuda_cache_empty()
progress.suspend()
if progress is not None:
progress.check_aborted()
if progress.paused:
# clear resources before pausing
gc.collect()
safe_cuda_cache_empty()
progress.suspend()

input_tensor = None
try:
Expand Down Expand Up @@ -165,4 +166,4 @@ def upscale(img: np.ndarray, _: object):
# Re-raise the exception if not an OOM error
raise

return auto_split(img, upscale, tiler)
return auto_split(img, upscale, tiler, progress=progress)
11 changes: 10 additions & 1 deletion backend/src/nodes/impl/tensorrt/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np

from api import Progress

from ..upscale.auto_split import Tiler, auto_split
from .inference import get_tensorrt_session
from .model import TensorRTEngine
Expand Down Expand Up @@ -59,6 +61,7 @@ def tensorrt_auto_split(
engine: TensorRTEngine,
tiler: Tiler,
gpu_index: int = 0,
progress: Progress | None = None,
) -> np.ndarray:
"""
Run TensorRT inference with automatic tiling for large images.
Expand All @@ -76,6 +79,12 @@ def tensorrt_auto_split(
is_fp16 = engine.precision == "fp16"

def upscale(img: np.ndarray, _: object):
if progress is not None:
progress.check_aborted()
if progress.paused:
gc.collect()
progress.suspend()

try:
# Convert to appropriate precision
lr_img = img.astype(np.float16) if is_fp16 else img.astype(np.float32)
Expand Down Expand Up @@ -113,6 +122,6 @@ def upscale(img: np.ndarray, _: object):
raise

try:
return auto_split(img, upscale, tiler)
return auto_split(img, upscale, tiler, progress=progress)
finally:
gc.collect()
15 changes: 15 additions & 0 deletions backend/src/nodes/impl/upscale/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np

from api import Progress
from logger import logger

from ...utils.utils import Region, Size, get_h_w_c
Expand All @@ -25,6 +26,7 @@ def auto_split(
upscale: SplitImageOp,
tiler: Tiler,
overlap: int = 16,
progress: Progress | None = None,
) -> np.ndarray:
"""
Splits the image into tiles according to the given tiler.
Expand All @@ -49,6 +51,7 @@ def auto_split(
starting_tile_size=tiler.starting_tile_size(w, h, c),
split_tile_size=tiler.split,
overlap=overlap,
progress=progress,
)


Expand All @@ -62,6 +65,7 @@ def _exact_split(
starting_tile_size: Size,
split_tile_size: Callable[[Size], Size],
overlap: int,
progress: Progress | None = None,
) -> np.ndarray:
h, w, c = get_h_w_c(img)
logger.debug(
Expand Down Expand Up @@ -89,6 +93,7 @@ def no_split_upscale(i: np.ndarray, r: Region) -> np.ndarray:
exact_size=starting_tile_size,
upscale=no_split_upscale,
overlap=min(max_overlap, overlap),
progress=progress,
)
except _SplitEx:
starting_tile_size = split_tile_size(starting_tile_size)
Expand All @@ -102,6 +107,7 @@ def _max_split(
starting_tile_size: Size,
split_tile_size: Callable[[Size], Size],
overlap: int,
progress: Progress | None = None,
) -> np.ndarray:
"""
Splits the image into tiles with at most the given tile size.
Expand All @@ -126,6 +132,8 @@ def _max_split(
# the image might be small enough so that we don't have to split at all
upscale_result = upscale(img, img_region)
if not isinstance(upscale_result, Split):
if progress is not None:
progress.set_progress(1.0)
return upscale_result

# the image was too large
Expand Down Expand Up @@ -161,6 +169,7 @@ def _max_split(
tile_count_y = math.ceil(h / max_tile_size[1])
tile_size_x = math.ceil(w / tile_count_x)
tile_size_y = math.ceil(h / tile_count_y)
total_tiles = tile_count_x * tile_count_y

logger.debug(
"Currently %dx%d tiles each %dx%dpx.",
Expand All @@ -171,6 +180,7 @@ def _max_split(
)

prev_row_result: TileBlender | None = None
tiles_processed = 0

for y in range(tile_count_y):
if y < start_y:
Expand Down Expand Up @@ -238,6 +248,11 @@ def _max_split(
upscale_result, TileOverlap(pad.left * scale, pad.right * scale)
)

# Report progress after each tile
tiles_processed += 1
if progress is not None:
progress.set_progress(tiles_processed / total_tiles)

if restart:
break

Expand Down
24 changes: 16 additions & 8 deletions backend/src/nodes/impl/upscale/basic_upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import numpy as np

from nodes.impl.image_op import ImageOp
from api import Progress
from nodes.impl.image_utils import BorderType, create_border
from nodes.impl.resize import ResizeFilter, resize
from nodes.utils.utils import Padding, get_h_w_c

from .convenient_upscale import convenient_upscale
from .convenient_upscale import ProgressImageOp, convenient_upscale


@dataclass
Expand Down Expand Up @@ -49,20 +49,25 @@ def to_border_type(self) -> BorderType:

def _custom_scale_upscale(
img: np.ndarray,
upscale: ImageOp,
upscale: ProgressImageOp,
natural_scale: int,
custom_scale: int,
separate_alpha: bool,
progress: Progress | None = None,
) -> np.ndarray:
if custom_scale == natural_scale:
return upscale(img)
return upscale(img, progress)

# number of iterations we need to do to reach the desired scale
# e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations
iterations = max(1, math.ceil(math.log(custom_scale, natural_scale)))
org_h, org_w, _ = get_h_w_c(img)
for _ in range(iterations):
img = upscale(img)
for i in range(iterations):
# Split progress evenly across iterations
iter_progress = (
progress.sub_progress(i / iterations, 1 / iterations) if progress else None
)
img = upscale(img, iter_progress)

# resize, if necessary
target_size = (
Expand All @@ -83,21 +88,23 @@ def _custom_scale_upscale(

def basic_upscale(
img: np.ndarray,
upscale: ImageOp,
upscale: ProgressImageOp,
upscale_info: UpscaleInfo,
scale: int,
separate_alpha: bool,
padding: PaddingType = PaddingType.NONE,
clip: bool = True,
progress: Progress | None = None,
):
def inner_upscale(img: np.ndarray) -> np.ndarray:
def inner_upscale(img: np.ndarray, p: Progress | None) -> np.ndarray:
return convenient_upscale(
img,
upscale_info.in_nc,
upscale_info.out_nc,
upscale,
separate_alpha,
clip=clip,
progress=p,
)

if not upscale_info.supports_custom_scale and scale != upscale_info.scale:
Expand All @@ -114,6 +121,7 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
natural_scale=upscale_info.scale,
custom_scale=scale,
separate_alpha=separate_alpha,
progress=progress,
)

if padding != PaddingType.NONE:
Expand Down
51 changes: 37 additions & 14 deletions backend/src/nodes/impl/upscale/convenient_upscale.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

from collections.abc import Callable

import numpy as np

from api import Progress

from ...utils.utils import get_h_w_c
from ..image_op import ImageOp, clipped
from ..image_op import ImageOp, clipped, to_op
from ..image_utils import as_target_channels

ProgressImageOp = Callable[[np.ndarray, Progress | None], np.ndarray]
"""
An image processing operation that takes an image and progress, and produces a new image.
"""


def with_black_and_white_backgrounds(img: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
c = get_h_w_c(img)[2]
Expand All @@ -32,9 +41,10 @@ def convenient_upscale(
img: np.ndarray,
model_in_nc: int,
model_out_nc: int,
upscale: ImageOp,
upscale: ProgressImageOp,
separate_alpha: bool = False,
clip: bool = True,
progress: Progress | None = None,
) -> np.ndarray:
"""
Upscales the given image in an intuitive/convenient way.
Expand All @@ -45,45 +55,60 @@ def convenient_upscale(
Additionally, guarantees that the number of channels of the output image will match
that of the input image in cases where `model_in_nc` == `model_out_nc`, and match
`model_out_nc` otherwise.

When multiple upscale operations are needed (e.g., for RGBA images), progress is
automatically divided between them using sub-progress ranges.
"""
in_img_c = get_h_w_c(img)[2]

if clip:
upscale = clipped(upscale)
def make_op(p: Progress | None) -> ImageOp:
"""Create an ImageOp with the given progress bound."""
op = to_op(upscale)(p)
return clipped(op) if clip else op

if model_in_nc != model_out_nc:
return upscale(as_target_channels(img, model_in_nc, True))
return make_op(progress)(as_target_channels(img, model_in_nc, True))

if in_img_c == model_in_nc:
return upscale(img)
return make_op(progress)(img)

if in_img_c == 4:
# Ignore alpha if single-color or not being replaced
unique = np.unique(img[:, :, 3])
if len(unique) == 1:
op = make_op(progress)
rgb = as_target_channels(
upscale(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
op(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
)
unique_alpha = np.full(rgb.shape[:-1], unique[0], np.float32)
return np.dstack((rgb, unique_alpha))

if separate_alpha:
# Upscale the RGB channels and alpha channel separately
# Split progress: 50% for RGB, 50% for alpha
rgb_op = make_op(progress.sub_progress(0, 0.5) if progress else None)
alpha_op = make_op(progress.sub_progress(0.5, 0.5) if progress else None)

rgb = as_target_channels(
upscale(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
rgb_op(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
)
alpha = denoise_and_flatten_alpha(
upscale(as_target_channels(img[:, :, 3], model_in_nc, True))
alpha_op(as_target_channels(img[:, :, 3], model_in_nc, True))
)
return np.dstack((rgb, alpha))
else:
# Transparency hack (white/black background difference alpha)
# Split progress: 50% for black background, 50% for white background
black, white = with_black_and_white_backgrounds(img)

black_op = make_op(progress.sub_progress(0, 0.5) if progress else None)
white_op = make_op(progress.sub_progress(0.5, 0.5) if progress else None)

black_up = as_target_channels(
upscale(as_target_channels(black, model_in_nc, True)), 3, True
black_op(as_target_channels(black, model_in_nc, True)), 3, True
)
white_up = as_target_channels(
upscale(as_target_channels(white, model_in_nc, True)), 3, True
white_op(as_target_channels(white, model_in_nc, True)), 3, True
)

# Interpolate between the alpha values to get a more defined alpha
Expand All @@ -92,6 +117,4 @@ def convenient_upscale(

return np.dstack((black_up, alpha))

return as_target_channels(
upscale(as_target_channels(img, model_in_nc, True)), in_img_c, True
)
return make_op(progress)(as_target_channels(img, model_in_nc, True))
Loading