Skip to content

Commit 320228a

Browse files
fix: fix unit test and some code typos (#2717)
1 parent fe98239 commit 320228a

File tree

7 files changed

+29
-20
lines changed

7 files changed

+29
-20
lines changed

packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@
140140
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
141141
5.输出内容中sql部分转换为:
142142
<api-call><name>[数据显示方式]</name><args><sql>\
143-
[正确的duckdb数据分析sql]</sql></args></api- call> \
143+
[正确的duckdb数据分析sql]</sql></args></api-call> \
144144
这样的格式,参考返回格式要求
145145
146146
请一步一步思考,给出回答,并确保你的回答内容格式如下:

packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/chat.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ def __init__(
3838

3939
@trace()
4040
async def generate_input_values(self) -> Dict:
41-
# colunms, datas = self.excel_reader.get_sample_data()
42-
colunms, datas = await blocking_func_to_async(
41+
columns, datas = await blocking_func_to_async(
4342
self._executor, self.excel_reader.get_sample_data, self._curr_table
4443
)
45-
self.prompt_template.output_parser.update(colunms)
46-
datas.insert(0, colunms)
44+
self.prompt_template.output_parser.update(columns)
45+
datas.insert(0, columns)
4746

4847
table_schema = await blocking_func_to_async(
4948
self._executor, self.excel_reader.get_create_table_sql, self._curr_table

packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
表格格式:|xxx|xxx|xxx|, 代码格式:```xxx```.
2626
2.如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" \
2727
禁止胡乱编造.
28-
3.回答的时候最好按照1.2.3.点进行总结, 并以markdwon格式显示.
28+
3.回答的时候最好按照1.2.3.点进行总结, 并以markdown格式显示.
2929
已知内容:
3030
{context}
3131
问题:
@@ -43,7 +43,7 @@
4343
question, state clearly: "The content provided in the knowledge base is not enough \
4444
to answer this question," and avoid making up answers.
4545
3.When responding, it is best to summarize the points in the order of 1, 2, 3, And \
46-
displayed in markdwon format.
46+
displayed in markdown format.
4747
known information:
4848
{context}
4949
question:

packages/dbgpt-core/src/dbgpt/datasource/rdbms/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ def __sql_parse(self, sql):
744744
sql = sql.strip()
745745
parsed = sqlparse.parse(sql)[0]
746746
sql_type = parsed.get_type()
747-
if sql_type == "CREATE":
747+
if sql_type == "CREATE" or sql_type == "DROP":
748748
table_name = self._extract_table_name_from_ddl(parsed)
749749
else:
750750
table_name = parsed.get_name()

packages/dbgpt-ext/src/dbgpt_ext/datasource/rdbms/conn_oracle.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def from_uri_db(
110110

111111
return cls.from_uri(db_url, engine_args=engine_args, **kwargs)
112112

113+
def get_simple_fields(self, table_name):
114+
"""Get column fields about specified table."""
115+
return self.get_fields(table_name)
116+
113117
def get_fields(self, table_name: str, db_name=None) -> List[Tuple]:
114118
with self.session_scope() as session:
115119
query = f"""

tests/intetration_tests/datasource/test_conn_mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import pytest
2222

23-
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector
23+
from dbgpt_ext.datasource.rdbms.conn_mysql import MySQLConnector
2424

2525
_create_table_sql = """
2626
CREATE TABLE IF NOT EXISTS `test` (

tests/intetration_tests/datasource/test_conn_oracle.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""
2-
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
3-
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
4-
mysql -h 127.0.0.1 -uroot -p -P3307
5-
Enter password:
6-
Welcome to the MySQL monitor. Commands end with ; or \g.
7-
Your MySQL connection id is 2
8-
Server version: 5.7.41 MySQL Community Server (GPL)
2+
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_oracle.py
3+
4+
Install Oracle With Docker
5+
docker run -d -p 1521:1521 -p 5502:5500 -e ORACLE_SID=ORCLCDB -e ORACLE_PDB=ORCLPDB
6+
-e ORACLE_PWD=oracle -e ORACLE_EDITION=standard -e ORACLE_CHARACTERSET=AL32UTF8
7+
--name oracle_19c registry.cn-hangzhou.aliyuncs.com/zhuyijun/oracle:19c
98
109
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
1110
@@ -50,14 +49,21 @@ def db():
5049

5150
def test_get_usable_table_names(db):
5251
db.run(_create_table_sql)
53-
db.run("COMMIT")
52+
# db.run("COMMIT") DML not need commit
53+
db._inspector.clear_cache()
54+
print(db._sync_tables_from_db())
5455
table_names = db.get_usable_table_names()
5556
assert "TEST" in map(str.upper, table_names)
5657

5758

5859
def test_get_table_info(db):
5960
db.run(_create_table_sql)
60-
db.run("COMMIT")
61+
# clean cache
62+
db._inspector.clear_cache()
63+
db._metadata.clear()
64+
db._metadata.reflect(bind=db._engine)
65+
# refresh table
66+
print(db._sync_tables_from_db())
6167
table_info = db.get_table_info()
6268
assert "CREATE TABLE TEST" in table_info.upper()
6369

@@ -70,7 +76,7 @@ def test_run_no_throw(db):
7076

7177
def test_get_index_empty(db):
7278
db.run(_create_table_sql)
73-
db.run("COMMIT")
79+
# db.run("COMMIT")
7480
indexes = db.get_indexes("TEST")
7581
assert indexes == []
7682

@@ -110,4 +116,4 @@ def test_get_database_lists(db):
110116
pdb_names = [name[0] for name in databases[1:]]
111117
else:
112118
pdb_names = ["ORCL"]
113-
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
119+
assert any(name in ("ORCLPDB", "ORCL") for name in pdb_names)

0 commit comments

Comments
 (0)