Skip to content

Commit 6f21464

Browse files
authored
feat: support batched image generation to speed things up. (#9)
1 parent 47893e0 commit 6f21464

File tree

3 files changed

+45
-21
lines changed

3 files changed

+45
-21
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Simple re-implementation of inference-time scaling Flux.1-Dev as introduced in [
99

1010
**Updates**
1111

12+
🔥 16/02/2025: Support for batched image generation has been added [in this PR](https://github.com/sayakpaul/tt-scale-flux/pull/9). It reduces the total generation time but consumes more memory.
13+
1214
🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using `outlines` and `pydantic`) in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/6).
1315

1416
🔥 15/02/2025: Support to load other pipelines has been added in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/5)! [Result section](#more-results) has been updated, too.
@@ -138,6 +140,9 @@ python process_results.py --path=path_to_the_output_dir
138140

139141
This should output a collage of the best images generated in each search round, grouped by the same prompt.
140142

143+
By default, `--batch_size_for_img_gen` is set to 1. To speed up the process (at the expense of more memory),
144+
this number can be increased.
145+
141146
## Controlling the pipeline checkpoint and `__call__()` args
142147

143148
This is controlled via the `--pipeline_config_path` CLI args. By default, it uses [`configs/flux.1_dev.json`](./configs/flux.1_dev.json). You can either modify this one or create your own JSON file to experiment with different pipelines. We provide some predefined configs for Flux.1-Dev, PixArt-Sigma, SDXL, and SD v1.5 in the [`configs`](./conf) directory.

main.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def sample(
2626
config: dict,
2727
) -> dict:
2828
"""
29-
For a given prompt, generate images using all provided noises,
29+
For a given prompt, generate images using all provided noises in batches,
3030
score them with the verifier, and select the top-K noise.
3131
The images and JSON artifacts are saved under `root_dir`.
3232
"""
@@ -35,30 +35,43 @@ def sample(
3535
choice_of_metric = config_cp.pop("choice_of_metric", None)
3636
verifier_to_use = config_cp.pop("verifier_to_use", "gemini")
3737
use_low_gpu_vram = config_cp.pop("use_low_gpu_vram", False)
38+
batch_size_for_img_gen = config_cp.pop("batch_size_for_img_gen", 1)
3839

3940
images_for_prompt = []
4041
noises_used = []
4142
seeds_used = []
4243
prompt_filename = prompt_to_filename(prompt)
4344

44-
for i, (seed, noise) in enumerate(noises.items()):
45-
# Build the output filename inside the provided root directory.
46-
filename = os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{seed}.png")
45+
# Convert the noises dictionary into a list of (seed, noise) tuples.
46+
noise_items = list(noises.items())
47+
48+
# Process the noises in batches.
49+
for i in range(0, len(noise_items), batch_size_for_img_gen):
50+
batch = noise_items[i : i + batch_size_for_img_gen]
51+
seeds_batch, noises_batch = zip(*batch)
52+
filenames_batch = [
53+
os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{seed}.png") for seed in seeds_batch
54+
]
4755

48-
# If using low GPU VRAM (and not Gemini) move the pipeline to cuda before generating.
4956
if use_low_gpu_vram and verifier_to_use != "gemini":
5057
pipe = pipe.to("cuda:0")
51-
print(f"Generating images.")
52-
image = pipe(prompt=prompt, latents=noise, **config_cp).images[0]
58+
print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.")
59+
60+
# Create a batched prompt list and stack the latents.
61+
batched_prompts = [prompt] * len(noises_batch)
62+
batched_latents = torch.stack(noises_batch).squeeze(dim=1)
63+
64+
batch_result = pipe(prompt=batched_prompts, latents=batched_latents, **config_cp)
65+
batch_images = batch_result.images
5366
if use_low_gpu_vram and verifier_to_use != "gemini":
5467
pipe = pipe.to("cpu")
5568

56-
images_for_prompt.append(image)
57-
noises_used.append(noise)
58-
seeds_used.append(seed)
59-
60-
# Save the intermediate image to the output folder.
61-
image.save(filename)
69+
# Iterate over the batch and save the images.
70+
for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch):
71+
images_for_prompt.append(image)
72+
noises_used.append(noise)
73+
seeds_used.append(seed)
74+
image.save(filename)
6275

6376
# Prepare verifier inputs and perform inference.
6477
verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[prompt] * len(images_for_prompt))
@@ -70,20 +83,20 @@ def sample(
7083
for o in outputs:
7184
assert choice_of_metric in o, o.keys()
7285

73-
assert (
74-
len(outputs) == len(images_for_prompt)
75-
), f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
86+
assert len(outputs) == len(images_for_prompt), (
87+
f"Expected len(outputs) to be same as len(images_for_prompt) but got {len(outputs)=} & {len(images_for_prompt)=}"
88+
)
7689

7790
results = []
7891
for json_dict, seed_val, noise in zip(outputs, seeds_used, noises_used):
79-
# Attach the noise tensor so we can select top-K
92+
# Attach the noise tensor so we can select top-K.
8093
merged = {**json_dict, "noise": noise, "seed": seed_val}
8194
results.append(merged)
8295

8396
# Sort by the chosen metric descending and pick top-K.
8497
for x in results:
8598
assert choice_of_metric in x, (
86-
f"Expected all dicts in `results` to contain the " f"`{choice_of_metric}` key; got {x.keys()}."
99+
f"Expected all dicts in `results` to contain the `{choice_of_metric}` key; got {x.keys()}."
87100
)
88101

89102
def f(x):
@@ -96,7 +109,7 @@ def f(x):
96109

97110
# Print debug information.
98111
for ts in topk_scores:
99-
print(f"Prompt='{prompt}' | Best seed={ts['seed']} | " f"Score={ts[choice_of_metric]}")
112+
print(f"Prompt='{prompt}' | Best seed={ts['seed']} | Score={ts[choice_of_metric]}")
100113

101114
best_img_path = os.path.join(root_dir, f"{prompt_filename}_i@{search_round}_s@{topk_scores[0]['seed']}.png")
102115
datapoint = {
@@ -135,6 +148,7 @@ def main():
135148
"use_low_gpu_vram": args.use_low_gpu_vram,
136149
"choice_of_metric": args.choice_of_metric,
137150
"verifier_to_use": args.verifier_to_use,
151+
"batch_size_for_img_gen": args.batch_size_for_img_gen,
138152
}
139153
with open(args.pipeline_config_path, "r") as f:
140154
config.update(json.load(f))

utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,15 @@ def parse_cli_args():
4848
parser.add_argument(
4949
"--max_new_tokens",
5050
type=int,
51-
default=600,
51+
default=800,
5252
help="Maximum number of tokens for the verifier. Ignored when using Gemini.",
5353
)
54+
parser.add_argument(
55+
"--batch_size_for_img_gen",
56+
type=int,
57+
default=1,
58+
help="Controls the batch size of noises during image generation. Increasing it reduces the total time at the cost of more memory.",
59+
)
5460
parser.add_argument(
5561
"--use_low_gpu_vram",
5662
action="store_true",
@@ -142,7 +148,6 @@ def get_noises(
142148
fn: callable = prepare_latents_for_flux,
143149
) -> Dict[int, torch.Tensor]:
144150
seeds = torch.randint(0, high=max_seed, size=(num_samples,))
145-
print(f"{seeds=}")
146151

147152
noises = {}
148153
for noise_seed in seeds:

0 commit comments

Comments
 (0)