Skip to content

Commit b800f0b

Browse files
authored
Use edge_ids directly in uniform sampling call to prevent cost of edge_id lookup (#2550)
This PR fixes #2520 **Speedup Details** We see a 2.6x speedup , ranging from 0.8x to 10x. **Benchmarking Gist:** Benchmark Link: https://gist.github.com/VibhuJawa/38da2f151141c0582a0532a364458602 **Benchmarking Table:** | dataset | fanout | seednodes | PR cugraph\_t (ms) | Main cugraph\_t (ms) | Speedup | | ----------- | ------ | --------- | ------------------ | -------------------- | ------------ | | livejournal | 5 | 6400 | 9.77469367 | 36.14 | 3.697743722 | | livejournal | 5 | 12800 | 10.24105188 | 37.04 | 3.617198402 | | livejournal | 5 | 25600 | 11.25398077 | 39.31 | 3.492790318 | | livejournal | 5 | 51200 | 19.90233963 | 48.31 | 2.427492542 | | livejournal | 20 | 6400 | 11.08045933 | 37.40 | 3.375111171 | | livejournal | 20 | 12800 | 12.41813744 | 39.78 | 3.203001674 | | livejournal | 20 | 25600 | 20.01964133 | 48.59 | 2.426926934 | | livejournal | 20 | 51200 | 20.479394 | 51.75 | 2.526783655 | | livejournal | 40 | 6400 | 18.02444187 | 38.42 | 2.13166189 | | livejournal | 40 | 12800 | 15.95887286 | 41.13 | 2.577490516 | | livejournal | 40 | 25600 | 30.42667777 | 49.21 | 1.617178892 | | livejournal | 40 | 51200 | 31.27987486 | 56.83 | 1.816870032 | | ogbn-arxiv | 5 | 6400 | 7.269433069 | 6.81 | 0.9363815769 | | ogbn-arxiv | 5 | 12800 | 3.700939559 | 6.48 | 1.750559107 | | ogbn-arxiv | 5 | 25600 | 7.43439748 | 6.74 | 0.9070057901 | | ogbn-arxiv | 5 | 51200 | 8.364707041 | 8.92 | 1.06631151 | | ogbn-arxiv | 20 | 6400 | 3.526507211 | 6.01 | 1.704996136 | | ogbn-arxiv | 20 | 12800 | 7.11795785 | 6.35 | 0.8917298112 | | ogbn-arxiv | 20 | 25600 | 9.83814247 | 8.87 | 0.9015745857 | | ogbn-arxiv | 20 | 51200 | 19.16898326 | 15.28 | 0.797070347 | | ogbn-arxiv | 40 | 6400 | 7.47879348 | 6.11 | 0.8169812813 | | ogbn-arxiv | 40 | 12800 | 8.980390432 | 7.44 | 0.828701598 | | ogbn-arxiv | 40 | 25600 | 9.939847551 | 9.78 | 0.9838518889 | | ogbn-arxiv | 40 | 51200 | 21.65015471 | 17.39 | 0.8032186603 | | reddit | 5 | 6400 | 4.485681872 | 47.60 | 10.61118206 | | reddit | 5 | 12800 | 8.203881669 | 48.36 | 5.894866842 | | reddit | 5 | 25600 | 10.19984847 | 51.61 | 5.05981494 | | reddit | 5 | 51200 | 25.52061113 | 61.15 | 2.39617171 | | reddit | 20 | 6400 | 9.60336474 | 51.21 | 5.333003796 | | reddit | 20 | 12800 | 22.43147231 | 60.14 | 2.681092588 | | reddit | 20 | 25600 | 23.204309 | 70.10 | 3.021163687 | | reddit | 20 | 51200 | 27.07365799 | 76.18 | 2.813953476 | | reddit | 40 | 6400 | 24.64297758 | 60.25 | 2.445081387 | | reddit | 40 | 12800 | 23.05950785 | 68.38 | 2.965428975 | | reddit | 40 | 25600 | 24.84033842 | 74.12 | 2.983957307 | | reddit | 40 | 51200 | 30.75342988 | 87.18 | 2.834787134 | **Bottleneck after the PR** ```python Timer unit: 1e-06 s Total time: 0.022579 s File: /datasets/vjawa/miniconda3/envs/cugraph_dev_aug_10/lib/python3.9/site-packages/cugraph-22.10.0a0+45.g3ff5b53ff.dirty-py3.9-linux-x86_64.egg/cugraph/gnn/graph_store.py Function: sample_neighbors at line 181 Line # Hits Time Per Hit % Time Line Contents ============================================================== 181 def sample_neighbors( 182 self, nodes, fanout=-1, edge_dir="in", prob=None, replace=False 183 ): ................ 216 """ 217 218 1 2.0 2.0 0.0 if edge_dir not in ["in", "out"]: 219 raise ValueError( 220 f"edge_dir must be either 'in' or 'out' got {edge_dir} instead" 221 ) 222 223 1 1.0 1.0 0.0 if edge_dir == "in": 224 1 1.0 1.0 0.0 sg = self.extracted_reverse_subgraph_without_renumbering 225 else: 226 sg = self.extracted_subgraph_without_renumbering 227 228 1 1.0 1.0 0.0 if not hasattr(self, '_sg_node_dtype'): 229 self._sg_node_dtype = sg.edgelist.edgelist_df['src'].dtype 230 231 # Uniform sampling assumes fails when the dtype 232 # if the seed dtype is not same as the node dtype 233 1 774.0 774.0 3.4 nodes = cudf.from_dlpack(nodes).astype(self._sg_node_dtype) 234 235 2 19303.0 9651.5 85.5 sampled_df = uniform_neighbor_sample( 236 1 1.0 1.0 0.0 sg, start_list=nodes, fanout_vals=[fanout], 237 1 0.0 0.0 0.0 with_replacement=replace, 238 1 1.0 1.0 0.0 is_edge_ids=True # FIXME: Does not seem to do anything 239 ) 240 241 # handle empty graph case 242 1 17.0 17.0 0.1 if len(sampled_df) == 0: 243 return None, None, None 244 245 # we reverse directions when directions=='in' 246 1 1.0 1.0 0.0 if edge_dir == "in": 247 2 136.0 68.0 0.6 sampled_df.rename( 248 1 1.0 1.0 0.0 columns={"destinations": src_n, "sources": dst_n}, inplace=True 249 ) 250 else: 251 sampled_df.rename( 252 columns={"sources": src_n, "destinations": dst_n}, inplace=True 253 ) 254 255 1 2.0 2.0 0.0 return ( 256 1 786.0 786.0 3.5 sampled_df[src_n].to_dlpack(), 257 1 776.0 776.0 3.4 sampled_df[dst_n].to_dlpack(), 258 1 776.0 776.0 3.4 sampled_df['indices'].to_dlpack(), 259 ) ``` Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Brad Rees (https://github.com/BradReesWork) - Rick Ratzel (https://github.com/rlratzel) - Alex Barghi (https://github.com/alexbarghi-nv) URL: #2550
1 parent 6632d1e commit b800f0b

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

python/cugraph/cugraph/gnn/graph_store.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,12 @@ def sample_neighbors(
234234

235235
sampled_df = uniform_neighbor_sample(
236236
sg, start_list=nodes, fanout_vals=[fanout],
237-
with_replacement=replace
237+
with_replacement=replace,
238+
is_edge_ids=True
239+
# FIXME: is_edge_ids=True does not seem to do anything
240+
# issue https://github.com/rapidsai/cugraph/issues/2562
238241
)
239242

240-
sampled_df.drop(columns=["indices"], inplace=True)
241-
242243
# handle empty graph case
243244
if len(sampled_df) == 0:
244245
return None, None, None
@@ -253,40 +254,35 @@ def sample_neighbors(
253254
columns={"sources": src_n, "destinations": dst_n}, inplace=True
254255
)
255256

256-
# FIXME: Remove once below lands
257-
# https://github.com/rapidsai/cugraph/issues/2444
258-
edge_df = self.gdata._edge_prop_dataframe[[src_n, dst_n, eid_n]]
259-
sampled_df = edge_df.merge(sampled_df)
260-
261257
return (
262258
sampled_df[src_n].to_dlpack(),
263259
sampled_df[dst_n].to_dlpack(),
264-
sampled_df[eid_n].to_dlpack(),
260+
sampled_df['indices'].to_dlpack(),
265261
)
266262

267263
@cached_property
268264
def extracted_reverse_subgraph_without_renumbering(self):
269265
# TODO: Switch to extract_subgraph based on response on
270266
# https://github.com/rapidsai/cugraph/issues/2458
271-
subset_df = self.gdata._edge_prop_dataframe[[src_n, dst_n]]
267+
subset_df = self.gdata._edge_prop_dataframe[[src_n, dst_n, eid_n]]
272268
subset_df.rename(columns={src_n: dst_n, dst_n: src_n}, inplace=True)
273-
subset_df["weight"] = cp.float32(1.0)
274269
subgraph = cugraph.Graph(directed=True)
275270
subgraph.from_cudf_edgelist(
276271
subset_df,
277272
source=src_n,
278273
destination=dst_n,
279-
edge_attr="weight",
280-
legacy_renum_only=True,
274+
edge_attr=eid_n,
275+
renumber=False,
276+
legacy_renum_only=False,
281277
)
282278
return subgraph
283279

284280
@cached_property
285281
def extracted_subgraph_without_renumbering(self):
286282
gr_template = cugraph.Graph(directed=True)
287283
subgraph = self.gdata.extract_subgraph(create_using=gr_template,
288-
default_edge_weight=1.0,
289-
renumber_graph=True)
284+
edge_weight_property=eid_n,
285+
renumber_graph=False)
290286
return subgraph
291287

292288
def find_edges(self, edge_ids_cap, etype):

python/cugraph/cugraph/tests/test_graph_store.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def test_sample_neighbors(graph_file):
160160
assert len(parents_list) > 0
161161

162162

163-
@pytest.mark.skip(reason="Neg one fanout fails see cugraph/issues/2446")
164163
@pytest.mark.parametrize("graph_file", utils.DATASETS)
165164
def test_sample_neighbor_neg_one_fanout(graph_file):
166165
cu_M = utils.read_csv_file(graph_file)
@@ -425,7 +424,6 @@ def test_sampling_gs(dataset1_CuGraphStore):
425424
assert len(src_ser) != 0
426425

427426

428-
@pytest.mark.skip(reason="Neg one fanout fails see cugraph/issues/2446")
429427
def test_sampling_dataset_gs_neg_one_fanout(dataset1_CuGraphStore):
430428
node_pack = cp.asarray([4]).toDlpack()
431429
gs = dataset1_CuGraphStore
@@ -460,9 +458,11 @@ def test_sampling_gs_out_dir():
460458
if sample_src is None:
461459
sample_src = cudf.Series([]).astype(np.int64)
462460
sample_dst = cudf.Series([]).astype(np.int64)
461+
sample_eid = cudf.Series([]).astype(np.int64)
463462
else:
464463
sample_src = cudf.from_dlpack(sample_src)
465464
sample_dst = cudf.from_dlpack(sample_dst)
465+
sample_eid = cudf.from_dlpack(sample_eid)
466466

467467
output_df = cudf.DataFrame({"src": sample_src, "dst": sample_dst})
468468
output_df = output_df.sort_values(by=["src", "dst"])
@@ -473,6 +473,11 @@ def test_sampling_gs_out_dir():
473473
).astype(np.int64)
474474
cudf.testing.assert_frame_equal(output_df, expected_df)
475475

