Skip to content

Commit 24c2f89

Browse files
Output normalization (#1717)
* Image output normalization * Apply to nodes * Added 2D-ize back * Didn't mean to remove that
1 parent bffc015 commit 24c2f89

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+129
-77
lines changed

backend/src/nodes/properties/inputs/numpy_inputs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from ...impl.image_utils import get_h_w_c, normalize
7+
from ...impl.image_utils import get_h_w_c
88
from ...utils.format import format_image_with_channels
99
from .. import expression
1010
from .base_input import BaseInput
@@ -47,10 +47,12 @@ def enforce(self, value):
4747
f"The input {self.label} only supports {expected} but was given {actual}."
4848
)
4949

50+
assert value.dtype == np.float32, "Expected the input image to be normalized."
51+
5052
if c == 1 and value.ndim == 3:
5153
value = value[:, :, 0]
5254

53-
return normalize(value)
55+
return value
5456

5557

5658
class VideoInput(BaseInput):

backend/src/nodes/properties/outputs/base_output.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Literal, Union
3+
from typing import Literal
44

55
from base_types import OutputId
66

@@ -20,7 +20,7 @@ def __init__(
2020
self.output_type: expression.ExpressionJson = output_type
2121
self.label: str = label
2222
self.id: OutputId = OutputId(-1)
23-
self.never_reason: Union[str, None] = None
23+
self.never_reason: str | None = None
2424
self.kind: OutputKind = kind
2525
self.has_handle: bool = has_handle
2626

@@ -34,7 +34,7 @@ def toDict(self):
3434
"hasHandle": self.has_handle,
3535
}
3636

37-
def with_id(self, output_id: Union[OutputId, int]):
37+
def with_id(self, output_id: OutputId | int):
3838
self.id = OutputId(output_id)
3939
return self
4040

@@ -54,5 +54,6 @@ def get_broadcast_data(self, _value):
5454
def get_broadcast_type(self, _value) -> expression.ExpressionJson | None:
5555
return None
5656

57-
def validate(self, value) -> None:
57+
def enforce(self, value: object) -> object:
5858
assert value is not None
59+
return value

backend/src/nodes/properties/outputs/file_outputs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ def __init__(self, file_type: expression.ExpressionJson, label: str):
1313
def get_broadcast_data(self, value: str):
1414
return value
1515

16-
def validate(self, value) -> None:
16+
def enforce(self, value) -> str:
1717
assert isinstance(value, str)
18+
return value
1819

1920

2021
class DirectoryOutput(BaseOutput):
@@ -32,5 +33,6 @@ def __init__(self, label: str = "Directory", of_input: int | None = None):
3233
def get_broadcast_type(self, value: str):
3334
return expression.named("Directory", {"path": expression.literal(value)})
3435

35-
def validate(self, value) -> None:
36+
def enforce(self, value) -> str:
3637
assert isinstance(value, str)
38+
return value

