@@ -105,6 +105,7 @@ def __init__(self,
105105 assert img_detail in ['high' , 'low' ]
106106 self .img_detail = img_detail
107107 self .timeout = timeout
108+ self .o1_model = 'o1' in model or 'o3' in model
108109
109110 super ().__init__ (wait = wait , retry = retry , system_prompt = system_prompt , verbose = verbose , ** kwargs )
110111
@@ -185,17 +186,6 @@ def generate_inner(self, inputs, **kwargs) -> str:
185186 temperature = kwargs .pop ('temperature' , self .temperature )
186187 max_tokens = kwargs .pop ('max_tokens' , self .max_tokens )
187188
188- # context_window = GPT_context_window(self.model)
189- # new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
190- # if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
191- # self.logger.warning(
192- # 'Less than 100 tokens left, '
193- # 'may exceed the context window with some additional meta symbols. '
194- # )
195- # if new_max_tokens <= 0:
196- # return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
197- # max_tokens = new_max_tokens
198-
199189 # Will send request if use Azure, dk how to use openai client for it
200190 if self .use_azure :
201191 headers = {'Content-Type' : 'application/json' , 'api-key' : self .key }
@@ -206,10 +196,16 @@ def generate_inner(self, inputs, **kwargs) -> str:
206196 payload = dict (
207197 model = self .model ,
208198 messages = input_msgs ,
209- max_tokens = max_tokens ,
199+ # max_tokens=max_tokens,
210200 n = 1 ,
211201 temperature = temperature ,
212202 ** kwargs )
203+ if self .o1_model :
204+ payload ['max_completion_tokens' ] = max_tokens
205+ payload .pop ('temperature' )
206+ else :
207+ payload ['max_tokens' ] = max_tokens
208+
213209 response = requests .post (
214210 self .api_base ,
215211 headers = headers , data = json .dumps (payload ), timeout = self .timeout * 1.1 )