Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Image output normalization
  • Loading branch information
RunDevelopment committed Apr 7, 2023
commit 10ae13898c043c49ee16c1aedb56a41b048f3619
8 changes: 4 additions & 4 deletions backend/src/nodes/properties/inputs/numpy_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from ...impl.image_utils import get_h_w_c, normalize
from ...impl.image_utils import get_h_w_c
from ...utils.format import format_image_with_channels
from .. import expression
from .base_input import BaseInput
Expand Down Expand Up @@ -47,10 +47,10 @@ def enforce(self, value):
f"The input {self.label} only supports {expected} but was given {actual}."
)

if c == 1 and value.ndim == 3:
value = value[:, :, 0]
assert value.dtype == np.float32, "Expected the input image to be normalized."
assert c != 1 or value.ndim == 2, "Expected single-channel images to be 2D."

return normalize(value)
return value


class VideoInput(BaseInput):
Expand Down
9 changes: 5 additions & 4 deletions backend/src/nodes/properties/outputs/base_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal, Union
from typing import Literal

from base_types import OutputId

Expand All @@ -20,7 +20,7 @@ def __init__(
self.output_type: expression.ExpressionJson = output_type
self.label: str = label
self.id: OutputId = OutputId(-1)
self.never_reason: Union[str, None] = None
self.never_reason: str | None = None
self.kind: OutputKind = kind
self.has_handle: bool = has_handle

Expand All @@ -34,7 +34,7 @@ def toDict(self):
"hasHandle": self.has_handle,
}

def with_id(self, output_id: Union[OutputId, int]):
def with_id(self, output_id: OutputId | int):
self.id = OutputId(output_id)
return self

Expand All @@ -54,5 +54,6 @@ def get_broadcast_data(self, _value):
def get_broadcast_type(self, _value) -> expression.ExpressionJson | None:
return None

def validate(self, value) -> None:
def enforce(self, value: object) -> object:
assert value is not None
return value
9 changes: 4 additions & 5 deletions backend/src/nodes/properties/outputs/file_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ class FileOutput(BaseOutput):
def __init__(self, file_type: expression.ExpressionJson, label: str):
super().__init__(file_type, label)

def get_broadcast_data(self, value: str):
return value

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value


class DirectoryOutput(BaseOutput):
Expand All @@ -32,5 +30,6 @@ def __init__(self, label: str = "Directory", of_input: int | None = None):
def get_broadcast_type(self, value: str):
return expression.named("Directory", {"path": expression.literal(value)})

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value
9 changes: 6 additions & 3 deletions backend/src/nodes/properties/outputs/generic_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def __init__(
def get_broadcast_type(self, value: int | float):
return expression.literal(value)

def validate(self, value) -> None:
def enforce(self, value) -> int | float:
assert isinstance(value, (int, float))
return value


class TextOutput(BaseOutput):
Expand All @@ -31,8 +32,9 @@ def __init__(
def get_broadcast_type(self, value: str):
return expression.literal(value)

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value


def FileNameOutput(label: str = "Name", of_input: int | None = None):
Expand All @@ -49,5 +51,6 @@ class SeedOutput(BaseOutput):
def __init__(self, label: str = "Seed"):
super().__init__(output_type="Seed", label=label, kind="generic")

def validate(self, value) -> None:
def enforce(self, value) -> Seed:
assert isinstance(value, Seed)
return value
25 changes: 21 additions & 4 deletions backend/src/nodes/properties/outputs/numpy_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cv2
import numpy as np

from ...impl.image_utils import to_uint8
from ...impl.image_utils import normalize, to_uint8
from ...impl.pil_utils import InterpolationMethod, resize
from ...utils.format import format_image_with_channels
from ...utils.utils import get_h_w_c
Expand All @@ -24,8 +24,9 @@ def __init__(
):
super().__init__(output_type, label, kind=kind, has_handle=has_handle)

def validate(self, value) -> None:
def enforce(self, value) -> np.ndarray:
assert isinstance(value, np.ndarray)
return value


def AudioOutput():
Expand All @@ -41,6 +42,7 @@ def __init__(
kind: OutputKind = "image",
has_handle: bool = True,
channels: Optional[int] = None,
assume_normalized: bool = False,
):
super().__init__(
expression.intersect(image_type, expression.Image(channels=channels)),
Expand All @@ -50,6 +52,7 @@ def __init__(
)

self.channels: Optional[int] = channels
self.assume_normalized: bool = assume_normalized

def get_broadcast_data(self, value: np.ndarray):
h, w, c = get_h_w_c(value)
Expand All @@ -63,9 +66,8 @@ def get_broadcast_type(self, value: np.ndarray):
h, w, c = get_h_w_c(value)
return expression.Image(width=w, height=h, channels=c)

def validate(self, value) -> None:
def enforce(self, value) -> np.ndarray:
assert isinstance(value, np.ndarray)
assert value.dtype == np.float32

_, _, c = get_h_w_c(value)

Expand All @@ -78,6 +80,21 @@ def validate(self, value) -> None:
f" Please report this bug."
)

# flatting 3D single-channel images to 2D
if c == 1 and value.ndim == 3:
value = value[:, :, 0]

if self.assume_normalized:
assert value.dtype == np.float32, (
f"The output {self.label} did not return a normalized image."
f" This is a bug in the implementation of the node."
f" Please report this bug."
f"\n\nTo the author of this node: Either use `normalize` or remove `assume_normalized=True` from this output."
)
return value

return normalize(value)


def preview_encode(
img: np.ndarray,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def to_output(raw_output: Any, node: NodeData) -> Output:

# output-specific validations
for i, o in enumerate(node.outputs):
o.validate(output[i])
output[i] = o.enforce(output[i])

return output

Expand Down