Skip to content

Commit 88e3d12

Browse files
KingSkyLiAppointattpoisonooovritser
authored
feat: add document structure into GraphRAG (#2033)
Co-authored-by: Appointat <kuda.czk@antgroup.com> Co-authored-by: tpoisonooo <khj.application@aliyun.com> Co-authored-by: vritser <vritser@163.com>
1 parent 811ce63 commit 88e3d12

File tree

29 files changed

+1910
-936
lines changed

29 files changed

+1910
-936
lines changed

.env.template

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,15 @@ EXECUTE_LOCAL_COMMANDS=False
157157
#*******************************************************************#
158158
VECTOR_STORE_TYPE=Chroma
159159
GRAPH_STORE_TYPE=TuGraph
160-
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
161160
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
162161
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
163162
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
164163
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
165164

165+
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
166+
ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
167+
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
168+
166169
### Chroma vector db config
167170
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
168171

dbgpt/_private/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ def __init__(self) -> None:
213213

214214
# Vector Store Configuration
215215
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
216-
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
217-
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
216+
self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
217+
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
218218
)
219219
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
220220
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")

dbgpt/app/knowledge/service.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
KnowledgeDocumentEntity,
1313
)
1414
from dbgpt.app.knowledge.request.request import (
15-
ChunkEditRequest,
1615
ChunkQueryRequest,
1716
DocumentQueryRequest,
1817
DocumentRecallTestRequest,
@@ -650,12 +649,17 @@ def query_graph(self, space_name, limit):
650649
{
651650
"id": node.vid,
652651
"communityId": node.get_prop("_community_id"),
653-
"name": node.vid,
654-
"type": "",
652+
"name": node.name,
653+
"type": node.get_prop("type") or "",
655654
}
656655
)
657656
for edge in graph.edges():
658657
res["edges"].append(
659-
{"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""}
658+
{
659+
"source": edge.sid,
660+
"target": edge.tid,
661+
"name": edge.name,
662+
"type": edge.get_prop("type") or "",
663+
}
660664
)
661665
return res

dbgpt/datasource/conn_tugraph.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""TuGraph Connector."""
22

33
import json
4-
from typing import Dict, Generator, List, cast
4+
from typing import Dict, Generator, List, Tuple, cast
55

66
from .base import BaseConnector
77

@@ -21,8 +21,7 @@ def __init__(self, driver, graph):
2121
self._session = None
2222

2323
def create_graph(self, graph_name: str) -> None:
24-
"""Create a new graph."""
25-
# run the query to get vertex labels
24+
"""Create a new graph in the database if it doesn't already exist."""
2625
try:
2726
with self._driver.session(database="default") as session:
2827
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
@@ -32,10 +31,10 @@ def create_graph(self, graph_name: str) -> None:
3231
f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)"
3332
)
3433
except Exception as e:
35-
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}")
34+
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
3635

3736
def delete_graph(self, graph_name: str) -> None:
38-
"""Delete a graph."""
37+
"""Delete a graph in the database if it exists."""
3938
with self._driver.session(database="default") as session:
4039
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
4140
exists = any(item["graph_name"] == graph_name for item in graph_list)
@@ -61,17 +60,20 @@ def from_uri_db(
6160
"`pip install neo4j`"
6261
) from err
6362

64-
def get_table_names(self) -> Dict[str, List[str]]:
63+
def get_table_names(self) -> Tuple[List[str], List[str]]:
6564
"""Get all table names from the TuGraph by Neo4j driver."""
66-
# run the query to get vertex labels
6765
with self._driver.session(database=self._graph) as session:
68-
v_result = session.run("CALL db.vertexLabels()").data()
69-
v_data = [table_name["label"] for table_name in v_result]
66+
# Run the query to get vertex labels
67+
raw_vertex_labels: Dict[str, str] = session.run(
68+
"CALL db.vertexLabels()"
69+
).data()
70+
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
71+
72+
# Run the query to get edge labels
73+
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
74+
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
7075

71-
# run the query to get edge labels
72-
e_result = session.run("CALL db.edgeLabels()").data()
73-
e_data = [table_name["label"] for table_name in e_result]
74-
return {"vertex_tables": v_data, "edge_tables": e_data}
76+
return vertex_labels, edge_labels
7577

7678
def get_grants(self):
7779
"""Get grants."""
@@ -100,7 +102,7 @@ def run(self, query: str, fetch: str = "all") -> List:
100102
result = session.run(query)
101103
return list(result)
102104
except Exception as e:
103-
raise Exception(f"Query execution failed: {e}")
105+
raise Exception(f"Query execution failed: {e}\nQuery: {query}") from e
104106

105107
def run_stream(self, query: str) -> Generator:
106108
"""Run GQL."""
@@ -109,11 +111,15 @@ def run_stream(self, query: str) -> Generator:
109111
yield from result
110112

111113
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
112-
"""Get fields about specified graph.
114+
"""Retrieve the column for a specified vertex or edge table in the graph db.
115+
116+
This function queries the schema of a given table (vertex or edge) and returns
117+
detailed information about its columns (properties).
113118
114119
Args:
115120
table_name (str): table name (graph name)
116121
table_type (str): table type (vertex or edge)
122+
117123
Returns:
118124
columns: List[Dict], which contains name: str, type: str,
119125
default_expression: str, is_in_primary_key: bool, comment: str
@@ -146,8 +152,8 @@ def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]
146152
"""Get table indexes about specified table.
147153
148154
Args:
149-
table_name:(str) table name
150-
table_type:(str'vertex' | 'edge'
155+
table_name (str): table name
156+
table_type (str): 'vertex' | 'edge'
151157
Returns:
152158
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
153159
"""

