Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion backend/src/api/node_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal
from typing import Callable, Literal

from .settings import SettingsParser

Expand Down Expand Up @@ -146,3 +146,11 @@ def storage_dir(self) -> Path:

This directory persists between node executions, and its contents are shared between different nodes.
"""

@abstractmethod
def add_cleanup(self, fn: Callable) -> None:
"""
Registers a function that will be called when the chain execution is finished.

The function will be called with no arguments.
"""
5 changes: 1 addition & 4 deletions backend/src/nodes/impl/pytorch/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,4 @@ def upscale(img: np.ndarray, _: object):
# Re-raise the exception if not an OOM error
raise

try:
return auto_split(img, upscale, tiler)
finally:
safe_cuda_cache_empty()
return auto_split(img, upscale, tiler)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from api import KeyInfo, NodeContext, Progress
from nodes.groups import Condition, if_enum_group, if_group
from nodes.impl.pytorch.auto_split import pytorch_auto_split
from nodes.impl.pytorch.utils import safe_cuda_cache_empty
from nodes.impl.upscale.auto_split_tiles import (
CUSTOM,
NO_TILING,
Expand Down Expand Up @@ -262,6 +263,8 @@ def upscale_image_node(
) -> np.ndarray:
exec_options = get_settings(context)

context.add_cleanup(safe_cuda_cache_empty)

in_nc = model.input_channels
out_nc = model.output_channels
scale = model.scale
Expand Down
15 changes: 14 additions & 1 deletion backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, NewType, Sequence, Union
from typing import Any, Callable, Iterable, List, NewType, Sequence, Union

from sanic.log import logger

Expand Down Expand Up @@ -342,6 +342,8 @@ def __init__(
self.__settings = settings
self._storage_dir = storage_dir

self.cleanup_fns: set[Callable] = set()

@property
def aborted(self) -> bool:
return self.progress.aborted
Expand All @@ -367,6 +369,9 @@ def settings(self) -> SettingsParser:
def storage_dir(self) -> Path:
return self._storage_dir

def add_cleanup(self, fn: Callable[..., Any]) -> None:
self.cleanup_fns.add(fn)


class Executor:
"""
Expand Down Expand Up @@ -805,6 +810,14 @@ async def __process_nodes(self):
# clear cache after the chain is done
self.cache.clear()

# Run cleanup functions
for context in self.__context_cache.values():
for fn in context.cleanup_fns:
try:
fn()
except Exception as e:
logger.error(f"Error running cleanup function: {e}")

# await all broadcasts
tasks = self.__broadcast_tasks
self.__broadcast_tasks = []
Expand Down