66import torch
77from diffusers import DiffusionPipeline
88from tqdm .auto import tqdm
9- import copy
109
11- from utils import prompt_to_filename , get_noises , TORCH_DTYPE_MAP , get_latent_prep_fn , parse_cli_args , MODEL_NAME_MAP
10+ from utils import (
11+ generate_neighbors ,
12+ prompt_to_filename ,
13+ get_noises ,
14+ TORCH_DTYPE_MAP ,
15+ get_latent_prep_fn ,
16+ parse_cli_args ,
17+ serialize_artifacts ,
18+ MODEL_NAME_MAP ,
19+ )
1220
1321# Non-configurable constants
1422TOPK = 1 # Always selecting the top-1 noise for the next round
@@ -28,18 +36,20 @@ def sample(
2836 """
2937 For a given prompt, generate images using all provided noises in batches,
3038 score them with the verifier, and select the top-K noise.
31- The images and JSON artifacts are saved under `root_dir `.
39+ The images and JSON artifacts are serialized via `serialize_artifacts `.
3240 """
3341 use_low_gpu_vram = config .get ("use_low_gpu_vram" , False )
3442 batch_size_for_img_gen = config .get ("batch_size_for_img_gen" , 1 )
3543 verifier_args = config .get ("verifier_args" )
3644 max_new_tokens = verifier_args .get ("max_new_tokens" , None )
3745 choice_of_metric = verifier_args .get ("choice_of_metric" , None )
3846 verifier_to_use = verifier_args .get ("name" , "gemini" )
47+ search_args = config .get ("search_args" , None )
3948
4049 images_for_prompt = []
4150 noises_used = []
4251 seeds_used = []
52+ images_info = [] # Will collect (seed, noise, image, filename) tuples for serialization.
4353 prompt_filename = prompt_to_filename (prompt )
4454
4555 # Convert the noises dictionary into a list of (seed, noise) tuples.
@@ -55,7 +65,7 @@ def sample(
5565
5666 if use_low_gpu_vram and verifier_to_use != "gemini" :
5767 pipe = pipe .to ("cuda:0" )
58- print (f"Generating images for batch with seeds: { [ s for s in seeds_batch ] } ." )
68+ print (f"Generating images for batch with seeds: { list ( seeds_batch ) } ." )
5969
6070 # Create a batched prompt list and stack the latents.
6171 batched_prompts = [prompt ] * len (noises_batch )
@@ -66,12 +76,12 @@ def sample(
6676 if use_low_gpu_vram and verifier_to_use != "gemini" :
6777 pipe = pipe .to ("cpu" )
6878
69- # Iterate over the batch and save the images .
79+ # Collect the images and corresponding info .
7080 for seed , noise , image , filename in zip (seeds_batch , noises_batch , batch_images , filenames_batch ):
7181 images_for_prompt .append (image )
7282 noises_used .append (noise )
7383 seeds_used .append (seed )
74- image . save ( filename )
84+ images_info . append (( seed , noise , image , filename ) )
7585
7686 # Prepare verifier inputs and perform inference.
7787 verifier_inputs = verifier .prepare_inputs (images = images_for_prompt , prompts = [prompt ] * len (images_for_prompt ))
@@ -83,23 +93,18 @@ def sample(
8393 for o in outputs :
8494 assert choice_of_metric in o , o .keys ()
8595
86- assert (
87- len (outputs ) == len (images_for_prompt )
88- ), f"Expected len(outputs) to be same as len(images_for_prompt) but got { len ( outputs ) = } & { len ( images_for_prompt ) = } "
96+ assert len ( outputs ) == len ( images_for_prompt ), (
97+ f"Expected len(outputs) to be same as len(images_for_prompt) but got { len ( outputs ) = } & { len (images_for_prompt )= } "
98+ )
8999
90100 results = []
91101 for json_dict , seed_val , noise in zip (outputs , seeds_used , noises_used ):
92- # Attach the noise tensor so we can select top-K .
102+ # Merge verifier outputs with noise info .
93103 merged = {** json_dict , "noise" : noise , "seed" : seed_val }
94104 results .append (merged )
95105
96- # Sort by the chosen metric descending and pick top-K.
97- for x in results :
98- assert (
99- choice_of_metric in x
100- ), f"Expected all dicts in `results` to contain the `{ choice_of_metric } ` key; got { x .keys ()} ."
101-
def f(entry):
    """Return a scalar score for `entry` under `choice_of_metric`.

    The verifier may report the metric either as a bare number or as a
    dict carrying the value under a "score" key; both forms are
    normalized here so results can be compared/sorted as plain scalars.
    """
    metric_value = entry[choice_of_metric]
    return metric_value["score"] if isinstance(metric_value, dict) else metric_value
@@ -117,72 +122,79 @@ def f(x):
117122 "search_round" : search_round ,
118123 "num_noises" : len (noises ),
119124 "best_noise_seed" : topk_scores [0 ]["seed" ],
125+ "best_noise" : topk_scores [0 ]["noise" ],
120126 "best_score" : topk_scores [0 ][choice_of_metric ],
121127 "choice_of_metric" : choice_of_metric ,
122128 "best_img_path" : best_img_path ,
123129 }
124- # Save the best config JSON file alongside the images.
125- best_json_filename = best_img_path .replace (".png" , ".json" )
126- with open (best_json_filename , "w" ) as f :
127- json .dump (datapoint , f , indent = 4 )
130+
131+ # Check if the neighbors have any improvements (zero-order only).
132+ search_method = search_args .get ("search_method" , "random" ) if search_args else "random"
133+ if search_args and search_method == "zero-order" :
134+ first_score = f (results [0 ])
135+ neighbors_with_better_score = any (f (item ) > first_score for item in results [1 :])
136+ datapoint ["neighbors_improvement" ] = neighbors_with_better_score
137+
138+ # Serialize.
139+ if search_method == "zero-order" :
140+ if datapoint ["neighbors_improvement" ]:
141+ serialize_artifacts (images_info , prompt , search_round , root_dir , datapoint )
142+ else :
143+ print ("Skipping serialization as there was no improvement in this round." )
144+ elif search_method == "random" :
145+ serialize_artifacts (images_info , prompt , search_round , root_dir , datapoint )
146+
128147 return datapoint
129148
130149
131150@torch .no_grad ()
132151def main ():
133- """
134- Main function:
135- - Parses CLI arguments.
136- - Creates an output directory based on verifier and current datetime.
137- - Loads prompts.
138- - Loads the image-generation pipeline.
139- - Loads the verifier model.
140- - Runs several search rounds where for each prompt a pool of random noises is generated,
141- candidate images are produced and verified, and the best noise is chosen.
142- """
152+ # === Load configuration and CLI arguments ===
143153 args = parse_cli_args ()
144-
145- # Build a config dictionary for parameters that need to be passed around.
146154 with open (args .pipeline_config_path , "r" ) as f :
147155 config = json .load (f )
148156 config .update (vars (args ))
149157
150- search_rounds = config ["search_args" ]["search_rounds" ]
158+ search_args = config ["search_args" ]
159+ search_rounds = search_args ["search_rounds" ]
160+ search_method = search_args .get ("search_method" , "random" )
151161 num_prompts = config ["num_prompts" ]
152162
153- # Create a root output directory: output/{verifier_to_use}/{current_datetime}
163+ # === Create output directory ===
154164 current_datetime = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
155165 pipeline_name = config .pop ("pretrained_model_name_or_path" )
156- root_dir = os .path .join (
166+ verifier_name = config ["verifier_args" ]["name" ]
167+ choice_of_metric = config ["verifier_args" ]["choice_of_metric" ]
168+ output_dir = os .path .join (
157169 "output" ,
158170 MODEL_NAME_MAP [pipeline_name ],
159- config [ "verifier_args" ][ "name" ] ,
160- config [ "verifier_args" ][ " choice_of_metric" ] ,
171+ verifier_name ,
172+ choice_of_metric ,
161173 current_datetime ,
162174 )
163- os .makedirs (root_dir , exist_ok = True )
164- print (f"Artifacts will be saved to: { root_dir } " )
165- with open (os .path .join (root_dir , "config.json" ), "w" ) as f :
166- json .dump (config , f )
175+ os .makedirs (output_dir , exist_ok = True )
176+ print (f"Artifacts will be saved to: { output_dir } " )
177+ with open (os .path .join (output_dir , "config.json" ), "w" ) as f :
178+ json .dump (config , f , indent = 4 )
167179
168- # Load prompts from file.
180+ # === Load prompts ===
169181 if args .prompt is None :
170182 with open ("prompts_open_image_pref_v1.txt" , "r" , encoding = "utf-8" ) as f :
171- prompts = [line .strip () for line in f . readlines () if line .strip ()]
183+ prompts = [line .strip () for line in f if line .strip ()]
172184 if num_prompts != "all" :
173185 prompts = prompts [:num_prompts ]
174- print (f"Using { len (prompts )} prompt(s)." )
175186 else :
176187 prompts = [args .prompt ]
188+ print (f"Using { len (prompts )} prompt(s)." )
177189
178- # Set up the image-generation pipeline (on the first GPU if available).
190+ # === Set up the image-generation pipeline ===
179191 torch_dtype = TORCH_DTYPE_MAP [config .pop ("torch_dtype" )]
180192 pipe = DiffusionPipeline .from_pretrained (pipeline_name , torch_dtype = torch_dtype )
181- if not config [ "use_low_gpu_vram" ] :
193+ if not config . get ( "use_low_gpu_vram" , False ) :
182194 pipe = pipe .to ("cuda:0" )
183195 pipe .set_progress_bar_config (disable = True )
184196
185- # Load the verifier model.
197+ # === Load verifier model ===
186198 verifier_args = config ["verifier_args" ]
187199 if verifier_args ["name" ] == "gemini" :
188200 from verifiers import GeminiVerifier
@@ -191,33 +203,88 @@ def main():
191203 else :
192204 from verifiers .qwen_verifier import QwenVerifier
193205
194- verifier = QwenVerifier (use_low_gpu_vram = config ["use_low_gpu_vram" ])
195-
196- # Main loop: For each search round and each prompt, generate images, verify, and save artifacts.
197- for round in range (1 , search_rounds + 1 ):
198- print (f"\n === Round: { round } ===" )
199- num_noises_to_sample = 2 ** round # scale noise pool.
200- for prompt in tqdm (prompts , desc = "Sampling prompts" ):
201- noises = get_noises (
202- max_seed = MAX_SEED ,
203- num_samples = num_noises_to_sample ,
204- height = config ["pipeline_call_args" ]["height" ],
205- width = config ["pipeline_call_args" ]["width" ],
206- dtype = torch_dtype ,
207- fn = get_latent_prep_fn (pipeline_name ),
208- )
209- print (f"Number of noise samples: { len (noises )} " )
210- datapoint_for_current_round = sample (
206+ verifier = QwenVerifier (use_low_gpu_vram = config .get ("use_low_gpu_vram" , False ))
207+
208+ # === Main loop: For each prompt and each search round ===
209+ for prompt in tqdm (prompts , desc = "Processing prompts" ):
210+ search_round = 1
211+
212+ # For zero-order search, we store the best datapoint per round.
213+ best_datapoint_per_round = {}
214+
215+ while search_round <= search_rounds :
216+ # Determine the number of noise samples.
217+ if search_method == "zero-order" :
218+ num_noises_to_sample = 1
219+ else :
220+ num_noises_to_sample = 2 ** search_round
221+
222+ print (f"\n === Prompt: { prompt } | Round: { search_round } ===" )
223+
224+ # --- Generate noise pool ---
225+ should_regenate_noise = True
226+ previous_round = search_round - 1
227+ if previous_round in best_datapoint_per_round :
228+ was_improvement = best_datapoint_per_round [previous_round ]["neighbors_improvement" ]
229+ if was_improvement :
230+ should_regenate_noise = False
231+
232+ # For subsequent rounds in zero-order: use best noise from previous round.
233+ # This happens ONLY if there was an improvement with the neighbors in the
234+ # previous round, otherwise round is progressed with newly sampled noise.
235+ if should_regenate_noise :
236+ # Standard noise sampling.
237+ if search_method == "zero-order" and search_round != 1 :
238+ print ("Regenerating base noise because the previous round was rejected." )
239+ noises = get_noises (
240+ max_seed = MAX_SEED ,
241+ num_samples = num_noises_to_sample ,
242+ height = config ["pipeline_call_args" ]["height" ],
243+ width = config ["pipeline_call_args" ]["width" ],
244+ dtype = torch_dtype ,
245+ fn = get_latent_prep_fn (pipeline_name ),
246+ )
247+ else :
248+ if best_datapoint_per_round [previous_round ]:
249+ if best_datapoint_per_round [previous_round ]["neighbors_improvement" ]:
250+ print ("Using the best noise from the previous round." )
251+ prev_dp = best_datapoint_per_round [previous_round ]
252+ noises = {int (prev_dp ["best_noise_seed" ]): prev_dp ["best_noise" ]}
253+
254+ if search_method == "zero-order" :
255+ # Process the noise to generate neighbors.
256+ base_seed , base_noise = next (iter (noises .items ()))
257+ neighbors = generate_neighbors (
258+ base_noise , threshold = search_args ["threshold" ], num_neighbors = search_args ["num_neighbors" ]
259+ ).squeeze (0 )
260+ # Concatenate the base noise with its neighbors.
261+ neighbors_and_noise = torch .cat ([base_noise , neighbors ], dim = 0 )
262+ new_noises = {}
263+ for i , noise_tensor in enumerate (neighbors_and_noise ):
264+ new_noises [base_seed + i ] = noise_tensor .unsqueeze (0 )
265+ noises = new_noises
266+
267+ print (f"Number of noise samples for prompt '{ prompt } ': { len (noises )} " )
268+
269+ # --- Sampling, verifying, and saving artifacts ---
270+ datapoint = sample (
211271 noises = noises ,
212272 prompt = prompt ,
213- search_round = round ,
273+ search_round = search_round ,
214274 pipe = pipe ,
215275 verifier = verifier ,
216276 topk = TOPK ,
217- root_dir = root_dir ,
277+ root_dir = output_dir ,
218278 config = config ,
219279 )
220280
281+ if search_method == "zero-order" :
282+ # Update the best datapoint for zero-order.
283+ if datapoint ["neighbors_improvement" ]:
284+ best_datapoint_per_round [search_round ] = datapoint
285+
286+ search_round += 1
287+
221288
# Run the search pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments