@@ -79,7 +79,7 @@ async def get_community(self, community_id: str) -> Community:
7979 @property
8080 def graph_store (self ) -> TuGraphStore :
8181 """Get the graph store."""
82- return self ._graph_store
82+ return self ._graph_store # type: ignore[return-value]
8383
8484 def get_graph_config (self ):
8585 """Get the graph store config."""
@@ -176,29 +176,23 @@ def upsert_edge(
176176 [{ self ._convert_dict_to_str (edge_list )} ])"""
177177 self .graph_store .conn .run (query = relation_query )
178178
179- def upsert_chunks (
180- self , chunks : Union [Iterator [Vertex ], Iterator [ParagraphChunk ]]
181- ) -> None :
179+ def upsert_chunks (self , chunks : Iterator [Union [Vertex , ParagraphChunk ]]) -> None :
182180 """Upsert chunks."""
183- chunks_list = list (chunks )
184- if chunks_list and isinstance (chunks_list [0 ], ParagraphChunk ):
185- chunk_list = [
186- {
187- "id" : self ._escape_quotes (chunk .chunk_id ),
188- "name" : self ._escape_quotes (chunk .chunk_name ),
189- "content" : self ._escape_quotes (chunk .content ),
190- }
191- for chunk in chunks_list
192- ]
193- else :
194- chunk_list = [
195- {
196- "id" : self ._escape_quotes (chunk .vid ),
197- "name" : self ._escape_quotes (chunk .name ),
198- "content" : self ._escape_quotes (chunk .get_prop ("content" )),
199- }
200- for chunk in chunks_list
201- ]
181+ chunk_list = [
182+ {
183+ "id" : self ._escape_quotes (chunk .chunk_id ),
184+ "name" : self ._escape_quotes (chunk .chunk_name ),
185+ "content" : self ._escape_quotes (chunk .content ),
186+ }
187+ if isinstance (chunk , ParagraphChunk )
188+ else {
189+ "id" : self ._escape_quotes (chunk .vid ),
190+ "name" : self ._escape_quotes (chunk .name ),
191+ "content" : self ._escape_quotes (chunk .get_prop ("content" )),
192+ }
193+ for chunk in chunks
194+ ]
195+
202196 chunk_query = (
203197 f"CALL db.upsertVertex("
204198 f'"{ GraphElemType .CHUNK .value } ", '
@@ -207,28 +201,24 @@ def upsert_chunks(
207201 self .graph_store .conn .run (query = chunk_query )
208202
209203 def upsert_documents (
210- self , documents : Union [ Iterator [Vertex ], Iterator [ ParagraphChunk ]]
204+ self , documents : Iterator [Union [ Vertex , ParagraphChunk ]]
211205 ) -> None :
212206 """Upsert documents."""
213- documents_list = list (documents )
214- if documents_list and isinstance (documents_list [0 ], ParagraphChunk ):
215- document_list = [
216- {
217- "id" : self ._escape_quotes (document .chunk_id ),
218- "name" : self ._escape_quotes (document .chunk_name ),
219- "content" : "" ,
220- }
221- for document in documents_list
222- ]
223- else :
224- document_list = [
225- {
226- "id" : self ._escape_quotes (document .vid ),
227- "name" : self ._escape_quotes (document .name ),
228- "content" : self ._escape_quotes (document .get_prop ("content" )) or "" ,
229- }
230- for document in documents_list
231- ]
207+ document_list = [
208+ {
209+ "id" : self ._escape_quotes (document .chunk_id ),
210+ "name" : self ._escape_quotes (document .chunk_name ),
211+ "content" : "" ,
212+ }
213+ if isinstance (document , ParagraphChunk )
214+ else {
215+ "id" : self ._escape_quotes (document .vid ),
216+ "name" : self ._escape_quotes (document .name ),
217+ "content" : "" ,
218+ }
219+ for document in documents
220+ ]
221+
232222 document_query = (
233223 "CALL db.upsertVertex("
234224 f'"{ GraphElemType .DOCUMENT .value } ", '
@@ -258,7 +248,7 @@ def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
258248 self .graph_store .conn .run (query = vertex_query )
259249 self .graph_store .conn .run (query = edge_query )
260250
261- def upsert_graph (self , graph : MemoryGraph ) -> None :
251+ def upsert_graph (self , graph : Graph ) -> None :
262252 """Add graph to the graph store.
263253
264254 Args:
@@ -362,7 +352,8 @@ def drop(self):
362352
363353 def create_graph (self , graph_name : str ):
364354 """Create a graph."""
365- self .graph_store .conn .create_graph (graph_name = graph_name )
355+ if not self .graph_store .conn .create_graph (graph_name = graph_name ):
356+ return
366357
367358 # Create the graph schema
368359 def _format_graph_propertity_schema (
@@ -474,12 +465,14 @@ def create_graph_label(
474465 (vertices) and edges in the graph.
475466 """
476467 if graph_elem_type .is_vertex (): # vertex
477- data = json .dumps ({
478- "label" : graph_elem_type .value ,
479- "type" : "VERTEX" ,
480- "primary" : "id" ,
481- "properties" : graph_properties ,
482- })
468+ data = json .dumps (
469+ {
470+ "label" : graph_elem_type .value ,
471+ "type" : "VERTEX" ,
472+ "primary" : "id" ,
473+ "properties" : graph_properties ,
474+ }
475+ )
483476 gql = f"""CALL db.createVertexLabelByJson('{ data } ')"""
484477 else : # edge
485478
@@ -505,12 +498,14 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
505498 else :
506499 raise ValueError ("Invalid graph element type." )
507500
508- data = json .dumps ({
509- "label" : graph_elem_type .value ,
510- "type" : "EDGE" ,
511- "constraints" : edge_direction (graph_elem_type ),
512- "properties" : graph_properties ,
513- })
501+ data = json .dumps (
502+ {
503+ "label" : graph_elem_type .value ,
504+ "type" : "EDGE" ,
505+ "constraints" : edge_direction (graph_elem_type ),
506+ "properties" : graph_properties ,
507+ }
508+ )
514509 gql = f"""CALL db.createEdgeLabelByJson('{ data } ')"""
515510
516511 self .graph_store .conn .run (gql )
@@ -530,18 +525,16 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:
530525 True if the label exists in the specified graph element type, otherwise
531526 False.
532527 """
533- vertex_tables , edge_tables = self .graph_store .conn .get_table_names ()
528+ tables = self .graph_store .conn .get_table_names ()
534529
535- if graph_elem_type .is_vertex ():
536- return graph_elem_type in vertex_tables
537- else :
538- return graph_elem_type in edge_tables
530+ return graph_elem_type .value in tables
539531
540532 def explore (
541533 self ,
542534 subs : List [str ],
543535 direct : Direction = Direction .BOTH ,
544536 depth : int = 3 ,
537+ fan : Optional [int ] = None ,
545538 limit : Optional [int ] = None ,
546539 search_scope : Optional [
547540 Literal ["knowledge_graph" , "document_graph" ]
@@ -621,11 +614,17 @@ def query(self, query: str, **kwargs) -> MemoryGraph:
621614 mg .append_edge (edge )
622615 return mg
623616
624- async def stream_query (self , query : str , ** kwargs ) -> AsyncGenerator [Graph , None ]:
617+ # type: ignore[override]
618+ # mypy: ignore-errors
619+ async def stream_query ( # type: ignore[override]
620+ self ,
621+ query : str ,
622+ ** kwargs ,
623+ ) -> AsyncGenerator [Graph , None ]:
625624 """Execute a stream query."""
626625 from neo4j import graph
627626
628- async for record in self .graph_store .conn .run_stream (query ):
627+ async for record in self .graph_store .conn .run_stream (query ): # type: ignore
629628 mg = MemoryGraph ()
630629 for key in record .keys ():
631630 value = record [key ]
@@ -650,15 +649,19 @@ async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None
650649 rels = list (record ["p" ].relationships )
651650 formatted_path = []
652651 for i in range (len (nodes )):
653- formatted_path .append ({
654- "id" : nodes [i ]._properties ["id" ],
655- "description" : nodes [i ]._properties ["description" ],
656- })
652+ formatted_path .append (
653+ {
654+ "id" : nodes [i ]._properties ["id" ],
655+ "description" : nodes [i ]._properties ["description" ],
656+ }
657+ )
657658 if i < len (rels ):
658- formatted_path .append ({
659- "id" : rels [i ]._properties ["id" ],
660- "description" : rels [i ]._properties ["description" ],
661- })
659+ formatted_path .append (
660+ {
661+ "id" : rels [i ]._properties ["id" ],
662+ "description" : rels [i ]._properties ["description" ],
663+ }
664+ )
662665 for i in range (0 , len (formatted_path ), 2 ):
663666 mg .upsert_vertex (
664667 Vertex (
0 commit comments