1616
1717#include " egonet_validate.hpp"
1818
19+ #include < structure/detail/structure_utils.cuh>
20+
1921#include < utilities/base_fixture.hpp>
2022#include < utilities/device_comm_wrapper.hpp>
2123#include < utilities/high_res_clock.h>
3436#include < thrust/execution_policy.h>
3537#include < thrust/iterator/counting_iterator.h>
3638#include < thrust/sequence.h>
39+ #include < thrust/sort.h>
3740
3841#include < gtest/gtest.h>
3942
@@ -126,7 +129,7 @@ class Tests_MGEgonet
126129 cugraph::extract_ego (
127130 *handle_,
128131 mg_graph_view,
129- raft::device_span<vertex_t const >{d_ego_sources.data (), egonet_usecase. ego_sources_ .size ()},
132+ raft::device_span<vertex_t const >{d_ego_sources.data (), d_ego_sources .size ()},
130133 static_cast <vertex_t >(egonet_usecase.radius_ ));
131134
132135 if (cugraph::test::g_perf) {
@@ -138,11 +141,6 @@ class Tests_MGEgonet
138141 }
139142
140143 if (egonet_usecase.check_correctness_ ) {
141- *d_renumber_map_labels = cugraph::test::device_gatherv (
142- *handle_,
143- raft::device_span<vertex_t const >(d_renumber_map_labels->data (),
144- d_renumber_map_labels->size ()));
145-
146144 d_ego_edgelist_src = cugraph::test::device_gatherv (
147145 *handle_,
148146 raft::device_span<vertex_t const >(d_ego_edgelist_src.data (), d_ego_edgelist_src.size ()));
@@ -157,54 +155,48 @@ class Tests_MGEgonet
157155 d_ego_edgelist_wgt->data (), d_ego_edgelist_wgt->size ()));
158156 }
159157
160- d_ego_edgelist_offsets = cugraph::test::device_gatherv (
161- *handle_,
158+ size_t offsets_size = d_ego_edgelist_offsets.size ();
159+
160+ auto graph_ids_v = cugraph::detail::expand_sparse_offsets (
162161 raft::device_span<size_t const >(d_ego_edgelist_offsets.data (),
163- d_ego_edgelist_offsets.size ()));
162+ d_ego_edgelist_offsets.size ()),
163+ vertex_t {0 },
164+ handle_->get_stream ());
164165
165- auto [sg_graph, sg_number_map] =
166- cugraph::test::mg_graph_to_sg_graph ( *handle_, mg_graph_view, d_renumber_map_labels, false );
166+ graph_ids_v = cugraph::test::device_gatherv (
167+ *handle_, raft::device_span< vertex_t const >(graph_ids_v. data (), graph_ids_v. size ()) );
167168
168- if (my_rank == 0 ) {
169- cugraph::unrenumber_int_vertices<vertex_t , false >(
170- *handle_,
171- d_ego_edgelist_src.data (),
172- d_ego_edgelist_src.size (),
173- d_renumber_map_labels->data (),
174- std::vector<vertex_t >{mg_graph_view.number_of_vertices ()});
175-
176- cugraph::unrenumber_int_vertices<vertex_t , false >(
177- *handle_,
178- d_ego_edgelist_dst.data (),
179- d_ego_edgelist_dst.size (),
180- d_renumber_map_labels->data (),
181- std::vector<vertex_t >{mg_graph_view.number_of_vertices ()});
182-
183- rmm::device_uvector<vertex_t > d_sg_ego_sources (egonet_usecase.ego_sources_ .size (),
184- handle_->get_stream ());
185-
186- if constexpr (std::is_same<int32_t , vertex_t >::value) {
187- raft::update_device (d_sg_ego_sources.data (),
188- egonet_usecase.ego_sources_ .data (),
189- egonet_usecase.ego_sources_ .size (),
190- handle_->get_stream ());
191- } else {
192- std::vector<vertex_t > h_ego_sources (d_sg_ego_sources.size ());
193- std::transform (egonet_usecase.ego_sources_ .begin (),
194- egonet_usecase.ego_sources_ .end (),
195- h_ego_sources.begin (),
196- [](auto v) { return static_cast <vertex_t >(v); });
197- raft::update_device (d_sg_ego_sources.data (),
198- h_ego_sources.data (),
199- h_ego_sources.size (),
200- handle_->get_stream ());
201- }
169+ if (d_ego_edgelist_wgt) {
170+ thrust::sort_by_key (
171+ handle_->get_thrust_policy (),
172+ thrust::make_zip_iterator (
173+ graph_ids_v.begin (), d_ego_edgelist_src.begin (), d_ego_edgelist_dst.begin ()),
174+ thrust::make_zip_iterator (
175+ graph_ids_v.end (), d_ego_edgelist_src.end (), d_ego_edgelist_dst.end ()),
176+ d_ego_edgelist_wgt->begin ());
177+ } else {
178+ thrust::sort (handle_->get_thrust_policy (),
179+ thrust::make_zip_iterator (
180+ graph_ids_v.begin (), d_ego_edgelist_src.begin (), d_ego_edgelist_dst.begin ()),
181+ thrust::make_zip_iterator (
182+ graph_ids_v.end (), d_ego_edgelist_src.end (), d_ego_edgelist_dst.end ()));
183+ }
184+
185+ d_ego_edgelist_offsets = cugraph::detail::compute_sparse_offsets<size_t >(
186+ graph_ids_v.begin (), graph_ids_v.end (), size_t {0 }, offsets_size - 1 , handle_->get_stream ());
202187
188+ auto [sg_graph, sg_number_map] = cugraph::test::mg_graph_to_sg_graph (
189+ *handle_, mg_graph_view, std::optional<rmm::device_uvector<vertex_t >>{std::nullopt }, false );
190+
191+ d_ego_sources = cugraph::test::device_gatherv (
192+ *handle_, raft::device_span<vertex_t const >(d_ego_sources.data (), d_ego_sources.size ()));
193+
194+ if (my_rank == 0 ) {
203195 auto [d_reference_src, d_reference_dst, d_reference_wgt, d_reference_offsets] =
204196 cugraph::extract_ego (
205197 *handle_,
206198 sg_graph.view (),
207- raft::device_span<vertex_t const >{d_sg_ego_sources .data (), d_sg_ego_sources .size ()},
199+ raft::device_span<vertex_t const >{d_ego_sources .data (), d_ego_sources .size ()},
208200 static_cast <vertex_t >(egonet_usecase.radius_ ));
209201
210202 cugraph::test::egonet_validate (*handle_,
@@ -289,7 +281,7 @@ INSTANTIATE_TEST_SUITE_P(
289281 Tests_MGEgonet_File,
290282 ::testing::Combine (
291283 // disable correctness checks for large graphs
292- ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , true }),
284+ ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , false }),
293285 ::testing::Values(cugraph::test::File_Usecase(" test/datasets/karate.mtx" ))));
294286
295287INSTANTIATE_TEST_SUITE_P (
@@ -301,7 +293,7 @@ INSTANTIATE_TEST_SUITE_P(
301293 Tests_MGEgonet_File64,
302294 ::testing::Combine (
303295 // disable correctness checks for large graphs
304- ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , true }),
296+ ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , false }),
305297 ::testing::Values(cugraph::test::File_Usecase(" test/datasets/karate.mtx" ))));
306298
307299INSTANTIATE_TEST_SUITE_P (
@@ -313,12 +305,7 @@ INSTANTIATE_TEST_SUITE_P(
313305 Tests_MGEgonet_Rmat,
314306 ::testing::Combine (
315307 // disable correctness checks for large graphs
316- ::testing::Values (Egonet_Usecase{std::vector<int32_t >{0 }, 1 , false , true },
317- Egonet_Usecase{std::vector<int32_t >{0 }, 2 , false , true },
318- Egonet_Usecase{std::vector<int32_t >{0 }, 3 , false , true },
319- Egonet_Usecase{std::vector<int32_t >{10 , 0 , 5 }, 2 , false , true },
320- Egonet_Usecase{std::vector<int32_t >{9 , 3 , 10 }, 2 , false , true },
321- Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , true }),
308+ ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , false }),
322309 ::testing::Values(
323310 cugraph::test::Rmat_Usecase (20 , 32 , 0.57 , 0.19 , 0.19 , 0 , true , false , 0 , true ))));
324311
@@ -331,7 +318,7 @@ INSTANTIATE_TEST_SUITE_P(
331318 Tests_MGEgonet_Rmat64,
332319 ::testing::Combine (
333320 // disable correctness checks for large graphs
334- ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , true }),
321+ ::testing::Values (Egonet_Usecase{std::vector<int32_t >{5 , 9 , 3 , 10 , 12 , 13 }, 2 , true , false }),
335322 ::testing::Values(
336323 cugraph::test::Rmat_Usecase (20 , 32 , 0.57 , 0.19 , 0.19 , 0 , true , false , 0 , true ))));
337324
0 commit comments