Skip to content

Commit df82b74

Browse files
authored
Fix Fanout -1 (#2358)
This PR fixes sampling for the default fanout value of -1, which is used in the call below: https://github.com/rapidsai/cugraph/blob/92f6ba451b2d9e6c3f60dbccfa05bbf3c480e43a/python/cugraph/cugraph/gnn/graph_store.py#L129-L133 In workflows this is called during inference, so it needs to work for inference to work with CuGraphStorage. Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Alex Barghi (https://github.com/alexbarghi-nv) URL: #2358
1 parent 9578cfe commit df82b74

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

python/cugraph/cugraph/gnn/graph_store.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ def sample_neighbors(self,
8585
Node IDs to sample neighbors from.
8686
fanout : int
8787
The number of edges to be sampled for each node on each edge type.
88+
If -1 is given, all the neighboring edges for each node on
89+
each edge type will be selected.
8890
edge_dir : str {"in" or "out"}
8991
Determines whether to sample inbound or outbound edges.
9092
Can take either in for inbound edges or out for outbound edges.

python/cugraph/cugraph/tests/test_graph_store.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,29 @@ def test_sample_neighbors(graph_file):
164164
assert len(parents_list) > 0
165165

166166

167+
@pytest.mark.parametrize("graph_file", utils.DATASETS)
168+
def test_sample_neighbor_neg_one_fanout(graph_file):
169+
cu_M = utils.read_csv_file(graph_file)
170+
171+
g = cugraph.Graph(directed=True)
172+
g.from_cudf_edgelist(cu_M, source='0', destination='1', renumber=True)
173+
174+
pg = PropertyGraph()
175+
pg.add_edge_data(cu_M,
176+
type_name="edge",
177+
vertex_col_names=("0", "1"),
178+
property_columns=["2"])
179+
180+
gstore = cugraph.gnn.CuGraphStore(graph=pg)
181+
182+
nodes = gstore.get_vertex_ids()
183+
sampled_nodes = nodes[:5]
184+
# -1, default fan_out
185+
parents_list, children_list = gstore.sample_neighbors(sampled_nodes, -1)
186+
187+
assert len(parents_list) > 0
188+
189+
167190
@pytest.mark.parametrize("graph_file", utils.DATASETS)
168191
def test_n_data(graph_file):
169192
cu_M = utils.read_csv_file(graph_file)

python/cugraph/cugraph/utilities/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,17 @@ def create_random_bipartite(v1, v2, size, dtype):
479479

480480

481481
def sample_groups(df, by, n_samples):
482-
# Sample n_samples in the df frm by column
482+
# Sample n_samples in the df using the by column
483483

484484
# Step 1
485485
# first, shuffle the dataframe and reset its index,
486486
# so that the ordering of values within each group
487487
# is made random:
488488
df = df.sample(frac=1).reset_index(drop=True)
489489

490+
# If n_samples is -1, keep all samples: return the shuffled df as-is
491+
if n_samples == -1:
492+
return df
490493
# Step 2
491494
# add an integer-encoded version of the "by" column,
492495
# since the rank aggregation seems not to work for

0 commit comments

Comments
 (0)