backend/src/nodes/properties/outputs/generic_outputs.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ def __init__(
1616
def get_broadcast_type(self, value: int | float):
1717
return expression.literal(value)
1818

19-
def validate(self, value) -> None:
19+
def enforce(self, value) -> int | float:
2020
assert isinstance(value, (int, float))
21+
return value
2122

2223

2324
class TextOutput(BaseOutput):
@@ -31,8 +32,9 @@ def __init__(
3132
def get_broadcast_type(self, value: str):
3233
return expression.literal(value)
3334

34-
def validate(self, value) -> None:
35+
def enforce(self, value) -> str:
3536
assert isinstance(value, str)
37+
return value
3638

3739

3840
def FileNameOutput(label: str = "Name", of_input: int | None = None):
@@ -49,5 +51,6 @@ class SeedOutput(BaseOutput):
4951
def __init__(self, label: str = "Seed"):
5052
super().__init__(output_type="Seed", label=label, kind="generic")
5153

52-
def validate(self, value) -> None:
54+
def enforce(self, value) -> Seed:
5355
assert isinstance(value, Seed)
56+
return value

backend/src/nodes/properties/outputs/numpy_outputs.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import cv2
55
import numpy as np
66

7-
from ...impl.image_utils import to_uint8
7+
from ...impl.image_utils import normalize, to_uint8
88
from ...impl.pil_utils import InterpolationMethod, resize
99
from ...utils.format import format_image_with_channels
1010
from ...utils.utils import get_h_w_c
@@ -24,8 +24,9 @@ def __init__(
2424
):
2525
super().__init__(output_type, label, kind=kind, has_handle=has_handle)
2626

27-
def validate(self, value) -> None:
27+
def enforce(self, value) -> np.ndarray:
2828
assert isinstance(value, np.ndarray)
29+
return value
2930

3031

3132
def AudioOutput():
@@ -41,6 +42,7 @@ def __init__(
4142
kind: OutputKind = "image",
4243
has_handle: bool = True,
4344
channels: Optional[int] = None,
45+
assume_normalized: bool = False,
4446
):
4547
super().__init__(
4648
expression.intersect(image_type, expression.Image(channels=channels)),
@@ -50,6 +52,7 @@ def __init__(
5052
)
5153

5254
self.channels: Optional[int] = channels
55+
self.assume_normalized: bool = assume_normalized
5356

5457
def get_broadcast_data(self, value: np.ndarray):
5558
h, w, c = get_h_w_c(value)
@@ -63,9 +66,8 @@ def get_broadcast_type(self, value: np.ndarray):
6366
h, w, c = get_h_w_c(value)
6467
return expression.Image(width=w, height=h, channels=c)
6568

66-
def validate(self, value) -> None:
69+
def enforce(self, value) -> np.ndarray:
6770
assert isinstance(value, np.ndarray)
68-
assert value.dtype == np.float32
6971

7072
_, _, c = get_h_w_c(value)
7173

@@ -78,6 +80,21 @@ def validate(self, value) -> None:
7880
f" Please report this bug."
7981
)
8082

83+
# flatting 3D single-channel images to 2D
84+
if c == 1 and value.ndim == 3:
85+
value = value[:, :, 0]
86+
87+
if self.assume_normalized:
88+
assert value.dtype == np.float32, (
89+
f"The output {self.label} did not return a normalized image."
90+
f" This is a bug in the implementation of the node."
91+
f" Please report this bug."
92+
f"\n\nTo the author of this node: Either use `normalize` or remove `assume_normalized=True` from this output."
93+
)
94+
return value
95+
96+
return normalize(value)
97+
8198

8299
def preview_encode(
83100
img: np.ndarray,

backend/src/packages/chaiNNer_standard/image/batch_processing/video_frame_iterator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
from sanic.log import logger
1212

13-
from nodes.impl.image_utils import normalize, to_uint8
13+
from nodes.impl.image_utils import to_uint8
1414
from nodes.properties.inputs import (
1515
BoolInput,
1616
DirectoryInput,
@@ -73,7 +73,7 @@ class Writer:
7373
def VideoFrameIteratorFrameLoaderNode(
7474
img: np.ndarray, idx: int, video_dir: str, video_name: str
7575
) -> Tuple[np.ndarray, int, str, str]:
76-
return normalize(img), idx, video_dir, video_name
76+
return img, idx, video_dir, video_name
7777

7878

7979
@batch_processing_group.register(

backend/src/packages/chaiNNer_standard/image/create_images/create_color_gray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
width="Input0",
3232
height="Input1",
3333
channels="1",
34-
)
34+
),
35+
assume_normalized=True,
3536
)
3637
],
3738
)

backend/src/packages/chaiNNer_standard/image/create_images/create_color_rgb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@
4545
width="Input0",
4646
height="Input1",
4747
channels="3",
48-
)
48+
),
49+
assume_normalized=True,
4950
)
5051
],
5152
)

backend/src/packages/chaiNNer_standard/image/create_images/create_color_rgba.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@
5252
width="Input0",
5353
height="Input1",
5454
channels="4",
55-
)
55+
),
56+
assume_normalized=True,
5657
)
5758
],
5859
)

backend/src/packages/chaiNNer_standard/image/create_images/create_noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,4 @@ def create_noise_node(
200200
kwargs["seed"] = (kwargs["seed"] + 1) % (2**32)
201201
img /= total_brightness
202202

203-
return np.clip(img, 0, 1)
203+
return img

0 commit comments

Comments
 (0)