Skip to content

Commit f79f81c

Browse files
authored
feat(flow): Add Chat Data template (#2716)
1 parent 7e7581e commit f79f81c

File tree

7 files changed

+3356
-37
lines changed

7 files changed

+3356
-37
lines changed

packages/dbgpt-app/src/dbgpt_app/component_configs.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,21 @@ def _initialize_openapi(system_app: SystemApp):
137137

138138

139139
def _initialize_operators():
    """Scan and register all AWEL operators shipped with the app.

    Replaces the previous hand-maintained import list: every module under
    the listed packages is scanned and each ``BaseOperator`` subclass found
    is registered via ``ModelScanner``.

    Returns:
        The scanner's full mapping of registered operator items.
    """
    from dbgpt.core.awel import BaseOperator
    from dbgpt.util.module_utils import ModelScanner, ScannerConfig

    modules = ["dbgpt_app.operators", "dbgpt_serve.agent.resource"]

    scanner = ModelScanner[BaseOperator]()
    for module in modules:
        config = ScannerConfig(
            module_path=module,
            base_class=BaseOperator,
        )
        # scan_and_register imports the module tree and registers matches;
        # its per-module return value is not needed here (the original
        # collected it into a dict that was never read).
        scanner.scan_and_register(config)
    return scanner.get_registered_items()
152155

153156

154157
def _initialize_code_server(system_app: SystemApp):

packages/dbgpt-app/src/dbgpt_app/operators/datasource.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
type=DBResource,
106106
description=_("The datasource to retrieve the context"),
107107
)
108+
108109
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
109110
_("Prompt Template"),
110111
"prompt_template",
@@ -172,7 +173,7 @@
172173
_("SQL result"),
173174
"sql_result",
174175
str,
175-
description=_("The result of the SQL execution"),
176+
description=_("The result of the SQL execution(GPT-Vis format)"),
176177
)
177178

178179
_INPUTS_SQL_DICT_LIST = IOField.build_from(
@@ -189,7 +190,9 @@
189190
class GPTVisMixin:
    async def save_view_message(self, dag_ctx: DAGContext, view: str):
        """Persist *view* as the model-output view on the shared DAG context.

        Overwrites any previously stored view for the same key.
        """
        key = BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
        await dag_ctx.save_to_share_data(key, view, overwrite=True)
193196

194197

195198
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
@@ -286,14 +289,36 @@ async def map(self, question: str) -> HOContextBody:
286289
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
287290
"""Execute the context from the datasource."""
288291

292+
_share_data_key = "__datasource_executor_result__"
293+
294+
class MarkdownMapper(MapOperator[str, str]):
    # Renders the executor's cached DataFrame as a Markdown table for the
    # "markdown_result" output field.
    async def map(self, context: str) -> str:
        """Convert the result to markdown."""
        # NOTE(review): the *context* argument is ignored; the DataFrame is
        # read from DAG share data saved by HODatasourceExecutorOperator.map.

        from dbgpt.util.pd_utils import df_to_markdown

        df = await self.current_dag_context.get_from_share_data(
            HODatasourceExecutorOperator._share_data_key
        )
        return df_to_markdown(df)
304+
289305
metadata = ViewMetadata(
290306
label=_("Datasource Executor Operator"),
291307
name="higher_order_datasource_executor_operator",
292308
description=_("Execute the context from the datasource."),
293309
category=OperatorCategory.DATABASE,
294310
parameters=[_PARAMETER_DATASOURCE.new()],
295311
inputs=[_INPUTS_SQL_DICT.new()],
296-
outputs=[_OUTPUTS_SQL_RESULT.new()],
312+
outputs=[
313+
_OUTPUTS_SQL_RESULT.new(),
314+
IOField.build_from(
315+
_("Markdown result"),
316+
"markdown_result",
317+
str,
318+
description=_("The markdown result of the SQL execution"),
319+
mappers=[MarkdownMapper],
320+
),
321+
],
297322
tags={"order": TAGS_ORDER_HIGH},
298323
)
299324

@@ -314,8 +339,16 @@ async def map(self, sql_dict: dict) -> str:
314339
sql = sql_dict.get("sql")
315340
if not sql:
316341
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
342+
343+
thoughts = sql_dict.get("thoughts", "")
344+
317345
data_df = await self._datasource.query_to_df(sql)
346+
# Save the result to share data, for markdown mapper
347+
await self.current_dag_context.save_to_share_data(
348+
HODatasourceExecutorOperator._share_data_key, data_df
349+
)
318350
view = await vis.display(chart=sql_dict, data_df=data_df)
351+
view = thoughts + "\n\n" + view
319352
await self.save_view_message(self.current_dag_context, view)
320353
return view
321354

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
from functools import cache
2+
from typing import Optional
3+
4+
from dbgpt.core import (
5+
ChatPromptTemplate,
6+
HumanPromptTemplate,
7+
LLMClient,
8+
ModelMessage,
9+
SystemPromptTemplate,
10+
)
11+
from dbgpt.core.awel import JoinOperator
12+
from dbgpt.core.awel.flow.base import (
13+
TAGS_ORDER_HIGH,
14+
IOField,
15+
OperatorCategory,
16+
Parameter,
17+
ViewMetadata,
18+
)
19+
from dbgpt.core.interface.llm import ModelRequest
20+
from dbgpt.model.operators import MixinLLMOperator
21+
from dbgpt.util.i18n_utils import _
22+
from dbgpt_app.operators.datasource import GPTVisMixin
23+
24+
# Default system prompt (English). Template variables: {data_report} — the
# Markdown-table query result — and {user_input} — the user's question.
_DEFAULT_PROMPT_EN = """You are a helpful AI assistant.

Please carefully read the data in the Markdown table format below, the data is a
database query result based on the user question. Please analyze and summarize the
data carefully, and provide a summary report in markdown format.

<data-report>
{data_report}
</data-report>

user question:
{user_input}

Please answer in the same language as the user's question.
"""

# Default system prompt (Chinese); same template variables as the English one.
_DEFAULT_PROMPT_ZH = """你是一个有用的AI助手。

请你仔细阅读下面的 Markdown 表格格式的数据,这是一份根据用户问题查询到的数据库的数据,\
你需要根据数据仔细分析和总结,给出一份总结报告,使用 markdown 格式输出。

<data-report>
{data_report}
</data-report>

用户的问题:
{user_input}

请用用户提问的语言回答。
"""

# Default human (user) prompt: echoes the raw user input verbatim.
_DEFAULT_USER_PROMPT = """\
{user_input}
"""
58+
59+
60+
@cache
def _get_default_prompt(language: str) -> ChatPromptTemplate:
    """Build the default chat prompt for the given language.

    Results are memoized per language code via ``functools.cache``.

    Args:
        language: Language code; ``"zh"`` selects the Chinese system prompt,
            anything else falls back to English.

    Returns:
        A ``ChatPromptTemplate`` with a system message (expects the
        ``data_report`` and ``user_input`` variables) and a human message
        (``user_input``).
    """
    # The user prompt is language-independent; only the system prompt varies.
    # (The original duplicated the identical user-prompt assignment in both
    # branches.)
    sys_prompt = _DEFAULT_PROMPT_ZH if language == "zh" else _DEFAULT_PROMPT_EN

    return ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template(sys_prompt),
            HumanPromptTemplate.from_template(_DEFAULT_USER_PROMPT),
        ]
    )
