Skip to content

Commit 3c301a4

Browse files
committed
add options to drop self-loops & multi_edges in test graph generation
1 parent 61950dd commit 3c301a4

File tree

3 files changed

+152
-4
lines changed

3 files changed

+152
-4
lines changed

cpp/tests/utilities/test_graphs.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,13 +537,19 @@ std::tuple<cugraph::graph_t<vertex_t, edge_t, weight_t, store_transposed, multi_
537537
construct_graph(raft::handle_t const& handle,
538538
input_usecase_t const& input_usecase,
539539
bool test_weighted,
540-
bool renumber = true)
540+
bool renumber = true,
541+
bool drop_self_loops = false,
542+
bool drop_multi_edges = false)
541543
{
542544
auto [d_src_v, d_dst_v, d_weights_v, d_vertices_v, num_vertices, is_symmetric] =
543545
input_usecase
544546
.template construct_edgelist<vertex_t, edge_t, weight_t, store_transposed, multi_gpu>(
545547
handle, test_weighted);
546548

549+
if (drop_self_loops) { remove_self_loops(handle, d_src_v, d_dst_v, d_weights_v); }
550+
551+
if (drop_multi_edges) { sort_and_remove_multi_edges(handle, d_src_v, d_dst_v, d_weights_v); }
552+
547553
return cugraph::
548554
create_graph_from_edgelist<vertex_t, edge_t, weight_t, store_transposed, multi_gpu>(
549555
handle,

cpp/tests/utilities/thrust_wrapper.cu

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
#include <rmm/exec_policy.hpp>
2222

2323
#include <thrust/copy.h>
24+
#include <thrust/remove.h>
2425
#include <thrust/shuffle.h>
2526
#include <thrust/sort.h>
27+
#include <thrust/unique.h>
2628

2729
namespace cugraph {
2830
namespace test {
@@ -164,5 +166,132 @@ template rmm::device_uvector<int32_t> randomly_select(raft::handle_t const& hand
164166
template rmm::device_uvector<int64_t> randomly_select(raft::handle_t const& handle,
165167
rmm::device_uvector<int64_t> const& input,
166168
size_t count);
169+
170+
template <typename vertex_t, typename weight_t>
171+
void remove_self_loops(raft::handle_t const& handle,
172+
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
173+
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
174+
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */)
175+
{
176+
if (d_weight_v) {
177+
auto edge_first = thrust::make_zip_iterator(
178+
thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin()));
179+
d_src_v.resize(
180+
thrust::distance(edge_first,
181+
thrust::remove_if(
182+
handle.get_thrust_policy(),
183+
edge_first,
184+
edge_first + d_src_v.size(),
185+
[] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })),
186+
handle.get_stream());
187+
d_dst_v.resize(d_src_v.size(), handle.get_stream());
188+
(*d_weight_v).resize(d_src_v.size(), handle.get_stream());
189+
} else {
190+
auto edge_first =
191+
thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin()));
192+
d_src_v.resize(
193+
thrust::distance(edge_first,
194+
thrust::remove_if(
195+
handle.get_thrust_policy(),
196+
edge_first,
197+
edge_first + d_src_v.size(),
198+
[] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })),
199+
handle.get_stream());
200+
d_dst_v.resize(d_src_v.size(), handle.get_stream());
201+
}
202+
203+
d_src_v.shrink_to_fit(handle.get_stream());
204+
d_dst_v.shrink_to_fit(handle.get_stream());
205+
if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); }
206+
}
207+
208+
template void remove_self_loops(
209+
raft::handle_t const& handle,
210+
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
211+
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
212+
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);
213+
214+
template void remove_self_loops(
215+
raft::handle_t const& handle,
216+
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
217+
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
218+
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);
219+
220+
template void remove_self_loops(
221+
raft::handle_t const& handle,
222+
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
223+
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
224+
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);
225+
226+
template void remove_self_loops(
227+
raft::handle_t const& handle,
228+
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
229+
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
230+
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);
231+
232+
template <typename vertex_t, typename weight_t>
233+
void sort_and_remove_multi_edges(
234+
raft::handle_t const& handle,
235+
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
236+
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
237+
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */)
238+
{
239+
if (d_weight_v) {
240+
auto edge_first = thrust::make_zip_iterator(
241+
thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin()));
242+
thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size());
243+
d_src_v.resize(
244+
thrust::distance(edge_first,
245+
thrust::unique(handle.get_thrust_policy(),
246+
edge_first,
247+
edge_first + d_src_v.size(),
248+
[] __device__(auto lhs, auto rhs) {
249+
return (thrust::get<0>(lhs) == thrust::get<0>(rhs)) &&
250+
(thrust::get<1>(lhs) == thrust::get<1>(rhs));
251+
})),
252+
handle.get_stream());
253+
d_dst_v.resize(d_src_v.size(), handle.get_stream());
254+
(*d_weight_v).resize(d_src_v.size(), handle.get_stream());
255+
} else {
256+
auto edge_first =
257+
thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin()));
258+
thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size());
259+
d_src_v.resize(
260+
thrust::distance(
261+
edge_first,
262+
thrust::unique(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size())),
263+
handle.get_stream());
264+
d_dst_v.resize(d_src_v.size(), handle.get_stream());
265+
}
266+
267+
d_src_v.shrink_to_fit(handle.get_stream());
268+
d_dst_v.shrink_to_fit(handle.get_stream());
269+
if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); }
270+
}
271+
272+
template void sort_and_remove_multi_edges(
273+
raft::handle_t const& handle,
274+
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
275+
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
276+
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);
277+
278+
template void sort_and_remove_multi_edges(
279+
raft::handle_t const& handle,
280+
rmm::device_uvector<int32_t>& d_src_v /* [INOUT] */,
281+
rmm::device_uvector<int32_t>& d_dst_v /* [INOUT] */,
282+
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);
283+
284+
template void sort_and_remove_multi_edges(
285+
raft::handle_t const& handle,
286+
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
287+
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
288+
std::optional<rmm::device_uvector<float>>& d_weight_v /* [INOUT] */);
289+
290+
template void sort_and_remove_multi_edges(
291+
raft::handle_t const& handle,
292+
rmm::device_uvector<int64_t>& d_src_v /* [INOUT] */,
293+
rmm::device_uvector<int64_t>& d_dst_v /* [INOUT] */,
294+
std::optional<rmm::device_uvector<double>>& d_weight_v /* [INOUT] */);
295+
167296
} // namespace test
168297
} // namespace cugraph

cpp/tests/utilities/thrust_wrapper.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,32 @@ std::tuple<key_buffer_type, value_buffer_type> sort_by_key(raft::handle_t const&
2929

3030
template <typename vertex_t>
3131
void translate_vertex_ids(raft::handle_t const& handle,
32-
rmm::device_uvector<vertex_t>& d_src_v,
33-
rmm::device_uvector<vertex_t>& d_dst_v,
32+
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
33+
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
3434
vertex_t vertex_id_offset);
3535

3636
template <typename vertex_t>
3737
void populate_vertex_ids(raft::handle_t const& handle,
38-
rmm::device_uvector<vertex_t>& d_vertices_v,
38+
rmm::device_uvector<vertex_t>& d_vertices_v /* [INOUT] */,
3939
vertex_t vertex_id_offset);
4040

4141
template <typename T>
4242
rmm::device_uvector<T> randomly_select(raft::handle_t const& handle,
4343
rmm::device_uvector<T> const& input,
4444
size_t count);
4545

46+
template <typename vertex_t, typename weight_t>
47+
void remove_self_loops(raft::handle_t const& handle,
48+
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
49+
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
50+
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */);
51+
52+
template <typename vertex_t, typename weight_t>
53+
void sort_and_remove_multi_edges(
54+
raft::handle_t const& handle,
55+
rmm::device_uvector<vertex_t>& d_src_v /* [INOUT] */,
56+
rmm::device_uvector<vertex_t>& d_dst_v /* [INOUT] */,
57+
std::optional<rmm::device_uvector<weight_t>>& d_weight_v /* [INOUT] */);
58+
4659
} // namespace test
4760
} // namespace cugraph

0 commit comments

Comments
 (0)