Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 95 additions & 47 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
import torch
from diffusers import DiffusionPipeline
from tqdm.auto import tqdm
import copy

from utils import prompt_to_filename, get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args, MODEL_NAME_MAP
from utils import (
generate_neighbors,
prompt_to_filename,
get_noises,
TORCH_DTYPE_MAP,
get_latent_prep_fn,
parse_cli_args,
MODEL_NAME_MAP,
)

# Non-configurable constants
TOPK = 1 # Always selecting the top-1 noise for the next round
Expand Down Expand Up @@ -117,72 +124,67 @@ def f(x):
"search_round": search_round,
"num_noises": len(noises),
"best_noise_seed": topk_scores[0]["seed"],
"best_noise": topk_scores[0]["noise"],
"best_score": topk_scores[0][choice_of_metric],
"choice_of_metric": choice_of_metric,
"best_img_path": best_img_path,
}
# Save the best config JSON file alongside the images.
best_json_filename = best_img_path.replace(".png", ".json")
with open(best_json_filename, "w") as f:
json.dump(datapoint, f, indent=4)
datapoint_cp = datapoint.copy()
datapoint_cp.pop("best_noise")
json.dump(datapoint_cp, f, indent=4)
return datapoint


@torch.no_grad()
def main():
"""
Main function:
- Parses CLI arguments.
- Creates an output directory based on verifier and current datetime.
- Loads prompts.
- Loads the image-generation pipeline.
- Loads the verifier model.
- Runs several search rounds where for each prompt a pool of random noises is generated,
candidate images are produced and verified, and the best noise is chosen.
"""
# === Load configuration and CLI arguments ===
args = parse_cli_args()

# Build a config dictionary for parameters that need to be passed around.
with open(args.pipeline_config_path, "r") as f:
config = json.load(f)
config.update(vars(args))

search_rounds = config["search_args"]["search_rounds"]
search_args = config["search_args"]
search_rounds = search_args["search_rounds"]
num_prompts = config["num_prompts"]

# Create a root output directory: output/{verifier_to_use}/{current_datetime}
# === Create output directory ===
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
pipeline_name = config.pop("pretrained_model_name_or_path")
root_dir = os.path.join(
verifier_name = config["verifier_args"]["name"]
choice_of_metric = config["verifier_args"]["choice_of_metric"]
output_dir = os.path.join(
"output",
MODEL_NAME_MAP[pipeline_name],
config["verifier_args"]["name"],
config["verifier_args"]["choice_of_metric"],
verifier_name,
choice_of_metric,
current_datetime,
)
os.makedirs(root_dir, exist_ok=True)
print(f"Artifacts will be saved to: {root_dir}")
with open(os.path.join(root_dir, "config.json"), "w") as f:
json.dump(config, f)
os.makedirs(output_dir, exist_ok=True)
print(f"Artifacts will be saved to: {output_dir}")
with open(os.path.join(output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)

# Load prompts from file.
# === Load prompts ===
if args.prompt is None:
with open("prompts_open_image_pref_v1.txt", "r", encoding="utf-8") as f:
prompts = [line.strip() for line in f.readlines() if line.strip()]
prompts = [line.strip() for line in f if line.strip()]
if num_prompts != "all":
prompts = prompts[:num_prompts]
print(f"Using {len(prompts)} prompt(s).")
else:
prompts = [args.prompt]
print(f"Using {len(prompts)} prompt(s).")

# Set up the image-generation pipeline (on the first GPU if available).
# === Set up the image-generation pipeline ===
torch_dtype = TORCH_DTYPE_MAP[config.pop("torch_dtype")]
pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype)
if not config["use_low_gpu_vram"]:
if not config.get("use_low_gpu_vram", False):
pipe = pipe.to("cuda:0")
pipe.set_progress_bar_config(disable=True)

# Load the verifier model.
# === Load verifier model ===
verifier_args = config["verifier_args"]
if verifier_args["name"] == "gemini":
from verifiers import GeminiVerifier
Expand All @@ -191,33 +193,79 @@ def main():
else:
from verifiers.qwen_verifier import QwenVerifier

verifier = QwenVerifier(use_low_gpu_vram=config["use_low_gpu_vram"])
verifier = QwenVerifier(use_low_gpu_vram=config.get("use_low_gpu_vram", False))

# For zero-order search, we store the best datapoint per prompt.
best_datapoint_for_prompt = {}

# === Main loop: For each search round and each prompt ===
for search_round in range(1, search_rounds + 1):
print(f"\n=== Round: {search_round} ===")
# For non-zero-order, the noise pool scales with the round.
num_noises_to_sample = 2**search_round if search_args["search_method"] != "zero-order" else 1

# Main loop: For each search round and each prompt, generate images, verify, and save artifacts.
for round in range(1, search_rounds + 1):
print(f"\n=== Round: {round} ===")
num_noises_to_sample = 2**round # scale noise pool.
for prompt in tqdm(prompts, desc="Sampling prompts"):
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
height=config["pipeline_call_args"]["height"],
width=config["pipeline_call_args"]["width"],
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
)
print(f"Number of noise samples: {len(noises)}")
datapoint_for_current_round = sample(
# --- Generate noise pool ---
if search_args["search_method"] != "zero-order":
# Standard noise sampling.
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
height=config["pipeline_call_args"]["height"],
width=config["pipeline_call_args"]["width"],
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
)
elif search_args["search_method"] == "zero-order":
if search_round == 1:
# First round: sample initial noise(s)
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
height=config["pipeline_call_args"]["height"],
width=config["pipeline_call_args"]["width"],
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
)
else:
# Subsequent rounds: use the best noise from the current round.
prev_dp = best_datapoint_for_prompt[prompt]
# Note: assuming the key is "best_noise_seed" (fixing the typo "seeed").
noises = {int(prev_dp["best_noise_seed"]): prev_dp["best_noise"]}

