46 changes: 46 additions & 0 deletions README.md
@@ -14,6 +14,8 @@ But it's been growing now! Check out the rest of the README to know more 🤗

**Updates**

🔥 04/03/2025: Support for LTX-Video and Wan in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/18) 🎬 Check out [this section](#videos) for results and more info.

🔥 01/03/2025: `OpenAIVerifier` was added in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/16). Specify "openai" in the `name` under `verifier_args`. Thanks to [zhuole1025](https://github.com/zhuole1025) for contributing this!

🔥 27/02/2025: [MaximClouser](https://github.com/MaximClouser) implemented a ComfyUI node for inference-time
@@ -442,6 +444,50 @@ between the outputs of different metrics -- "overall_score" vs. "emotional_or_th

</details>&nbsp;&nbsp;

## Videos

We currently support [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video) and [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). Only LAION aesthetic scoring is
supported for these pipelines. Check out the LTX and Wan configs [here](./configs/ltx_video.json) and [here](./configs/wan.json).
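The configs linked above share the same verifier and search settings and differ only in their pipeline call arguments. A minimal sketch of loading and sanity-checking such a config (the dict is inlined here for illustration; in practice you would `json.load()` the file from `./configs/`):

```python
# Illustrative copy of the LTX-Video config shipped in this PR.
config = {
    "pretrained_model_name_or_path": "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
    "torch_dtype": "bf16",
    "pipeline_call_args": {"height": 480, "width": 704, "num_frames": 161, "num_inference_steps": 50},
    "verifier_args": {"name": "laion_aesthetic", "choice_of_metric": "laion_aesthetic_score"},
    "search_args": {"search_method": "random", "search_rounds": 4},
    "export_args": {"fps": 24},
}

# Video configs must use the LAION aesthetic verifier,
# since that is the only scorer supported for video.
assert config["verifier_args"]["name"] == "laion_aesthetic"
assert config["search_args"]["search_method"] == "random"
```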

<details>
<summary>Expand for results</summary>

<table>
<tr>
<th>Wan</th>
</tr>
<tr>
<td>
<video width="320" height="240" controls>
<source src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_Two_anthropomorphic_cats_in_comfy_boxing_gear_and_bright_gloves_fight_intensely_on_a__i%401-4.mp4" type="video/mp4">
</video>
<br>
<i>Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.</i>
</td>
</tr>
</table>
<sup>Check the video manually <a href="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_Two_anthropomorphic_cats_in_comfy_boxing_gear_and_bright_gloves_fight_intensely_on_a__i%401-4.mp4">here</a> if it doesn't show up.</sup>
<br>

<table>
<tr>
<th>LTX-Video</th>
</tr>
<tr>
<td>
<video width="320" height="240" controls>
<source src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_The_camera_pans_over_a_snow_covered_mountain_range_revealing_a_vast_expanse_of_snow_c_i%401-4.mp4" type="video/mp4">
</video>
<br>
<br>
<i>The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys. The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature.</i>
</td>
</tr>
</table>
<sup>Check the video manually <a href="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/tt-scale-flux/videos/collage_The_camera_pans_over_a_snow_covered_mountain_range_revealing_a_vast_expanse_of_snow_c_i%401-4.mp4">here</a> if it doesn't show up.</sup>

</details>

## Acknowledgements

* Thanks to [Willis Ma](https://twitter.com/ma_nanye) for all the guidance and pair-coding.
20 changes: 20 additions & 0 deletions configs/ltx_video.json
@@ -0,0 +1,20 @@
{
"pretrained_model_name_or_path": "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
"torch_dtype": "bf16",
"pipeline_call_args": {
"height": 480,
"width": 704,
"num_frames": 161,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
"num_inference_steps": 50
},
"verifier_args": {
"name": "laion_aesthetic",
"choice_of_metric": "laion_aesthetic_score"
},
"search_args": {
"search_method": "random",
"search_rounds": 4
},
"export_args": {"fps": 24}
}
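A quick sanity check on the numbers above: 161 frames exported at 24 fps imply a clip of roughly 6.7 seconds. A sketch:

```python
num_frames = 161  # pipeline_call_args["num_frames"] in the config above
fps = 24          # export_args["fps"] in the config above

# Clip duration in seconds at the export frame rate.
duration_s = num_frames / fps
assert 6.70 < duration_s < 6.72
```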
21 changes: 21 additions & 0 deletions configs/wan.json
@@ -0,0 +1,21 @@
{
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"torch_dtype": "bf16",
"pipeline_call_args": {
"height": 480,
"width": 832,
"num_frames": 81,
"guidance_scale": 5.0,
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
"num_inference_steps": 50
},
"verifier_args": {
"name": "laion_aesthetic",
"choice_of_metric": "laion_aesthetic_score"
},
"search_args": {
"search_method": "random",
"search_rounds": 4
},
"export_args": {"fps": 15}
}
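The `pipeline_call_args` block above is passed through to the pipeline verbatim via keyword splatting. A minimal sketch of that pattern with a stub standing in for the real Wan pipeline (the stub is an assumption, not the diffusers API):

```python
# Subset of the Wan config's pipeline_call_args above.
pipeline_call_args = {
    "height": 480,
    "width": 832,
    "num_frames": 81,
    "guidance_scale": 5.0,
    "num_inference_steps": 50,
}

def fake_pipe(prompt, **kwargs):
    # Stub: just echoes the keyword arguments it received.
    return {"prompt": prompt, **kwargs}

out = fake_pipe("a red fox in the snow", **pipeline_call_args)
assert out["num_frames"] == 81 and out["guidance_scale"] == 5.0
```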
69 changes: 55 additions & 14 deletions main.py
@@ -1,11 +1,13 @@
import os
import json
from datetime import datetime

from PIL import Image
import numpy as np
import torch
from diffusers import DiffusionPipeline
from tqdm.auto import tqdm
import tempfile
from diffusers.utils import export_to_video

from utils import (
generate_neighbors,
@@ -16,6 +18,7 @@
parse_cli_args,
serialize_artifacts,
MODEL_NAME_MAP,
prepare_video_frames,
)
from verifiers import SUPPORTED_VERIFIERS

@@ -42,7 +45,6 @@ def sample(
use_low_gpu_vram = config.get("use_low_gpu_vram", False)
batch_size_for_img_gen = config.get("batch_size_for_img_gen", 1)
verifier_args = config.get("verifier_args")
choice_of_metric = verifier_args.get("choice_of_metric", None)
verifier_to_use = verifier_args.get("name", "gemini")
search_args = config.get("search_args", None)
@@ -57,11 +59,18 @@
noise_items = list(noises.items())

# Process the noises in batches.
# TODO: find a better way to infer the output extension.
extension_to_use = "png"
if "LTX" in pipe.__class__.__name__ or "Wan" in pipe.__class__.__name__:
extension_to_use = "mp4"
for i in range(0, len(noise_items), batch_size_for_img_gen):
batch = noise_items[i : i + batch_size_for_img_gen]
seeds_batch, noises_batch = zip(*batch)
filenames_batch = [
os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{seed}.{extension_to_use}")
for seed in seeds_batch
]

if use_low_gpu_vram and verifier_to_use != "gemini":
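The extension selection keyed on the pipeline class name could be factored into a small lookup helper; a sketch under the assumption that only LTX and Wan emit video (the helper name is hypothetical):

```python
# Class-name substrings identifying pipelines that produce video.
VIDEO_PIPELINE_MARKERS = ("LTX", "Wan")

def output_extension(pipeline_class_name: str) -> str:
    """Return 'mp4' for video pipelines, 'png' otherwise."""
    if any(marker in pipeline_class_name for marker in VIDEO_PIPELINE_MARKERS):
        return "mp4"
    return "png"

assert output_extension("LTXVideoPipeline") == "mp4"
assert output_extension("WanPipeline") == "mp4"
assert output_extension("FluxPipeline") == "png"
```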
@@ -73,7 +82,11 @@
batched_latents = torch.stack(noises_batch).squeeze(dim=1)

batch_result = pipe(prompt=batched_prompts, latents=batched_latents, **config["pipeline_call_args"])
if hasattr(batch_result, "images"):
batch_images = batch_result.images
elif hasattr(batch_result, "frames"):
batch_images = [vid for vid in batch_result.frames]

if use_low_gpu_vram and verifier_to_use != "gemini":
pipe = pipe.to("cpu")

@@ -85,15 +98,34 @@
images_info.append((seed, noise, image, filename))

# Prepare verifier inputs and perform inference.
if isinstance(images_for_prompt[0], Image.Image):
verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
else:
export_args = config.get("export_args", None) or {}
fps = export_args.get("fps", 24)
temp_vid_paths = []
with tempfile.TemporaryDirectory() as tmpdir:
for idx, vid in enumerate(images_for_prompt):
vid_path = os.path.join(tmpdir, f"{idx}.mp4")
export_to_video(vid, vid_path, fps=fps)
temp_vid_paths.append(vid_path)

verifier_inputs = []
for vid_path in temp_vid_paths:
frames = prepare_video_frames(vid_path)
verifier_inputs.append(verifier.prepare_inputs(images=frames, prompts=[prompt] * len(frames)))

print("Scoring with the verifier.")
outputs = verifier.score(inputs=verifier_inputs)
for o in outputs:
assert choice_of_metric in o, o.keys()

assert (
len(outputs) == len(images_for_prompt)
), f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"

results = []
for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
@@ -114,7 +146,9 @@ def f(x):
for ts in topk_scores:
print(f"Prompt='{prompt}' | Best seed={ts['seed']} | Score={ts[choice_of_metric]}")

best_img_path = os.path.join(
root_dir, f"{prompt_filename}_i@{search_round}_s@{topk_scores[0]['seed']}.{extension_to_use}"
)
datapoint = {
"prompt": prompt,
"search_round": search_round,
@@ -136,11 +170,11 @@ def f(x):
# Serialize.
if search_method == "zero-order":
if datapoint["neighbors_improvement"]:
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint, **export_args)
else:
print("Skipping serialization as there was no improvement in this round.")
elif search_method == "random":
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint, **export_args)

return datapoint

@@ -187,7 +221,14 @@ def main():

# === Set up the image-generation pipeline ===
torch_dtype = TORCH_DTYPE_MAP[config.pop("torch_dtype")]
fp_kwargs = {"pretrained_model_name_or_path": pipeline_name, "torch_dtype": torch_dtype}
if "Wan" in pipeline_name:
# As per recommendations from https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan.
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(pipeline_name, subfolder="vae", torch_dtype=torch.float32)
fp_kwargs.update({"vae": vae})
pipe = DiffusionPipeline.from_pretrained(**fp_kwargs)
if not config.get("use_low_gpu_vram", False):
pipe = pipe.to("cuda:0")
pipe.set_progress_bar_config(disable=True)
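The `fp_kwargs` pattern above — assemble `from_pretrained` kwargs, then conditionally swap in a component such as the fp32 Wan VAE — generalizes to other per-pipeline overrides. A stub-based sketch (the helper and stub values are assumptions, not the diffusers API):

```python
def build_pipeline_kwargs(pipeline_name, torch_dtype, overrides=None):
    """Assemble from_pretrained kwargs, letting callers swap in components."""
    kwargs = {"pretrained_model_name_or_path": pipeline_name, "torch_dtype": torch_dtype}
    if overrides:
        kwargs.update(overrides)
    return kwargs

# A Wan pipeline would override the VAE; "fp32-vae-stub" stands in for the real module.
kwargs = build_pipeline_kwargs(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "bf16", overrides={"vae": "fp32-vae-stub"}
)
assert kwargs["vae"] == "fp32-vae-stub"
```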
@@ -201,6 +242,7 @@
verifier = verifier_cls(**verifier_args)

# === Main loop: For each prompt and each search round ===
pipeline_call_args = config["pipeline_call_args"].copy()
for prompt in tqdm(prompts, desc="Processing prompts"):
search_round = 1

@@ -234,10 +276,9 @@
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
**pipeline_call_args,
)
else:
if best_datapoint_per_round[previous_round]: