Skip to content

Commit e4e0263

Browse files
[VLM] Support o1
1 parent f22451a commit e4e0263

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

vlmeval/api/gpt.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(self,
105105
assert img_detail in ['high', 'low']
106106
self.img_detail = img_detail
107107
self.timeout = timeout
108+
self.o1_model = 'o1' in model or 'o3' in model
108109

109110
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
110111

@@ -185,17 +186,6 @@ def generate_inner(self, inputs, **kwargs) -> str:
185186
temperature = kwargs.pop('temperature', self.temperature)
186187
max_tokens = kwargs.pop('max_tokens', self.max_tokens)
187188

188-
# context_window = GPT_context_window(self.model)
189-
# new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
190-
# if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
191-
# self.logger.warning(
192-
# 'Less than 100 tokens left, '
193-
# 'may exceed the context window with some additional meta symbols. '
194-
# )
195-
# if new_max_tokens <= 0:
196-
# return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
197-
# max_tokens = new_max_tokens
198-
199189
# Will send request if use Azure, dk how to use openai client for it
200190
if self.use_azure:
201191
headers = {'Content-Type': 'application/json', 'api-key': self.key}
@@ -206,10 +196,16 @@ def generate_inner(self, inputs, **kwargs) -> str:
206196
payload = dict(
207197
model=self.model,
208198
messages=input_msgs,
209-
max_tokens=max_tokens,
199+
# max_tokens=max_tokens,
210200
n=1,
211201
temperature=temperature,
212202
**kwargs)
203+
if self.o1_model:
204+
payload['max_completion_tokens'] = max_tokens
205+
payload.pop('temperature')
206+
else:
207+
payload['max_tokens'] = max_tokens
208+
213209
response = requests.post(
214210
self.api_base,
215211
headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)

vlmeval/config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@
8080
"Falcon2-VLM-11B": partial(Falcon2VLM, model_path="tiiuae/falcon-11B-vlm"),
8181
}
8282

83+
o1_key = 'XXX' # noqa: E501
84+
o1_apis = {
85+
'o1': partial(
86+
GPT4V,
87+
model="o1-2024-12-17",
88+
key=o1_key,
89+
api_base='OFFICIAL',
90+
temperature=0,
91+
img_detail='high',
92+
retry=10,
93+
verbose=False,
94+
),
95+
}
96+
8397
api_models = {
8498
# GPT
8599
"GPT4V": partial(
@@ -1086,6 +1100,7 @@
10861100

10871101
model_groups = [
10881102
ungrouped,
1103+
o1_apis,
10891104
api_models,
10901105
xtuner_series,
10911106
qwen_series,

0 commit comments

Comments (0)