README.md (26 additions, 0 deletions)
@@ -263,6 +263,32 @@ The verifier prompt that is used during grading/verification is specified in [th
the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to
experiment with a different prompt.

### Controlling search

You can configure search-related arguments through `search_args` in the configuration file. Currently, "random search" and "zero-order search" are supported. The default configurations provided under [`configs/`](./configs/) are all for random search.
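
For reference, a minimal random-search block might look like the following (a sketch; the exact values used by the provided configs live under [`configs/`](./configs/)):

```json
"search_args": {
    "search_method": "random",
    "search_rounds": 4
}
```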

Below is a configuration for zero-order search:

```json
"search_args": {
"search_method": "zero-order",
"search_rounds": 4,
"threshold": 0.95,
"num_neighbors": 4
}
```

<details>
<summary>Details about the parameters</summary>

* `threshold`: threshold to use for filtering out neighbor candidates from the base noise
* `num_neighbors`: number of neighbors to generate from the base noise

</details>

If none of the neighbors improves on the base noise's score, we simply reject the round: its artifacts are not serialized, and a fresh base noise is sampled for the next round.
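
To make this concrete, below is a minimal sketch of how neighbor generation could work. It assumes `threshold` acts as a similarity cutoff between each candidate and the base noise; the actual `generate_neighbors` in `utils.py` may use a different scheme.

```python
import torch
import torch.nn.functional as F


def generate_neighbors_sketch(base_noise: torch.Tensor, threshold: float = 0.95, num_neighbors: int = 4) -> torch.Tensor:
    """Illustrative only -- not the actual implementation used in this repository."""
    neighbors = []
    while len(neighbors) < num_neighbors:
        # Perturb the base noise slightly to obtain a candidate neighbor.
        candidate = base_noise + 0.1 * torch.randn_like(base_noise)
        # Keep only candidates that stay close to the base noise (assumed meaning of `threshold`).
        similarity = F.cosine_similarity(candidate.flatten(), base_noise.flatten(), dim=0)
        if similarity >= threshold:
            neighbors.append(candidate)
    return torch.stack(neighbors, dim=0)
```

In `main.py`, the base noise and its neighbors are rendered and scored together, and a round is accepted only if at least one neighbor scores higher than the base noise.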

## More results

<details>
main.py (140 additions, 66 deletions)
@@ -6,9 +6,17 @@
import torch
from diffusers import DiffusionPipeline
from tqdm.auto import tqdm
import copy

from utils import prompt_to_filename, get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args, MODEL_NAME_MAP
from utils import (
generate_neighbors,
prompt_to_filename,
get_noises,
TORCH_DTYPE_MAP,
get_latent_prep_fn,
parse_cli_args,
serialize_artifacts,
MODEL_NAME_MAP,
)

# Non-configurable constants
TOPK = 1 # Always selecting the top-1 noise for the next round
@@ -28,18 +36,20 @@ def sample(
"""
For a given prompt, generate images using all provided noises in batches,
score them with the verifier, and select the top-K noise.
The images and JSON artifacts are saved under `root_dir`.
The images and JSON artifacts are serialized via `serialize_artifacts`.
"""
use_low_gpu_vram = config.get("use_low_gpu_vram", False)
batch_size_for_img_gen = config.get("batch_size_for_img_gen", 1)
verifier_args = config.get("verifier_args")
max_new_tokens = verifier_args.get("max_new_tokens", None)
choice_of_metric = verifier_args.get("choice_of_metric", None)
verifier_to_use = verifier_args.get("name", "gemini")
search_args = config.get("search_args", None)

images_for_prompt = []
noises_used = []
seeds_used = []
images_info = [] # Will collect (seed, noise, image, filename) tuples for serialization.
prompt_filename = prompt_to_filename(prompt)

# Convert the noises dictionary into a list of (seed, noise) tuples.
@@ -55,7 +65,7 @@

if use_low_gpu_vram and verifier_to_use != "gemini":
pipe = pipe.to("cuda:0")
print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.")
print(f"Generating images for batch with seeds: {list(seeds_batch)}.")

# Create a batched prompt list and stack the latents.
batched_prompts = [prompt] * len(noises_batch)
@@ -66,12 +76,12 @@
if use_low_gpu_vram and verifier_to_use != "gemini":
pipe = pipe.to("cpu")

# Iterate over the batch and save the images.
# Collect the images and corresponding info.
for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch):
images_for_prompt.append(image)
noises_used.append(noise)
seeds_used.append(seed)
image.save(filename)
images_info.append((seed, noise, image, filename))

# Prepare verifier inputs and perform inference.
verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
@@ -83,23 +93,18 @@
for o in outputs:
assert choice_of_metric in o, o.keys()

assert (
len(outputs) == len(images_for_prompt)
), f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
assert len(outputs) == len(images_for_prompt), (
f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
)

results = []
for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
# Attach the noise tensor so we can select top-K.
# Merge verifier outputs with noise info.
merged = {**json_dict, "noise": noise, "seed": seed_val}
results.append(merged)

# Sort by the chosen metric descending and pick top-K.
for x in results:
assert (
choice_of_metric in x
), f"Expected all dicts in `results` to contain the `{choice_of_metric}` key; got {x.keys()}."

def f(x):
# If the verifier output is a dict, assume it contains a "score" key.
if isinstance(x[choice_of_metric], dict):
return x[choice_of_metric]["score"]
return x[choice_of_metric]
@@ -117,72 +122,79 @@ def f(x):
"search_round": search_round,
"num_noises": len(noises),
"best_noise_seed": topk_scores[0]["seed"],
"best_noise": topk_scores[0]["noise"],
"best_score": topk_scores[0][choice_of_metric],
"choice_of_metric": choice_of_metric,
"best_img_path": best_img_path,
}
# Save the best config JSON file alongside the images.
best_json_filename = best_img_path.replace(".png", ".json")
with open(best_json_filename, "w") as f:
json.dump(datapoint, f, indent=4)

# Check if the neighbors have any improvements (zero-order only).
search_method = search_args.get("search_method", "random") if search_args else "random"
if search_args and search_method == "zero-order":
first_score = f(results[0])
neighbors_with_better_score = any(f(item) > first_score for item in results[1:])
datapoint["neighbors_improvement"] = neighbors_with_better_score

# Serialize.
if search_method == "zero-order":
if datapoint["neighbors_improvement"]:
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
else:
print("Skipping serialization as there was no improvement in this round.")
elif search_method == "random":
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)

return datapoint


@torch.no_grad()
def main():
"""
Main function:
- Parses CLI arguments.
- Creates an output directory based on verifier and current datetime.
- Loads prompts.
- Loads the image-generation pipeline.
- Loads the verifier model.
- Runs several search rounds where for each prompt a pool of random noises is generated,
candidate images are produced and verified, and the best noise is chosen.
"""
# === Load configuration and CLI arguments ===
args = parse_cli_args()

# Build a config dictionary for parameters that need to be passed around.
with open(args.pipeline_config_path, "r") as f:
config = json.load(f)
config.update(vars(args))

search_rounds = config["search_args"]["search_rounds"]
search_args = config["search_args"]
search_rounds = search_args["search_rounds"]
search_method = search_args.get("search_method", "random")
num_prompts = config["num_prompts"]

# Create a root output directory: output/{verifier_to_use}/{current_datetime}
# === Create output directory ===
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
pipeline_name = config.pop("pretrained_model_name_or_path")
root_dir = os.path.join(
verifier_name = config["verifier_args"]["name"]
choice_of_metric = config["verifier_args"]["choice_of_metric"]
output_dir = os.path.join(
"output",
MODEL_NAME_MAP[pipeline_name],
config["verifier_args"]["name"],
config["verifier_args"]["choice_of_metric"],
verifier_name,
choice_of_metric,
current_datetime,
)
os.makedirs(root_dir, exist_ok=True)
print(f"Artifacts will be saved to: {root_dir}")
with open(os.path.join(root_dir, "config.json"), "w") as f:
json.dump(config, f)
os.makedirs(output_dir, exist_ok=True)
print(f"Artifacts will be saved to: {output_dir}")
with open(os.path.join(output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)

# Load prompts from file.
# === Load prompts ===
if args.prompt is None:
with open("prompts_open_image_pref_v1.txt", "r", encoding="utf-8") as f:
prompts = [line.strip() for line in f.readlines() if line.strip()]
prompts = [line.strip() for line in f if line.strip()]
if num_prompts != "all":
prompts = prompts[:num_prompts]
print(f"Using {len(prompts)} prompt(s).")
else:
prompts = [args.prompt]
print(f"Using {len(prompts)} prompt(s).")

# Set up the image-generation pipeline (on the first GPU if available).
# === Set up the image-generation pipeline ===
torch_dtype = TORCH_DTYPE_MAP[config.pop("torch_dtype")]
pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype)
if not config["use_low_gpu_vram"]:
if not config.get("use_low_gpu_vram", False):
pipe = pipe.to("cuda:0")
pipe.set_progress_bar_config(disable=True)

# Load the verifier model.
# === Load verifier model ===
verifier_args = config["verifier_args"]
if verifier_args["name"] == "gemini":
from verifiers import GeminiVerifier
@@ -191,33 +203,95 @@ def main():
else:
from verifiers.qwen_verifier import QwenVerifier

verifier = QwenVerifier(use_low_gpu_vram=config["use_low_gpu_vram"])

# Main loop: For each search round and each prompt, generate images, verify, and save artifacts.
for round in range(1, search_rounds + 1):
print(f"\n=== Round: {round} ===")
num_noises_to_sample = 2**round # scale noise pool.
for prompt in tqdm(prompts, desc="Sampling prompts"):
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
height=config["pipeline_call_args"]["height"],
width=config["pipeline_call_args"]["width"],
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
)
print(f"Number of noise samples: {len(noises)}")
datapoint_for_current_round = sample(
verifier = QwenVerifier(use_low_gpu_vram=config.get("use_low_gpu_vram", False))

# === Main loop: For each search round and each prompt ===
for prompt in tqdm(prompts, desc="Processing prompts"):
search_round = 1

# For zero-order search, we store the best datapoint per round.
best_datapoint_per_round = {}

while search_round <= search_rounds:
# Determine the number of noise samples.
if search_method == "zero-order":
num_noises_to_sample = 1
else:
num_noises_to_sample = 2**search_round

print(f"\n=== Prompt: {prompt} | Round: {search_round} ===")

# --- Generate noise pool ---
should_regenerate_noise = True
previous_round = search_round - 1
if previous_round in best_datapoint_per_round:
was_improvement = best_datapoint_per_round[previous_round]["neighbors_improvement"]
if was_improvement:
should_regenerate_noise = False

# For subsequent rounds in zero-order: use best noise from previous round.
# This happens ONLY if there was an improvement with the neighbors, otherwise
# round is progressed.
if should_regenerate_noise:
# Standard noise sampling.
if search_method == "zero-order" and search_round != 1:
print("Regenerating base noise because the previous round was rejected.")
noises = get_noises(
max_seed=MAX_SEED,
num_samples=num_noises_to_sample,
height=config["pipeline_call_args"]["height"],
width=config["pipeline_call_args"]["width"],
dtype=torch_dtype,
fn=get_latent_prep_fn(pipeline_name),
)
else:
if best_datapoint_per_round[previous_round]:
if best_datapoint_per_round[previous_round]["neighbors_improvement"]:
print("Using the best noise from the previous round.")
prev_dp = best_datapoint_per_round[previous_round]
noises = {int(prev_dp["best_noise_seed"]): prev_dp["best_noise"]}
else:

> **Collaborator (review comment):** I'm a little confused about here & the should_generate_noise flag; what's the difference between the two?

print(
f"No improvement in neighbors found for prompt '{prompt}' at round {search_round}. "
"Rejecting this round and progressing to the next."
)
search_round += 1
continue

if search_method == "zero-order":
# Process the noise to generate neighbors.
base_seed, base_noise = next(iter(noises.items()))
neighbors = generate_neighbors(
base_noise, threshold=search_args["threshold"], num_neighbors=search_args["num_neighbors"]
).squeeze(0)
# Concatenate the base noise with its neighbors.
neighbors_and_noise = torch.cat([base_noise, neighbors], dim=0)
new_noises = {}
for i, noise_tensor in enumerate(neighbors_and_noise):
new_noises[base_seed + i] = noise_tensor.unsqueeze(0)
noises = new_noises

print(f"Number of noise samples for prompt '{prompt}': {len(noises)}")

# --- Sampling, verifying, and saving artifacts ---
datapoint = sample(
noises=noises,
prompt=prompt,
search_round=round,
search_round=search_round,
pipe=pipe,
verifier=verifier,
topk=TOPK,
root_dir=root_dir,
root_dir=output_dir,
config=config,
)

if search_method == "zero-order":
# Update the best datapoint for zero-order.
if datapoint["neighbors_improvement"]:
best_datapoint_per_round[search_round] = datapoint

search_round += 1


if __name__ == "__main__":
main()