Skip to content

Commit bd5a5cf

Browse files
committed
Working tests apart from hetrograph_sampling and eids
1 parent 7a5e993 commit bd5a5cf

File tree

2 files changed

+113
-71
lines changed

2 files changed

+113
-71
lines changed

python/cugraph/cugraph/gnn/graph_store.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def __init__(self, graph, backend_lib="torch"):
4040
if isinstance(graph, (PropertyGraph, MGPropertyGraph)):
4141
self.__G = graph
4242
else:
43-
raise ValueError("graph must be a PropertyGraph or MGPropertyGraph")
43+
raise ValueError("graph must be a PropertyGraph or"
44+
" MGPropertyGraph")
4445
# dict to map column names corresponding to edge features
4546
# of each type
4647
self.edata_feat_col_d = defaultdict(list)
@@ -58,7 +59,8 @@ def add_node_data(self, df, node_col_name, feat_name, ntype=None):
5859
self.ndata_feat_col_d[feat_name] += col_names
5960

6061
# Ensure that we only keep unique column names lying around
61-
self.ndata_feat_col_d[feat_name] = list(set(self.ndata_feat_col_d[feat_name]))
62+
unique_names = list(set(self.ndata_feat_col_d[feat_name]))
63+
self.ndata_feat_col_d[feat_name] = unique_names
6264

6365
def add_edge_data(self, df, vertex_col_names, feat_name, etype=None):
6466
self.gdata.add_edge_data(
@@ -69,7 +71,8 @@ def add_edge_data(self, df, vertex_col_names, feat_name, etype=None):
6971
]
7072
self.edata_feat_col_d[feat_name] += col_names
7173
# Ensure that we only keep unique column names lying around
72-
self.edata_feat_col_d[feat_name] = list(set(self.edata_feat_col_d[feat_name]))
74+
unique_names = list(set(self.edata_feat_col_d[feat_name]))
75+
self.edata_feat_col_d[feat_name] = unique_names
7376

7477
def get_node_storage(self, feat_name, ntype=None):
7578

@@ -83,6 +86,10 @@ def get_node_storage(self, feat_name, ntype=None):
8386
)
8487
)
8588
ntype = ntypes[0]
89+
if feat_name not in self.ndata_feat_col_d:
90+
raise ValueError(f"feat_name {feat_name} not found in CuGraphStore"
91+
" node features",
92+
f" {list(self.ndata_feat_col_d.keys())}")
8693

8794
col_names = self.ndata_feat_col_d[feat_name]
8895

@@ -99,18 +106,22 @@ def get_edge_storage(self, feat_name, etype=None):
99106
if len(self.etypes) > 1:
100107
raise ValueError(
101108
(
102-
"Edge type name must be specified if there"
109+
"Edge type name must be specified if there "
103110
"are more than one edge types."
104111
)
105112
)
106113

107114
etype = etypes[0]
115+
if feat_name not in self.edata_feat_col_d:
116+
raise ValueError(f"feat_name {feat_name} not found in CuGraphStore"
117+
" edge features",
118+
f" {list(self.edata_feat_col_d.keys())}")
108119
col_names = self.edata_feat_col_d[feat_name]
109120

110121
return CuFeatureStorage(
111122
pg=self.gdata,
112123
col_names=col_names,
113-
storage_type="node",
124+
storage_type="edge",
114125
backend_lib=self.backend_lib,
115126
)
116127

@@ -126,11 +137,11 @@ def has_multiple_etypes(self):
126137

127138
@property
128139
def ntypes(self):
129-
return self.gdata.vertex_types
140+
return list(self.gdata.vertex_types)
130141

131142
@property
132143
def etypes(self):
133-
return self.gdata.edge_types
144+
return list(self.gdata.edge_types)
134145

135146
@property
136147
def is_mg(self):
@@ -359,8 +370,6 @@ def node_subgraph(
359370
self,
360371
nodes=None,
361372
create_using=cugraph.MultiGraph,
362-
directed=False,
363-
multigraph=True,
364373
):
365374
"""
366375
Return a subgraph induced on the given nodes.
@@ -379,10 +388,7 @@ def node_subgraph(
379388
The sampled subgraph with the same node ID space with the original
380389
graph.
381390
"""
382-
_g = self.gdata.extract_subgraph(
383-
create_using=create_using,
384-
allow_multi_edges=cugraph.MultiGraph,
385-
)
391+
_g = self.gdata.extract_subgraph(create_using=create_using)
386392

387393
if nodes is None:
388394
return _g
@@ -455,7 +461,6 @@ def fetch(self, indices, device, pin_memory=False, **kwargs):
455461

456462
if isinstance(subset_df, dask_cudf.DataFrame):
457463
subset_df = subset_df.compute()
458-
459464
tensor = self.from_dlpack(subset_df.to_dlpack())
460465

461466
if isinstance(tensor, cp.ndarray):

0 commit comments

Comments
 (0)