Skip to content

Commit e3e0518

Browse files
openai verifier (#16)
* openai verifier Co-authored-by: zhuole1025 <zhuole1025@gmail.com> * fixes * fixes * updates * contributor note --------- Co-authored-by: zhuole1025 <zhuole1025@gmail.com>
1 parent 53cecf8 commit e3e0518

File tree

7 files changed

+163
-8
lines changed

7 files changed

+163
-8
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ But it's been growing now! Check out the rest of the README to know more 🤗
1414

1515
**Updates**
1616

17+
🔥 01/03/2025: OpenAIVerifier was added in [this PR](https://github.com/sayakpaul/tt-scale-flux/pull/16). Specify "openai" in the `name` under `verifier_args`. Thanks to [zhuole1025](https://github.com/zhuole1025) for contributing this!
18+
1719
🔥 27/02/2025: [MaximClouser](https://github.com/MaximClouser) implemented a ComfyUI node for inference-time
1820
scaling in [this repo](https://github.com/YRIKKA/ComfyUI-InferenceTimeScaling). Check it out!
1921

@@ -268,6 +270,9 @@ parameter under `verifier_args`:
268270
}
269271
```
270272

273+
The `model_name` argument is also supported by the other non-local (API-based) verifiers. For example, for the `GeminiVerifier`, you can
274+
pass any model supported by the Gemini API through `model_name`.
275+
271276
You can also bring in your own verifier by implementing a so-called `Verifier` class following the structure of either of `GeminiVerifier` or `QwenVerifier`. You will then have to make adjustments to the following places:
272277

273278
* [Scoring](https://github.com/sayakpaul/tt-scale-flux/blob/c654bc066171aee9c765fa42a322f65415529a77/main.py#L135)

utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from diffusers.utils.torch_utils import randn_tensor
33
from diffusers import FluxPipeline
4+
import base64
45
import re
56
import hashlib
67
from typing import Dict
@@ -236,12 +237,16 @@ def load_image(path_or_url: Union[str, Image.Image]) -> Image.Image:
236237
return Image.open(path_or_url)
237238

238239

239-
def convert_to_bytes(path_or_url: Union[str, Image.Image], b64_encode: bool = False) -> Union[bytes, str]:
    """Load an image from a path, URL, or in-memory PIL image and serialize it as PNG.

    Args:
        path_or_url: Image source — a local path, a URL, or a `PIL.Image.Image`.
        b64_encode: When True, return a base64-encoded ASCII string (suitable for
            embedding in a data URL) instead of raw bytes.

    Returns:
        Raw PNG bytes, or a base64 `str` when `b64_encode` is True.
        (The original annotation claimed `bytes` for both paths, which was wrong
        for the base64 branch.)
    """
    # Force RGB so palette/alpha images serialize consistently as PNG.
    image = load_image(path_or_url).convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    image_bytes = buffer.getvalue()
    if b64_encode:
        # Decode to str: JSON payloads (e.g. data URLs) need text, not bytes.
        return base64.b64encode(image_bytes).decode("utf-8")
    return image_bytes
245250

246251

247252
def recover_json_from_output(output: str):

verifiers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33
except Exception as e:
44
GeminiVerifier = None
55

6+
try:
7+
from .openai_verifier import OpenAIVerifier
8+
except Exception as e:
9+
print(f"{e=}")
10+
OpenAIVerifier = None
11+
612
try:
713
from .qwen_verifier import QwenVerifier
8-
except:
14+
except Exception as e:
915
QwenVerifier = None
1016

1117
SUPPORTED_VERIFIERS = {
1218
"qwen": QwenVerifier if QwenVerifier else None,
1319
"gemini": GeminiVerifier if GeminiVerifier else None,
20+
"openai": OpenAIVerifier if OpenAIVerifier else None,
1421
}
1522

1623
SUPPORTED_METRICS = {k: getattr(v, "SUPPORTED_METRIC_CHOICES", None) for k, v in SUPPORTED_VERIFIERS.items()}

verifiers/gemini_verifier.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@
33
import typing_extensions as typing
44
import json
55
import os
6+
import sys
67
from typing import Union
78
from PIL import Image
89
from concurrent.futures import ThreadPoolExecutor, as_completed
9-
import sys
10-
from .base_verifier import BaseVerifier
1110

12-
sys.path.append("..")
11+
current_dir = os.path.dirname(os.path.abspath(__file__))
12+
root_dir = os.path.abspath(os.path.join(current_dir, ".."))
13+
14+
sys.path.insert(0, current_dir)
15+
sys.path.insert(0, root_dir)
16+
1317

18+
from base_verifier import BaseVerifier
1419
from utils import convert_to_bytes
1520

1621

verifiers/openai_verifier.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from openai import OpenAI
2+
from pydantic import BaseModel
3+
import os
4+
from typing import Union
5+
from PIL import Image
6+
import json
7+
from concurrent.futures import ThreadPoolExecutor, as_completed
8+
import sys
9+
10+
current_dir = os.path.dirname(os.path.abspath(__file__))
11+
root_dir = os.path.abspath(os.path.join(current_dir, ".."))
12+
13+
sys.path.insert(0, current_dir)
14+
sys.path.insert(0, root_dir)
15+
16+
from base_verifier import BaseVerifier
17+
from utils import convert_to_bytes
18+
19+
20+
class Score(BaseModel):
    """One structured sub-score in the verifier's response."""

    # Numeric rating assigned by the model (scale is set by the verifier
    # prompt — not visible here; confirm against the prompt file).
    score: int
    # Model-written justification accompanying the score.
    explanation: str
23+
24+
25+
class Grading(BaseModel):
    """Structured-output schema the OpenAI model is asked to fill in.

    One `Score` per metric; the field names mirror
    `OpenAIVerifier.SUPPORTED_METRIC_CHOICES`.
    """

    accuracy_to_prompt: Score
    creativity_and_originality: Score
    visual_quality_and_realism: Score
    consistency_and_cohesion: Score
    emotional_or_thematic_resonance: Score
    overall_score: Score
32+
33+
34+
class OpenAIVerifier(BaseVerifier):
    """Verifier that grades generated images against prompts with an OpenAI
    multimodal model.

    Uses the structured-outputs endpoint (`beta.chat.completions.parse`) so each
    response comes back as a validated `Grading` pydantic object.
    """

    # Must mirror the field names of the `Grading` schema.
    SUPPORTED_METRIC_CHOICES = [
        "accuracy_to_prompt",
        "creativity_and_originality",
        "visual_quality_and_realism",
        "consistency_and_cohesion",
        "emotional_or_thematic_resonance",
        "overall_score",
    ]

    def __init__(self, seed=1994, model_name="gpt-4o-2024-11-20", **kwargs):
        """Create the OpenAI client and configure the verifier.

        Args:
            seed: Stored for bookkeeping and forwarded to `BaseVerifier`.
            model_name: Any OpenAI model supporting vision inputs and
                structured outputs.
            **kwargs: `prompt_path` is forwarded to `BaseVerifier`; other
                keys are ignored.

        Reads the API key from the `OPENAI_API_KEY` environment variable.
        """
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        super().__init__(seed=seed, prompt_path=kwargs.pop("prompt_path", None))
        self.model_name = model_name
        self.seed = seed

    def prepare_inputs(self, images: Union[list[Image.Image], Image.Image], prompts: Union[list[str], str], **kwargs):
        """Build one user chat message per (prompt, image) pair for the API.

        Args:
            images: A single image or list of images (paths/URLs/PIL images are
                all accepted by `convert_to_bytes`).
            prompts: A single prompt or list of prompts, paired 1:1 with images.

        Returns:
            A list of chat-message dicts, one per pair.

        Raises:
            ValueError: If the number of images and prompts differ.
        """
        images = images if isinstance(images, list) else [images]
        prompts = prompts if isinstance(prompts, list) else [prompts]
        if len(images) != len(prompts):
            # zip() would silently drop the surplus items otherwise.
            raise ValueError(f"Got {len(images)} images but {len(prompts)} prompts; they must pair up 1:1.")

        inputs = []
        for prompt, image in zip(prompts, images):
            base64_image = convert_to_bytes(image, b64_encode=True)
            message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    # convert_to_bytes serializes as PNG, so the data URL must
                    # declare image/png (it previously claimed image/jpeg).
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
                ],
            }
            inputs.append(message)

        return inputs

    def score(self, inputs, **kwargs) -> list[dict[str, float]]:
        """Score each prepared input concurrently.

        Results are returned in the same order as `inputs`; failed API calls
        are logged and omitted rather than raised (best-effort semantics,
        matching the original behavior).
        """
        if not inputs:
            # ThreadPoolExecutor rejects max_workers=0, so short-circuit.
            return []

        system_message = {"role": "system", "content": self.verifier_prompt}

        def call_generate_content(user_message):
            conversation = [system_message, user_message]
            response = self.client.beta.chat.completions.parse(
                model=self.model_name, messages=conversation, temperature=1, response_format=Grading
            )
            return response.choices[0].message.parsed.model_dump()

        # Map each future back to its input index: `as_completed` yields in
        # completion order, which previously scrambled scores relative to the
        # prompts they belong to.
        results = [None] * len(inputs)
        max_workers = min(len(inputs), 4)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {executor.submit(call_generate_content, group): i for i, group in enumerate(inputs)}
            for future in as_completed(future_to_idx):
                try:
                    results[future_to_idx[future]] = future.result()
                except Exception as e:
                    # Keep scoring the remaining inputs even if one call fails.
                    print(f"An error occurred during API call: {e}")
        return [r for r in results if r is not None]
92+
93+
94+
# Smoke test: score two sample (prompt, image URL) pairs and dump the results.
if __name__ == "__main__":
    verifier = OpenAIVerifier()
    image_urls = [
        (
            "realistic photo a shiny black SUV car with a mountain in the background.",
            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux-edit-artifacts/assets/car.jpg",
        ),
        (
            "photo a green and funny creature standing in front a lightweight forest.",
            "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/flux-edit-artifacts/assets/green_creature.jpg",
        ),
    ]

    # Split the (prompt, url) pairs into the parallel lists the verifier expects.
    prompts = [text for text, _ in image_urls]
    images = [path_or_url for _, path_or_url in image_urls]

    inputs = verifier.prepare_inputs(images=images, prompts=prompts)
    response = verifier.score(inputs)

    with open("results.json", "w") as f:
        json.dump(response, f)

    print(json.dumps(response, indent=4))

verifiers/qwen_verifier.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
import torch
77
from PIL import Image
88
from typing import Union
9-
from .base_verifier import BaseVerifier
9+
import os
10+
import sys
1011

12+
current_dir = os.path.dirname(os.path.abspath(__file__))
13+
sys.path.insert(0, current_dir)
14+
15+
from base_verifier import BaseVerifier
1116

1217
DEFAULT_QWEN_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
1318
# Optional device map that one can use to let `transformers` share a single GPU and CPU.

verifiers/results.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)