
Commit d3d9c45

Add LTX support for video inference-time scaling (#18)
* quick video support in hackiest way
* serialization
* config
* updates
* moviepy
* fixes
* support wan
* fixes
* get vae in fp32 when using wan
* examples
* date
* fixes
1 parent bd9314c commit d3d9c45

File tree

7 files changed: +379 -58 lines changed


README.md

Lines changed: 46 additions & 0 deletions
@@ -14,6 +14,8 @@ But it's been growing now! Check out the rest of the README to know more 🤗
 
 **Updates**
 
+🔥 04/03/2025: Support for LTX-Video and Wan in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/18) 🎬 Check out [this section](#videos) for results and more info.
+
 🔥 01/03/2025: `OpenAIVerifier` was added in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/16). Specify "openai" in the `name` under `verifier_args`. Thanks to [zhuole1025](https://github.com/zhuole1025) for contributing this!
 
 🔥 27/02/2025: [MaximClouser](https://github.com/MaximClouser) implemented a ComfyUI node for inference-time
@@ -442,6 +444,50 @@ between the outputs of different metrics -- "overall_score" vs. "emotional_or_th
 
 </details>&nbsp;&nbsp;
 
+## Videos
+
+We currently support [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video) and [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). Only LAION aesthetic scoring is
+supported for these. Check out the LTX and Wan configs [here](./configs/ltx_video.json) and [here](./configs/wan.json).
+
+<details>
+<summary>Expand for results</summary>
+
+<table>
+<tr>
+<th>Wan</th>
+</tr>
+<tr>
+<td>
+<video width="320" height="240" controls>
+<source src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_Two_anthropomorphic_cats_in_comfy_boxing_gear_and_bright_gloves_fight_intensely_on_a__i%401-4.mp4" type="video/mp4">
+</video>
+<br>
+<i>Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.</i>
+</td>
+</tr>
+</table>
+<sup>Check the video manually <a href=https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_Two_anthropomorphic_cats_in_comfy_boxing_gear_and_bright_gloves_fight_intensely_on_a__i%401-4.mp4>here</a> if it doesn't show up.</sup>
+<br>
+
+<table>
+<tr>
+<th>LTX-Video</th>
+</tr>
+<tr>
+<td>
+<video width="320" height="240" controls>
+<source src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_The_camera_pans_over_a_snow_covered_mountain_range_revealing_a_vast_expanse_of_snow_c_i%401-4.mp4" type="video/mp4">
+</video>
+<br>
+<br>
+<i>The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys. The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature.</i>
+</td>
+</tr>
+</table>
+<sup>Check the video manually <a href=https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_The_camera_pans_over_a_snow_covered_mountain_range_revealing_a_vast_expanse_of_snow_c_i%401-4.mp4>here</a> if it doesn't show up.</sup>
+
+</details>
+
 ## Acknowledgements
 
 * Thanks to [Willis Ma](https://twitter.com/ma_nanye) for all the guidance and pair-coding.
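
For context on the new Videos section above, here is a minimal standalone sketch of LTX-Video generation using the same settings that the new `configs/ltx_video.json` (below) ships. This is plain diffusers usage, not this repo's search loop; the prompt is a placeholder and a CUDA GPU with enough memory is assumed.

```python
# Minimal sketch: standalone LTX-Video generation with the settings from
# configs/ltx_video.json, without the inference-time search/verifier loop.
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="The camera pans over a snow-covered mountain range, revealing snow-capped peaks and valleys.",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    height=480,
    width=704,
    num_frames=161,
    num_inference_steps=50,
).frames[0]

export_to_video(video, "ltx_sample.mp4", fps=24)
```

Within this repo, `main.py` (diffed below) runs this call once per candidate noise, exports the result to mp4, and scores sampled frames with the LAION aesthetic verifier.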

configs/ltx_video.json

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+{
+    "pretrained_model_name_or_path": "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
+    "torch_dtype": "bf16",
+    "pipeline_call_args": {
+        "height": 480,
+        "width": 704,
+        "num_frames": 161,
+        "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
+        "num_inference_steps": 50
+    },
+    "verifier_args": {
+        "name": "laion_aesthetic",
+        "choice_of_metric": "laion_aesthetic_score"
+    },
+    "search_args": {
+        "search_method": "random",
+        "search_rounds": 4
+    },
+    "export_args": {"fps": 24}
+}

configs/wan.json

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+{
+    "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+    "torch_dtype": "bf16",
+    "pipeline_call_args": {
+        "height": 480,
+        "width": 832,
+        "num_frames": 81,
+        "guidance_scale": 5.0,
+        "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
+        "num_inference_steps": 50
+    },
+    "verifier_args": {
+        "name": "laion_aesthetic",
+        "choice_of_metric": "laion_aesthetic_score"
+    },
+    "search_args": {
+        "search_method": "random",
+        "search_rounds": 4
+    },
+    "export_args": {"fps": 15}
+}
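
The commit message notes "get vae in fp32 when using wan", and the `main.py` hunk below loads `AutoencoderKLWan` in `float32` while keeping the rest of the pipeline in `bfloat16`, following the diffusers Wan docs linked in the diff. For reference, a standalone sketch of the same recipe with the settings from this config; this is not code from this repo, and the negative prompt is shortened for brevity.

```python
# Sketch of standalone Wan 2.1 generation with the fp32-VAE recommendation from
# the diffusers Wan docs; settings mirror configs/wan.json. main.py (below) wires
# the same idea into its inference-time search loop.
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
# Keep the VAE in fp32 for numerical stability, per the docs.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

frames = pipe(
    prompt="Two anthropomorphic cats in comfy boxing gear fight intensely on a spotlighted stage.",
    negative_prompt="Bright tones, overexposed, static, blurred details, worst quality, low quality",
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=50,
).frames[0]

export_to_video(frames, "wan_sample.mp4", fps=15)
```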

main.py

Lines changed: 55 additions & 14 deletions
@@ -1,11 +1,13 @@
 import os
 import json
 from datetime import datetime
-
+from PIL import Image
 import numpy as np
 import torch
 from diffusers import DiffusionPipeline
 from tqdm.auto import tqdm
+import tempfile
+from diffusers.utils import export_to_video
 
 from utils import (
     generate_neighbors,
@@ -16,6 +18,7 @@
     parse_cli_args,
     serialize_artifacts,
     MODEL_NAME_MAP,
+    prepare_video_frames,
 )
 from verifiers import SUPPORTED_VERIFIERS
 
@@ -42,7 +45,6 @@ def sample(
     use_low_gpu_vram = config.get("use_low_gpu_vram", False)
     batch_size_for_img_gen = config.get("batch_size_for_img_gen", 1)
     verifier_args = config.get("verifier_args")
-    max_new_tokens = verifier_args.get("max_new_tokens", None)
     choice_of_metric = verifier_args.get("choice_of_metric", None)
     verifier_to_use = verifier_args.get("name", "gemini")
     search_args = config.get("search_args", None)
@@ -57,11 +59,18 @@ def sample(
     noise_items = list(noises.items())
 
     # Process the noises in batches.
+    # TODO: find better way
+    extension_to_use = "png"
+    if "LTX" in pipe.__class__.__name__:
+        extension_to_use = "mp4"
+    elif "Wan" in pipe.__class__.__name__:
+        extension_to_use = "mp4"
     for i in range(0, len(noise_items), batch_size_for_img_gen):
         batch = noise_items[i : i + batch_size_for_img_gen]
         seeds_batch, noises_batch = zip(*batch)
         filenames_batch = [
-            os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{seed}.png") for seed in seeds_batch
+            os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{seed}.{extension_to_use}")
+            for seed in seeds_batch
         ]
 
         if use_low_gpu_vram and verifier_to_use != "gemini":
@@ -73,7 +82,11 @@ def sample(
         batched_latents = torch.stack(noises_batch).squeeze(dim=1)
 
         batch_result = pipe(prompt=batched_prompts, latents=batched_latents, **config["pipeline_call_args"])
-        batch_images = batch_result.images
+        if hasattr(batch_result, "images"):
+            batch_images = batch_result.images
+        elif hasattr(batch_result, "frames"):
+            batch_images = [vid for vid in batch_result.frames]
+
         if use_low_gpu_vram and verifier_to_use != "gemini":
             pipe = pipe.to("cpu")
 
@@ -85,15 +98,34 @@ def sample(
             images_info.append((seed, noise, image, filename))
 
     # Prepare verifier inputs and perform inference.
-    verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
+    if isinstance(images_for_prompt[0], Image.Image):
+        verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
+    else:
+        export_args = config.get("export_args", None) or {}
+        if export_args:
+            fps = export_args.get("fps", 24)
+        else:
+            fps = 24
+        temp_vid_paths = []
+        with tempfile.TemporaryDirectory() as tmpdir:
+            for idx, vid in enumerate(images_for_prompt):
+                vid_path = os.path.join(tmpdir, f"{idx}.mp4")
+                export_to_video(vid, vid_path, fps=fps)
+                temp_vid_paths.append(vid_path)
+
+            verifier_inputs = []
+            for vid_path in temp_vid_paths:
+                frames = prepare_video_frames(vid_path)
+                verifier_inputs.append(verifier.prepare_inputs(images=frames, prompts=[prompt] * len(frames)))
+
     print("Scoring with the verifier.")
     outputs = verifier.score(inputs=verifier_inputs)
     for o in outputs:
         assert choice_of_metric in o, o.keys()
 
-    assert len(outputs) == len(images_for_prompt), (
-        f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
-    )
+    assert (
+        len(outputs) == len(images_for_prompt)
+    ), f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
 
     results = []
     for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
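
`prepare_video_frames` is imported from `utils.py`, which is not part of this diff, so its implementation isn't visible here; the commit message mentions moviepy, so the real helper presumably decodes the temporarily exported mp4 back into frames that the image-based verifier can score. A hypothetical OpenCV-based sketch of that idea follows; the name, signature, and even-spacing strategy are assumptions, not the repo's actual code.

```python
# Hypothetical sketch of a prepare_video_frames-style helper: the real one lives in
# utils.py and is not shown in this diff. It samples a handful of evenly spaced
# frames from an mp4 and returns them as PIL images for frame-level scoring.
import cv2  # assumes opencv-python is installed
from PIL import Image


def prepare_video_frames_sketch(video_path: str, num_frames: int = 8) -> list[Image.Image]:
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Pick evenly spaced frame indices across the clip.
    indices = [int(i * (total - 1) / max(num_frames - 1, 1)) for i in range(min(num_frames, total))]
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok:
            continue
        # OpenCV decodes BGR; convert to RGB before handing frames to the verifier.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return frames
```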
@@ -114,7 +146,9 @@ def f(x):
     for ts in topk_scores:
         print(f"Prompt='{prompt}' | Best seed={ts['seed']} | Score={ts[choice_of_metric]}")
 
-    best_img_path = os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{topk_scores[0]['seed']}.png")
+    best_img_path = os.path.join(
+        root_dir, f"{prompt_filename}_i@{search_round}_s@{topk_scores[0]['seed']}.{extension_to_use}"
+    )
     datapoint = {
         "prompt": prompt,
         "search_round": search_round,
@@ -136,11 +170,11 @@ def f(x):
     # Serialize.
     if search_method == "zero-order":
         if datapoint["neighbors_improvement"]:
-            serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
+            serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint, **export_args)
         else:
             print("Skipping serialization as there was no improvement in this round.")
     elif search_method == "random":
-        serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
+        serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint, **export_args)
 
     return datapoint
 
@@ -187,7 +221,14 @@ def main():
 
     # === Set up the image-generation pipeline ===
     torch_dtype = TORCH_DTYPE_MAP[config.pop("torch_dtype")]
-    pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype)
+    fp_kwargs = {"pretrained_model_name_or_path": pipeline_name, "torch_dtype": torch_dtype}
+    if "Wan" in pipeline_name:
+        # As per recommendations from https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan.
+        from diffusers import AutoencoderKLWan
+
+        vae = AutoencoderKLWan.from_pretrained(pipeline_name, subfolder="vae", torch_dtype=torch.float32)
+        fp_kwargs.update({"vae": vae})
+    pipe = DiffusionPipeline.from_pretrained(**fp_kwargs)
     if not config.get("use_low_gpu_vram", False):
         pipe = pipe.to("cuda:0")
     pipe.set_progress_bar_config(disable=True)
@@ -201,6 +242,7 @@ def main():
     verifier = verifier_cls(**verifier_args)
 
     # === Main loop: For each prompt and each search round ===
+    pipeline_call_args = config["pipeline_call_args"].copy()
     for prompt in tqdm(prompts, desc="Processing prompts"):
         search_round = 1
 
@@ -234,10 +276,9 @@ def main():
             noises = get_noises(
                 max_seed=MAX_SEED,
                 num_samples=num_noises_to_sample,
-                height=config["pipeline_call_args"]["height"],
-                width=config["pipeline_call_args"]["width"],
                 dtype=torch_dtype,
                 fn=get_latent_prep_fn(pipeline_name),
+                **pipeline_call_args,
             )
         else:
             if best_datapoint_per_round[previous_round]:
