Skip to content

Commit b4a4160

Browse files
authored
Replace raw pointers with device_span in induced subgraph (#2348)
Changed the interface of function extract_induced_subgraphs, and all corresponding test codes using this interface. Authors: - Yang Hu (https://github.com/yang-hu-nv) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) URL: #2348
1 parent 8587225 commit b4a4160

File tree

7 files changed

+115
-98
lines changed

7 files changed

+115
-98
lines changed

cpp/include/cugraph/graph_functions.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cugraph/graph_view.hpp>
2020

2121
#include <raft/handle.hpp>
22+
#include <raft/span.hpp>
2223
#include <rmm/device_uvector.hpp>
2324

2425
#include <memory>
@@ -471,8 +472,8 @@ std::tuple<rmm::device_uvector<vertex_t>,
471472
extract_induced_subgraphs(
472473
raft::handle_t const& handle,
473474
graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu> const& graph_view,
474-
size_t const* subgraph_offsets /* size == num_subgraphs + 1 */,
475-
vertex_t const* subgraph_vertices /* size == subgraph_offsets[num_subgraphs] */,
475+
raft::device_span<size_t const> subgraph_offsets /* size == num_subgraphs + 1 */,
476+
raft::device_span<vertex_t const> subgraph_vertices /* size == subgraph_offsets[num_subgraphs] */,
476477
size_t num_subgraphs,
477478
bool do_expensive_check = false);
478479

cpp/src/community/legacy/egonet.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,12 @@ extract(raft::handle_t const& handle,
169169

170170
// extract
171171
return cugraph::extract_induced_subgraphs(
172-
handle, csr_view, neighbors_offsets.data().get(), neighbors.data().get(), n_subgraphs);
172+
handle,
173+
csr_view,
174+
raft::device_span<size_t const>(neighbors_offsets.data().get(), neighbors_offsets.size()),
175+
raft::device_span<vertex_t const>(neighbors.data().get(), neighbors.size()),
176+
n_subgraphs,
177+
false);
173178
}
174179

175180
} // namespace

