Skip to content

Commit b4550cb

Browse files
committed
updates
1 parent 1a0a941 commit b4550cb

File tree

2 files changed

+86
-63
lines changed

2 files changed

+86
-63
lines changed

main.py

Lines changed: 62 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
TORCH_DTYPE_MAP,
1515
get_latent_prep_fn,
1616
parse_cli_args,
17+
serialize_artifacts,
1718
MODEL_NAME_MAP,
1819
)
1920

@@ -35,7 +36,7 @@ def sample(
3536
"""
3637
For a given prompt, generate images using all provided noises in batches,
3738
score them with the verifier, and select the top-K noise.
38-
The images and JSON artifacts are saved under `root_dir`.
39+
The images and JSON artifacts are serialized via `serialize_artifacts`.
3940
"""
4041
use_low_gpu_vram = config.get("use_low_gpu_vram", False)
4142
batch_size_for_img_gen = config.get("batch_size_for_img_gen", 1)
@@ -48,6 +49,7 @@ def sample(
4849
images_for_prompt = []
4950
noises_used = []
5051
seeds_used = []
52+
images_info = [] # Will collect (seed, noise, image, filename) tuples for serialization.
5153
prompt_filename = prompt_to_filename(prompt)
5254

5355
# Convert the noises dictionary into a list of (seed, noise) tuples.
@@ -63,7 +65,7 @@ def sample(
6365

6466
if use_low_gpu_vram and verifier_to_use != "gemini":
6567
pipe = pipe.to("cuda:0")
66-
print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.")
68+
print(f"Generating images for batch with seeds: {list(seeds_batch)}.")
6769

6870
# Create a batched prompt list and stack the latents.
6971
batched_prompts = [prompt] * len(noises_batch)
@@ -74,12 +76,12 @@ def sample(
7476
if use_low_gpu_vram and verifier_to_use != "gemini":
7577
pipe = pipe.to("cpu")
7678

77-
# Iterate over the batch and save the images.
79+
# Collect the images and corresponding info.
7880
for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch):
7981
images_for_prompt.append(image)
8082
noises_used.append(noise)
8183
seeds_used.append(seed)
82-
image.save(filename)
84+
images_info.append((seed, noise, image, filename))
8385

8486
# Prepare verifier inputs and perform inference.
8587
verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
@@ -97,17 +99,12 @@ def sample(
9799

98100
results = []
99101
for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
100-
# Attach the noise tensor so we can select top-K.
102+
# Merge verifier outputs with noise info.
101103
merged = {**json_dict, "noise": noise, "seed": seed_val}
102104
results.append(merged)
103105

104-
# Sort by the chosen metric descending and pick top-K.
105-
for x in results:
106-
assert choice_of_metric in x, (
107-
f"Expected all dicts in `results` to contain the `{choice_of_metric}` key; got {x.keys()}."
108-
)
109-
110106
def f(x):
107+
# If the verifier output is a dict, assume it contains a "score" key.
111108
if isinstance(x[choice_of_metric], dict):
112109
return x[choice_of_metric]["score"]
113110
return x[choice_of_metric]
@@ -131,22 +128,21 @@ def f(x):
131128
"best_img_path": best_img_path,
132129
}
133130

134-
# Check if the neighbors have any improvements.
135-
if search_args and search_args.get("search_method") == "zero-order":
136-
# `first_score` corresponds to the base noise.
131+
# Check if the neighbors have any improvements (zero-order only).
132+
search_method = search_args.get("search_method", "random") if search_args else "random"
133+
if search_args and search_method == "zero-order":
137134
first_score = f(results[0])
138135
neighbors_with_better_score = any(f(item) > first_score for item in results[1:])
139-
if not neighbors_with_better_score:
140-
datapoint["neighbors_no_improvement"] = True
141-
else:
142-
datapoint["neighbors_no_improvement"] = False
136+
datapoint["neighbors_improvement"] = neighbors_with_better_score
143137

144-
# Save the best config JSON file alongside the images.
145-
best_json_filename = best_img_path.replace(".png", ".json")
146-
with open(best_json_filename, "w") as f:
147-
datapoint_cp = datapoint.copy()
148-
datapoint_cp.pop("best_noise")
149-
json.dump(datapoint_cp, f, indent=4)
138+
# Serialize.
139+
if search_method == "zero-order":
140+
if datapoint["neighbors_improvement"]:
141+
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
142+
else:
143+
print("Skipping serialization as there was no improvement in this round.")
144+
elif search_method == "random":
145+
serialize_artifacts(images_info, prompt, search_round, root_dir, datapoint)
150146

151147
return datapoint
152148

@@ -161,6 +157,7 @@ def main():
161157

162158
search_args = config["search_args"]
163159
search_rounds = search_args["search_rounds"]
160+
search_method = search_args.get("search_method", "random")
164161
num_prompts = config["num_prompts"]
165162

166163
# === Create output directory ===
@@ -208,30 +205,37 @@ def main():
208205

209206
verifier = QwenVerifier(use_low_gpu_vram=config.get("use_low_gpu_vram", False))
210207

211-
# For zero-order search, we store the best datapoint per prompt.
212-
best_datapoint_for_prompt = {}
213-
214208
# === Main loop: For each search round and each prompt ===
215-
search_round = 1
216-
search_method = search_args.get("search_method", "random")
217-
tolerance_count = 0 # Only used for zero-order
209+
for prompt in tqdm(prompts, desc="Processing prompts"):
210+
search_round = 1
218211

219-
while search_round <= search_rounds:
220-
# Determine the number of noise samples.
221-
if search_method == "zero-order":
222-
num_noises_to_sample = 1
223-
else:
224-
num_noises_to_sample = 2**search_round
212+
# For zero-order search, we store the best datapoint per round.
213+
best_datapoint_per_round = {}
225214

226-
print(f"\n=== Round: {search_round} (tolerance_count: {tolerance_count}) ===")
215+
while search_round <= search_rounds:
216+
# Determine the number of noise samples.
217+
if search_method == "zero-order":
218+
num_noises_to_sample = 1
219+
else:
220+
num_noises_to_sample = 2**search_round
227221

228-
# Track if any prompt improved in this round.
229-
round_improved = False
222+
print(f"\n=== Prompt: {prompt} | Round: {search_round} ===")
230223

231-
for prompt in tqdm(prompts, desc="Sampling prompts"):
232224
# --- Generate noise pool ---
233-
if search_method != "zero-order" or search_round == 1:
234-
# Standard noise sampling
225+
should_regenerate_noise = True
226+
previous_round = search_round - 1
227+
if previous_round in best_datapoint_per_round:
228+
was_improvement = best_datapoint_per_round[previous_round]["neighbors_improvement"]
229+
if was_improvement:
230+
should_regenerate_noise = False
231+
232+
# For subsequent rounds in zero-order: use best noise from previous round.
233+
# This happens ONLY if there was an improvement with the neighbors; otherwise
234+
# the round is advanced.
235+
if should_regenerate_noise:
236+
# Standard noise sampling.
237+
if search_method == "zero-order" and search_round != 1:
238+
print("Regenerating base noise because the previous round was rejected.")
235239
noises = get_noises(
236240
max_seed=MAX_SEED,
237241
num_samples=num_noises_to_sample,
@@ -241,9 +245,18 @@ def main():
241245
fn=get_latent_prep_fn(pipeline_name),
242246
)
243247
else:
244-
# For subsequent rounds in zero-order: use best noise from previous round.
245-
prev_dp = best_datapoint_for_prompt[prompt]
246-
noises = {int(prev_dp["best_noise_seed"]): prev_dp["best_noise"]}
248+
if best_datapoint_per_round[previous_round]:
249+
if best_datapoint_per_round[previous_round]["neighbors_improvement"]:
250+
print("Using the best noise from the previous round.")
251+
prev_dp = best_datapoint_per_round[previous_round]
252+
noises = {int(prev_dp["best_noise_seed"]): prev_dp["best_noise"]}
253+
else:
254+
print(
255+
f"No improvement in neighbors found for prompt '{prompt}' at round {search_round}. "
256+
"Rejecting this round and progressing to the next."
257+
)
258+
search_round += 1
259+
continue
247260

248261
if search_method == "zero-order":
249262
# Process the noise to generate neighbors.
@@ -273,24 +286,10 @@ def main():
273286
)
274287

275288
if search_method == "zero-order":
276-
# Update the best noise for zero-order.
277-
best_datapoint_for_prompt[prompt] = datapoint
278-
279-
# If there was an improvement, flag this round as improved.
280-
if not datapoint.get("neighbors_no_improvement", False):
281-
round_improved = True
282-
283-
# --- Decide on round incrementation ---
284-
if search_method == "zero-order":
285-
if round_improved:
286-
tolerance_count = 0
287-
search_round += 1
288-
else:
289-
tolerance_count += 1
290-
if tolerance_count >= search_args["search_round_tolerance"]:
291-
tolerance_count = 0
292-
search_round += 1
293-
else:
289+
# Update the best datapoint for zero-order.
290+
if datapoint["neighbors_improvement"]:
291+
best_datapoint_per_round[search_round] = datapoint
292+
294293
search_round += 1
295294

296295

utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,27 @@ def recover_json_from_output(output: str):
249249
end = output.rfind("}") + 1
250250
json_part = output[start:end]
251251
return json.loads(json_part)
252+
253+
254+
def serialize_artifacts(
255+
images_info: list[tuple[int, torch.Tensor, Image.Image, str]],
256+
prompt: str,
257+
search_round: int,
258+
root_dir: str,
259+
datapoint: dict,
260+
) -> None:
261+
"""
262+
Serialize generated images and the best datapoint JSON configuration.
263+
"""
264+
# Save each image.
265+
for seed, noise, image, filename in images_info:
266+
image.save(filename)
267+
268+
# Save the best datapoint config as a JSON file.
269+
best_json_filename = datapoint["best_img_path"].replace(".png", ".json")
270+
with open(best_json_filename, "w") as f:
271+
# Remove the noise tensor (or any non-serializable object) from the JSON.
272+
datapoint_copy = datapoint.copy()
273+
datapoint_copy.pop("best_noise", None)
274+
json.dump(datapoint_copy, f, indent=4)
275+
print(f"Serialized JSON configuration and images to {root_dir}.")

0 commit comments

Comments
 (0)