@@ -40,7 +40,8 @@ def __init__(self, graph, backend_lib="torch"):
4040 if isinstance (graph , (PropertyGraph , MGPropertyGraph )):
4141 self .__G = graph
4242 else :
43- raise ValueError ("graph must be a PropertyGraph or MGPropertyGraph" )
43+ raise ValueError ("graph must be a PropertyGraph or"
44+ " MGPropertyGraph" )
4445 # dict to map column names corresponding to edge features
4546 # of each type
4647 self .edata_feat_col_d = defaultdict (list )
@@ -58,7 +59,8 @@ def add_node_data(self, df, node_col_name, feat_name, ntype=None):
5859 self .ndata_feat_col_d [feat_name ] += col_names
5960
6061 # Ensure that we only keep unique column names lying around
61- self .ndata_feat_col_d [feat_name ] = list (set (self .ndata_feat_col_d [feat_name ]))
62+ unique_names = list (set (self .ndata_feat_col_d [feat_name ]))
63+ self .ndata_feat_col_d [feat_name ] = unique_names
6264
6365 def add_edge_data (self , df , vertex_col_names , feat_name , etype = None ):
6466 self .gdata .add_edge_data (
@@ -69,7 +71,8 @@ def add_edge_data(self, df, vertex_col_names, feat_name, etype=None):
6971 ]
7072 self .edata_feat_col_d [feat_name ] += col_names
7173 # Ensure that we only keep unique column names lying around
72- self .edata_feat_col_d [feat_name ] = list (set (self .edata_feat_col_d [feat_name ]))
74+ unique_names = list (set (self .edata_feat_col_d [feat_name ]))
75+ self .edata_feat_col_d [feat_name ] = unique_names
7376
7477 def get_node_storage (self , feat_name , ntype = None ):
7578
@@ -83,6 +86,10 @@ def get_node_storage(self, feat_name, ntype=None):
8386 )
8487 )
8588 ntype = ntypes [0 ]
89+ if feat_name not in self .ndata_feat_col_d :
90+ raise ValueError (f"feat_name { feat_name } not found in CuGraphStore"
91+ " node features" ,
92+ f" { list (self .ndata_feat_col_d .keys ())} " )
8693
8794 col_names = self .ndata_feat_col_d [feat_name ]
8895
@@ -99,18 +106,22 @@ def get_edge_storage(self, feat_name, etype=None):
99106 if len (self .etypes ) > 1 :
100107 raise ValueError (
101108 (
102- "Edge type name must be specified if there"
109+ "Edge type name must be specified if there "
103110 "are more than one edge types."
104111 )
105112 )
106113
107114 etype = etypes [0 ]
115+ if feat_name not in self .edata_feat_col_d :
116+ raise ValueError (f"feat_name { feat_name } not found in CuGraphStore"
117+ " edge features" ,
118+ f" { list (self .edata_feat_col_d .keys ())} " )
108119 col_names = self .edata_feat_col_d [feat_name ]
109120
110121 return CuFeatureStorage (
111122 pg = self .gdata ,
112123 col_names = col_names ,
113- storage_type = "node " ,
124+ storage_type = "edge " ,
114125 backend_lib = self .backend_lib ,
115126 )
116127
@@ -126,11 +137,11 @@ def has_multiple_etypes(self):
126137
127138 @property
128139 def ntypes (self ):
129- return self .gdata .vertex_types
140+ return list ( self .gdata .vertex_types )
130141
131142 @property
132143 def etypes (self ):
133- return self .gdata .edge_types
144+ return list ( self .gdata .edge_types )
134145
135146 @property
136147 def is_mg (self ):
@@ -359,8 +370,6 @@ def node_subgraph(
359370 self ,
360371 nodes = None ,
361372 create_using = cugraph .MultiGraph ,
362- directed = False ,
363- multigraph = True ,
364373 ):
365374 """
366375 Return a subgraph induced on the given nodes.
@@ -379,10 +388,7 @@ def node_subgraph(
379388 The sampled subgraph with the same node ID space with the original
380389 graph.
381390 """
382- _g = self .gdata .extract_subgraph (
383- create_using = create_using ,
384- allow_multi_edges = cugraph .MultiGraph ,
385- )
391+ _g = self .gdata .extract_subgraph (create_using = create_using )
386392
387393 if nodes is None :
388394 return _g
@@ -455,7 +461,6 @@ def fetch(self, indices, device, pin_memory=False, **kwargs):
455461
456462 if isinstance (subset_df , dask_cudf .DataFrame ):
457463 subset_df = subset_df .compute ()
458-
459464 tensor = self .from_dlpack (subset_df .to_dlpack ())
460465
461466 if isinstance (tensor , cp .ndarray ):
0 commit comments