From cbfc6d835e493e20ec34be264f451f61ad30ccfe Mon Sep 17 00:00:00 2001 From: swg Date: Mon, 24 Apr 2023 21:21:16 -0400 Subject: [PATCH 1/4] Support defaulting to infinity or -1 for chat completions --- examples/high_level_api/fastapi_server.py | 4 ++-- llama_cpp/llama.py | 16 +++++++++------- llama_cpp/server/__main__.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/high_level_api/fastapi_server.py b/examples/high_level_api/fastapi_server.py index 3ed0eaca3..bc4b0f305 100644 --- a/examples/high_level_api/fastapi_server.py +++ b/examples/high_level_api/fastapi_server.py @@ -62,7 +62,7 @@ class Settings(BaseSettings): class CreateCompletionRequest(BaseModel): prompt: str suffix: Optional[str] = Field(None) - max_tokens: int = 16 + max_tokens: Optional[Union[int, None]] = 16 temperature: float = 0.8 top_p: float = 0.95 echo: bool = False @@ -156,7 +156,7 @@ class CreateChatCompletionRequest(BaseModel): top_p: float = 0.95 stream: bool = False stop: List[str] = [] - max_tokens: int = 128 + max_tokens: Optional[Union[int, None]] = -1 # ignored or currently unsupported model: Optional[str] = Field(None) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 70dcea992..5405b1629 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -377,7 +377,7 @@ def _create_completion( self, prompt: str, suffix: Optional[str] = None, - max_tokens: int = 16, + max_tokens: Optional[Union[int, None]] = 16, temperature: float = 0.8, top_p: float = 0.95, logprobs: Optional[int] = None, @@ -402,7 +402,9 @@ def _create_completion( if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)): + if max_tokens not in [-1, None] and len(prompt_tokens) + max_tokens > int( + llama_cpp.llama_n_ctx(self.ctx) + ): raise ValueError( f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" ) @@ -487,7 +489,7 @@ def _create_completion( ], } - if len(completion_tokens) >= max_tokens: + if max_tokens not in [-1, None] and len(completion_tokens) >= max_tokens: text = self.detokenize(completion_tokens) finish_reason = "length" break @@ -591,7 +593,7 @@ def create_completion( self, prompt: str, suffix: Optional[str] = None, - max_tokens: int = 128, + max_tokens: Optional[Union[int, None]] = 128, temperature: float = 0.8, top_p: float = 0.95, logprobs: Optional[int] = None, @@ -626,7 +628,7 @@ def create_completion( completion_or_chunks = self._create_completion( prompt=prompt, suffix=suffix, - max_tokens=max_tokens, + max_tokens=-1 if max_tokens is None else max_tokens, temperature=temperature, top_p=top_p, logprobs=logprobs, @@ -646,7 +648,7 @@ def __call__( self, prompt: str, suffix: Optional[str] = None, - max_tokens: int = 128, + max_tokens: Optional[Union[int, None]] = 128, temperature: float = 0.8, top_p: float = 0.95, logprobs: Optional[int] = None, @@ -758,7 +760,7 @@ def create_chat_completion( top_k: int = 40, stream: bool = False, stop: Optional[List[str]] = [], - max_tokens: int = 256, + max_tokens: Optional[Union[int, None]] = -1, repeat_penalty: float = 1.1, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: """Generate a chat completion from a list of messages. diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index b2ec4def6..b76cae2d4 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -77,7 +77,7 @@ def get_llama(): class CreateCompletionRequest(BaseModel): prompt: Union[str, List[str]] suffix: Optional[str] = Field(None) - max_tokens: int = 16 + max_tokens: Optional[Union[int, None]] = 16 temperature: float = 0.8 top_p: float = 0.95 echo: bool = False @@ -179,7 +179,7 @@ class CreateChatCompletionRequest(BaseModel): top_p: float = 0.95 stream: bool = False stop: Optional[List[str]] = [] - max_tokens: int = 128 + max_tokens: Optional[Union[int, None]] = -1 # ignored or currently unsupported model: Optional[str] = Field(None) From 99ff17556284d684693821d169cd35059a82b6a9 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 22 Dec 2023 13:41:06 -0500 Subject: [PATCH 2/4] Check if completion_tokens is none in error handler. --- llama_cpp/server/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py index febe3e39d..9d3d35598 100644 --- a/llama_cpp/server/errors.py +++ b/llama_cpp/server/errors.py @@ -72,7 +72,7 @@ def context_length_exceeded( return 400, ErrorResponse( message=message.format( context_window, - completion_tokens + prompt_tokens, + (completion_tokens or 0) + prompt_tokens, prompt_tokens, completion_tokens, ), # type: ignore From 06808782405e9a49550fb008f507974d0353ce2a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 22 Dec 2023 13:47:04 -0500 Subject: [PATCH 3/4] fix: max_tokens in create completion should match openai spec --- llama_cpp/server/types.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py index f0867bc4e..f0827d762 100644 --- a/llama_cpp/server/types.py +++ b/llama_cpp/server/types.py @@ -110,7 +110,9 @@ class CreateCompletionRequest(BaseModel): default=None, description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.", ) - max_tokens: int = max_tokens_field + max_tokens: Optional[int] = Field( + default=16, ge=0, description="The maximum number of tokens to generate." + ) temperature: float = temperature_field top_p: float = top_p_field min_p: float = min_p_field From 27856f980643a860bde0fc73c967d727fe6be63d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 22 Dec 2023 13:59:14 -0500 Subject: [PATCH 4/4] Fix __call__ --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5c1311b82..788732bd3 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1951,7 +1951,7 @@ def __call__( self, prompt: str, suffix: Optional[str] = None, - max_tokens: Optional[Union[int, None]] = 128, + max_tokens: Optional[int] = 16, temperature: float = 0.8, top_p: float = 0.95, min_p: float = 0.05,