cpp/src/structure/induced_subgraph_impl.cuh

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
* limitations under the License.
1515
*/
1616
#pragma once
17-
1817
#include <cugraph/edge_partition_device_view.cuh>
1918
#include <cugraph/graph_functions.hpp>
2019
#include <cugraph/graph_view.hpp>
@@ -48,8 +47,8 @@ std::tuple<rmm::device_uvector<vertex_t>,
4847
extract_induced_subgraphs(
4948
raft::handle_t const& handle,
5049
graph_view_t<vertex_t, edge_t, weight_t, store_transposed, multi_gpu> const& graph_view,
51-
size_t const* subgraph_offsets /* size == num_subgraphs + 1 */,
52-
vertex_t const* subgraph_vertices /* size == subgraph_offsets[num_subgraphs] */,
50+
raft::device_span<size_t const> subgraph_offsets /* size == num_subgraphs + 1 */,
51+
raft::device_span<vertex_t const> subgraph_vertices /* size == subgraph_offsets[num_subgraphs] */,
5352
size_t num_subgraphs,
5453
bool do_expensive_check)
5554
{
@@ -69,24 +68,25 @@ extract_induced_subgraphs(
6968
if (do_expensive_check) {
7069
size_t should_be_zero{std::numeric_limits<size_t>::max()};
7170
size_t num_aggregate_subgraph_vertices{};
72-
raft::update_host(&should_be_zero, subgraph_offsets, 1, handle.get_stream());
73-
raft::update_host(
74-
&num_aggregate_subgraph_vertices, subgraph_offsets + num_subgraphs, 1, handle.get_stream());
71+
raft::update_host(&should_be_zero, subgraph_offsets.data(), 1, handle.get_stream());
72+
raft::update_host(&num_aggregate_subgraph_vertices,
73+
subgraph_offsets.data() + num_subgraphs,
74+
1,
75+
handle.get_stream());
7576
handle.sync_stream();
7677
CUGRAPH_EXPECTS(should_be_zero == 0,
7778
"Invalid input argument: subgraph_offsets[0] should be 0.");
7879

79-
CUGRAPH_EXPECTS(
80-
thrust::is_sorted(
81-
handle.get_thrust_policy(), subgraph_offsets, subgraph_offsets + (num_subgraphs + 1)),
82-
"Invalid input argument: subgraph_offsets is not sorted.");
80+
CUGRAPH_EXPECTS(thrust::is_sorted(
81+
handle.get_thrust_policy(), subgraph_offsets.begin(), subgraph_offsets.end()),
82+
"Invalid input argument: subgraph_offsets is not sorted.");
8383
auto vertex_partition =
8484
vertex_partition_device_view_t<vertex_t, multi_gpu>(graph_view.local_vertex_partition_view());
8585

8686
CUGRAPH_EXPECTS(
8787
thrust::count_if(handle.get_thrust_policy(),
88-
subgraph_vertices,
89-
subgraph_vertices + num_aggregate_subgraph_vertices,
88+
subgraph_vertices.begin(),
89+
subgraph_vertices.end(),
9090
[vertex_partition] __device__(auto v) {
9191
return !vertex_partition.is_valid_vertex(v) ||
9292
!vertex_partition.in_local_vertex_partition_range_nocheck(v);
@@ -101,8 +101,8 @@ extract_induced_subgraphs(
101101
[subgraph_offsets, subgraph_vertices] __device__(auto i) {
102102
// vertices are sorted and unique
103103
return !thrust::is_sorted(thrust::seq,
104-
subgraph_vertices + subgraph_offsets[i],
105-
subgraph_vertices + subgraph_offsets[i + 1]) ||
104+
subgraph_vertices.begin() + subgraph_offsets[i],
105+
subgraph_vertices.begin() + subgraph_offsets[i + 1]) ||
106106
(thrust::count_if(
107107
thrust::seq,
108108
thrust::make_counting_iterator(subgraph_offsets[i]),
@@ -127,8 +127,10 @@ extract_induced_subgraphs(
127127
// 2-1. Phase 1: calculate memory requirements
128128

129129
size_t num_aggregate_subgraph_vertices{};
130-
raft::update_host(
131-
&num_aggregate_subgraph_vertices, subgraph_offsets + num_subgraphs, 1, handle.get_stream());
130+
raft::update_host(&num_aggregate_subgraph_vertices,
131+
subgraph_offsets.data() + num_subgraphs,
132+
1,
133+
handle.get_stream());
132134
handle.sync_stream();
133135

134136
rmm::device_uvector<size_t> subgraph_vertex_output_offsets(
@@ -145,9 +147,10 @@ extract_induced_subgraphs(
145147
thrust::make_counting_iterator(num_aggregate_subgraph_vertices),
146148
subgraph_vertex_output_offsets.begin(),
147149
[subgraph_offsets, subgraph_vertices, num_subgraphs, edge_partition] __device__(auto i) {
148-
auto subgraph_idx = thrust::distance(
149-
subgraph_offsets + 1,
150-
thrust::upper_bound(thrust::seq, subgraph_offsets, subgraph_offsets + num_subgraphs, i));
150+
auto subgraph_idx =
151+
thrust::distance(subgraph_offsets.begin() + 1,
152+
thrust::upper_bound(
153+
thrust::seq, subgraph_offsets.begin(), subgraph_offsets.end() - 1, i));
151154
vertex_t const* indices{nullptr};
152155
thrust::optional<weight_t const*> weights{thrust::nullopt};
153156
edge_t local_degree{};
@@ -158,9 +161,9 @@ extract_induced_subgraphs(
158161
thrust::seq,
159162
indices,
160163
indices + local_degree,
161-
[vertex_first = subgraph_vertices + subgraph_offsets[subgraph_idx],
164+
[vertex_first = subgraph_vertices.begin() + subgraph_offsets[subgraph_idx],
162165
vertex_last =
163-
subgraph_vertices + subgraph_offsets[subgraph_idx + 1]] __device__(auto nbr) {
166+
subgraph_vertices.begin() + subgraph_offsets[subgraph_idx + 1]] __device__(auto nbr) {
164167
return thrust::binary_search(thrust::seq, vertex_first, vertex_last, nbr);
165168
});
166169
});
@@ -201,9 +204,9 @@ extract_induced_subgraphs(
201204
edge_weights = edge_weights ? thrust::optional<weight_t*>{(*edge_weights).data()}
202205
: thrust::nullopt] __device__(auto i) {
203206
auto subgraph_idx = thrust::distance(
204-
subgraph_offsets + 1,
207+
subgraph_offsets.begin() + 1,
205208
thrust::upper_bound(
206-
thrust::seq, subgraph_offsets, subgraph_offsets + num_subgraphs, size_t{i}));
209+
thrust::seq, subgraph_offsets.begin(), subgraph_offsets.end() - 1, size_t{i}));
207210
vertex_t const* indices{nullptr};
208211
thrust::optional<weight_t const*> weights{thrust::nullopt};
209212
edge_t local_degree{};
@@ -219,34 +222,35 @@ extract_induced_subgraphs(
219222
triplet_first + local_degree,
220223
thrust::make_zip_iterator(thrust::make_tuple(edge_majors, edge_minors, *edge_weights)) +
221224
subgraph_vertex_output_offsets[i],
222-
[vertex_first = subgraph_vertices + subgraph_offsets[subgraph_idx],
225+
[vertex_first = subgraph_vertices.begin() + subgraph_offsets[subgraph_idx],
223226
vertex_last =
224-
subgraph_vertices + subgraph_offsets[subgraph_idx + 1]] __device__(auto t) {
227+
subgraph_vertices.begin() + subgraph_offsets[subgraph_idx + 1]] __device__(auto t) {
225228
return thrust::binary_search(
226229
thrust::seq, vertex_first, vertex_last, thrust::get<1>(t));
227230
});
228231
} else {
229232
auto pair_first = thrust::make_zip_iterator(
230233
thrust::make_tuple(thrust::make_constant_iterator(subgraph_vertices[i]), indices));
231234
// FIXME: this is inefficient for high local degree vertices
232-
thrust::copy_if(thrust::seq,
233-
pair_first,
234-
pair_first + local_degree,
235-
thrust::make_zip_iterator(thrust::make_tuple(edge_majors, edge_minors)) +
236-
subgraph_vertex_output_offsets[i],
237-
[vertex_first = subgraph_vertices + subgraph_offsets[subgraph_idx],
238-
vertex_last = subgraph_vertices +
239-
subgraph_offsets[subgraph_idx + 1]] __device__(auto t) {
240-
return thrust::binary_search(
241-
thrust::seq, vertex_first, vertex_last, thrust::get<1>(t));
242-
});
235+
thrust::copy_if(
236+
thrust::seq,
237+
pair_first,
238+
pair_first + local_degree,
239+
thrust::make_zip_iterator(thrust::make_tuple(edge_majors, edge_minors)) +
240+
subgraph_vertex_output_offsets[i],
241+
[vertex_first = subgraph_vertices.begin() + subgraph_offsets[subgraph_idx],
242+
vertex_last =
243+
subgraph_vertices.begin() + subgraph_offsets[subgraph_idx + 1]] __device__(auto t) {
244+
return thrust::binary_search(
245+
thrust::seq, vertex_first, vertex_last, thrust::get<1>(t));
246+
});
243247
}
244248
});
245249

246250
rmm::device_uvector<size_t> subgraph_edge_offsets(num_subgraphs + 1, handle.get_stream());
247251
thrust::gather(handle.get_thrust_policy(),
248-
subgraph_offsets,
249-
subgraph_offsets + (num_subgraphs + 1),
252+
subgraph_offsets.begin(),
253+
subgraph_offsets.end(),
250254
subgraph_vertex_output_offsets.begin(),
251255
subgraph_edge_offsets.begin());
252256
#ifdef TIMING

cpp/src/structure/induced_subgraph_mg.cu

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
2525
rmm::device_uvector<size_t>>
2626
extract_induced_subgraphs(raft::handle_t const& handle,
2727
graph_view_t<int32_t, int32_t, float, true, true> const& graph_view,
28-
size_t const* subgraph_offsets,
29-
int32_t const* subgraph_vertices,
28+
raft::device_span<size_t const> subgraph_offsets,
29+
raft::device_span<int32_t const> subgraph_vertices,
3030
size_t num_subgraphs,
3131
bool do_expensive_check);
3232

@@ -36,8 +36,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
3636
rmm::device_uvector<size_t>>
3737
extract_induced_subgraphs(raft::handle_t const& handle,
3838
graph_view_t<int32_t, int32_t, float, false, true> const& graph_view,
39-
size_t const* subgraph_offsets,
40-
int32_t const* subgraph_vertices,
39+
raft::device_span<size_t const> subgraph_offsets,
40+
raft::device_span<int32_t const> subgraph_vertices,
4141
size_t num_subgraphs,
4242
bool do_expensive_check);
4343

@@ -47,8 +47,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
4747
rmm::device_uvector<size_t>>
4848
extract_induced_subgraphs(raft::handle_t const& handle,
4949
graph_view_t<int32_t, int32_t, double, true, true> const& graph_view,
50-
size_t const* subgraph_offsets,
51-
int32_t const* subgraph_vertices,
50+
raft::device_span<size_t const> subgraph_offsets,
51+
raft::device_span<int32_t const> subgraph_vertices,
5252
size_t num_subgraphs,
5353
bool do_expensive_check);
5454

@@ -58,8 +58,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
5858
rmm::device_uvector<size_t>>
5959
extract_induced_subgraphs(raft::handle_t const& handle,
6060
graph_view_t<int32_t, int32_t, double, false, true> const& graph_view,
61-
size_t const* subgraph_offsets,
62-
int32_t const* subgraph_vertices,
61+
raft::device_span<size_t const> subgraph_offsets,
62+
raft::device_span<int32_t const> subgraph_vertices,
6363
size_t num_subgraphs,
6464
bool do_expensive_check);
6565

@@ -69,8 +69,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
6969
rmm::device_uvector<size_t>>
7070
extract_induced_subgraphs(raft::handle_t const& handle,
7171
graph_view_t<int32_t, int64_t, float, true, true> const& graph_view,
72-
size_t const* subgraph_offsets,
73-
int32_t const* subgraph_vertices,
72+
raft::device_span<size_t const> subgraph_offsets,
73+
raft::device_span<int32_t const> subgraph_vertices,
7474
size_t num_subgraphs,
7575
bool do_expensive_check);
7676

@@ -80,8 +80,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
8080
rmm::device_uvector<size_t>>
8181
extract_induced_subgraphs(raft::handle_t const& handle,
8282
graph_view_t<int32_t, int64_t, float, false, true> const& graph_view,
83-
size_t const* subgraph_offsets,
84-
int32_t const* subgraph_vertices,
83+
raft::device_span<size_t const> subgraph_offsets,
84+
raft::device_span<int32_t const> subgraph_vertices,
8585
size_t num_subgraphs,
8686
bool do_expensive_check);
8787

@@ -91,8 +91,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
9191
rmm::device_uvector<size_t>>
9292
extract_induced_subgraphs(raft::handle_t const& handle,
9393
graph_view_t<int32_t, int64_t, double, true, true> const& graph_view,
94-
size_t const* subgraph_offsets,
95-
int32_t const* subgraph_vertices,
94+
raft::device_span<size_t const> subgraph_offsets,
95+
raft::device_span<int32_t const> subgraph_vertices,
9696
size_t num_subgraphs,
9797
bool do_expensive_check);
9898

@@ -102,8 +102,8 @@ template std::tuple<rmm::device_uvector<int32_t>,
102102
rmm::device_uvector<size_t>>
103103
extract_induced_subgraphs(raft::handle_t const& handle,
104104
graph_view_t<int32_t, int64_t, double, false, true> const& graph_view,
105-
size_t const* subgraph_offsets,
106-
int32_t const* subgraph_vertices,
105+
raft::device_span<size_t const> subgraph_offsets,
106+
raft::device_span<int32_t const> subgraph_vertices,
107107
size_t num_subgraphs,
108108
bool do_expensive_check);
109109

@@ -113,8 +113,8 @@ template std::tuple<rmm::device_uvector<int64_t>,
113113
rmm::device_uvector<size_t>>
114114
extract_induced_subgraphs(raft::handle_t const& handle,
115115
graph_view_t<int64_t, int64_t, float, true, true> const& graph_view,
116-
size_t const* subgraph_offsets,
117-
int64_t const* subgraph_vertices,
116+
raft::device_span<size_t const> subgraph_offsets,
117+
raft::device_span<int64_t const> subgraph_vertices,
118118
size_t num_subgraphs,
119119
bool do_expensive_check);
120120

@@ -124,8 +124,8 @@ template std::tuple<rmm::device_uvector<int64_t>,
124124
rmm::device_uvector<size_t>>
125125
extract_induced_subgraphs(raft::handle_t const& handle,
126126
graph_view_t<int64_t, int64_t, float, false, true> const& graph_view,
127-
size_t const* subgraph_offsets,
128-
int64_t const* subgraph_vertices,
127+
raft::device_span<size_t const> subgraph_offsets,
128+
raft::device_span<int64_t const> subgraph_vertices,
129129
size_t num_subgraphs,
130130
bool do_expensive_check);
131131

@@ -135,8 +135,8 @@ template std::tuple<rmm::device_uvector<int64_t>,
135135
rmm::device_uvector<size_t>>
136136
extract_induced_subgraphs(raft::handle_t const& handle,
137137
graph_view_t<int64_t, int64_t, double, true, true> const& graph_view,
138-
size_t const* subgraph_offsets,
139-
int64_t const* subgraph_vertices,
138+
raft::device_span<size_t const> subgraph_offsets,
139+
raft::device_span<int64_t const> subgraph_vertices,
140140
size_t num_subgraphs,
141141
bool do_expensive_check);
142142

@@ -146,8 +146,8 @@ template std::tuple<rmm::device_uvector<int64_t>,
146146
rmm::device_uvector<size_t>>
147147
extract_induced_subgraphs(raft::handle_t const& handle,
148148
graph_view_t<int64_t, int64_t, double, false, true> const& graph_view,
149-
size_t const* subgraph_offsets,
150-
int64_t const* subgraph_vertices,
149+
raft::device_span<size_t const> subgraph_offsets,
150+
raft::device_span<int64_t const> subgraph_vertices,
151151
size_t num_subgraphs,
152152
bool do_expensive_check);
153153

0 commit comments

Comments
 (0)