@@ -499,16 +499,27 @@ def from_internal_vertex_id(
499499 @staticmethod
500500 def renumber_and_segment (
501501 df , src_col_names , dst_col_names , preserve_order = False ,
502- store_transposed = False
502+ store_transposed = False , legacy_renum_only = False
503503 ):
504+ # FIXME: Drop the renumber_type 'experimental' once all the
505+ # algos follow the C/Pylibcugraph path
506+
507+ # The renumber_type 'legacy' runs both the Python and the
508+ # C++ renumbering.
504509 if isinstance (src_col_names , list ):
505510 renumber_type = 'legacy'
506511 elif not (df [src_col_names ].dtype == np .int32 or
507512 df [src_col_names ].dtype == np .int64 ):
508513 renumber_type = 'legacy'
509514 else :
515+ # The renumber_type 'experimental' only runs the C++
516+ # renumbering
510517 renumber_type = 'experimental'
511518
519+ if legacy_renum_only and renumber_type == 'experimental' :
520+ # The original dataframe will be returned.
521+ renumber_type = 'skip_renumbering'
522+
512523 renumber_map = NumberMap ()
513524 if not isinstance (src_col_names , list ):
514525 src_col_names = [src_col_names ]
@@ -547,6 +558,12 @@ def renumber_and_segment(
547558 df , renumber_map .renumbered_dst_col_name , dst_col_names ,
548559 drop = True , preserve_order = preserve_order
549560 )
561+ elif renumber_type == 'skip_renumbering' :
562+ # Update the renumbered source and destination column names
563+ # with the original input's source and destination names
564+ renumber_map .renumbered_src_col_name = src_col_names [0 ]
565+ renumber_map .renumbered_dst_col_name = dst_col_names [0 ]
566+
550567 else :
551568 df = df .rename (
552569 columns = {src_col_names [0 ]:
@@ -562,69 +579,77 @@ def renumber_and_segment(
562579 is_mnmg = False
563580
564581 if is_mnmg :
565- client = default_client ()
566- data = get_distributed_data (df )
567- result = [(client .submit (call_renumber ,
568- Comms .get_session_id (),
569- wf [1 ],
570- renumber_map .renumbered_src_col_name ,
571- renumber_map .renumbered_dst_col_name ,
572- num_edges ,
573- is_mnmg ,
574- store_transposed ,
575- workers = [wf [0 ]]), wf [0 ])
576- for idx , wf in enumerate (data .worker_to_parts .items ())]
577- wait (result )
578-
579- def get_renumber_map (id_type , data ):
580- return data [0 ].astype (id_type )
581-
582- def get_segment_offsets (data ):
583- return data [1 ]
584-
585- def get_renumbered_df (id_type , data ):
586- data [2 ][renumber_map .renumbered_src_col_name ] = \
587- data [2 ][renumber_map .renumbered_src_col_name ]\
588- .astype (id_type )
589- data [2 ][renumber_map .renumbered_dst_col_name ] = \
590- data [2 ][renumber_map .renumbered_dst_col_name ]\
591- .astype (id_type )
592- return data [2 ]
593-
594- renumbering_map = dask_cudf .from_delayed (
595- [client .submit (get_renumber_map ,
596- id_type ,
597- data ,
598- workers = [wf ])
599- for (data , wf ) in result ])
600-
601- list_of_segment_offsets = client .gather (
602- [client .submit (get_segment_offsets ,
603- data ,
604- workers = [wf ])
605- for (data , wf ) in result ])
606- aggregate_segment_offsets = []
607- for segment_offsets in list_of_segment_offsets :
608- aggregate_segment_offsets .extend (segment_offsets )
609-
610- renumbered_df = dask_cudf .from_delayed (
611- [client .submit (get_renumbered_df ,
612- id_type ,
613- data ,
614- workers = [wf ])
615- for (data , wf ) in result ])
616- if renumber_type == 'legacy' :
617- renumber_map .implementation .ddf = indirection_map .merge (
618- renumbering_map ,
619- right_on = 'original_ids' , left_on = 'global_id' ,
620- how = 'right' ).\
621- drop (columns = ['global_id' , 'original_ids' ])\
622- .rename (columns = {'new_ids' : 'global_id' })
582+ # Do not renumber the algos following the C/Pylibcugraph path
583+ if renumber_type in ['legacy' , 'experimental' ]:
584+ client = default_client ()
585+ data = get_distributed_data (df )
586+ result = [(client .submit (call_renumber ,
587+ Comms .get_session_id (),
588+ wf [1 ],
589+ renumber_map .renumbered_src_col_name ,
590+ renumber_map .renumbered_dst_col_name ,
591+ num_edges ,
592+ is_mnmg ,
593+ store_transposed ,
594+ workers = [wf [0 ]]), wf [0 ])
595+ for idx , wf in enumerate (
596+ data .worker_to_parts .items ())]
597+ wait (result )
598+
599+ def get_renumber_map (id_type , data ):
600+ return data [0 ].astype (id_type )
601+
602+ def get_segment_offsets (data ):
603+ return data [1 ]
604+
605+ def get_renumbered_df (id_type , data ):
606+ data [2 ][renumber_map .renumbered_src_col_name ] = \
607+ data [2 ][renumber_map .renumbered_src_col_name ]\
608+ .astype (id_type )
609+ data [2 ][renumber_map .renumbered_dst_col_name ] = \
610+ data [2 ][renumber_map .renumbered_dst_col_name ]\
611+ .astype (id_type )
612+ return data [2 ]
613+
614+ renumbering_map = dask_cudf .from_delayed (
615+ [client .submit (get_renumber_map ,
616+ id_type ,
617+ data ,
618+ workers = [wf ])
619+ for (data , wf ) in result ])
620+
621+ list_of_segment_offsets = client .gather (
622+ [client .submit (get_segment_offsets ,
623+ data ,
624+ workers = [wf ])
625+ for (data , wf ) in result ])
626+ aggregate_segment_offsets = []
627+ for segment_offsets in list_of_segment_offsets :
628+ aggregate_segment_offsets .extend (segment_offsets )
629+
630+ renumbered_df = dask_cudf .from_delayed (
631+ [client .submit (get_renumbered_df ,
632+ id_type ,
633+ data ,
634+ workers = [wf ])
635+ for (data , wf ) in result ])
636+ if renumber_type == 'legacy' :
637+ renumber_map .implementation .ddf = indirection_map .merge (
638+ renumbering_map ,
639+ right_on = 'original_ids' , left_on = 'global_id' ,
640+ how = 'right' ).\
641+ drop (columns = ['global_id' , 'original_ids' ])\
642+ .rename (columns = {'new_ids' : 'global_id' })
643+ else :
644+ renumber_map .implementation .ddf = renumbering_map .rename (
645+ columns = {'original_ids' : '0' , 'new_ids' : 'global_id' })
646+ renumber_map .implementation .numbered = True
647+ return renumbered_df , renumber_map , aggregate_segment_offsets
648+
623649 else :
624- renumber_map .implementation .ddf = renumbering_map .rename (
625- columns = {'original_ids' : '0' , 'new_ids' : 'global_id' })
626- renumber_map .implementation .numbered = True
627- return renumbered_df , renumber_map , aggregate_segment_offsets
650+ # There is no aggregate_segment_offsets since the
651+ # C++ renumbering is skipped
652+ return df , renumber_map , None
628653
629654 else :
630655 renumbering_map , segment_offsets , renumbered_df = \
0 commit comments