Skip to content

Commit 4ea838f

Browse files
[FIX] Match the PyG API for Node Input to the Loader (#3514)
Ensures heterogeneous graphs can be handled correctly in cugraph-pyg. Also cleans up some technical debt and adds some key tests. Resolves #3333 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Brad Rees (https://github.com/BradReesWork) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: #3514
1 parent 649edb6 commit 4ea838f

File tree

6 files changed

+222
-42
lines changed

6 files changed

+222
-42
lines changed

python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def __init__(
271271

272272
self.__infer_offsets(num_nodes_dict, num_edges_dict)
273273
self.__infer_existing_tensors(F)
274-
self.__infer_edge_types(num_edges_dict)
274+
self.__infer_edge_types(num_nodes_dict, num_edges_dict)
275275

276276
self._edge_attr_cls = CuGraphEdgeAttr
277277

@@ -462,6 +462,9 @@ def _is_delayed(self):
462462
return False
463463
return self.__graph.is_multi_gpu()
464464

465+
def _numeric_vertex_type_from_name(self, vertex_type_name: str) -> int:
466+
return np.searchsorted(self.__vertex_type_offsets["type"], vertex_type_name)
467+
465468
def get_vertex_index(self, vtypes) -> TensorType:
466469
if isinstance(vtypes, str):
467470
vtypes = [vtypes]
@@ -559,12 +562,12 @@ def _get_edge_index(self, attr: CuGraphEdgeAttr) -> Tuple[TensorType, TensorType
559562
src_type, _, dst_type = attr.edge_type
560563
src_offset = int(
561564
self.__vertex_type_offsets["start"][
562-
np.searchsorted(self.__vertex_type_offsets["type"], src_type)
565+
self._numeric_vertex_type_from_name(src_type)
563566
]
564567
)
565568
dst_offset = int(
566569
self.__vertex_type_offsets["start"][
567-
np.searchsorted(self.__vertex_type_offsets["type"], dst_type)
570+
self._numeric_vertex_type_from_name(dst_type)
568571
]
569572
)
570573
coli = np.searchsorted(
@@ -693,6 +696,29 @@ def _get_vertex_groups_from_sample(
693696

694697
return noi_index
695698

699+
def _get_sample_from_vertex_groups(
700+
self, vertex_groups: Dict[str, TensorType]
701+
) -> TensorType:
702+
"""
703+
Inverse of _get_vertex_groups_from_sample() (although with de-offsetted ids).
704+
Given a dictionary of node types and de-offsetted node ids, return
705+
the global (non-renumbered) vertex ids.
706+
707+
Example Input: {'horse': [1, 3, 5], 'duck': [1, 2]}
708+
Output: [1, 3, 5, 14, 15]
709+
"""
710+
t = torch.tensor([], dtype=torch.int64, device="cuda")
711+
712+
for group_name, ix in vertex_groups.items():
713+
type_id = self._numeric_vertex_type_from_name(group_name)
714+
if not ix.is_cuda:
715+
ix = ix.cuda()
716+
offset = self.__vertex_type_offsets["start"][type_id]
717+
u = ix + offset
718+
t = torch.concatenate([t, u])
719+
720+
return t
721+
696722
def _get_renumbered_edge_groups_from_sample(
697723
self, sampling_results: cudf.DataFrame, noi_index: dict
698724
) -> Tuple[dict, dict]:
@@ -823,16 +849,21 @@ def create_named_tensor(
823849
)
824850
)
825851

826-
def __infer_edge_types(self, num_edges_dict) -> None:
852+
def __infer_edge_types(
853+
self,
854+
num_nodes_dict: Dict[str, int],
855+
num_edges_dict: Dict[Tuple[str, str, str], int],
856+
) -> None:
827857
self.__edge_types_to_attrs = {}
828858

829859
for pyg_can_edge_type in sorted(num_edges_dict.keys()):
830-
sz = num_edges_dict[pyg_can_edge_type]
860+
sz_src = num_nodes_dict[pyg_can_edge_type[0]]
861+
sz_dst = num_nodes_dict[pyg_can_edge_type[-1]]
831862
self.__edge_types_to_attrs[pyg_can_edge_type] = CuGraphEdgeAttr(
832863
edge_type=pyg_can_edge_type,
833864
layout=EdgeLayout.COO,
834865
is_sorted=False,
835-
size=(sz, sz),
866+
size=(sz_src, sz_dst),
836867
)
837868

838869
def __infer_existing_tensors(self, F) -> None:
@@ -862,22 +893,25 @@ def _get_tensor(self, attr: CuGraphTensorAttr) -> TensorType:
862893
cols = attr.properties
863894

864895
idx = attr.index
865-
if feature_backend == "torch":
866-
if not isinstance(idx, torch.Tensor):
867-
raise TypeError(
868-
f"Type {type(idx)} invalid"
869-
f" for feature store backend {feature_backend}"
870-
)
871-
idx = idx.cpu()
872-
elif feature_backend == "numpy":
873-
# allow indexing through cupy arrays
874-
if isinstance(idx, cupy.ndarray):
875-
idx = idx.get()
876-
elif isinstance(idx, torch.Tensor):
877-
idx = np.asarray(idx.cpu())
896+
if idx is not None:
897+
if feature_backend == "torch":
898+
if not isinstance(idx, torch.Tensor):
899+
raise TypeError(
900+
f"Type {type(idx)} invalid"
901+
f" for feature store backend {feature_backend}"
902+
)
903+
idx = idx.cpu()
904+
elif feature_backend == "numpy":
905+
# allow feature indexing through cupy arrays
906+
if isinstance(idx, cupy.ndarray):
907+
idx = idx.get()
908+
elif isinstance(idx, torch.Tensor):
909+
idx = np.asarray(idx.cpu())
878910

879911
if cols is None:
880912
t = self.__features.get_data(idx, attr.group_name, attr.attr_name)
913+
if idx is None:
914+
t = t[-1]
881915

882916
if isinstance(t, np.ndarray):
883917
t = torch.as_tensor(t, device="cuda")

python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,14 @@
2626
from cugraph_pyg.loader.filter import _filter_cugraph_store
2727
from cugraph_pyg.sampler.cugraph_sampler import _sampler_output_from_sampling_results
2828

29-
from typing import Union, Tuple, Sequence, List
29+
from typing import Union, Tuple, Sequence, List, Dict
3030

3131
torch_geometric = import_optional("torch_geometric")
32+
InputNodes = (
33+
Sequence
34+
if isinstance(torch_geometric, MissingModule)
35+
else torch_geometric.typing.InputNodes
36+
)
3237

3338

3439
class EXPERIMENTAL__BulkSampleLoader:
@@ -39,15 +44,15 @@ def __init__(
3944
self,
4045
feature_store: CuGraphStore,
4146
graph_store: CuGraphStore,
42-
all_indices: Union[Sequence, int],
47+
input_nodes: Union[InputNodes, int] = None,
4348
batch_size: int = 0,
4449
shuffle=False,
4550
edge_types: Sequence[Tuple[str]] = None,
4651
directory=None,
4752
starting_batch_id=0,
4853
batches_per_partition=100,
4954
# Sampler args
50-
num_neighbors: List[int] = [1, 1],
55+
num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]] = None,
5156
replace: bool = True,
5257
# Other kwargs for the BulkSampler
5358
**kwargs,
@@ -64,9 +69,9 @@ def __init__(
6469
graph_store: CuGraphStore
6570
The graph store containing the graph structure.
6671
67-
all_indices: Union[Tensor, int]
72+
input_nodes: Union[InputNodes, int]
6873
The input nodes associated with this sampler.
69-
If this is an integer N , this loader will load N batches
74+
If this is an integer N, this loader will load N batches
7075
from disk rather than performing sampling in memory.
7176
7277
batch_size: int
@@ -97,6 +102,16 @@ def __init__(
97102
Defaults to 100. Gets passed to the bulk
98103
sampler if there is one; otherwise, this argument
99104
is used to determine which files to read.
105+
106+
num_neighbors: Union[List[int],
107+
Dict[Tuple[str, str, str], List[int]]] (required)
108+
The number of neighbors to sample for each node in each iteration.
109+
If an entry is set to -1, all neighbors will be included.
110+
In heterogeneous graphs, may also take in a dictionary denoting
111+
the number of neighbors to sample for each individual edge type.
112+
113+
Note: in cuGraph, only one value of num_neighbors is currently supported.
114+
Passing in a dictionary will result in an exception.
100115
"""
101116

102117
self.__feature_store = feature_store
@@ -106,18 +121,29 @@ def __init__(
106121
self.__batches_per_partition = batches_per_partition
107122
self.__starting_batch_id = starting_batch_id
108123

109-
if isinstance(all_indices, int):
124+
if isinstance(input_nodes, int):
110125
# Will be loading from disk
111-
self.__num_batches = all_indices
126+
self.__num_batches = input_nodes
112127
self.__directory = directory
113128
iter(os.listdir(self.__directory))
114129
return
115130

131+
input_type, input_nodes = torch_geometric.loader.utils.get_input_nodes(
132+
(feature_store, graph_store), input_nodes
133+
)
134+
if input_type is not None:
135+
input_nodes = graph_store._get_sample_from_vertex_groups(
136+
{input_type: input_nodes}
137+
)
138+
116139
if batch_size is None or batch_size < 1:
117140
raise ValueError("Batch size must be >= 1")
118141

119142
self.__directory = tempfile.TemporaryDirectory(dir=directory)
120143

144+
if isinstance(num_neighbors, dict):
145+
raise ValueError("num_neighbors dict is currently unsupported!")
146+
121147
bulk_sampler = BulkSampler(
122148
batch_size,
123149
self.__directory.name,
@@ -129,21 +155,21 @@ def __init__(
129155
)
130156

131157
# Make sure indices are in cupy
132-
all_indices = cupy.asarray(all_indices)
158+
input_nodes = cupy.asarray(input_nodes)
133159

134160
# Shuffle
135161
if shuffle:
136-
cupy.random.shuffle(all_indices)
162+
cupy.random.shuffle(input_nodes)
137163

138164
# Truncate if we can't evenly divide the input array
139-
stop = (len(all_indices) // batch_size) * batch_size
140-
all_indices = all_indices[:stop]
165+
stop = (len(input_nodes) // batch_size) * batch_size
166+
input_nodes = input_nodes[:stop]
141167

142168
# Split into batches
143-
all_indices = cupy.split(all_indices, len(all_indices) // batch_size)
169+
input_nodes = cupy.split(input_nodes, len(input_nodes) // batch_size)
144170

145171
self.__num_batches = 0
146-
for batch_num, batch_i in enumerate(all_indices):
172+
for batch_num, batch_i in enumerate(input_nodes):
147173
self.__num_batches += 1
148174
bulk_sampler.add_batches(
149175
cudf.DataFrame(
@@ -246,8 +272,8 @@ class EXPERIMENTAL__CuGraphNeighborLoader:
246272
def __init__(
247273
self,
248274
data: Union[CuGraphStore, Tuple[CuGraphStore, CuGraphStore]],
249-
input_nodes: Sequence,
250-
batch_size: int,
275+
input_nodes: Union[InputNodes, int] = None,
276+
batch_size: int = None,
251277
**kwargs,
252278
):
253279
"""
@@ -256,19 +282,23 @@ def __init__(
256282
data: CuGraphStore or (CuGraphStore, CuGraphStore)
257283
The CuGraphStore or stores where the graph/feature data is held.
258284
259-
batch_size: int
285+
batch_size: int (required)
260286
The number of input nodes in each batch.
261287
262-
input_nodes: Tensor
263-
The input nodes for *this* loader. If there are multiple loaders,
264-
the appropriate split should be given for this loader.
288+
input_nodes: Union[InputNodes, int] (required)
289+
The input nodes associated with this sampler.
265290
266291
**kwargs: kwargs
267292
Keyword arguments to pass through for sampling.
268293
i.e. "shuffle", "fanout"
269294
See BulkSampleLoader.
270295
"""
271296

297+
if input_nodes is None:
298+
raise ValueError("input_nodes is required")
299+
if batch_size is None:
300+
raise ValueError("batch_size is required")
301+
272302
# Allow passing in a feature store and graph store as a tuple, as
273303
# in the standard PyG API. If only one is passed, it is assumed
274304
# it is behaving as both a graph store and a feature store.

python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,30 @@ def test_cugraph_loader_basic(dask_client, karate_gnn):
4747
if "type1" in sample:
4848
for prop in sample["type1"]["prop0"].tolist():
4949
assert prop % 41 == 0
50+
51+
52+
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
53+
def test_cugraph_loader_hetero(dask_client, karate_gnn):
54+
F, G, N = karate_gnn
55+
cugraph_store = CuGraphStore(F, G, N, multi_gpu=True)
56+
loader = CuGraphNeighborLoader(
57+
(cugraph_store, cugraph_store),
58+
input_nodes=("type1", torch.tensor([0, 1, 2, 5], device="cuda")),
59+
batch_size=2,
60+
num_neighbors=[4, 4],
61+
random_state=62,
62+
replace=False,
63+
)
64+
65+
samples = [s for s in loader]
66+
67+
assert len(samples) == 2
68+
for sample in samples:
69+
print(sample)
70+
if "type0" in sample:
71+
for prop in sample["type0"]["prop0"].tolist():
72+
assert prop % 31 == 0
73+
74+
if "type1" in sample:
75+
for prop in sample["type1"]["prop0"].tolist():
76+
assert prop % 41 == 0

python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
torch = import_optional("torch")
35+
torch_geometric = import_optional("torch_geometric")
3536

3637

3738
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@@ -152,7 +153,10 @@ def test_edge_types(graph, dask_client):
152153
assert eta.keys() == G.keys()
153154

154155
for attr_name, attr_repr in eta.items():
155-
assert len(G[attr_name][0]) == attr_repr.size[-1]
156+
src_size = N[attr_name[0]]
157+
dst_size = N[attr_name[-1]]
158+
assert src_size == attr_repr.size[0]
159+
assert dst_size == attr_repr.size[-1]
156160
assert attr_name == attr_repr.edge_type
157161

158162

@@ -311,6 +315,17 @@ def test_get_tensor(graph, dask_client):
311315
assert tsr == base_series
312316

313317

318+
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
319+
def test_get_tensor_empty_idx(karate_gnn, dask_client):
320+
F, G, N = karate_gnn
321+
cugraph_store = CuGraphStore(F, G, N, multi_gpu=True)
322+
323+
t = cugraph_store.get_tensor(
324+
CuGraphTensorAttr(group_name="type0", attr_name="prop0", index=None)
325+
)
326+
assert t.tolist() == (torch.arange(17, dtype=torch.float32) * 31).tolist()
327+
328+
314329
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
315330
def test_multi_get_tensor(graph, dask_client):
316331
F, G, N = graph
@@ -397,6 +412,22 @@ def test_get_tensor_size(graph, dask_client):
397412
assert cugraph_store.get_tensor_size(tensor_attr) == torch.Size((sz,))
398413

399414

415+
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
416+
@pytest.mark.skipif(
417+
isinstance(torch_geometric, MissingModule), reason="pyg not available"
418+
)
419+
def test_get_input_nodes(karate_gnn, dask_client):
420+
F, G, N = karate_gnn
421+
cugraph_store = CuGraphStore(F, G, N, multi_gpu=True)
422+
423+
node_type, input_nodes = torch_geometric.loader.utils.get_input_nodes(
424+
(cugraph_store, cugraph_store), "type0"
425+
)
426+
427+
assert node_type == "type0"
428+
assert input_nodes.tolist() == torch.arange(17, dtype=torch.int32).tolist()
429+
430+
400431
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
401432
def test_mg_frame_handle(graph, dask_client):
402433
F, G, N = graph

0 commit comments

Comments
 (0)