Skip to content

Commit 4c6a5bc

Browse files
authored
Zero order search (#14)
* zero-order search * finish * search tolerance. * updates * updates * changes * changes * changes * get merge ready
1 parent d15309d commit 4c6a5bc

File tree

3 files changed

+234
-75
lines changed

3 files changed

+234
-75
lines changed

README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ But it's been growing now! Check out the rest of the README to know more 🤗
1717
🔥 27/02/2025: [MaximClouser](https://github.com/MaximClouser) implemented a ComfyUI node for inference-time
1818
scaling in [this repo](https://github.com/YRIKKA/ComfyUI-InferenceTimeScaling). Check it out!
1919

20+
🔥 25/02/2025: Support for zero-order search has been added [in this PR](https://github.com/sayakpaul/tt-scale-flux/pull/14). Many thanks to Willis Ma for the reviews. Check out [this section](#controlling-search) for more details.
21+
22+
🔥 21/02/2025: Better configuration management for more flexibility, freeing us from the `argparse` madness.
23+
2024
🔥 16/02/2025: Support for batched image generation has been added [in this PR](https://github.com/sayakpaul/tt-scale-flux/pull/9). It speeds up the total time but consumes more memory.
2125

2226
🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using `outlines` and `pydantic`) in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/6).
@@ -271,6 +275,36 @@ The verifier prompt that is used during grading/verification is specified in [th
271275
the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to
272276
experiment with a different prompt.
273277

278+
### Controlling search
279+
280+
You can configure search related arguments through `search_args` in the configuration file. Currently,
281+
"random search" and "zero-order search" are supported. The default configurations provided
282+
under [`configs/`](./configs/) are all for random search.
283+
284+
Below is a configuration for zero-order search:
285+
286+
```json
287+
"search_args": {
288+
"search_method": "zero-order",
289+
"search_rounds": 4,
290+
"threshold": 0.95,
291+
"num_neighbors": 4
292+
}
293+
```
294+
295+
<details>
296+
<summary>For details about the parameters</summary>
297+
298+
* `threshold`: threshold to use for filtering out neighbor candidates from the base noise
299+
* `num_neighbors`: number of neighbors to generate from the base noise
300+
301+
</details>
302+
303+
> [!NOTE]
304+
> If the neighbors in the current round do not improve the current search round results,
305+
> we simply reject the round, starting the next round with a new base noise. In case of
306+
> worse neighbors, we don't serialize the artifacts.
307+
274308
## More results
275309

276310
<details>

main.py

Lines changed: 133 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,17 @@
66
import torch
77
from diffusers import DiffusionPipeline
88
from tqdm.auto import tqdm
9-
import copy
109

11-
from utils import prompt_to_filename, get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args, MODEL_NAME_MAP
10+
from utils import (
11+
generate_neighbors,
12+
prompt_to_filename,
13+
get_noises,
14+
TORCH_DTYPE_MAP,
15+
get_latent_prep_fn,
16+
parse_cli_args,
17+
serialize_artifacts,
18+
MODEL_NAME_MAP,
19+
)
1220

1321
# Non-configurable constants
1422
TOPK = 1 # Always selecting the top-1 noise for the next round
@@ -28,18 +36,20 @@ def sample(
2836
"""
2937
For a given prompt, generate images using all provided noises in batches,
3038
score them with the verifier, and select the top-K noise.
31-
The images and JSON artifacts are saved under `root_dir`.
39+
The images and JSON artifacts are serialized via `serialize_artifacts`.
3240
"""
3341
use_low_gpu_vram = config.get("use_low_gpu_vram", False)
3442
batch_size_for_img_gen = config.get("batch_size_for_img_gen", 1)
3543
verifier_args = config.get("verifier_args")
3644
max_new_tokens = verifier_args.get("max_new_tokens", None)
3745
choice_of_metric = verifier_args.get("choice_of_metric", None)
3846
verifier_to_use = verifier_args.get("name", "gemini")
47+
search_args = config.get("search_args", None)
3948

4049
images_for_prompt = []
4150
noises_used = []
4251
seeds_used = []
52+
images_info = [] # Will collect (seed, noise, image, filename) tuples for serialization.
4353
prompt_filename = prompt_to_filename(prompt)
4454

4555
# Convert the noises dictionary into a list of (seed, noise) tuples.
@@ -55,7 +65,7 @@ def sample(
5565

5666
if use_low_gpu_vram and verifier_to_use != "gemini":
5767
pipe = pipe.to("cuda:0")
58-
print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.")
68+
print(f"Generating images for batch with seeds: {list(seeds_batch)}.")
5969

6070
# Create a batched prompt list and stack the latents.
6171
batched_prompts = [prompt] * len(noises_batch)
@@ -66,12 +76,12 @@ def sample(
6676
if use_low_gpu_vram and verifier_to_use != "gemini":
6777
pipe = pipe.to("cpu")
6878

69-
# Iterate over the batch and save the images.
79+
# Collect the images and corresponding info.
7080
for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch):
7181
images_for_prompt.append(image)
7282
noises_used.append(noise)
7383
seeds_used.append(seed)
74-
image.save(filename)
84+
images_info.append((seed, noise, image, filename))
7585

7686
# Prepare verifier inputs and perform inference.
7787
verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
@@ -83,23 +93,18 @@ def sample(
8393
for o in outputs:
8494
assert choice_of_metric in o, o.keys()
8595

86-
assert (
87-
len(outputs) == len(images_for_prompt)
88-
), f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
96+
assert len(outputs) == len(images_for_prompt), (
97+
f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
98+
)
8999

90100
results = []
91101
for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
92-
# Attach the noise tensor so we can select top-K.
102+
# Merge verifier outputs with noise info.
93103
merged = {**json_dict, "noise": noise, "seed": seed_val}
94104
results.append(merged)
95105

96-
# Sort by the chosen metric descending and pick top-K.
97-
for x in results:
98-
assert (
99-
choice_of_metric in x
100-
), f"Expected all dicts in `results` to contain the `{choice_of_metric}` key; got {x.keys()}."
101-
102106
def f(x):
107+
# If the verifier output is a dict, assume it contains a "score" key.
103108
if isinstance(x[choice_of_metric], dict):
104109
return x[choice_of_metric]["score"]
105110
return x[choice_of_metric]
@@ -117,72 +122,79 @@ def f(x):
117122
"search_round": search_round,
118123
"num_noises": len(noises),
119124
"best_noise_seed": topk_scores[0]["seed"],
125+
"best_noise": topk_scores[0]["noise"],
120126
"best_score": topk_scores[0][choice_of_metric],
121127
"choice_of_metric": choice_of_metric,
122128
"best_img_path": best_img_path,
123129
}
124-
# Save the best config JSON file alongside the images.
125-
best_json_filename = best_img_path.replace(".png", ".json")
126-
with open(best_json_filename, "w") as f:
127-
json.dump(datapoint, f, indent=4)
130+
131+
# Check if the neighbors have any improvements (zero-order only).
132+
search_method = search_args.get("search_method", "random") if search_args else "random"
133+
if search_args and search_method == "zero-order":
134+
first_score = f(results[0])
135+
neighbors_with_better_score = any(f(item) > first_score for item in results[1:])
136+
datapoint["neighbors_improvement"] = neighbors_with_better_score
137+
138+
# Serialize.
139+
if search_method == "zero-order":
140+
if datapoint["neighbors_improvement"]:
141+
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
142+
else:
143+
print("Skipping serialization as there was no improvement in this round.")
144+
elif search_method == "random":
145+
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
146+
128147
return datapoint
129148

130149

131150
@torch.no_grad()
132151
def main():
133-
"""
134-
Main function:
135-
- Parses CLI arguments.
136-
- Creates an output directory based on verifier and current datetime.
137-
- Loads prompts.
138-
- Loads the image-generation pipeline.
139-
- Loads the verifier model.
140-
- Runs several search rounds where for each prompt a pool of random noises is generated,
141-
candidate images are produced and verified, and the best noise is chosen.
142-
"""
152+
# === Load configuration and CLI arguments ===
143153
args = parse_cli_args()
144-
145-
# Build a config dictionary for parameters that need to be passed around.
146154
with open(args.pipeline_config_path, "r") as f:
147155
config = json.load(f)
148156
config.update(vars(args))
149157

150-
search_rounds = config["search_args"]["search_rounds"]
158+
search_args = config["search_args"]
159+
search_rounds = search_args["search_rounds"]
160+
search_method = search_args.get("search_method", "random")
151161
num_prompts = config["num_prompts"]
152162

153-
# Create a root output directory: output/{verifier_to_use}/{current_datetime}
163+
# === Create output directory ===
154164
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
155165
pipeline_name = config.pop("pretrained_model_name_or_path")
156-
root_dir = os.path.join(
166+
verifier_name = config["verifier_args"]["name"]
167+
choice_of_metric = config["verifier_args"]["choice_of_metric"]
168+
output_dir = os.path.join(
157169
"output",
158170
MODEL_NAME_MAP[pipeline_name],
159-
config["verifier_args"]["name"],
160-
config["verifier_args"]["choice_of_metric"],
171+
verifier_name,
172+
choice_of_metric,
161173
current_datetime,
162174
)
163-
os.makedirs(root_dir, exist_ok=True)
164-
print(f"Artifacts will be saved to: {root_dir}")
165-
with open(os.path.join(root_dir, "config.json"), "w") as f:
166-
json.dump(config, f)
175+
os.makedirs(output_dir, exist_ok=True)
176+
print(f"Artifacts will be saved to: {output_dir}")
177+
with open(os.path.join(output_dir, "config.json"), "w") as f:
178+
json.dump(config, f, indent=4)
167179

168-
# Load prompts from file.
180+
# === Load prompts ===
169181
if args.prompt is None:
170182
with open("prompts_open_image_pref_v1.txt", "r", encoding="utf-8") as f:
171-
prompts = [line.strip() for line in f.readlines() if line.strip()]
183+
prompts = [line.strip() for line in f if line.strip()]
172184
if num_prompts != "all":
173185
prompts = prompts[:num_prompts]
174-
print(f"Using {len(prompts)} prompt(s).")
175186
else:
176187
prompts = [args.prompt]
188+
print(f"Using {len(prompts)} prompt(s).")
177189

178-
# Set up the image-generation pipeline (on the first GPU if available).
190+
# === Set up the image-generation pipeline ===
179191
torch_dtype = TORCH_DTYPE_MAP[config.pop("torch_dtype")]
180192
pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype)
181-
if not config["use_low_gpu_vram"]:
193+
if not config.get("use_low_gpu_vram", False):
182194
pipe = pipe.to("cuda:0")
183195
pipe.set_progress_bar_config(disable=True)
184196

185-
# Load the verifier model.
197+
# === Load verifier model ===
186198
verifier_args = config["verifier_args"]
187199
if verifier_args["name"] == "gemini":
188200
from verifiers import GeminiVerifier
@@ -191,33 +203,88 @@ def main():
191203
else:
192204
from verifiers.qwen_verifier import QwenVerifier
193205

194-
verifier = QwenVerifier(use_low_gpu_vram=config["use_low_gpu_vram"])
195-
196-
# Main loop: For each search round and each prompt, generate images, verify, and save artifacts.
197-
for round in range(1, search_rounds + 1):
198-
print(f"\n=== Round: {round} ===")
199-
num_noises_to_sample = 2**round # scale noise pool.
200-
for prompt in tqdm(prompts, desc="Sampling prompts"):
201-
noises = get_noises(
202-
max_seed=MAX_SEED,
203-
num_samples=num_noises_to_sample,
204-
height=config["pipeline_call_args"]["height"],
205-
width=config["pipeline_call_args"]["width"],
206-
dtype=torch_dtype,
207-
fn=get_latent_prep_fn(pipeline_name),
208-
)
209-
print(f"Number of noise samples: {len(noises)}")
210-
datapoint_for_current_round = sample(
206+
verifier = QwenVerifier(use_low_gpu_vram=config.get("use_low_gpu_vram", False))
207+
208+
# === Main loop: For each prompt and each search round ===
209+
for prompt in tqdm(prompts, desc="Processing prompts"):
210+
search_round = 1
211+
212+
# For zero-order search, we store the best datapoint per round.
213+
best_datapoint_per_round = {}
214+
215+
while search_round <= search_rounds:
216+
# Determine the number of noise samples.
217+
if search_method == "zero-order":
218+
num_noises_to_sample = 1
219+
else:
220+
num_noises_to_sample = 2**search_round
221+
222+
print(f"\n=== Prompt: {prompt} | Round: {search_round} ===")
223+
224+
# --- Generate noise pool ---
225+
should_regenate_noise = True
226+
previous_round = search_round - 1
227+
if previous_round in best_datapoint_per_round:
228+
was_improvement = best_datapoint_per_round[previous_round]["neighbors_improvement"]
229+
if was_improvement:
230+
should_regenate_noise = False
231+
232+
# For subsequent rounds in zero-order: use best noise from previous round.
233+
# This happens ONLY if there was an improvement with the neighbors in the
234+
# previous round, otherwise round is progressed with newly sampled noise.
235+
if should_regenate_noise:
236+
# Standard noise sampling.
237+
if search_method == "zero-order" and search_round != 1:
238+
print("Regenerating base noise because the previous round was rejected.")
239+
noises = get_noises(
240+
max_seed=MAX_SEED,
241+
num_samples=num_noises_to_sample,
242+
height=config["pipeline_call_args"]["height"],
243+
width=config["pipeline_call_args"]["width"],
244+
dtype=torch_dtype,
245+
fn=get_latent_prep_fn(pipeline_name),
246+
)
247+
else:
248+
if best_datapoint_per_round[previous_round]:
249+
if best_datapoint_per_round[previous_round]["neighbors_improvement"]:
250+
print("Using the best noise from the previous round.")
251+
prev_dp = best_datapoint_per_round[previous_round]
252+
noises = {int(prev_dp["best_noise_seed"]): prev_dp["best_noise"]}
253+
254+
if search_method == "zero-order":
255+
# Process the noise to generate neighbors.
256+
base_seed, base_noise = next(iter(noises.items()))
257+
neighbors = generate_neighbors(
258+
base_noise, threshold=search_args["threshold"], num_neighbors=search_args["num_neighbors"]
259+
).squeeze(0)
260+
# Concatenate the base noise with its neighbors.
261+
neighbors_and_noise = torch.cat([base_noise, neighbors], dim=0)
262+
new_noises = {}
263+
for i, noise_tensor in enumerate(neighbors_and_noise):
264+
new_noises[base_seed + i] = noise_tensor.unsqueeze(0)
265+
noises = new_noises
266+
267+
print(f"Number of noise samples for prompt '{prompt}': {len(noises)}")
268+
269+
# --- Sampling, verifying, and saving artifacts ---
270+
datapoint = sample(
211271
noises=noises,
212272
prompt=prompt,
213-
search_round=round,
273+
search_round=search_round,
214274
pipe=pipe,
215275
verifier=verifier,
216276
topk=TOPK,
217-
root_dir=root_dir,
277+
root_dir=output_dir,
218278
config=config,
219279
)
220280

281+
if search_method == "zero-order":
282+
# Update the best datapoint for zero-order.
283+
if datapoint["neighbors_improvement"]:
284+
best_datapoint_per_round[search_round] = datapoint
285+
286+
search_round += 1
287+
221288

222289
if __name__ == "__main__":
223290
main()

0 commit comments

Comments
 (0)