Skip to content

Commit 7f55aa4

Browse files
authored
feat(model): Support Phi-3 models (#1554)
1 parent 47430f2 commit 7f55aa4

File tree

5 files changed

+74
-0
lines changed

5 files changed

+74
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ At present, we have introduced several key features to showcase our current capa
158158
We offer extensive model support, including dozens of large language models (LLMs) from both open-source communities and API proxies, such as LLaMA/LLaMA2, Baichuan, ChatGLM, Wenxin, Tongyi, Zhipu, and many more.
159159

160160
- News
161+
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
161162
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
162163
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
163164
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)

README.zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
海量模型支持,包括开源、API代理等几十种大语言模型。如LLaMA/LLaMA2、Baichuan、ChatGLM、文心、通义、智谱等。当前已支持如下模型:
153153

154154
- 新增支持模型
155+
- 🔥🔥🔥 [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
155156
- 🔥🔥🔥 [Yi-1.5-34B-Chat](https://huggingface.co/01-ai/Yi-1.5-34B-Chat)
156157
- 🔥🔥🔥 [Yi-1.5-9B-Chat](https://huggingface.co/01-ai/Yi-1.5-9B-Chat)
157158
- 🔥🔥🔥 [Yi-1.5-6B-Chat](https://huggingface.co/01-ai/Yi-1.5-6B-Chat)

dbgpt/configs/model_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,16 @@ def get_device() -> str:
187187
"gemma-2b-it": os.path.join(MODEL_PATH, "gemma-2b-it"),
188188
"starling-lm-7b-beta": os.path.join(MODEL_PATH, "Starling-LM-7B-beta"),
189189
"deepseek-v2-lite-chat": os.path.join(MODEL_PATH, "DeepSeek-V2-Lite-Chat"),
190+
"sailor-14b-chat": os.path.join(MODEL_PATH, "Sailor-14B-Chat"),
191+
# https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
192+
"phi-3-medium-128k-instruct": os.path.join(
193+
MODEL_PATH, "Phi-3-medium-128k-instruct"
194+
),
195+
"phi-3-medium-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-medium-4k-instruct"),
196+
"phi-3-small-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-128k-instruct"),
197+
"phi-3-small-8k-instruct": os.path.join(MODEL_PATH, "Phi-3-small-8k-instruct"),
198+
"phi-3-mini-128k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-128k-instruct"),
199+
"phi-3-mini-4k-instruct": os.path.join(MODEL_PATH, "Phi-3-mini-4k-instruct"),
190200
}
191201

192202
EMBEDDING_MODEL_CONFIG = {

dbgpt/model/adapter/hf_adapter.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,61 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
396396
return model, tokenizer
397397

398398

399+
class SailorAdapter(QwenAdapter):
    """Adapter for Sailor chat models (Qwen-compatible prompt format).

    See https://huggingface.co/sail/Sailor-14B-Chat
    """

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        # Match only chat-tuned Sailor checkpoints; base models are excluded.
        name = lower_model_name_or_path
        return name and all(token in name for token in ("sailor", "chat"))
410+
411+
412+
class PhiAdapter(NewHFChatModelAdapter):
    """Adapter for Microsoft Phi-3 instruct models.

    See https://huggingface.co/microsoft/Phi-3-medium-128k-instruct
    """

    support_4bit: bool = True
    support_8bit: bool = True
    # Phi-3 chat format has no system role, so system messages are not passed through.
    support_system_message: bool = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        # Match only instruct-tuned Phi-3 checkpoints.
        name = lower_model_name_or_path
        return name and all(token in name for token in ("phi-3", "instruct"))

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the model/tokenizer, forcing ``trust_remote_code`` unless the caller set it.

        Phi-3 repositories ship custom modeling code, so remote code must be
        trusted for ``from_pretrained`` to work.
        """
        kwargs = from_pretrained_kwargs or {}
        kwargs.setdefault("trust_remote_code", True)
        return super().load(model_path, kwargs)

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        """Build the prompt via the parent adapter, then register Phi-3's stop word.

        Phi-3 terminates each turn with the literal token ``<|end|>``; exposing
        it through ``params["custom_stop_words"]`` lets the generation loop
        strip it from streamed output.
        """
        rendered = super().get_str_prompt(
            params,
            messages,
            tokenizer,
            prompt_template,
            convert_to_compatible_format,
        )
        params["custom_stop_words"] = ["<|end|>"]
        return rendered
452+
453+
399454
# The following code is used to register the model adapter
400455
# The last registered model adapter is matched first
401456
register_model_adapter(YiAdapter)
@@ -408,3 +463,5 @@ def load(self, model_path: str, from_pretrained_kwargs: dict):
408463
register_model_adapter(QwenMoeAdapter)
409464
register_model_adapter(Llama3Adapter)
410465
register_model_adapter(DeepseekV2Adapter)
466+
register_model_adapter(SailorAdapter)
467+
register_model_adapter(PhiAdapter)

dbgpt/model/llm_out/hf_chat_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def huggingface_chat_generate_stream(
2222
max_new_tokens = int(params.get("max_new_tokens", 2048))
2323
stop_token_ids = params.get("stop_token_ids", [])
2424
do_sample = params.get("do_sample", None)
25+
custom_stop_words = params.get("custom_stop_words", [])
2526

2627
input_ids = tokenizer(prompt).input_ids
2728
# input_ids = input_ids.to(device)
@@ -62,4 +63,8 @@ def huggingface_chat_generate_stream(
6263
out = ""
6364
for new_text in streamer:
6465
out += new_text
66+
if custom_stop_words:
67+
for stop_word in custom_stop_words:
68+
if out.endswith(stop_word):
69+
out = out[: -len(stop_word)]
6570
yield out

0 commit comments

Comments
 (0)