Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions backend/src/nodes/properties/inputs/file_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from pathlib import Path
from typing import Literal, Union

Expand All @@ -8,6 +9,7 @@

# pylint: disable=relative-beyond-top-level
from ...impl.image_formats import get_available_image_formats
from .generic_inputs import TextInput
from .label import LabelStyle

FileInputKind = Union[
Expand Down Expand Up @@ -112,15 +114,14 @@ def PthFileInput(primary_input: bool = False) -> FileInput:
)


class DirectoryInput(BaseInput):
class DirectoryInput(BaseInput[Path]):
"""Input for submitting a local directory"""

def __init__(
self,
label: str = "Directory",
has_handle: bool = True,
must_exist: bool = True,
create: bool = False,
label_style: LabelStyle = "default",
):
super().__init__("Directory", label, kind="directory", has_handle=has_handle)
Expand All @@ -133,7 +134,6 @@ def __init__(
"""

self.must_exist: bool = must_exist
self.create: bool = create
self.label_style: LabelStyle = label_style

self.associated_type = Path
Expand All @@ -144,14 +144,12 @@ def to_dict(self):
"labelStyle": self.label_style,
}

def enforce(self, value: object):
def enforce(self, value: object) -> Path:
if isinstance(value, str):
value = Path(value)
assert isinstance(value, Path)

if self.create:
value.mkdir(parents=True, exist_ok=True)
elif self.must_exist:
if self.must_exist:
assert value.exists(), f"Directory {value} does not exist"

return value
Expand Down Expand Up @@ -185,3 +183,47 @@ def OnnxFileInput(primary_input: bool = False) -> FileInput:
filetypes=[".onnx"],
primary_input=primary_input,
)


_INVALID_PATH_CHARS = re.compile(r'[<>:"|?*\x00-\x1F]')


def _is_abs_path(path: str) -> bool:
return path.startswith(("/", "\\")) or Path(path).is_absolute()


class RelativePathInput(TextInput):
def __init__(
self,
label: str,
has_handle: bool = True,
placeholder: str | None = None,
allow_numbers: bool = True,
default: str | None = None,
label_style: LabelStyle = "default",
):
super().__init__(
label,
has_handle=has_handle,
min_length=1,
max_length=None,
placeholder=placeholder,
multiline=False,
allow_numbers=allow_numbers,
default=default,
label_style=label_style,
allow_empty_string=False,
invalid_pattern=_INVALID_PATH_CHARS.pattern,
)

def enforce(self, value: object) -> str:
value = super().enforce(value)

if _is_abs_path(value):
raise ValueError(f"Absolute paths are not allowed for input {self.label}.")

invalid = _INVALID_PATH_CHARS.search(value)
if invalid is not None:
raise ValueError(f"Invalid character '{invalid.group()}' in {self.label}.")

return value
3 changes: 3 additions & 0 deletions backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def __init__(
default: str | None = None,
label_style: LabelStyle = "default",
allow_empty_string: bool = False,
invalid_pattern: str | None = None,
):
super().__init__(
input_type="string" if min_length == 0 else 'invStrSet("")',
Expand All @@ -300,6 +301,7 @@ def __init__(
self.multiline = multiline
self.label_style: LabelStyle = label_style
self.allow_empty_string = allow_empty_string
self.invalid_pattern = invalid_pattern

if default is not None:
assert default != ""
Expand Down Expand Up @@ -339,6 +341,7 @@ def to_dict(self):
"def": self.default,
"labelStyle": self.label_style,
"allowEmptyString": self.allow_empty_string,
"invalidPattern": self.invalid_pattern,
}


Expand Down
14 changes: 7 additions & 7 deletions backend/src/packages/chaiNNer_ncnn/ncnn/io/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sanic.log import logger

from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.properties.inputs import DirectoryInput, NcnnModelInput, TextInput
from nodes.properties.inputs import DirectoryInput, NcnnModelInput, RelativePathInput

from .. import io_group

Expand All @@ -15,18 +15,18 @@
icon="MdSave",
inputs=[
NcnnModelInput(),
DirectoryInput(create=True),
TextInput("Param/Bin Name"),
DirectoryInput(must_exist=False),
RelativePathInput("Param/Bin Name"),
],
outputs=[],
side_effects=True,
)
def save_model_node(model: NcnnModelWrapper, directory: Path, name: str) -> None:
full_bin = f"{name}.bin"
full_param = f"{name}.param"
full_bin_path = directory / full_bin
full_param_path = directory / full_param
full_bin_path = (directory / f"{name}.bin").resolve()
full_param_path = (directory / f"{name}.param").resolve()

logger.debug(f"Writing NCNN model to paths: {full_bin_path} {full_param_path}")

full_bin_path.parent.mkdir(parents=True, exist_ok=True)
model.model.write_bin(full_bin_path)
model.model.write_param(full_param_path)
9 changes: 5 additions & 4 deletions backend/src/packages/chaiNNer_onnx/onnx/io/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sanic.log import logger

from nodes.impl.onnx.model import OnnxModel
from nodes.properties.inputs import DirectoryInput, OnnxModelInput, TextInput
from nodes.properties.inputs import DirectoryInput, OnnxModelInput, RelativePathInput

from .. import io_group

Expand All @@ -17,14 +17,15 @@
icon="MdSave",
inputs=[
OnnxModelInput(),
DirectoryInput(create=True),
TextInput("Model Name"),
DirectoryInput(must_exist=False),
RelativePathInput("Model Name"),
],
outputs=[],
side_effects=True,
)
def save_model_node(model: OnnxModel, directory: Path, model_name: str) -> None:
full_path = f"{directory / model_name}.onnx"
full_path = (directory / f"{model_name}.onnx").resolve()
logger.debug(f"Writing file to path: {full_path}")
full_path.parent.mkdir(parents=True, exist_ok=True)
with open(full_path, "wb") as f:
f.write(model.bytes)
15 changes: 10 additions & 5 deletions backend/src/packages/chaiNNer_pytorch/pytorch/io/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from sanic.log import logger
from spandrel import ModelDescriptor

from nodes.properties.inputs import DirectoryInput, EnumInput, ModelInput, TextInput
from nodes.properties.inputs import (
DirectoryInput,
EnumInput,
ModelInput,
RelativePathInput,
)

from .. import io_group

Expand All @@ -28,8 +33,8 @@ class WeightFormat(Enum):
icon="MdSave",
inputs=[
ModelInput(),
DirectoryInput(create=True),
TextInput("Model Name"),
DirectoryInput(must_exist=False),
RelativePathInput("Model Name"),
EnumInput(
WeightFormat,
"Weight Format",
Expand All @@ -46,9 +51,9 @@ class WeightFormat(Enum):
def save_model_node(
model: ModelDescriptor, directory: Path, name: str, weight_format: WeightFormat
) -> None:
full_file = f"{name}.{weight_format.value}"
full_path = directory / full_file
full_path = (directory / f"{name}.{weight_format.value}").resolve()
logger.debug(f"Writing model to path: {full_path}")
full_path.parent.mkdir(parents=True, exist_ok=True)
if weight_format == WeightFormat.PTH:
torch.save(model.model.state_dict(), full_path)
elif weight_format == WeightFormat.ST:
Expand Down
8 changes: 4 additions & 4 deletions backend/src/packages/chaiNNer_standard/image/io/save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
DropDownInput,
EnumInput,
ImageInput,
RelativePathInput,
SliderInput,
TextInput,
)
from nodes.utils.utils import get_h_w_c

Expand Down Expand Up @@ -161,13 +161,13 @@ def DdsMipMapsDropdown() -> DropDownInput:
inputs=[
ImageInput(),
DirectoryInput(must_exist=False),
TextInput("Subdirectory Path")
RelativePathInput("Subdirectory Path")
.make_optional()
.with_docs(
"An optional subdirectory path. Use this to save the image to a subdirectory of the specified directory. If the subdirectory does not exist, it will be created. Multiple subdirectories can be specified by separating them with a forward slash (`/`).",
"Example: `foo/bar`",
),
TextInput("Image Name").with_docs(
RelativePathInput("Image Name").with_docs(
"The name of the image file **without** the file extension. If the file already exists, it will be overwritten.",
"Example: `my-image`",
),
Expand Down Expand Up @@ -385,4 +385,4 @@ def get_full_path(
if relative_path and relative_path != ".":
base_directory = base_directory / relative_path
full_path = base_directory / file
return full_path
return full_path.resolve()
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from subprocess import Popen
from typing import Any, Literal

Expand All @@ -19,6 +20,7 @@
DirectoryInput,
EnumInput,
ImageInput,
RelativePathInput,
SliderInput,
TextInput,
)
Expand Down Expand Up @@ -219,8 +221,8 @@ def close(self):
icon="MdVideoCameraBack",
inputs=[
ImageInput("Image Sequence", channels=3),
DirectoryInput(create=True),
TextInput("Video Name"),
DirectoryInput(must_exist=False),
RelativePathInput("Video Name"),
EnumInput(
VideoFormat,
label="Video Format",
Expand Down Expand Up @@ -324,7 +326,7 @@ def close(self):
def save_video_node(
node_context: NodeContext,
_: None,
save_dir: str,
save_dir: Path,
video_name: str,
container: VideoFormat,
encoder: VideoEncoder,
Expand All @@ -336,11 +338,12 @@ def save_video_node(
audio: Any,
audio_settings: AudioSettings,
) -> Collector[np.ndarray, None]:
save_path = os.path.join(save_dir, f"{video_name}.{container.ext}")
save_path = (save_dir / f"{video_name}.{container.ext}").resolve()
save_path.parent.mkdir(parents=True, exist_ok=True)

# Common output settings
output_params = {
"filename": save_path,
"filename": str(save_path),
"pix_fmt": "yuv420p",
"r": fps,
"movflags": "faststart",
Expand Down Expand Up @@ -388,7 +391,7 @@ def save_video_node(
fps=fps,
audio=audio,
audio_settings=audio_settings,
save_path=save_path,
save_path=str(save_path),
output_params=output_params,
global_params=global_params,
ffmpeg_env=FFMpegEnv.get_integrated(node_context.storage_dir),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,13 @@
from __future__ import annotations

import re
from pathlib import Path

from nodes.groups import optional_list_group
from nodes.properties.inputs import DirectoryInput, TextInput
from nodes.properties.inputs import DirectoryInput, RelativePathInput
from nodes.properties.outputs import DirectoryOutput

from .. import directory_group

INVALID_CHARS = re.compile(r"[<>:\"|?*\x00-\x1F]")


def is_abs(path: str) -> bool:
return path.startswith(("/", "\\")) or Path(path).is_absolute()


def go_into(dir: Path, folder: str) -> Path:
if is_abs(folder):
raise ValueError("Absolute paths are not allowed as folders.")

invalid = INVALID_CHARS.search(folder)
if invalid is not None:
raise ValueError(f"Invalid character '{invalid.group()}' in folder name.")

return (dir / folder).resolve()


@directory_group.register(
schema_id="chainner:utility:into_directory",
Expand All @@ -34,9 +16,9 @@ def go_into(dir: Path, folder: str) -> Path:
icon="BsFolder",
inputs=[
DirectoryInput(must_exist=False, label_style="hidden"),
TextInput("Folder"),
RelativePathInput("Folder"),
optional_list_group(
*[TextInput(f"Folder {i}").make_optional() for i in range(2, 11)],
*[RelativePathInput(f"Folder {i}").make_optional() for i in range(2, 11)],
),
],
outputs=[
Expand Down Expand Up @@ -78,5 +60,5 @@ def into(dir: Directory | Error, folder: string | null): Directory | Error {
def directory_go_into_node(directory: Path, *folders: str | None) -> Path:
for folder in folders:
if folder is not None:
directory = go_into(directory, folder)
directory = (directory / folder).resolve()
return directory
1 change: 1 addition & 0 deletions src/common/common-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export interface TextInput extends InputBase {
readonly placeholder?: string | null;
readonly def?: string | null;
readonly allowEmptyString?: boolean;
readonly invalidPattern?: string | null;
readonly labelStyle: LabelStyle | undefined;
}
export interface NumberInput extends InputBase {
Expand Down
Loading