Commit b800f0b
authored
Use edge_ids directly in uniform sampling call to prevent cost of edge_id lookup (#2550)
This PR fixes #2520
**Speedup Details**
We see a 2.6x speedup , ranging from 0.8x to 10x.
**Benchmarking Gist:**
Benchmark Link: https://gist.github.com/VibhuJawa/38da2f151141c0582a0532a364458602
**Benchmarking Table:**
| dataset | fanout | seednodes | PR cugraph\_t (ms) | Main cugraph\_t (ms) | Speedup |
| ----------- | ------ | --------- | ------------------ | -------------------- | ------------ |
| livejournal | 5 | 6400 | 9.77469367 | 36.14 | 3.697743722 |
| livejournal | 5 | 12800 | 10.24105188 | 37.04 | 3.617198402 |
| livejournal | 5 | 25600 | 11.25398077 | 39.31 | 3.492790318 |
| livejournal | 5 | 51200 | 19.90233963 | 48.31 | 2.427492542 |
| livejournal | 20 | 6400 | 11.08045933 | 37.40 | 3.375111171 |
| livejournal | 20 | 12800 | 12.41813744 | 39.78 | 3.203001674 |
| livejournal | 20 | 25600 | 20.01964133 | 48.59 | 2.426926934 |
| livejournal | 20 | 51200 | 20.479394 | 51.75 | 2.526783655 |
| livejournal | 40 | 6400 | 18.02444187 | 38.42 | 2.13166189 |
| livejournal | 40 | 12800 | 15.95887286 | 41.13 | 2.577490516 |
| livejournal | 40 | 25600 | 30.42667777 | 49.21 | 1.617178892 |
| livejournal | 40 | 51200 | 31.27987486 | 56.83 | 1.816870032 |
| ogbn-arxiv | 5 | 6400 | 7.269433069 | 6.81 | 0.9363815769 |
| ogbn-arxiv | 5 | 12800 | 3.700939559 | 6.48 | 1.750559107 |
| ogbn-arxiv | 5 | 25600 | 7.43439748 | 6.74 | 0.9070057901 |
| ogbn-arxiv | 5 | 51200 | 8.364707041 | 8.92 | 1.06631151 |
| ogbn-arxiv | 20 | 6400 | 3.526507211 | 6.01 | 1.704996136 |
| ogbn-arxiv | 20 | 12800 | 7.11795785 | 6.35 | 0.8917298112 |
| ogbn-arxiv | 20 | 25600 | 9.83814247 | 8.87 | 0.9015745857 |
| ogbn-arxiv | 20 | 51200 | 19.16898326 | 15.28 | 0.797070347 |
| ogbn-arxiv | 40 | 6400 | 7.47879348 | 6.11 | 0.8169812813 |
| ogbn-arxiv | 40 | 12800 | 8.980390432 | 7.44 | 0.828701598 |
| ogbn-arxiv | 40 | 25600 | 9.939847551 | 9.78 | 0.9838518889 |
| ogbn-arxiv | 40 | 51200 | 21.65015471 | 17.39 | 0.8032186603 |
| reddit | 5 | 6400 | 4.485681872 | 47.60 | 10.61118206 |
| reddit | 5 | 12800 | 8.203881669 | 48.36 | 5.894866842 |
| reddit | 5 | 25600 | 10.19984847 | 51.61 | 5.05981494 |
| reddit | 5 | 51200 | 25.52061113 | 61.15 | 2.39617171 |
| reddit | 20 | 6400 | 9.60336474 | 51.21 | 5.333003796 |
| reddit | 20 | 12800 | 22.43147231 | 60.14 | 2.681092588 |
| reddit | 20 | 25600 | 23.204309 | 70.10 | 3.021163687 |
| reddit | 20 | 51200 | 27.07365799 | 76.18 | 2.813953476 |
| reddit | 40 | 6400 | 24.64297758 | 60.25 | 2.445081387 |
| reddit | 40 | 12800 | 23.05950785 | 68.38 | 2.965428975 |
| reddit | 40 | 25600 | 24.84033842 | 74.12 | 2.983957307 |
| reddit | 40 | 51200 | 30.75342988 | 87.18 | 2.834787134 |
**Bottleneck after the PR**
```python
Timer unit: 1e-06 s
Total time: 0.022579 s
File: /datasets/vjawa/miniconda3/envs/cugraph_dev_aug_10/lib/python3.9/site-packages/cugraph-22.10.0a0+45.g3ff5b53ff.dirty-py3.9-linux-x86_64.egg/cugraph/gnn/graph_store.py
Function: sample_neighbors at line 181
Line # Hits Time Per Hit % Time Line Contents
==============================================================
181 def sample_neighbors(
182 self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False
183 ):
................
216 """
217
218 1 2.0 2.0 0.0 if edge_dir not in ["in", "out"]:
219 raise ValueError(
220 f"edge_dir must be either 'in' or 'out' got {edge_dir} instead"
221 )
222
223 1 1.0 1.0 0.0 if edge_dir == "in":
224 1 1.0 1.0 0.0 sg = self.extracted_reverse_subgraph_without_renumbering
225 else:
226 sg = self.extracted_subgraph_without_renumbering
227
228 1 1.0 1.0 0.0 if not hasattr(self, '_sg_node_dtype'):
229 self._sg_node_dtype = sg.edgelist.edgelist_df['src'].dtype
230
231 # Uniform sampling assumes fails when the dtype
232 # if the seed dtype is not same as the node dtype
233 1 774.0 774.0 3.4 nodes = cudf.from_dlpack(nodes).astype(self._sg_node_dtype)
234
235 2 19303.0 9651.5 85.5 sampled_df = uniform_neighbor_sample(
236 1 1.0 1.0 0.0 sg, start_list=nodes, fanout_vals=[fanout],
237 1 0.0 0.0 0.0 with_replacement=replace,
238 1 1.0 1.0 0.0 is_edge_ids=True # FIXME: Does not seem to do anything
239 )
240
241 # handle empty graph case
242 1 17.0 17.0 0.1 if len(sampled_df) == 0:
243 return None, None, None
244
245 # we reverse directions when directions=='in'
246 1 1.0 1.0 0.0 if edge_dir == "in":
247 2 136.0 68.0 0.6 sampled_df.rename(
248 1 1.0 1.0 0.0 columns={"destinations": src_n, "sources": dst_n}, inplace=True
249 )
250 else:
251 sampled_df.rename(
252 columns={"sources": src_n, "destinations": dst_n}, inplace=True
253 )
254
255 1 2.0 2.0 0.0 return (
256 1 786.0 786.0 3.5 sampled_df[src_n].to_dlpack(),
257 1 776.0 776.0 3.4 sampled_df[dst_n].to_dlpack(),
258 1 776.0 776.0 3.4 sampled_df['indices'].to_dlpack(),
259 )
```
Authors:
- Vibhu Jawa (https://github.com/VibhuJawa)
Approvers:
- Brad Rees (https://github.com/BradReesWork)
- Rick Ratzel (https://github.com/rlratzel)
- Alex Barghi (https://github.com/alexbarghi-nv)
URL: #25501 parent 6632d1e commit b800f0b
File tree
2 files changed
+43
-17
lines changed- python/cugraph/cugraph
- gnn
- tests
2 files changed
+43
-17
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
234 | 234 | | |
235 | 235 | | |
236 | 236 | | |
237 | | - | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
238 | 241 | | |
239 | 242 | | |
240 | | - | |
241 | | - | |
242 | 243 | | |
243 | 244 | | |
244 | 245 | | |
| |||
253 | 254 | | |
254 | 255 | | |
255 | 256 | | |
256 | | - | |
257 | | - | |
258 | | - | |
259 | | - | |
260 | | - | |
261 | 257 | | |
262 | 258 | | |
263 | 259 | | |
264 | | - | |
| 260 | + | |
265 | 261 | | |
266 | 262 | | |
267 | 263 | | |
268 | 264 | | |
269 | 265 | | |
270 | 266 | | |
271 | | - | |
| 267 | + | |
272 | 268 | | |
273 | | - | |
274 | 269 | | |
275 | 270 | | |
276 | 271 | | |
277 | 272 | | |
278 | 273 | | |
279 | | - | |
280 | | - | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
281 | 277 | | |
282 | 278 | | |
283 | 279 | | |
284 | 280 | | |
285 | 281 | | |
286 | 282 | | |
287 | 283 | | |
288 | | - | |
289 | | - | |
| 284 | + | |
| 285 | + | |
290 | 286 | | |
291 | 287 | | |
292 | 288 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
160 | 160 | | |
161 | 161 | | |
162 | 162 | | |
163 | | - | |
164 | 163 | | |
165 | 164 | | |
166 | 165 | | |
| |||
425 | 424 | | |
426 | 425 | | |
427 | 426 | | |
428 | | - | |
429 | 427 | | |
430 | 428 | | |
431 | 429 | | |
| |||
460 | 458 | | |
461 | 459 | | |
462 | 460 | | |
| 461 | + | |
463 | 462 | | |
464 | 463 | | |
465 | 464 | | |
| 465 | + | |
466 | 466 | | |
467 | 467 | | |
468 | 468 | | |
| |||
473 | 473 | | |
474 | 474 | | |
475 | 475 | | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
476 | 481 | | |
477 | 482 | | |
478 | 483 | | |
| |||
498 | 503 | | |
499 | 504 | | |
500 | 505 | | |
| 506 | + | |
501 | 507 | | |
502 | 508 | | |
503 | 509 | | |
| 510 | + | |
504 | 511 | | |
505 | 512 | | |
506 | 513 | | |
| |||
510 | 517 | | |
511 | 518 | | |
512 | 519 | | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
0 commit comments