Skip to content

Commit 8507cbf

Browse files
[FEA] Heterogeneous Distributed Sampling (rapidsai#4795)
Adds support for heterogeneous distributed sampling to the cuGraph distributed sampler. Prerequisite for exposing this functionality to cuGraph-PyG. Has been initially tested with cuGraph-PyG. Updates the distributed sampler to use the new sampling API. Merge after rapidsai#4775, rapidsai#4827, rapidsai#4820 Closes rapidsai#4773 Closes rapidsai#4401 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Joseph Nke (https://github.com/jnke2016) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: rapidsai#4795
1 parent dd228f9 commit 8507cbf

File tree

10 files changed

+349
-274
lines changed

10 files changed

+349
-274
lines changed

python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -79,9 +79,15 @@ def get_reader(
7979
return DistSampleReader(self._directory, format=self._format, rank=rank)
8080

8181
def __write_minibatches_coo(self, minibatch_dict):
82-
has_edge_ids = minibatch_dict["edge_id"] is not None
83-
has_edge_types = minibatch_dict["edge_type"] is not None
84-
has_weights = minibatch_dict["weight"] is not None
82+
has_edge_ids = (
83+
"edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
84+
)
85+
has_edge_types = (
86+
"edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
87+
)
88+
has_weights = (
89+
"weight" in minibatch_dict and minibatch_dict["weight"] is not None
90+
)
8591

8692
if minibatch_dict["renumber_map"] is None:
8793
raise ValueError(
@@ -92,22 +98,22 @@ def __write_minibatches_coo(self, minibatch_dict):
9298
if len(minibatch_dict["batch_id"]) == 0:
9399
return
94100

95-
fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
96-
minibatch_dict["batch_id"]
97-
)
101+
fanout_length = len(minibatch_dict["fanout"])
102+
total_num_batches = (
103+
len(minibatch_dict["label_hop_offsets"]) - 1
104+
) / fanout_length
98105

99-
for p in range(
100-
0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
101-
):
106+
for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
102107
partition_start = p * (self.__batches_per_partition)
103108
partition_end = (p + 1) * (self.__batches_per_partition)
104109

105110
label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
106111
partition_start * fanout_length : partition_end * fanout_length + 1
107112
]
108113

109-
batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
110-
start_batch_id = batch_id_array_p[0]
114+
num_batches_p = len(label_hop_offsets_array_p) - 1
115+
116+
start_batch_id = minibatch_dict["batch_start"]
111117

112118
input_offsets_p = minibatch_dict["input_offsets"][
113119
partition_start : (partition_end + 1)
@@ -171,7 +177,7 @@ def __write_minibatches_coo(self, minibatch_dict):
171177
}
172178
)
173179

174-
end_batch_id = start_batch_id + len(batch_id_array_p) - 1
180+
end_batch_id = start_batch_id + num_batches_p - 1
175181
rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0
176182

177183
full_output_path = os.path.join(
@@ -188,9 +194,15 @@ def __write_minibatches_coo(self, minibatch_dict):
188194
)
189195

190196
def __write_minibatches_csr(self, minibatch_dict):
191-
has_edge_ids = minibatch_dict["edge_id"] is not None
192-
has_edge_types = minibatch_dict["edge_type"] is not None
193-
has_weights = minibatch_dict["weight"] is not None
197+
has_edge_ids = (
198+
"edge_id" in minibatch_dict and minibatch_dict["edge_id"] is not None
199+
)
200+
has_edge_types = (
201+
"edge_type" in minibatch_dict and minibatch_dict["edge_type"] is not None
202+
)
203+
has_weights = (
204+
"weight" in minibatch_dict and minibatch_dict["weight"] is not None
205+
)
194206

195207
if minibatch_dict["renumber_map"] is None:
196208
raise ValueError(
@@ -201,22 +213,22 @@ def __write_minibatches_csr(self, minibatch_dict):
201213
if len(minibatch_dict["batch_id"]) == 0:
202214
return
203215

204-
fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len(
205-
minibatch_dict["batch_id"]
206-
)
216+
fanout_length = len(minibatch_dict["fanout"])
217+
total_num_batches = (
218+
len(minibatch_dict["label_hop_offsets"]) - 1
219+
) / fanout_length
207220

208-
for p in range(
209-
0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition))
210-
):
221+
for p in range(0, int(ceil(total_num_batches / self.__batches_per_partition))):
211222
partition_start = p * (self.__batches_per_partition)
212223
partition_end = (p + 1) * (self.__batches_per_partition)
213224

214225
label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][
215226
partition_start * fanout_length : partition_end * fanout_length + 1
216227
]
217228

218-
batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
219-
start_batch_id = batch_id_array_p[0]
229+
num_batches_p = len(label_hop_offsets_array_p) - 1
230+
231+
start_batch_id = minibatch_dict["batch_start"]
220232

221233
input_offsets_p = minibatch_dict["input_offsets"][
222234
partition_start : (partition_end + 1)
@@ -292,7 +304,7 @@ def __write_minibatches_csr(self, minibatch_dict):
292304
}
293305
)
294306

295-
end_batch_id = start_batch_id + len(batch_id_array_p) - 1
307+
end_batch_id = start_batch_id + num_batches_p - 1
296308
rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0
297309

298310
full_output_path = os.path.join(
@@ -309,12 +321,19 @@ def __write_minibatches_csr(self, minibatch_dict):
309321
)
310322

311323
def write_minibatches(self, minibatch_dict):
312-
if (minibatch_dict["majors"] is not None) and (
313-
minibatch_dict["minors"] is not None
314-
):
324+
if "minors" not in minibatch_dict:
325+
raise ValueError("invalid columns")
326+
327+
# PLC API specifies this behavior for empty input
328+
# This needs to be handled here to avoid causing a hang
329+
if len(minibatch_dict["minors"]) == 0:
330+
return
331+
332+
if "majors" in minibatch_dict and minibatch_dict["majors"] is not None:
315333
self.__write_minibatches_coo(minibatch_dict)
316-
elif (minibatch_dict["major_offsets"] is not None) and (
317-
minibatch_dict["minors"] is not None
334+
elif (
335+
"major_offsets" in minibatch_dict
336+
and minibatch_dict["major_offsets"] is not None
318337
):
319338
self.__write_minibatches_csr(minibatch_dict)
320339
else:

0 commit comments

Comments (0)