Skip to content

Commit d5d7f8e

Browse files
committed
Update demo code.
1 parent 9c23f63 commit d5d7f8e

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

chat_demo.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from transformers import AutoModelForCausalLM, AutoTokenizer,GenerationConfig
77

88
tokenizer, model = None, None
9+
system_tpl = None
10+
exec_tpl = None
911
css = """
1012
.message.user{
1113
border-color: #BFB0FA !important;
@@ -25,15 +27,21 @@ def init_model(args):
2527
model.generation_config = GenerationConfig.from_pretrained(args.model_path)
2628
model = model.eval()
2729

28-
def chat(message, history, request: gr.Request):
30+
def chat(message, history, system_text, exec_text, request: gr.Request):
2931
global tokenizer, model
3032
history = history or []
31-
history.append({"role": "user", "content": message})
33+
if system_text is not None and system_text != "" and len(history) == 0:
34+
history.append({"role": "system", "content": system_text})
35+
if exec_text is not None and exec_text != "":
36+
history.append({"role": "exec", "content": exec_text})
37+
else:
38+
history.append({"role": "user", "content": message})
3239

3340
# init
3441
history.append({"role": "assistant", "content": ""})
3542
utter_history = []
36-
for i in range(0, len(history), 2):
43+
start_index = 1 if len(history) % 2 == 1 else 0
44+
for i in range(start_index, len(history), 2):
3745
utter_history.append([history[i]["content"], history[i+1]["content"]])
3846

3947
# chat with stream
@@ -54,7 +62,7 @@ def get_args():
5462
help="server port")
5563
parser.add_argument("--title", type=str, default="XVERSE-13B-Chat",
5664
help="server title")
57-
parser.add_argument("--model_path", type=str, default="./VERSE-13B-Chat",
65+
parser.add_argument("--model_path", type=str, default="./XVERSE-13B-Chat",
5866
help="model path")
5967
parser.add_argument("--tokenizer_path", type=str, default="./XVERSE-13B-Chat",
6068
help="Path to the tokenizer.")
@@ -71,6 +79,13 @@ def get_args():
7179
gr.Markdown("# <center>{}</center>".format(args.title))
7280
chatbot = gr.Chatbot(label="Chat history", height=650)
7381
state = gr.State([])
82+
advanced_desc = f"<b>system</b><br>输入对话背景知识。<br><b>exec</b><br>agent调用时,输入api执行后的结果。<br>"
83+
with gr.Accordion("Advanced settings", open=False):
84+
gr.HTML(advanced_desc)
85+
with gr.Row():
86+
system_text = gr.Textbox(label="system", show_label=True, value=system_tpl)
87+
with gr.Row():
88+
exec_text = gr.Textbox(label="exec", show_label=True, value=exec_tpl)
7489

7590
with gr.Row():
7691
text_box = gr.Textbox(label="Message", show_label=False, placeholder="Enter message and press enter")
@@ -80,18 +95,18 @@ def get_args():
8095
reset_btn = gr.Button(value="Reset")
8196

8297
text_box.submit(fn=chat,
83-
inputs=[text_box, state],
98+
inputs=[text_box, state, system_text, exec_text],
8499
outputs=[chatbot, state],
85100
api_name="chat")
86101
submit_btn.click(fn=chat,
87-
inputs=[text_box, state],
102+
inputs=[text_box, state, system_text, exec_text],
88103
outputs=[chatbot, state])
89104

90105
# 用于清空text_box
91106
def clear_textbox():
92-
return gr.update(value="")
93-
text_box.submit(fn=clear_textbox, inputs=None, outputs=[text_box])
94-
submit_btn.click(fn=clear_textbox, inputs=None, outputs=[text_box])
107+
return gr.update(value=""), gr.update(value="")
108+
text_box.submit(fn=clear_textbox, inputs=None, outputs=[text_box, exec_text])
109+
submit_btn.click(fn=clear_textbox, inputs=None, outputs=[text_box, exec_text])
95110

96111
# 用于清空页面和重置state
97112
def reset():

0 commit comments

Comments
 (0)