@@ -55,6 +55,7 @@ use datafusion::{
5555 } ,
5656} ;
5757use datafusion_physical_expr:: EquivalenceProperties ;
58+ use either:: { Either , Left , Right } ;
5859use futures:: { lock:: Mutex , Stream , StreamExt , TryFutureExt , TryStreamExt } ;
5960use itertools:: Itertools ;
6061use simd_adler32:: Adler32 ;
@@ -233,10 +234,8 @@ impl PartitionBuffer {
233234 }
234235
235236 /// Initializes active builders if necessary.
236- async fn init_active_if_necessary (
237- & mut self ,
238- repartitioner : & ShuffleRepartitioner ,
239- ) -> Result < isize > {
237+ /// Returns error if memory reservation fails.
238+ fn init_active_if_necessary ( & mut self ) -> Result < isize > {
240239 let mut mem_diff = 0 ;
241240
242241 if self . active . is_empty ( ) {
@@ -250,15 +249,7 @@ impl PartitionBuffer {
250249 . sum :: < usize > ( ) ;
251250 }
252251
253- if self
254- . reservation
255- . try_grow ( self . active_slots_mem_size )
256- . is_err ( )
257- {
258- repartitioner. spill ( ) . await ?;
259- self . reservation . free ( ) ;
260- self . reservation . try_grow ( self . active_slots_mem_size ) ?;
261- }
252+ self . reservation . try_grow ( self . active_slots_mem_size ) ?;
262253
263254 self . active = new_array_builders ( & self . schema , self . batch_size ) ;
264255
@@ -267,32 +258,23 @@ impl PartitionBuffer {
267258 Ok ( mem_diff)
268259 }
269260
270- /// Appends all rows of given batch into active array builders.
271- async fn append_batch (
272- & mut self ,
273- batch : & RecordBatch ,
274- time_metric : & Time ,
275- repartitioner : & ShuffleRepartitioner ,
276- ) -> Result < isize > {
277- let columns = batch. columns ( ) ;
278- let indices = ( 0 ..batch. num_rows ( ) ) . collect :: < Vec < usize > > ( ) ;
279- self . append_rows ( columns, & indices, time_metric, repartitioner)
280- . await
281- }
282-
283261 /// Appends rows of specified indices from columns into active array builders.
284- async fn append_rows (
262+ fn append_rows (
285263 & mut self ,
286264 columns : & [ ArrayRef ] ,
287265 indices : & [ usize ] ,
266+ start_index : usize ,
288267 time_metric : & Time ,
289- repartitioner : & ShuffleRepartitioner ,
290- ) -> Result < isize > {
268+ ) -> Either < Result < isize > , usize > {
291269 let mut mem_diff = 0 ;
292- let mut start = 0 ;
270+ let mut start = start_index ;
293271
294272 // lazy init because some partition may be empty
295- mem_diff += self . init_active_if_necessary ( repartitioner) . await ?;
273+ let init = self . init_active_if_necessary ( ) ;
274+ if init. is_err ( ) {
275+ return Right ( start) ;
276+ }
277+ mem_diff += init. unwrap ( ) ;
296278
297279 while start < indices. len ( ) {
298280 let end = ( start + self . batch_size ) . min ( indices. len ( ) ) ;
@@ -305,14 +287,22 @@ impl PartitionBuffer {
305287 self . num_active_rows += end - start;
306288 if self . num_active_rows >= self . batch_size {
307289 let mut timer = time_metric. timer ( ) ;
308- mem_diff += self . flush ( ) ?;
290+ let flush = self . flush ( ) ;
291+ if let Err ( e) = flush {
292+ return Left ( Err ( e) ) ;
293+ }
294+ mem_diff += flush. unwrap ( ) ;
309295 timer. stop ( ) ;
310296
311- mem_diff += self . init_active_if_necessary ( repartitioner) . await ?;
297+ let init = self . init_active_if_necessary ( ) ;
298+ if init. is_err ( ) {
299+ return Right ( end) ;
300+ }
301+ mem_diff += init. unwrap ( ) ;
312302 }
313303 start = end;
314304 }
315- Ok ( mem_diff)
305+ Left ( Ok ( mem_diff) )
316306 }
317307
318308 /// flush active data into frozen bytes
@@ -326,7 +316,7 @@ impl PartitionBuffer {
326316 let active = std:: mem:: take ( & mut self . active ) ;
327317 let num_rows = self . num_active_rows ;
328318 self . num_active_rows = 0 ;
329- mem_diff -= self . active_slots_mem_size as isize ;
319+ self . reservation . try_shrink ( self . active_slots_mem_size ) ? ;
330320
331321 let frozen_batch = make_batch ( Arc :: clone ( & self . schema ) , active, num_rows) ?;
332322
@@ -610,7 +600,7 @@ struct ShuffleRepartitioner {
610600 output_data_file : String ,
611601 output_index_file : String ,
612602 schema : SchemaRef ,
613- buffered_partitions : Mutex < Vec < PartitionBuffer > > ,
603+ buffered_partitions : Vec < PartitionBuffer > ,
614604 spills : Mutex < Vec < SpillInfo > > ,
615605 /// Sort expressions
616606 /// Partitioning scheme to use
@@ -683,18 +673,11 @@ impl ShuffleRepartitioner {
683673 output_data_file,
684674 output_index_file,
685675 schema : Arc :: clone ( & schema) ,
686- buffered_partitions : Mutex :: new (
687- ( 0 ..num_output_partitions)
688- . map ( |partition_id| {
689- PartitionBuffer :: new (
690- Arc :: clone ( & schema) ,
691- batch_size,
692- partition_id,
693- & runtime,
694- )
695- } )
696- . collect :: < Vec < _ > > ( ) ,
697- ) ,
676+ buffered_partitions : ( 0 ..num_output_partitions)
677+ . map ( |partition_id| {
678+ PartitionBuffer :: new ( Arc :: clone ( & schema) , batch_size, partition_id, & runtime)
679+ } )
680+ . collect :: < Vec < _ > > ( ) ,
698681 spills : Mutex :: new ( vec ! [ ] ) ,
699682 partitioning,
700683 num_output_partitions,
@@ -741,8 +724,6 @@ impl ShuffleRepartitioner {
741724 // Update data size metric
742725 self . metrics . data_size . add ( input. get_array_memory_size ( ) ) ;
743726
744- let time_metric = self . metrics . baseline . elapsed_compute ( ) ;
745-
746727 // NOTE: in shuffle writer exec, the output_rows metrics represents the
747728 // number of rows those are written to output data file.
748729 self . metrics . baseline . record_output ( input. num_rows ( ) ) ;
@@ -807,17 +788,11 @@ impl ShuffleRepartitioner {
807788 . enumerate ( )
808789 . filter ( |( _, ( start, end) ) | start < end)
809790 {
810- let mut buffered_partitions = self . buffered_partitions . lock ( ) . await ;
811- let output = & mut buffered_partitions[ partition_id] ;
812-
813- // If the range of indices is not big enough, just appending the rows into
814- // active array builders instead of directly adding them as a record batch.
815- mem_diff += output
816- . append_rows (
791+ mem_diff += self
792+ . append_rows_to_partition (
817793 input. columns ( ) ,
818794 & shuffled_partition_ids[ start..end] ,
819- time_metric,
820- self ,
795+ partition_id,
821796 )
822797 . await ?;
823798
@@ -842,16 +817,18 @@ impl ShuffleRepartitioner {
842817 }
843818 }
844819 Partitioning :: UnknownPartitioning ( n) if * n == 1 => {
845- let mut buffered_partitions = self . buffered_partitions . lock ( ) . await ;
820+ let buffered_partitions = & mut self . buffered_partitions ;
846821
847822 assert ! (
848823 buffered_partitions. len( ) == 1 ,
849824 "Expected 1 partition but got {}" ,
850825 buffered_partitions. len( )
851826 ) ;
852827
853- let output = & mut buffered_partitions[ 0 ] ;
854- output. append_batch ( & input, time_metric, self ) . await ?;
828+ let indices = ( 0 ..input. num_rows ( ) ) . collect :: < Vec < usize > > ( ) ;
829+
830+ self . append_rows_to_partition ( input. columns ( ) , & indices, 0 )
831+ . await ?;
855832 }
856833 other => {
857834 // this should be unreachable as long as the validation logic
@@ -868,7 +845,7 @@ impl ShuffleRepartitioner {
868845 /// Writes buffered shuffled record batches into Arrow IPC bytes.
869846 async fn shuffle_write ( & mut self ) -> Result < SendableRecordBatchStream > {
870847 let num_output_partitions = self . num_output_partitions ;
871- let mut buffered_partitions = self . buffered_partitions . lock ( ) . await ;
848+ let buffered_partitions = & mut self . buffered_partitions ;
872849 let mut output_batches: Vec < Vec < u8 > > = vec ! [ vec![ ] ; num_output_partitions] ;
873850
874851 for i in 0 ..num_output_partitions {
@@ -966,16 +943,15 @@ impl ShuffleRepartitioner {
966943 self . metrics . data_size . value ( )
967944 }
968945
969- async fn spill ( & self ) -> Result < usize > {
946+ async fn spill ( & mut self ) -> Result < usize > {
970947 log:: debug!(
971948 "ShuffleRepartitioner spilling shuffle data of {} to disk while inserting ({} time(s) so far)" ,
972949 self . used( ) ,
973950 self . spill_count( )
974951 ) ;
975952
976- let mut buffered_partitions = self . buffered_partitions . lock ( ) . await ;
977953 // we could always get a chance to free some memory as long as we are holding some
978- if buffered_partitions. len ( ) == 0 {
954+ if self . buffered_partitions . is_empty ( ) {
979955 return Ok ( 0 ) ;
980956 }
981957
@@ -986,7 +962,7 @@ impl ShuffleRepartitioner {
986962 . disk_manager
987963 . create_tmp_file ( "shuffle writer spill" ) ?;
988964 let offsets = spill_into (
989- & mut buffered_partitions,
965+ & mut self . buffered_partitions ,
990966 spillfile. path ( ) ,
991967 self . num_output_partitions ,
992968 )
@@ -1004,6 +980,60 @@ impl ShuffleRepartitioner {
1004980 } ) ;
1005981 Ok ( used)
1006982 }
983+
984+ /// Appends rows of specified indices from columns into active array builders in the specified partition.
985+ async fn append_rows_to_partition (
986+ & mut self ,
987+ columns : & [ ArrayRef ] ,
988+ indices : & [ usize ] ,
989+ partition_id : usize ,
990+ ) -> Result < isize > {
991+ let mut mem_diff = 0 ;
992+
993+ let output = & mut self . buffered_partitions [ partition_id] ;
994+
995+ let time_metric = self . metrics . baseline . elapsed_compute ( ) ;
996+
997+ // If the range of indices is not big enough, just appending the rows into
998+ // active array builders instead of directly adding them as a record batch.
999+ let mut start_index: usize = 0 ;
1000+ let mut output_ret = output. append_rows ( columns, indices, start_index, time_metric) ;
1001+
1002+ loop {
1003+ match output_ret {
1004+ Left ( l) => {
1005+ mem_diff += l?;
1006+ break ;
1007+ }
1008+ Right ( new_start) => {
1009+ // Cannot allocate enough memory for the array builders in the partition,
1010+ // spill partitions and retry.
1011+ self . spill ( ) . await ?;
1012+
1013+ let output = & mut self . buffered_partitions [ partition_id] ;
1014+ output. reservation . free ( ) ;
1015+
1016+ let time_metric = self . metrics . baseline . elapsed_compute ( ) ;
1017+
1018+ start_index = new_start;
1019+ output_ret = output. append_rows ( columns, indices, start_index, time_metric) ;
1020+
1021+ if let Right ( new_start) = output_ret {
1022+ if new_start == start_index {
1023+ // If the start index is not updated, it means that the partition
1024+ // is still not able to allocate enough memory for the array builders.
1025+ return Err ( DataFusionError :: Internal (
1026+ "Partition is still not able to allocate enough memory for the array builders after spilling."
1027+ . to_string ( ) ,
1028+ ) ) ;
1029+ }
1030+ }
1031+ }
1032+ }
1033+ }
1034+
1035+ Ok ( mem_diff)
1036+ }
10071037}
10081038
10091039/// consume the `buffered_partitions` and do spill into a single temp shuffle output file
@@ -1520,6 +1550,8 @@ mod test {
15201550 use datafusion:: physical_plan:: common:: collect;
15211551 use datafusion:: physical_plan:: memory:: MemoryExec ;
15221552 use datafusion:: prelude:: SessionContext ;
1553+ use datafusion_execution:: config:: SessionConfig ;
1554+ use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
15231555 use datafusion_physical_expr:: expressions:: Column ;
15241556 use tokio:: runtime:: Runtime ;
15251557
@@ -1554,25 +1586,63 @@ mod test {
15541586 #[ test]
15551587 #[ cfg_attr( miri, ignore) ] // miri can't call foreign function `ZSTD_createCCtx`
15561588 fn test_insert_larger_batch ( ) {
1589+ shuffle_write_test ( 10000 , 1 , 16 , None ) ;
1590+ }
1591+
1592+ #[ test]
1593+ #[ cfg_attr( miri, ignore) ] // miri can't call foreign function `ZSTD_createCCtx`
1594+ fn test_insert_smaller_batch ( ) {
1595+ shuffle_write_test ( 1000 , 1 , 16 , None ) ;
1596+ shuffle_write_test ( 1000 , 10 , 16 , None ) ;
1597+ }
1598+
1599+ #[ test]
1600+ #[ cfg_attr( miri, ignore) ] // miri can't call foreign function `ZSTD_createCCtx`
1601+ fn test_large_number_of_partitions ( ) {
1602+ shuffle_write_test ( 10000 , 10 , 200 , Some ( 10 * 1024 * 1024 ) ) ;
1603+ shuffle_write_test ( 10000 , 10 , 2000 , Some ( 10 * 1024 * 1024 ) ) ;
1604+ }
1605+
1606+ #[ test]
1607+ #[ cfg_attr( miri, ignore) ] // miri can't call foreign function `ZSTD_createCCtx`
1608+ fn test_large_number_of_partitions_spilling ( ) {
1609+ shuffle_write_test ( 10000 , 100 , 200 , Some ( 10 * 1024 * 1024 ) ) ;
1610+ }
1611+
1612+ fn shuffle_write_test (
1613+ batch_size : usize ,
1614+ num_batches : usize ,
1615+ num_partitions : usize ,
1616+ memory_limit : Option < usize > ,
1617+ ) {
15571618 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
15581619 let mut b = StringBuilder :: new ( ) ;
1559- for i in 0 ..10000 {
1620+ for i in 0 ..batch_size {
15601621 b. append_value ( format ! ( "{i}" ) ) ;
15611622 }
15621623 let array = b. finish ( ) ;
15631624 let batch = RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ Arc :: new( array) ] ) . unwrap ( ) ;
15641625
1565- let batches = vec ! [ batch. clone( ) ] ;
1626+ let batches = ( 0 ..num_batches ) . map ( |_| batch. clone ( ) ) . collect :: < Vec < _ > > ( ) ;
15661627
15671628 let partitions = & [ batches] ;
15681629 let exec = ShuffleWriterExec :: try_new (
15691630 Arc :: new ( MemoryExec :: try_new ( partitions, batch. schema ( ) , None ) . unwrap ( ) ) ,
1570- Partitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , 16 ) ,
1631+ Partitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , num_partitions ) ,
15711632 "/tmp/data.out" . to_string ( ) ,
15721633 "/tmp/index.out" . to_string ( ) ,
15731634 )
15741635 . unwrap ( ) ;
1575- let ctx = SessionContext :: new ( ) ;
1636+
1637+ // 10MB memory should be enough for running this test
1638+ let config = SessionConfig :: new ( ) ;
1639+ let mut runtime_env_builder = RuntimeEnvBuilder :: new ( ) ;
1640+ runtime_env_builder = match memory_limit {
1641+ Some ( limit) => runtime_env_builder. with_memory_limit ( limit, 1.0 ) ,
1642+ None => runtime_env_builder,
1643+ } ;
1644+ let runtime_env = Arc :: new ( runtime_env_builder. build ( ) . unwrap ( ) ) ;
1645+ let ctx = SessionContext :: new_with_config_rt ( config, runtime_env) ;
15761646 let task_ctx = ctx. task_ctx ( ) ;
15771647 let stream = exec. execute ( 0 , task_ctx) . unwrap ( ) ;
15781648 let rt = Runtime :: new ( ) . unwrap ( ) ;
0 commit comments