# --- Process the single noise to generate neighbors ---
# Extract the base noise and its seed.
base_seed, base_noise = next(iter(noises.items()))
# Generate neighbors from the base noise (after squeezing if needed).
neighbors = generate_neighbors(
base_noise, threshold=search_args["threshold"], num_neighbors=search_args["num_neighbors"]
).squeeze(0)
# Concatenate the base noise with its neighbors.
neighbors_and_noise = torch.cat([base_noise, neighbors], dim=0)
# Build a new dictionary mapping updated seeds to each noise.
new_noises = {}
for i, noise_tensor in enumerate(neighbors_and_noise):
new_noises[base_seed + i] = noise_tensor.unsqueeze(0)
noises = new_noises

print(f"Number of noise samples for prompt '{prompt}': {len(noises)}")

# --- Sampling, verifying, and saving artifacts ---
datapoint = sample(
noises=noises,
prompt=prompt,
search_round=round,
search_round=search_round,
pipe=pipe,
verifier=verifier,
topk=TOPK,
root_dir=root_dir,
root_dir=output_dir,
config=config,
)

# Update the best noise for zero-order search.
if search_args["search_method"] == "zero-order":
best_datapoint_for_prompt[prompt] = datapoint


# Script entry point: run the full search pipeline when executed directly
# (not when imported as a module).
if __name__ == "__main__":
    main()
34 changes: 34 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import requests
import argparse
import io
import numpy as np
import torch.nn.functional as F


TORCH_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
Expand Down Expand Up @@ -77,6 +79,11 @@ def validate_args(args):
element in config_keys for element in MANDATORY_CONFIG_KEYS
), f"Expected the following keys to be present: {MANDATORY_CONFIG_KEYS} but got: {config_keys}."

_validate_verifier_args(config)
_validate_search_args(config)


def _validate_verifier_args(config):
from verifiers import SUPPORTED_VERIFIERS, SUPPORTED_METRICS

verifier_args = config["verifier_args"]
Expand All @@ -93,6 +100,16 @@ def validate_args(args):
), f"Unsupported metric provided: {choice_of_metric}, supported ones are: {supported_metrics}."


def _validate_search_args(config):
search_args = config["search_args"]
search_method = search_args["search_method"]
supported_search_methods = ["random", "zero-order"]

assert (
search_method in supported_search_methods
), f"Unsupported search method provided: {search_method}, supported ones are: {supported_search_methods}."


# Adapted from Diffusers.
def prepare_latents_for_flux(
batch_size: int,
Expand Down Expand Up @@ -166,6 +183,23 @@ def get_noises(
return noises


def generate_neighbors(x, threshold=0.95, num_neighbors=4):
    """Sample ``num_neighbors`` random perturbations of ``x`` that keep the
    same L2 norm (per batch element) and cosine similarity ``threshold``
    with the original tensor.

    Works on a batched tensor of any trailing shape: each batch element is
    flattened, perturbed on the hypersphere of its own norm, and reshaped
    back to ``(batch, num_neighbors, *x.shape[1:])`` in ``x.dtype``.
    NOTE: uses a fresh, unseeded PCG64 generator, so results are
    nondeterministic across calls.

    Courtesy: Willis Ma
    """
    # Fresh RNG per call; draws are done in float64 for numerical stability.
    rng = np.random.Generator(np.random.PCG64())
    # Flatten all non-batch dims: (B, D).
    x_f = x.flatten(1)
    # Per-element L2 norm in float64, shaped (B, 1, 1) for broadcasting below.
    x_norm = torch.linalg.norm(x_f, dim=-1, keepdim=True, dtype=torch.float64).unsqueeze(-2)
    # Unit direction of each batch element: (B, 1, D). clamp_min guards
    # against division by zero for an all-zero input.
    u = x_f.unsqueeze(-2) / x_norm.clamp_min(1e-12)
    # Random Gaussian directions: (B, num_neighbors, D), float64, on u's device.
    v = torch.from_numpy(rng.standard_normal(size=(u.shape[0], num_neighbors, u.shape[-1]), dtype=np.float64)).to(
        u.device
    )
    # Gram-Schmidt: remove the component of v along u, then renormalize, so
    # each w row is a unit vector orthogonal to u.
    w = F.normalize(v - (v @ u.transpose(-2, -1)) * u, dim=-1)
    # Since u ⟂ w and both are unit vectors, threshold*u + sqrt(1-threshold²)*w
    # is a unit vector whose dot product with u equals `threshold`; scaling by
    # x_norm restores the original magnitude. Reshape to (B, num_neighbors, ...)
    # and cast back to the input dtype.
    return (
        (x_norm * (threshold * u + np.sqrt(1 - threshold**2) * w))
        .reshape(x.shape[0], num_neighbors, *x.shape[1:])
        .to(x.dtype)
    )


def load_verifier_prompt(path: str) -> str:
with open(path, "r") as f:
verifier_prompt = f.read().replace('"""', "")
Expand Down