Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ But it's been growing now! Check out the rest of the README to know more 🤗

**Updates**

🔥 01/03/2025: OpenAIVerifier was added in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/16). Specify "openai" in the `name` under `verifier_args`. Thanks to [zhuole1025](https://github.com/zhuole1025) for contributing this!

🔥 27/02/2025: [MaximClouser](https://github.com/MaximClouser) implemented a ComfyUI node for inference-time
scaling in [this repo](https://github.com/YRIKKA/ComfyUI-InferenceTimeScaling). Check it out!

Expand Down Expand Up @@ -268,6 +270,9 @@ parameter under `verifier_args`:
}
```

`model_name` is also supported by the other non-local verifiers. For example, for the `GeminiVerifier`, you can
pass any model supported by the Gemini API through `model_name`.

You can also bring in your own verifier by implementing a so-called `Verifier` class following the structure of either of `GeminiVerifier` or `QwenVerifier`. You will then have to make adjustments to the following places:

* [Scoring](https://github.com/sayakpaul/tt-scale-flux/blob/c654bc066171aee9c765fa42a322f65415529a77/main.py#L135)
Expand Down
9 changes: 7 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers import FluxPipeline
import base64
import re
import hashlib
from typing import Dict
Expand Down Expand Up @@ -236,12 +237,16 @@ def load_image(path_or_url: Union[str, Image.Image]) -> Image.Image:
return Image.open(path_or_url)


def convert_to_bytes(path_or_url: Union[str, Image.Image]) -> bytes:
def convert_to_bytes(path_or_url: Union[str, Image.Image], b64_encode: bool = False) -> Union[bytes, str]:
    """Load an image from a path/URL and serialize it to PNG bytes.

    Args:
        path_or_url: Filesystem path, URL, or an already-loaded PIL image.
        b64_encode: When True, return a base64-encoded UTF-8 string
            (suitable for embedding in data URLs) instead of raw bytes.

    Returns:
        Raw PNG bytes, or a base64-encoded ``str`` when ``b64_encode`` is True.
    """
    image = load_image(path_or_url).convert("RGB")
    image_bytes_io = io.BytesIO()
    image.save(image_bytes_io, format="PNG")
    image_bytes = image_bytes_io.getvalue()
    if b64_encode:
        return base64.b64encode(image_bytes).decode("utf-8")
    return image_bytes


def recover_json_from_output(output: str):
Expand Down
9 changes: 8 additions & 1 deletion verifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
except Exception as e:
GeminiVerifier = None

# Optional import: the OpenAI verifier is only available when the `openai`
# package (and its transitive deps) are installed. Mirrors the gemini/qwen
# guards: on failure the name is set to None so SUPPORTED_VERIFIERS can
# still be built.
try:
    from .openai_verifier import OpenAIVerifier
except Exception as e:
    OpenAIVerifier = None

try:
from .qwen_verifier import QwenVerifier
except:
except Exception as e:
QwenVerifier = None

# Registry of verifier names -> classes. A value is None when the verifier's
# optional dependencies failed to import (see the try/except guards above).
# `X if X else None` was redundant: X is already either a class or None.
SUPPORTED_VERIFIERS = {
    "qwen": QwenVerifier,
    "gemini": GeminiVerifier,
    "openai": OpenAIVerifier,
}

# Metric choices per verifier; None for verifiers that are unavailable or
# do not declare SUPPORTED_METRIC_CHOICES.
SUPPORTED_METRICS = {k: getattr(v, "SUPPORTED_METRIC_CHOICES", None) for k, v in SUPPORTED_VERIFIERS.items()}
11 changes: 8 additions & 3 deletions verifiers/gemini_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import typing_extensions as typing
import json
import os
import sys
from typing import Union
from PIL import Image
from concurrent.futures import ThreadPoolExecutor, as_completed
import sys
from .base_verifier import BaseVerifier

sys.path.append("..")
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, ".."))

sys.path.insert(0, current_dir)
sys.path.insert(0, root_dir)


from base_verifier import BaseVerifier
from utils import convert_to_bytes


Expand Down
129 changes: 129 additions & 0 deletions verifiers/openai_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from openai import OpenAI
from pydantic import BaseModel
import os
from typing import Union
from PIL import Image
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import sys

current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, ".."))

sys.path.insert(0, current_dir)
sys.path.insert(0, root_dir)

from base_verifier import BaseVerifier
from utils import convert_to_bytes


class Score(BaseModel):
    """Structured score for one evaluation axis: an integer rating plus a short justification."""

    score: int
    explanation: str


class Grading(BaseModel):
    """Full structured grading returned by the model — one Score per supported metric.

    Field names must stay in sync with OpenAIVerifier.SUPPORTED_METRIC_CHOICES,
    since this model is passed as ``response_format`` to the OpenAI API.
    """

    accuracy_to_prompt: Score
    creativity_and_originality: Score
    visual_quality_and_realism: Score
    consistency_and_cohesion: Score
    emotional_or_thematic_resonance: Score
    overall_score: Score


class OpenAIVerifier(BaseVerifier):
    """Scores generated images against their prompts using an OpenAI chat model
    with structured outputs (``response_format=Grading``)."""

    SUPPORTED_METRIC_CHOICES = [
        "accuracy_to_prompt",
        "creativity_and_originality",
        "visual_quality_and_realism",
        "consistency_and_cohesion",
        "emotional_or_thematic_resonance",
        "overall_score",
    ]

    def __init__(self, seed=1994, model_name="gpt-4o-2024-11-20", **kwargs):
        """Create a verifier backed by the OpenAI API.

        Args:
            seed: Seed recorded for reproducibility bookkeeping.
            model_name: Any chat model supported by the OpenAI API.
            **kwargs: ``prompt_path`` is forwarded to ``BaseVerifier``.
        """
        # Reads the key from the environment; OPENAI_API_KEY must be set.
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        super().__init__(seed=seed, prompt_path=kwargs.pop("prompt_path", None))
        self.model_name = model_name
        self.seed = seed

    def prepare_inputs(self, images: Union[list[Image.Image], Image.Image], prompts: Union[list[str], str], **kwargs):
        """Build one user message per (prompt, image) pair for the chat API.

        Singletons are promoted to lists; extra images/prompts beyond the
        shorter list are silently dropped by ``zip``.
        """
        inputs = []
        images = images if isinstance(images, list) else [images]
        prompts = prompts if isinstance(prompts, list) else [prompts]

        for prompt, image in zip(prompts, images):
            base64_image = convert_to_bytes(image, b64_encode=True)

            message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    # convert_to_bytes serializes as PNG, so the data URL must
                    # advertise image/png (was incorrectly image/jpeg).
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                ],
            }
            inputs.append(message)

        return inputs

    def score(self, inputs, **kwargs) -> list[dict[str, float]]:
        """Score each prepared message concurrently.

        Returns one parsed ``Grading`` dict per successful call, in the SAME
        order as ``inputs`` (failed calls are dropped after being logged).
        """
        system_message = {"role": "system", "content": self.verifier_prompt}

        def call_generate_content(parts):
            conversation = [system_message, parts]
            response = self.client.beta.chat.completions.parse(
                model=self.model_name, messages=conversation, temperature=1, response_format=Grading
            )
            return response.choices[0].message.parsed.model_dump()

        # Guard: ThreadPoolExecutor(max_workers=0) raises ValueError.
        if not inputs:
            return []

        # Index the futures so results come back aligned with `inputs`;
        # plain `as_completed` appending loses the input order.
        results: list = [None] * len(inputs)
        max_workers = min(len(inputs), 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {executor.submit(call_generate_content, group): i for i, group in enumerate(inputs)}
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    # Best-effort: log and drop the failed item below.
                    print(f"An error occurred during API call: {e}")
        return [r for r in results if r is not None]


# Manual smoke test: score two sample (prompt, image URL) pairs and dump the
# raw gradings to results.json. Requires OPENAI_API_KEY and network access.
if __name__ == "__main__":
    verifier = OpenAIVerifier()
    image_urls = [
        (
            "realistic photo a shiny black SUV car with a mountain in the background.",
            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux-edit-artifacts/assets/car.jpg",
        ),
        (
            "photo a green and funny creature standing in front a lightweight forest.",
            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux-edit-artifacts/assets/green_creature.jpg",
        ),
    ]

    # Unzip the pairs into the parallel lists expected by prepare_inputs.
    prompts = [text for text, _ in image_urls]
    images = [path_or_url for _, path_or_url in image_urls]

    inputs = verifier.prepare_inputs(images=images, prompts=prompts)
    response = verifier.score(inputs)

    with open("results.json", "w") as f:
        json.dump(response, f)

    print(json.dumps(response, indent=4))
7 changes: 6 additions & 1 deletion verifiers/qwen_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
import torch
from PIL import Image
from typing import Union
from .base_verifier import BaseVerifier
import os
import sys

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

from base_verifier import BaseVerifier

DEFAULT_QWEN_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Optional device map that one can use to let `transformers` share a single GPU and CPU.
Expand Down
1 change: 0 additions & 1 deletion verifiers/results.json

This file was deleted.