Merged
44 changes: 33 additions & 11 deletions comfy/model_management.py
@@ -19,7 +19,8 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
import platform
@@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state

@@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return

def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
    with torch.inference_mode():
        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
    soft_empty_cache()

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    #Deliberately load models outside of the Aimdo mempool so they can be retained across
    #nodes. Use a dummy thread to do it, since PyTorch documents that mempool contexts are
    #thread-local; exploit that to escape the context.
    if enables_dynamic_vram():
        t = threading.Thread(
            target=load_models_gpu_thread,
            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
        )
        t.start()
        t.join()
    else:
        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)

def load_model_gpu(model):
return load_models_gpu([model])
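
The wrapper above relies on PyTorch's mempool context being thread-local: a worker thread started while the main thread sits inside a mempool context still allocates from the default caching allocator, so weights loaded there outlive the pool. A minimal sketch of that behaviour, assuming PyTorch 2.5+ where `torch.cuda.MemPool` and `torch.cuda.use_mem_pool` are available; the names are illustrative only and not ComfyUI's Aimdo allocator:

```python
import threading
import torch

def demo_mempool_escape():
    # Allocations made inside use_mem_pool() go to `pool`, but only on the
    # thread that entered the context, because the active-mempool state is
    # thread-local in PyTorch.
    pool = torch.cuda.MemPool()
    escaped = []

    def alloc_outside_pool():
        # This thread never entered the mempool context, so its allocation
        # comes from the default caching allocator and survives pool teardown.
        escaped.append(torch.empty(1024, device="cuda"))

    with torch.cuda.use_mem_pool(pool):
        inside = torch.empty(1024, device="cuda")  # owned by `pool`
        t = threading.Thread(target=alloc_outside_pool)
        t.start()
        t.join()
    del inside
    return escaped[0]  # still usable after the pool's memory is released
```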

@@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
torch.cuda.synchronize()
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
torch.cuda.empty_cache()
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@@ -1132,9 +1153,7 @@ def reset_cast_buffers():
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
STREAM_CAST_BUFFERS.clear()
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.empty_cache()
soft_empty_cache()

def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype

def synchronize():
    if is_intel_xpu():
        torch.xpu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()

def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
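
The new `synchronize()` helper mirrors `soft_empty_cache()`: call sites stop hard-coding `torch.cuda.synchronize()`, so Intel XPU builds take the `torch.xpu` path. A hedged sketch of the intended pattern; the scratch buffer is a stand-in for illustration, not the real cast buffer:

```python
import torch
import comfy.model_management as mm

# Allocate, drop, and reclaim a temporary buffer in a device-agnostic way.
scratch = torch.empty(64 * 1024 * 1024, dtype=torch.int8, device=mm.get_torch_device())
del scratch
mm.synchronize()       # wait for any in-flight work that may still touch the buffer
mm.soft_empty_cache()  # then ask the caching allocator to release freed blocks
```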
@@ -1713,9 +1738,6 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""

#TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):
pass

2 changes: 1 addition & 1 deletion comfy/model_patcher.py
@@ -1597,7 +1597,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):

if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None)
self.partially_unload(None, 1e32)

def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
51 changes: 51 additions & 0 deletions comfy_api_nodes/apis/hitpaw.py
@@ -0,0 +1,51 @@
from typing import TypedDict

from pydantic import BaseModel, Field


class InputVideoModel(TypedDict):
    model: str
    resolution: str


class ImageEnhanceTaskCreateRequest(BaseModel):
    model_name: str = Field(...)
    img_url: str = Field(...)
    extension: str = Field(".png")
    exif: bool = Field(False)
    DPI: int | None = Field(None)


class VideoEnhanceTaskCreateRequest(BaseModel):
    video_url: str = Field(...)
    extension: str = Field(".mp4")
    model_name: str | None = Field(...)
    resolution: list[int] = Field(..., description="Target resolution [width, height]")
    original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")


class TaskCreateDataResponse(BaseModel):
    job_id: str = Field(...)
    consume_coins: int | None = Field(None)


class TaskStatusPollRequest(BaseModel):
    job_id: str = Field(...)


class TaskCreateResponse(BaseModel):
    code: int = Field(...)
    message: str = Field(...)
    data: TaskCreateDataResponse | None = Field(None)


class TaskStatusDataResponse(BaseModel):
    job_id: str = Field(...)
    status: str = Field(...)
    res_url: str = Field("")


class TaskStatusResponse(BaseModel):
    code: int = Field(...)
    message: str = Field(...)
    data: TaskStatusDataResponse = Field(...)
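
A hedged sketch of how these models could drive a create/poll cycle, assuming pydantic v2; the endpoint paths, HTTP client, model name, and terminal status string are illustrative assumptions, and only the models above come from this PR:

```python
import time
import requests  # assumed client, illustration only

API_BASE = "https://example.invalid/hitpaw"  # placeholder, not the real endpoint

def enhance_image(img_url: str) -> str:
    req = ImageEnhanceTaskCreateRequest(model_name="general", img_url=img_url)
    created = TaskCreateResponse.model_validate(
        requests.post(f"{API_BASE}/image/enhance", json=req.model_dump()).json()
    )
    assert created.data is not None, created.message

    poll = TaskStatusPollRequest(job_id=created.data.job_id)
    while True:
        status = TaskStatusResponse.model_validate(
            requests.post(f"{API_BASE}/status", json=poll.model_dump()).json()
        )
        if status.data.status == "success":  # assumed terminal status value
            return status.data.res_url
        time.sleep(2)
```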