Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions apis/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from llms.chat import Chat
from llms.llama import LlamaChat
from llms.llama_hf import LlamaHFChat
from llms.chatglm import ChatglmChat
import warnings


class ChatLLM:
Expand Down Expand Up @@ -41,13 +43,11 @@ def __init__(
llm = fetch_llm(model_name_or_path, host)
self.chat = llm(model_name_or_path=model_name_or_path, task=task)

def completion(
self,
prompt: str,
system_prompt: str = None,
) -> str:
def completion(self, prompt: str, system_prompt: str = None,) -> str:
if prompt is None:
raise Exception("user prompt must exist")
if self.chat.support_system_prompt() is False and system_prompt is not None:
warnings.warn("system_prompt is not supported by the api")

return self.chat.completion(system_prompt=system_prompt, user_prompt=prompt)

Expand All @@ -57,9 +57,11 @@ def fetch_llm(model_name: str, host: str) -> Chat:

if "llama_2" in model_name or "llama-2" in model_name:
return LlamaChat if host == "local" else LlamaHFChat
if "chatglm2" in model_name and host == "local":
return ChatglmChat

raise UnavailableModelException(
'model unavailable, supporting model family: "llama_2".'
'model unavailable, supporting model family: "llama_2", "local chatglm2"'
)


Expand Down
34 changes: 34 additions & 0 deletions integration_tests/test_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from apis.api import ChatLLM
import unittest

class TestOutput(unittest.TestCase):
def test_ChatLLM(self):
test_cases = [
{
"name": "local model with llama 2",
"model": "/models/llama-2-7b-chat-hf",
"task": "text-generation",
"host": "local",
"prompt": "How many people are there in China",
"system_prompt": None,
"chat": ChatLLM,
},
{
"name": "local model with chatglm2",
"model": "/data/models/chatglm2-6b",
"task": "text-generation",
"host": "local",
"prompt": "How many people are there in China",
"system_prompt": None,
"chat": ChatLLM,
},

]
for test in test_cases:
chat = ChatLLM(
model_name_or_path=test["model"], task=test["task"], host=test["host"],
)
result = chat.completion(prompt=test["prompt"], system_prompt=test["system_prompt"])
self.assertNotEqual(len(result),0)


12 changes: 2 additions & 10 deletions llms/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,11 @@ def support_system_prompt() -> bool:
# TODO: Support history conversation in the future.
@classmethod
@abstractmethod
def prompt(
self,
system_prompt: str = None,
user_prompt: str = None,
) -> str:
def prompt(self, system_prompt: str = None, user_prompt: str = None,) -> str:
pass

@abstractmethod
def completion(
self,
system_prompt: str = None,
user_prompt: str = None,
) -> str:
def completion(self, system_prompt: str = None, user_prompt: str = None,) -> str:
"""
Args:
system_prompt (str): Not all language models support system prompt, e.g. ChatGLM2.
Expand Down
26 changes: 26 additions & 0 deletions llms/chatglm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from transformers import AutoTokenizer, AutoModel
from llms.chat import LocalChat


class ChatglmChat(LocalChat):
def __init__(self, model_name_or_path, task=None):
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
)
self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
self.model = self.model.eval()

def completion(self, system_prompt, user_prompt):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a warning here if we do provide the system_prompt.

history = []
response, history = self.model.chat(
self.tokenizer, user_prompt, history=history
)
return response

@classmethod
def support_system_prompt(self) -> bool:
return False

@classmethod
def prompt(cls, system_prompt: str = None, user_prompt: str = None):
pass
24 changes: 5 additions & 19 deletions llms/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,11 @@ def __init__(
)

@classmethod
def support_system_prompt() -> bool:
def support_system_prompt(self) -> bool:
return True

@classmethod
def prompt(
cls,
system_prompt: str = None,
user_prompt: str = None,
) -> str:
def prompt(cls, system_prompt: str = None, user_prompt: str = None,) -> str:
if system_prompt is not None and user_prompt is not None:
system_content = format_llama_prompt(
role=SYSTEM_PROMPT, content=system_prompt
Expand All @@ -60,11 +56,7 @@ def prompt(
else:
return format_llama_prompt(content=user_prompt)

def completion(
self,
system_prompt: str = None,
user_prompt: str = None,
) -> str:
def completion(self, system_prompt: str = None, user_prompt: str = None,) -> str:
prompt = LlamaChat.prompt(system_prompt=system_prompt, user_prompt=user_prompt)

logging.debug(
Expand Down Expand Up @@ -116,9 +108,7 @@ def format_llama_prompt(


def build_pipeline(
model_name_or_path: str,
task: str,
torch_dtype: torch.dtype,
model_name_or_path: str, task: str, torch_dtype: torch.dtype,
) -> transformers.pipeline:
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
Expand All @@ -134,9 +124,5 @@ def build_pipeline(
)

return transformers.pipeline(
task=task,
model=model,
tokenizer=tokenizer,
torch_dtype=torch_dtype,
device=0,
task=task, model=model, tokenizer=tokenizer, torch_dtype=torch_dtype, device=0,
)