dbgpt/rag/transformer/graph_extractor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]
6565
match = re.match(r"\((.*?)#(.*?)\)", line)
6666
if match:
6767
name, summary = [part.strip() for part in match.groups()]
68-
graph.upsert_vertex(Vertex(name, description=summary))
68+
graph.upsert_vertex(
69+
Vertex(name, description=summary, vertex_type="entity")
70+
)
6971
elif current_section == "Relationships":
7072
match = re.match(r"\((.*?)#(.*?)#(.*?)#(.*?)\)", line)
7173
if match:
@@ -74,7 +76,13 @@ def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]
7476
]
7577
edge_count += 1
7678
graph.append_edge(
77-
Edge(source, target, name, description=summary)
79+
Edge(
80+
source,
81+
target,
82+
name,
83+
description=summary,
84+
edge_type="relation",
85+
)
7886
)
7987

8088
if limit and edge_count >= limit:

dbgpt/rag/transformer/keyword_extractor.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""KeywordExtractor class."""
2+
23
import logging
34
from typing import List, Optional
45

@@ -39,12 +40,15 @@ def __init__(self, llm_client: LLMClient, model_name: str):
3940
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
4041
keywords = set()
4142

42-
for part in text.split(";"):
43-
for s in part.strip().split(","):
44-
keyword = s.strip()
45-
if keyword:
46-
keywords.add(keyword)
47-
if limit and len(keywords) >= limit:
48-
return list(keywords)
43+
lines = text.replace(":", "\n").split("\n")
44+
45+
for line in lines:
46+
for part in line.split(";"):
47+
for s in part.strip().split(","):
48+
keyword = s.strip()
49+
if keyword:
50+
keywords.add(keyword)
51+
if limit and len(keywords) >= limit:
52+
return list(keywords)
4953

5054
return list(keywords)

dbgpt/serve/rag/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def __init__(
128128

129129
def __rewrite_index_store_type(self, index_store_type):
130130
# Rewrite Knowledge Graph Type
131-
if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
131+
if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
132132
if index_store_type == "KnowledgeGraph":
133133
return "CommunitySummaryKnowledgeGraph"
134134
return index_store_type

dbgpt/storage/graph_store/base.py

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Graph store base class."""
2+
23
import logging
34
from abc import ABC, abstractmethod
4-
from typing import Generator, List, Optional, Tuple
5+
from typing import Optional
56

67
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
78
from dbgpt.core import Embeddings
8-
from dbgpt.storage.graph_store.graph import Direction, Graph
99

1010
logger = logging.getLogger(__name__)
1111

@@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel):
2323
default=None,
2424
description="The embedding function of graph store, optional.",
2525
)
26-
summary_enabled: bool = Field(
26+
enable_summary: bool = Field(
2727
default=False,
2828
description="Enable graph community summary or not.",
2929
)
30+
enable_document_graph: bool = Field(
31+
default=True,
32+
description="Enable document graph search or not.",
33+
)
34+
enable_triplet_graph: bool = Field(
35+
default=True,
36+
description="Enable knowledge graph search or not.",
37+
)
3038

3139

3240
class GraphStoreBase(ABC):
3341
"""Graph store base class."""
3442

43+
def __init__(self, config: GraphStoreConfig):
44+
"""Initialize graph store."""
45+
self._config = config
46+
self._conn = None
47+
3548
@abstractmethod
3649
def get_config(self) -> GraphStoreConfig:
3750
"""Get the graph store config."""
3851

3952
@abstractmethod
40-
def get_vertex_type(self) -> str:
41-
"""Get the vertex type."""
42-
43-
@abstractmethod
44-
def get_edge_type(self) -> str:
45-
"""Get the edge type."""
46-
47-
@abstractmethod
48-
def insert_triplet(self, sub: str, rel: str, obj: str):
49-
"""Add triplet."""
50-
51-
@abstractmethod
52-
def insert_graph(self, graph: Graph):
53-
"""Add graph."""
54-
55-
@abstractmethod
56-
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
57-
"""Get triplets."""
58-
59-
@abstractmethod
60-
def delete_triplet(self, sub: str, rel: str, obj: str):
61-
"""Delete triplet."""
53+
def _escape_quotes(self, text: str) -> str:
54+
"""Escape single and double quotes in a string for queries."""
6255

63-
@abstractmethod
64-
def truncate(self):
65-
"""Truncate Graph."""
66-
67-
@abstractmethod
68-
def drop(self):
69-
"""Drop graph."""
70-
71-
@abstractmethod
72-
def get_schema(self, refresh: bool = False) -> str:
73-
"""Get schema."""
74-
75-
@abstractmethod
76-
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
77-
"""Get full graph."""
78-
79-
@abstractmethod
80-
def explore(
81-
self,
82-
subs: List[str],
83-
direct: Direction = Direction.BOTH,
84-
depth: Optional[int] = None,
85-
fan: Optional[int] = None,
86-
limit: Optional[int] = None,
87-
) -> Graph:
88-
"""Explore on graph."""
89-
90-
@abstractmethod
91-
def query(self, query: str, **args) -> Graph:
92-
"""Execute a query."""
93-
94-
def aquery(self, query: str, **args) -> Graph:
95-
"""Async execute a query."""
96-
return self.query(query, **args)
97-
98-
@abstractmethod
99-
def stream_query(self, query: str) -> Generator[Graph, None, None]:
100-
"""Execute stream query."""
56+
# @abstractmethod
57+
# def _paser(self, entities: List[Vertex]) -> str:
58+
# """Parse entities to string."""

dbgpt/storage/graph_store/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Graph store factory."""
2+
23
import logging
34
from typing import Tuple, Type
45

0 commit comments

Comments
 (0)