476+
sample_edge_id_df = cudf.DataFrame({"src": sample_src,
477+
"dst": sample_dst,
478+
'edge_id': sample_eid})
479+
assert_correct_eids(df, sample_edge_id_df)
480+
476481

477482
def test_sampling_gs_in_dir():
478483
src_ser = cudf.Series([1, 1, 1, 1, 1, 2, 2, 3])
@@ -498,9 +503,11 @@ def test_sampling_gs_in_dir():
498503
if sample_src is None:
499504
sample_src = cudf.Series([]).astype(np.int64)
500505
sample_dst = cudf.Series([]).astype(np.int64)
506+
sample_eid = cudf.Series([]).astype(np.int64)
501507
else:
502508
sample_src = cudf.from_dlpack(sample_src)
503509
sample_dst = cudf.from_dlpack(sample_dst)
510+
sample_eid = cudf.from_dlpack(sample_eid)
504511

505512
output_df = cudf.DataFrame({"src": sample_src, "dst": sample_dst})
506513
output_df = output_df.sort_values(by=["src", "dst"])
@@ -510,3 +517,26 @@ def test_sampling_gs_in_dir():
510517
{"src": expected_in[seed][0], "dst": expected_in[seed][1]}
511518
).astype(np.int64)
512519
cudf.testing.assert_frame_equal(output_df, expected_df)
520+
521+
sample_edge_id_df = cudf.DataFrame({"src": sample_src,
522+
"dst": sample_dst,
523+
'edge_id': sample_eid})
524+
525+
assert_correct_eids(df, sample_edge_id_df)
526+
527+
528+
def assert_correct_eids(edge_df, sample_edge_id_df):
529+
# We test that all src, dst correspond to the correct
530+
# eids in the sample_edge_id_df
531+
# we do this by ensuring that the inner merge to edge_df
532+
# remains the same as sample_edge_id_df
533+
# if they don't correspond correctly
534+
# the inner merge would fail
535+
536+
sample_edge_id_df = sample_edge_id_df.sort_values(by='edge_id')
537+
sample_edge_id_df = sample_edge_id_df.reset_index(drop=True)
538+
539+
sample_merged_df = sample_edge_id_df.merge(edge_df, how='inner')
540+
sample_merged_df = sample_merged_df.sort_values(by='edge_id')
541+
sample_merged_df = sample_merged_df.reset_index(drop=True)
542+
assert sample_merged_df.equals(sample_edge_id_df)

0 commit comments

Comments
 (0)