1414 TORCH_DTYPE_MAP ,
1515 get_latent_prep_fn ,
1616 parse_cli_args ,
17+ serialize_artifacts ,
1718 MODEL_NAME_MAP ,
1819)
1920
@@ -35,7 +36,7 @@ def sample(
3536 """
3637 For a given prompt, generate images using all provided noises in batches,
3738 score them with the verifier, and select the top-K noise.
38- The images and JSON artifacts are saved under `root_dir `.
39+ The images and JSON artifacts are serialized via `serialize_artifacts `.
3940 """
4041 use_low_gpu_vram = config .get ("use_low_gpu_vram" , False )
4142 batch_size_for_img_gen = config .get ("batch_size_for_img_gen" , 1 )
@@ -48,6 +49,7 @@ def sample(
4849 images_for_prompt = []
4950 noises_used = []
5051 seeds_used = []
52+ images_info = [] # Will collect (seed, noise, image, filename) tuples for serialization.
5153 prompt_filename = prompt_to_filename (prompt )
5254
5355 # Convert the noises dictionary into a list of (seed, noise) tuples.
@@ -63,7 +65,7 @@ def sample(
6365
6466 if use_low_gpu_vram and verifier_to_use != "gemini" :
6567 pipe = pipe .to ("cuda:0" )
66- print (f"Generating images for batch with seeds: { [ s for s in seeds_batch ] } ." )
68+ print (f"Generating images for batch with seeds: { list ( seeds_batch ) } ." )
6769
6870 # Create a batched prompt list and stack the latents.
6971 batched_prompts = [prompt ] * len (noises_batch )
@@ -74,12 +76,12 @@ def sample(
7476 if use_low_gpu_vram and verifier_to_use != "gemini" :
7577 pipe = pipe .to ("cpu" )
7678
77- # Iterate over the batch and save the images .
79+ # Collect the images and corresponding info .
7880 for seed , noise , image , filename in zip (seeds_batch , noises_batch , batch_images , filenames_batch ):
7981 images_for_prompt .append (image )
8082 noises_used .append (noise )
8183 seeds_used .append (seed )
82- image . save ( filename )
84+ images_info . append (( seed , noise , image , filename ) )
8385
8486 # Prepare verifier inputs and perform inference.
8587 verifier_inputs = verifier .prepare_inputs (images = images_for_prompt , prompts = [prompt ] * len (images_for_prompt ))
@@ -97,17 +99,12 @@ def sample(
9799
98100 results = []
99101 for json_dict , seed_val , noise in zip (outputs , seeds_used , noises_used ):
100- # Attach the noise tensor so we can select top-K .
102+ # Merge verifier outputs with noise info .
101103 merged = {** json_dict , "noise" : noise , "seed" : seed_val }
102104 results .append (merged )
103105
104- # Sort by the chosen metric descending and pick top-K.
105- for x in results :
106- assert choice_of_metric in x , (
107- f"Expected all dicts in `results` to contain the `{ choice_of_metric } ` key; got { x .keys ()} ."
108- )
109-
110106 def f (x ):
107+ # If the verifier output is a dict, assume it contains a "score" key.
111108 if isinstance (x [choice_of_metric ], dict ):
112109 return x [choice_of_metric ]["score" ]
113110 return x [choice_of_metric ]
@@ -131,22 +128,21 @@ def f(x):
131128 "best_img_path" : best_img_path ,
132129 }
133130
134- # Check if the neighbors have any improvements.
135- if search_args and search_args .get ("search_method" ) == "zero-order" :
136- # `first_score` corresponds to the base noise.
131+ # Check if the neighbors have any improvements (zero-order only) .
132+ search_method = search_args .get ("search_method" , "random" ) if search_args else "random"
133+ if search_args and search_method == "zero-order" :
137134 first_score = f (results [0 ])
138135 neighbors_with_better_score = any (f (item ) > first_score for item in results [1 :])
139- if not neighbors_with_better_score :
140- datapoint ["neighbors_no_improvement" ] = True
141- else :
142- datapoint ["neighbors_no_improvement" ] = False
136+ datapoint ["neighbors_improvement" ] = neighbors_with_better_score
143137
144- # Save the best config JSON file alongside the images.
145- best_json_filename = best_img_path .replace (".png" , ".json" )
146- with open (best_json_filename , "w" ) as f :
147- datapoint_cp = datapoint .copy ()
148- datapoint_cp .pop ("best_noise" )
149- json .dump (datapoint_cp , f , indent = 4 )
138+ # Serialize.
139+ if search_method == "zero-order" :
140+ if datapoint ["neighbors_improvement" ]:
141+ serialize_artifacts (images_info , prompt , search_round , root_dir , datapoint )
142+ else :
143+ print ("Skipping serialization as there was no improvement in this round." )
144+ elif search_method == "random" :
145+ serialize_artifacts (images_info , prompt , search_round , root_dir , datapoint )
150146
151147 return datapoint
152148
@@ -161,6 +157,7 @@ def main():
161157
162158 search_args = config ["search_args" ]
163159 search_rounds = search_args ["search_rounds" ]
160+ search_method = search_args .get ("search_method" , "random" )
164161 num_prompts = config ["num_prompts" ]
165162
166163 # === Create output directory ===
@@ -208,30 +205,37 @@ def main():
208205
209206 verifier = QwenVerifier (use_low_gpu_vram = config .get ("use_low_gpu_vram" , False ))
210207
211- # For zero-order search, we store the best datapoint per prompt.
212- best_datapoint_for_prompt = {}
213-
214208 # === Main loop: For each search round and each prompt ===
215- search_round = 1
216- search_method = search_args .get ("search_method" , "random" )
217- tolerance_count = 0 # Only used for zero-order
209+ for prompt in tqdm (prompts , desc = "Processing prompts" ):
210+ search_round = 1
218211
219- while search_round <= search_rounds :
220- # Determine the number of noise samples.
221- if search_method == "zero-order" :
222- num_noises_to_sample = 1
223- else :
224- num_noises_to_sample = 2 ** search_round
212+ # For zero-order search, we store the best datapoint per round.
213+ best_datapoint_per_round = {}
225214
226- print (f"\n === Round: { search_round } (tolerance_count: { tolerance_count } ) ===" )
215+ while search_round <= search_rounds :
216+ # Determine the number of noise samples.
217+ if search_method == "zero-order" :
218+ num_noises_to_sample = 1
219+ else :
220+ num_noises_to_sample = 2 ** search_round
227221
228- # Track if any prompt improved in this round.
229- round_improved = False
222+ print (f"\n === Prompt: { prompt } | Round: { search_round } ===" )
230223
231- for prompt in tqdm (prompts , desc = "Sampling prompts" ):
232224 # --- Generate noise pool ---
233- if search_method != "zero-order" or search_round == 1 :
234- # Standard noise sampling
225+ should_regenerate_noise = True
226+ previous_round = search_round - 1
227+ if previous_round in best_datapoint_per_round :
228+ was_improvement = best_datapoint_per_round [previous_round ]["neighbors_improvement" ]
229+ if was_improvement :
230+ should_regenerate_noise = False
231+
232+ # For subsequent rounds in zero-order: reuse the best noise from the previous
233+ # round, but ONLY if its neighbors improved the score; otherwise a fresh base
234+ # noise is regenerated and the round is progressed.
235+ if should_regenerate_noise :
236+ # Standard noise sampling.
237+ if search_method == "zero-order" and search_round != 1 :
238+ print ("Regenerating base noise because the previous round was rejected." )
235239 noises = get_noises (
236240 max_seed = MAX_SEED ,
237241 num_samples = num_noises_to_sample ,
@@ -241,9 +245,18 @@ def main():
241245 fn = get_latent_prep_fn (pipeline_name ),
242246 )
243247 else :
244- # For subsequent rounds in zero-order: use best noise from previous round.
245- prev_dp = best_datapoint_for_prompt [prompt ]
246- noises = {int (prev_dp ["best_noise_seed" ]): prev_dp ["best_noise" ]}
248+ if best_datapoint_per_round [previous_round ]:
249+ if best_datapoint_per_round [previous_round ]["neighbors_improvement" ]:
250+ print ("Using the best noise from the previous round." )
251+ prev_dp = best_datapoint_per_round [previous_round ]
252+ noises = {int (prev_dp ["best_noise_seed" ]): prev_dp ["best_noise" ]}
253+ else :
254+ print (
255+ f"No improvement in neighbors found for prompt '{ prompt } ' at round { search_round } . "
256+ "Rejecting this round and progressing to the next."
257+ )
258+ search_round += 1
259+ continue
247260
248261 if search_method == "zero-order" :
249262 # Process the noise to generate neighbors.
@@ -273,24 +286,10 @@ def main():
273286 )
274287
275288 if search_method == "zero-order" :
276- # Update the best noise for zero-order.
277- best_datapoint_for_prompt [prompt ] = datapoint
278-
279- # If there was an improvement, flag this round as improved.
280- if not datapoint .get ("neighbors_no_improvement" , False ):
281- round_improved = True
282-
283- # --- Decide on round incrementation ---
284- if search_method == "zero-order" :
285- if round_improved :
286- tolerance_count = 0
287- search_round += 1
288- else :
289- tolerance_count += 1
290- if tolerance_count >= search_args ["search_round_tolerance" ]:
291- tolerance_count = 0
292- search_round += 1
293- else :
289+ # Update the best datapoint for zero-order.
290+ if datapoint ["neighbors_improvement" ]:
291+ best_datapoint_per_round [search_round ] = datapoint
292+
294293 search_round += 1
295294
296295
0 commit comments