|
| 1 | +import argparse |
| 2 | +import torch |
| 3 | +import gradio as gr |
| 4 | +import json |
| 5 | +from datetime import datetime |
| 6 | +from transformers import AutoModelForCausalLM, AutoTokenizer,GenerationConfig |
| 7 | + |
# Module-level handles for the tokenizer and model; populated by init_model()
# at startup and read by chat().
tokenizer, model = None, None
| 9 | + |
def init_model(args):
    """Load the tokenizer and model into the module-level globals.

    Args:
        args: parsed CLI namespace providing ``tokenizer_path`` and
            ``model_path``.
    """
    global tokenizer, model
    # Left-side truncation/padding keeps the most recent turns when the
    # prompt exceeds the context window.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_path,
        truncation_side="left",
        padding_side="left",
    )
    # bfloat16 + device_map='auto' shards the weights across available devices.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map='auto',
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.model_path)
| 17 | + |
def chat(message, history, request: gr.Request):
    """Stream a model reply for *message*, yielding updated chat state.

    Args:
        message: the new user utterance.
        history: role-tagged turn list (``{"role", "content"}`` dicts) or None.
        request: gradio request object, used only for logging the client IP.

    Yields:
        (utter_history, history) — the [user, assistant] pair list for the
        Chatbot widget, and the role-tagged state list.
    """
    global tokenizer, model
    history = history or []
    history.append({"role": "user", "content": message})

    # Placeholder assistant turn, filled in as tokens stream back.
    history.append({"role": "assistant", "content": ""})

    # Re-shape the role-tagged history into [user, assistant] pairs
    # for the gradio Chatbot component.
    utter_history = [
        [user_turn["content"], bot_turn["content"]]
        for user_turn, bot_turn in zip(history[::2], history[1::2])
    ]

    # Stream the generation; history[:-1] excludes the empty placeholder
    # from the prompt sent to the model.
    for next_text in model.chat(tokenizer, history[:-1], stream=True):
        utter_history[-1][1] += next_text
        history[-1]["content"] += next_text
        # On Apple Silicon, release cached MPS memory between steps.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        yield utter_history, history

    # Log the completed exchange with client IP and timestamp.
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f'{current_time} request_ip:{request.client.host}\nquery: {message}\nhistory: {json.dumps(history, ensure_ascii=False)}\nanswer: {json.dumps(utter_history[-1][1], ensure_ascii=False)}')
| 40 | + |
def get_args(argv=None):
    """Parse command-line options for the demo server.

    Args:
        argv: optional argument list; defaults to ``sys.argv[1:]`` when None
            (backward-compatible with the original zero-argument call).

    Returns:
        argparse.Namespace with ``port``, ``title``, ``model_path`` and
        ``tokenizer_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=36000,
                        help="server port")
    parser.add_argument("--title", type=str, default="XVERSE-13B-Chat",
                        help="server title")
    # Fixed typo: default was "./VERSE-13B-Chat" (missing leading "X"),
    # inconsistent with --tokenizer_path and the default title.
    parser.add_argument("--model_path", type=str, default="./XVERSE-13B-Chat",
                        help="model path")
    parser.add_argument("--tokenizer_path", type=str, default="./XVERSE-13B-Chat",
                        help="Path to the tokenizer.")
    args = parser.parse_args(argv)
    return args
| 53 | + |
if __name__ == "__main__":
    args = get_args()
    # Initialize the global tokenizer/model before building the UI.
    init_model(args)

    # Build the demo app.
    with gr.Blocks() as demo:
        gr.Markdown("# <center>{}</center>".format(args.title))
        # NOTE(review): .style() is a legacy gradio API (pre-4.x) — this file
        # appears pinned to an older gradio version; confirm before upgrading.
        chatbot = gr.Chatbot(label="Chat history", height=650).style(color_map=("green", "pink"))
        # Server-side chat state: role-tagged turn list consumed by chat().
        state = gr.State([])

        with gr.Row():
            text_box = gr.Textbox(label="Message", show_label=False, placeholder="Enter message and press enter").style(container=False)

        with gr.Row():
            submit_btn = gr.Button(value="Send", variant="secondary")
            reset_btn = gr.Button(value="Reset")

        # Both Enter in the textbox and the Send button trigger chat();
        # chat() is a generator, so the Chatbot updates as tokens stream in.
        text_box.submit(fn=chat,
                        inputs=[text_box, state],
                        outputs=[chatbot, state],
                        api_name="chat")
        submit_btn.click(fn=chat,
                         inputs=[text_box, state],
                         outputs=[chatbot, state])

        # Clear the text box after a message is sent.
        def clear_textbox():
            return gr.update(value="")
        text_box.submit(fn=clear_textbox, inputs=None, outputs=[text_box])
        submit_btn.click(fn=clear_textbox, inputs=None, outputs=[text_box])

        # Clear the chat display and reset the state.
        def reset():
            return None, []
        reset_btn.click(fn=reset, inputs=None, outputs=[chatbot, state])

    # Allow up to 4 concurrent generation requests; bind on all interfaces.
    demo.queue(concurrency_count=4)
    demo.launch(server_name="0.0.0.0", server_port=args.port)
0 commit comments