Skip to content

Commit ddc541f

Browse files
authored
feat(api-nodes): add WaveSpeed nodes (Comfy-Org#11945)
1 parent 8ccc0c9 commit ddc541f

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

comfy_api_nodes/apis/wavespeed.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pydantic import BaseModel, Field
2+
3+
4+
class SeedVR2ImageRequest(BaseModel):
5+
image: str = Field(...)
6+
target_resolution: str = Field(...)
7+
output_format: str = Field("png")
8+
enable_sync_mode: bool = Field(False)
9+
10+
11+
class FlashVSRRequest(BaseModel):
12+
target_resolution: str = Field(...)
13+
video: str = Field(...)
14+
duration: float = Field(...)
15+
16+
17+
class TaskCreatedDataResponse(BaseModel):
18+
id: str = Field(...)
19+
20+
21+
class TaskCreatedResponse(BaseModel):
22+
code: int = Field(...)
23+
message: str = Field(...)
24+
data: TaskCreatedDataResponse | None = Field(None)
25+
26+
27+
class TaskResultDataResponse(BaseModel):
28+
status: str = Field(...)
29+
outputs: list[str] = Field([])
30+
31+
32+
class TaskResultResponse(BaseModel):
33+
code: int = Field(...)
34+
message: str = Field(...)
35+
data: TaskResultDataResponse | None = Field(None)

comfy_api_nodes/nodes_wavespeed.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing_extensions import override
2+
3+
from comfy_api.latest import IO, ComfyExtension, Input
4+
from comfy_api_nodes.apis.wavespeed import (
5+
FlashVSRRequest,
6+
TaskCreatedResponse,
7+
TaskResultResponse,
8+
SeedVR2ImageRequest,
9+
)
10+
from comfy_api_nodes.util import (
11+
ApiEndpoint,
12+
download_url_to_video_output,
13+
poll_op,
14+
sync_op,
15+
upload_video_to_comfyapi,
16+
validate_container_format_is_mp4,
17+
validate_video_duration,
18+
upload_images_to_comfyapi,
19+
get_number_of_images,
20+
download_url_to_image_tensor,
21+
)
22+
23+
24+
class WavespeedFlashVSRNode(IO.ComfyNode):
25+
@classmethod
26+
def define_schema(cls):
27+
return IO.Schema(
28+
node_id="WavespeedFlashVSRNode",
29+
display_name="FlashVSR Video Upscale",
30+
category="api node/video/WaveSpeed",
31+
description="Fast, high-quality video upscaler that "
32+
"boosts resolution and restores clarity for low-resolution or blurry footage.",
33+
inputs=[
34+
IO.Video.Input("video"),
35+
IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
36+
],
37+
outputs=[
38+
IO.Video.Output(),
39+
],
40+
hidden=[
41+
IO.Hidden.auth_token_comfy_org,
42+
IO.Hidden.api_key_comfy_org,
43+
IO.Hidden.unique_id,
44+
],
45+
is_api_node=True,
46+
price_badge=IO.PriceBadge(
47+
depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
48+
expr="""
49+
(
50+
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
51+
{
52+
"type":"usd",
53+
"usd": $lookup($price_for_1sec, widgets.target_resolution),
54+
"format":{"suffix": "/second", "approximate": true}
55+
}
56+
)
57+
""",
58+
),
59+
)
60+
61+
@classmethod
62+
async def execute(
63+
cls,
64+
video: Input.Video,
65+
target_resolution: str,
66+
) -> IO.NodeOutput:
67+
validate_container_format_is_mp4(video)
68+
validate_video_duration(video, min_duration=5, max_duration=60 * 10)
69+
initial_res = await sync_op(
70+
cls,
71+
ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
72+
response_model=TaskCreatedResponse,
73+
data=FlashVSRRequest(
74+
target_resolution=target_resolution.lower(),
75+
video=await upload_video_to_comfyapi(cls, video),
76+
duration=video.get_duration(),
77+
),
78+
)
79+
if initial_res.code != 200:
80+
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
81+
final_response = await poll_op(
82+
cls,
83+
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
84+
response_model=TaskResultResponse,
85+
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
86+
poll_interval=10.0,
87+
max_poll_attempts=480,
88+
)
89+
if final_response.code != 200:
90+
raise ValueError(
91+
f"Task processing failed with code={final_response.code} and message={final_response.message}"
92+
)
93+
return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
94+
95+
96+
class WavespeedImageUpscaleNode(IO.ComfyNode):
97+
@classmethod
98+
def define_schema(cls):
99+
return IO.Schema(
100+
node_id="WavespeedImageUpscaleNode",
101+
display_name="WaveSpeed Image Upscale",
102+
category="api node/image/WaveSpeed",
103+
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
104+
inputs=[
105+
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
106+
IO.Image.Input("image"),
107+
IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
108+
],
109+
outputs=[
110+
IO.Image.Output(),
111+
],
112+
hidden=[
113+
IO.Hidden.auth_token_comfy_org,
114+
IO.Hidden.api_key_comfy_org,
115+
IO.Hidden.unique_id,
116+
],
117+
is_api_node=True,
118+
price_badge=IO.PriceBadge(
119+
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
120+
expr="""
121+
(
122+
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
123+
{"type":"usd", "usd": $lookup($prices, widgets.model)}
124+
)
125+
""",
126+
),
127+
)
128+
129+
@classmethod
130+
async def execute(
131+
cls,
132+
model: str,
133+
image: Input.Image,
134+
target_resolution: str,
135+
) -> IO.NodeOutput:
136+
if get_number_of_images(image) != 1:
137+
raise ValueError("Exactly one input image is required.")
138+
if model == "SeedVR2":
139+
model_path = "seedvr2/image"
140+
else:
141+
model_path = "ultimate-image-upscaler"
142+
initial_res = await sync_op(
143+
cls,
144+
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
145+
response_model=TaskCreatedResponse,
146+
data=SeedVR2ImageRequest(
147+
target_resolution=target_resolution.lower(),
148+
image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
149+
),
150+
)
151+
if initial_res.code != 200:
152+
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
153+
final_response = await poll_op(
154+
cls,
155+
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
156+
response_model=TaskResultResponse,
157+
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
158+
poll_interval=10.0,
159+
max_poll_attempts=480,
160+
)
161+
if final_response.code != 200:
162+
raise ValueError(
163+
f"Task processing failed with code={final_response.code} and message={final_response.message}"
164+
)
165+
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
166+
167+
168+
class WavespeedExtension(ComfyExtension):
169+
@override
170+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
171+
return [
172+
WavespeedFlashVSRNode,
173+
WavespeedImageUpscaleNode,
174+
]
175+
176+
177+
async def comfy_entrypoint() -> WavespeedExtension:
178+
return WavespeedExtension()

0 commit comments

Comments
 (0)