75+
76+
77+
class ReportAnalystOperator(MixinLLMOperator, JoinOperator[str]):
    """Join operator that asks an LLM to analyze a markdown data report.

    Joins a user question with one or more data-report strings, renders them
    into a chat prompt, and returns the model's analysis text.
    """

    metadata = ViewMetadata(
        label=_("Report Analyst"),
        name="report_analyst",
        description=_("Report Analyst"),
        category=OperatorCategory.DATABASE,
        tags={"order": TAGS_ORDER_HIGH},
        parameters=[
            Parameter.build_from(
                _("Prompt Template"),
                "prompt_template",
                ChatPromptTemplate,
                description=_("The prompt template for the conversation."),
                optional=True,
                default=None,
            ),
            Parameter.build_from(
                _("Model Name"),
                "model",
                str,
                optional=True,
                default=None,
                description=_("The model name."),
            ),
            Parameter.build_from(
                _("LLM Client"),
                "llm_client",
                LLMClient,
                optional=True,
                default=None,
                description=_(
                    "The LLM Client, how to connect to the LLM model, if not provided,"
                    " it will use the default client deployed by DB-GPT."
                ),
            ),
        ],
        inputs=[
            IOField.build_from(
                _("User question"),
                "question",
                str,
                description=_("The question of user"),
            ),
            IOField.build_from(
                _("The data report"),
                "data_report",
                str,
                # NOTE(review): description passed positionally here, unlike
                # the keyword style used above — confirm this matches the
                # IOField.build_from signature.
                _("The data report in markdown format."),
                dynamic=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Report Analyst Result"),
                "report_analyst_result",
                str,
                description=_("The report analyst result."),
            )
        ],
    )

    def __init__(
        self,
        prompt_template: Optional[ChatPromptTemplate] = None,
        model: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        **kwargs,
    ):
        # Initialize both bases explicitly: JoinOperator wires the combine
        # function; MixinLLMOperator wires the (optional) LLM client.
        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
        MixinLLMOperator.__init__(self, llm_client=llm_client, **kwargs)

        # Cache the configured template/model; None means "use defaults".
        # NOTE(review): original comment here said "User must select a history
        # merge mode" — appears copy-pasted from another operator.
        self._prompt_template = prompt_template
        self._model = model

    @property
    def prompt_template(self) -> ChatPromptTemplate:
        """Get the prompt template.

        Falls back to the built-in default prompt, localized by the current
        system language, when no template was configured.
        """
        language = "en"
        if self.system_app:
            language = self.system_app.config.get_current_lang()
        if self._prompt_template is None:
            return _get_default_prompt(language)
        return self._prompt_template

    async def _join_func(self, question: str, data_report: str, *args):
        """Combine the inputs and query the LLM.

        Args:
            question: The user question.
            data_report: The first data report; any additional *string*
                dynamic inputs in ``args`` are appended, newline-separated
                (non-string extras are silently ignored).

        Returns:
            The generated analysis text (including "thinking" content).

        Raises:
            Exception: If the LLM client reports no available models.
        """
        # Merge the primary report with any extra dynamic string inputs.
        dynamic_inputs = [data_report]
        for arg in args:
            if isinstance(arg, str):
                dynamic_inputs.append(arg)
        data_report = "\n".join(dynamic_inputs)
        messages = self.prompt_template.format_messages(
            user_input=question,
            data_report=data_report,
        )
        model_messages = ModelMessage.from_base_messages(messages)
        models = await self.llm_client.models()
        if not models:
            raise Exception("No models available.")
        # Prefer the explicitly configured model; otherwise the first
        # model the client reports.
        model = self._model or models[0].model

        model_request = ModelRequest.build_request(model, messages=model_messages)
        model_output = await self.llm_client.generate(model_request)
        text = model_output.gen_text_with_thinking()

        return text
183+
184+
185+
class StringJoinOperator(GPTVisMixin, JoinOperator[str]):
    """Join operator for strings.

    This operator joins the input strings with a specified separator.
    """

    metadata = ViewMetadata(
        label=_("String Join Operator"),
        name="string_join_operator",
        description=_("Merge multiple inputs into a single string."),
        category=OperatorCategory.COMMON,
        parameters=[
            Parameter.build_from(
                _("Separator"),
                "separator",
                str,
                optional=True,
                default="\n\n",
                description=_("The separator to join the strings."),
            ),
        ],
        inputs=[
            IOField.build_from(
                _("Input Strings"),
                "input_strings",
                str,
                description=_("The input strings to join."),
                dynamic=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Joined String"),
                "joined_string",
                str,
                description=_("The joined string."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, separator: str = "\n\n", **kwargs):
        super().__init__(combine_function=self._join_func, **kwargs)
        # Separator inserted between each pair of joined inputs.
        self.separator = separator

    async def _join_func(self, *args) -> str:
        """Join the strings with the separator."""
        view = self.separator.join(args)
        # Also publish the joined string as the model-output view
        # (via GPTVisMixin), so the UI renders it.
        await self.save_view_message(self.current_dag_context, view)
        return view

packages/dbgpt-core/src/dbgpt/util/pd_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,25 @@ def csv_colunm_foramt(val):
2020
return val
2121
except ValueError:
2222
return val
23+
24+
25+
def df_to_markdown(df: pd.DataFrame, index=False) -> str:
    """Convert a pandas DataFrame to a Markdown table.

    Args:
        df: The DataFrame to render.
        index: If True, include the DataFrame index as a leading ``index``
            column.

    Returns:
        The Markdown table: header row, separator row, then one line per
        DataFrame row, joined with newlines.

    Note:
        Cell values are converted with ``str`` and not escaped, so values
        containing ``|`` or newlines will break the table layout.
    """
    # Build the column list once; the index variant just prepends a column.
    # (The original built the whole non-index table first and threw it away
    # when index=True; it also crashed on non-string column names.)
    columns = [str(c) for c in df.columns]
    if index:
        columns = ["index"] + columns

    header = "| " + " | ".join(columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"

    rows = []
    for idx, row in df.iterrows():
        cells = [str(v) for v in row.values]
        if index:
            cells.insert(0, str(idx))
        rows.append("| " + " | ".join(cells) + " |")

    return "\n".join([header, separator] + rows)

0 commit comments

Comments
 (0)