Skip to content

Commit 32bc50b

Browse files
committed
Release XVERSE-13B-Chat
1 parent 93d6399 commit 32bc50b

File tree

6 files changed

+1133
-134
lines changed

6 files changed

+1133
-134
lines changed

README.md

Lines changed: 344 additions & 44 deletions
Large diffs are not rendered by default.

README_EN.md

Lines changed: 344 additions & 41 deletions
Large diffs are not rendered by default.

README_JA.md

Lines changed: 350 additions & 48 deletions
Large diffs are not rendered by default.

chat_demo.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import argparse
2+
import torch
3+
import gradio as gr
4+
import json
5+
from datetime import datetime
6+
from transformers import AutoModelForCausalLM, AutoTokenizer,GenerationConfig
7+
8+
tokenizer, model = None, None
9+
10+
def init_model(args):
    """Load the tokenizer and chat model into the module-level globals.

    Args:
        args: parsed CLI namespace providing ``tokenizer_path`` and ``model_path``.
    """
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_path, truncation_side="left", padding_side="left"
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map='auto',
    )
    # Generation defaults (temperature, max tokens, ...) ship with the checkpoint.
    model.generation_config = GenerationConfig.from_pretrained(args.model_path)
    model = model.eval()
17+
18+
def chat(message, history, request: gr.Request):
    """Stream a model reply for *message*, yielding updated UI state.

    Yields ``(utter_history, history)`` after each generated chunk, where
    ``utter_history`` is the [user, assistant] pair list shown by gr.Chatbot
    and ``history`` is the flat role/content list kept in gr.State.
    """
    global tokenizer, model
    history = history or []
    history.append({"role": "user", "content": message})

    # Placeholder entry that the streaming loop below fills in incrementally.
    history.append({"role": "assistant", "content": ""})

    # Fold the flat role/content list into [user, assistant] display pairs.
    contents = [turn["content"] for turn in history]
    utter_history = [list(pair) for pair in zip(contents[0::2], contents[1::2])]

    # Stream chunks from the model, updating both representations each step.
    for next_text in model.chat(tokenizer, history[:-1], stream=True):
        utter_history[-1][1] += next_text
        history[-1]["content"] += next_text
        if torch.backends.mps.is_available():
            # Apple-silicon backend: release cached allocations between steps.
            torch.mps.empty_cache()
        yield utter_history, history

    # Log the completed exchange with a timestamp and the caller's address.
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f'{current_time} request_ip:{request.client.host}\nquery: {message}\nhistory: {json.dumps(history, ensure_ascii=False)}\nanswer: {json.dumps(utter_history[-1][1], ensure_ascii=False)}')
40+
41+
def get_args():
    """Parse command-line options for the chat demo.

    Returns:
        argparse.Namespace with ``port``, ``title``, ``model_path`` and
        ``tokenizer_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=36000,
                        help="server port")
    parser.add_argument("--title", type=str, default="XVERSE-13B-Chat",
                        help="server title")
    # Fixed typo: default was "./VERSE-13B-Chat", which did not match the
    # tokenizer default "./XVERSE-13B-Chat" and pointed at a nonexistent path.
    parser.add_argument("--model_path", type=str, default="./XVERSE-13B-Chat",
                        help="model path")
    parser.add_argument("--tokenizer_path", type=str, default="./XVERSE-13B-Chat",
                        help="Path to the tokenizer.")
    args = parser.parse_args()
    return args
53+
54+
if __name__ == "__main__":
    args = get_args()
    # Load tokenizer/model into the module globals before building the UI.
    init_model(args)

    # Assemble the Gradio demo application.
    with gr.Blocks() as demo:
        gr.Markdown(f"# <center>{args.title}</center>")
        chatbot = gr.Chatbot(label="Chat history", height=650).style(color_map=("green", "pink"))
        state = gr.State([])

        with gr.Row():
            text_box = gr.Textbox(label="Message", show_label=False, placeholder="Enter message and press enter").style(container=False)

        with gr.Row():
            submit_btn = gr.Button(value="Send", variant="secondary")
            reset_btn = gr.Button(value="Reset")

        def clear_textbox():
            # Empty the message box after a submission.
            return gr.update(value="")

        def reset():
            # Wipe both the visible chat and the server-side state.
            return None, []

        # NOTE: the chat handlers are registered before the textbox-clearing
        # handlers, preserving the original listener order.
        text_box.submit(fn=chat, inputs=[text_box, state], outputs=[chatbot, state], api_name="chat")
        submit_btn.click(fn=chat, inputs=[text_box, state], outputs=[chatbot, state])
        text_box.submit(fn=clear_textbox, inputs=None, outputs=[text_box])
        submit_btn.click(fn=clear_textbox, inputs=None, outputs=[text_box])
        reset_btn.click(fn=reset, inputs=None, outputs=[chatbot, state])

    demo.queue(concurrency_count=4)
    demo.launch(server_name="0.0.0.0", server_port=args.port)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
transformers>=4.29.1
22
torch>=2.0
33
gradio>=3.39.0
4+
accelerate>=0.21.0
5+
cpm_kernels>=1.0.11

text_generation_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def init_model(args):
1212
global tokenizer, model
1313
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, truncation_side="left", padding_side="left")
14-
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map='auto')
14+
model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')
1515
model = model.eval()
1616

1717

0 commit comments

Comments
 (0)