From 5dab5e81809c172cde9d60b7b1b957eae167a214 Mon Sep 17 00:00:00 2001 From: Jerry-Kon <1156807819@qq.com> Date: Fri, 8 Sep 2023 20:38:58 +0800 Subject: [PATCH] updata pr --- apis/api.py | 14 +++++++------ integration_tests/test_output.py | 34 ++++++++++++++++++++++++++++++++ llms/chat.py | 12 ++--------- llms/chatglm.py | 26 ++++++++++++++++++++++++ llms/llama.py | 24 +++++----------------- 5 files changed, 75 insertions(+), 35 deletions(-) create mode 100644 integration_tests/test_output.py create mode 100644 llms/chatglm.py diff --git a/apis/api.py b/apis/api.py index f2d3bca..6c9a6ab 100644 --- a/apis/api.py +++ b/apis/api.py @@ -1,6 +1,8 @@ from llms.chat import Chat from llms.llama import LlamaChat from llms.llama_hf import LlamaHFChat +from llms.chatglm import ChatglmChat +import warnings class ChatLLM: @@ -41,13 +43,11 @@ def __init__( llm = fetch_llm(model_name_or_path, host) self.chat = llm(model_name_or_path=model_name_or_path, task=task) - def completion( - self, - prompt: str, - system_prompt: str = None, - ) -> str: + def completion(self, prompt: str, system_prompt: str = None,) -> str: if prompt is None: raise Exception("user prompt must exist") + if self.chat.support_system_prompt() is False and system_prompt is not None: + warnings.warn("system_prompt is not supported by the api") return self.chat.completion(system_prompt=system_prompt, user_prompt=prompt) @@ -57,9 +57,11 @@ def fetch_llm(model_name: str, host: str) -> Chat: if "llama_2" in model_name or "llama-2" in model_name: return LlamaChat if host == "local" else LlamaHFChat + if "chatglm2" in model_name and host == "local": + return ChatglmChat raise UnavailableModelException( - 'model unavailable, supporting model family: "llama_2".' + 'model unavailable, supporting model family: "llama_2", "local chatglm2"' ) diff --git a/integration_tests/test_output.py b/integration_tests/test_output.py new file mode 100644 index 0000000..f615f9f --- /dev/null +++ b/integration_tests/test_output.py @@ -0,0 +1,34 @@ +from apis.api import ChatLLM +import unittest + +class TestOutput(unittest.TestCase): + def test_ChatLLM(self): + test_cases = [ + { + "name": "local model with llama 2", + "model": "/models/llama-2-7b-chat-hf", + "task": "text-generation", + "host": "local", + "prompt": "How many people are there in China", + "system_prompt": None, + "chat": ChatLLM, + }, + { + "name": "local model with chatglm2", + "model": "/data/models/chatglm2-6b", + "task": "text-generation", + "host": "local", + "prompt": "How many people are there in China", + "system_prompt": None, + "chat": ChatLLM, + }, + + ] + for test in test_cases: + chat = ChatLLM( + model_name_or_path=test["model"], task=test["task"], host=test["host"], + ) + result = chat.completion(prompt=test["prompt"], system_prompt=test["system_prompt"]) + self.assertNotEqual(len(result),0) + + diff --git a/llms/chat.py b/llms/chat.py index 34e4f35..f691d65 100644 --- a/llms/chat.py +++ b/llms/chat.py @@ -28,19 +28,11 @@ def support_system_prompt() -> bool: # TODO: Support history conversation in the future. @classmethod @abstractmethod - def prompt( - self, - system_prompt: str = None, - user_prompt: str = None, - ) -> str: + def prompt(self, system_prompt: str = None, user_prompt: str = None,) -> str: pass @abstractmethod - def completion( - self, - system_prompt: str = None, - user_prompt: str = None, - ) -> str: + def completion(self, system_prompt: str = None, user_prompt: str = None,) -> str: """ Args: system_prompt (str): Not all language models support system prompt, e.g. ChatGLM2. diff --git a/llms/chatglm.py b/llms/chatglm.py new file mode 100644 index 0000000..8a815e5 --- /dev/null +++ b/llms/chatglm.py @@ -0,0 +1,26 @@ +from transformers import AutoTokenizer, AutoModel +from llms.chat import LocalChat + + +class ChatglmChat(LocalChat): + def __init__(self, model_name_or_path, task=None): + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda() + self.model = self.model.eval() + + def completion(self, system_prompt, user_prompt): + history = [] + response, history = self.model.chat( + self.tokenizer, user_prompt, history=history + ) + return response + + @classmethod + def support_system_prompt(self) -> bool: + return False + + @classmethod + def prompt(cls, system_prompt: str = None, user_prompt: str = None): + pass diff --git a/llms/llama.py b/llms/llama.py index 78ba31d..c409146 100644 --- a/llms/llama.py +++ b/llms/llama.py @@ -40,15 +40,11 @@ def __init__( ) @classmethod - def support_system_prompt() -> bool: + def support_system_prompt(self) -> bool: return True @classmethod - def prompt( - cls, - system_prompt: str = None, - user_prompt: str = None, - ) -> str: + def prompt(cls, system_prompt: str = None, user_prompt: str = None,) -> str: if system_prompt is not None and user_prompt is not None: system_content = format_llama_prompt( role=SYSTEM_PROMPT, content=system_prompt @@ -60,11 +56,7 @@ def prompt( else: return format_llama_prompt(content=user_prompt) - def completion( - self, - system_prompt: str = None, - user_prompt: str = None, - ) -> str: + def completion(self, system_prompt: str = None, user_prompt: str = None,) -> str: prompt = LlamaChat.prompt(system_prompt=system_prompt, user_prompt=user_prompt) logging.debug( @@ -116,9 +108,7 @@ def format_llama_prompt( def build_pipeline( - model_name_or_path: str, - task: str, - torch_dtype: torch.dtype, + model_name_or_path: str, task: str, torch_dtype: torch.dtype, ) -> transformers.pipeline: tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, trust_remote_code=True @@ -134,9 +124,5 @@ def build_pipeline( ) return transformers.pipeline( - task=task, - model=model, - tokenizer=tokenizer, - torch_dtype=torch_dtype, - device=0, + task=task, model=model, tokenizer=tokenizer, torch_dtype=torch_dtype, device=0, )