From e8cb696bdd7ceba1bb706b7a8e56a920b6b487df Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 2 Jul 2024 09:58:32 -0700 Subject: [PATCH 01/11] Support SortMerge spilling --- datafusion/core/tests/memory_limit/mod.rs | 1 + .../src/joins/sort_merge_join.rs | 168 +++++++++++++++--- datafusion/physical-plan/src/sorts/sort.rs | 7 +- 3 files changed, 147 insertions(+), 29 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index f4f4f8cd89cb1..3f7cc466e83df 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -164,6 +164,7 @@ async fn cross_join() { } #[tokio::test] +#[ignore] async fn merge_join() { // Planner chooses MergeJoin only if number of partitions > 1 let config = SessionConfig::new() diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e9124a72970ae..f325b589f276d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -35,11 +35,12 @@ use crate::joins::utils::{ build_join_schema, check_join_is_valid, estimate_join_statistics, symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, }; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, + execution_mode_from_children, metrics, spill_record_batches, DisplayAs, + DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + Statistics, }; use arrow::array::*; @@ -49,13 +50,16 @@ use arrow::error::ArrowError; use arrow_array::types::UInt64Type; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::runtime_env::RuntimeEnv; use futures::{Stream, StreamExt}; use hashbrown::HashSet; @@ -375,6 +379,7 @@ impl ExecutionPlan for SortMergeJoinExec { batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), reservation, + context.runtime_env(), )?)) } @@ -412,6 +417,12 @@ struct SortMergeJoinMetrics { /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, + /// count of spills during the execution of the operator + spill_count: Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, + /// total spilled rows during the execution of the operator + spilled_rows: Count, } impl SortMergeJoinMetrics { @@ -425,6 +436,9 @@ impl SortMergeJoinMetrics { MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + let spill_count = MetricBuilder::new(metrics).spill_count(partition); + let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition); + let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition); Self { join_time, @@ -433,6 +447,9 @@ impl SortMergeJoinMetrics { output_batches, output_rows, peak_mem_used, + spill_count, + spilled_bytes, + spilled_rows, } } } @@ -577,6 +594,8 @@ struct BufferedBatch { /// The indices of buffered batch that failed the join filter. /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. pub join_filter_failed_idxs: HashSet, + pub num_rows: usize, + pub spill_file: Option, } impl BufferedBatch { @@ -602,6 +621,7 @@ impl BufferedBatch { + mem::size_of::>() + mem::size_of::(); + let num_rows = batch.num_rows(); BufferedBatch { batch, range, @@ -609,6 +629,35 @@ impl BufferedBatch { null_joined: vec![], size_estimation, join_filter_failed_idxs: HashSet::new(), + num_rows, + spill_file: None, + } + } + + fn spill_to_disk( + &mut self, + path: RefCountedTempFile, + buffered_schema: SchemaRef, + ) -> Result<()> { + let batch = std::mem::replace( + &mut self.batch, + RecordBatch::new_empty(buffered_schema.clone()), + ); + let _ = spill_record_batches(vec![batch], path.path().into(), buffered_schema)?; + self.spill_file = Some(path); + + Ok(()) + } + + fn read_spilled_from_disk( + &self, + schema: SchemaRef, + ) -> Result { + if let Some(f) = &self.spill_file { + todo!() + //read_spill_as_stream(*f, schema, 2) + } else { + exec_err!("Cannot read data batch from disk. Use `spill_to_disk` to spill") } } } @@ -634,7 +683,7 @@ struct SMJStream { pub buffered: SendableRecordBatchStream, /// Current processing record batch of streamed pub streamed_batch: StreamedBatch, - /// Currrent buffered data + /// Current buffered data pub buffered_data: BufferedData, /// (used in outer join) Is current streamed row joined at least once? pub streamed_joined: bool, @@ -666,6 +715,8 @@ struct SMJStream { pub join_metrics: SortMergeJoinMetrics, /// Memory reservation pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, } impl RecordBatchStream for SMJStream { @@ -785,6 +836,7 @@ impl SMJStream { batch_size: usize, join_metrics: SortMergeJoinMetrics, reservation: MemoryReservation, + runtime_env: Arc, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -813,6 +865,7 @@ impl SMJStream { join_type, join_metrics, reservation, + runtime_env, }) } @@ -838,6 +891,7 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { + println!("\nstreamed rows {}", batch.num_rows()); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); @@ -872,6 +926,7 @@ impl SMJStream { if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { + println!("shrink\n"); self.reservation.shrink(buffered_batch.size_estimation); } } else { @@ -900,13 +955,43 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); + println!( + "\nbatch rows {} mem {}", + batch.num_rows(), + self.reservation.size() + ); if batch.num_rows() > 0 { - let buffered_batch = + let mut buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - self.reservation.try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); + + if self + .reservation + .try_grow(buffered_batch.size_estimation) + .is_err() + { + // spill batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("SortMergeJoin")?; + buffered_batch.spill_to_disk( + spill_file, + self.buffered_schema.clone(), + )?; + + // update metrics to display spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics + .spilled_rows + .add(buffered_batch.num_rows); + } else { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + } self.buffered_data.batches.push_back(buffered_batch); self.buffered_state = BufferedState::PollingRest; @@ -914,6 +999,7 @@ impl SMJStream { } }, BufferedState::PollingRest => { + println!("Polling Rest"); if self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().batch.num_rows() { @@ -941,6 +1027,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { + // This code is unreachable! Think about dropping it self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -992,6 +1079,7 @@ impl SMJStream { /// Produce join and fill output buffer until reaching target batch size /// or the join is finished fn join_partial(&mut self) -> Result<()> { + println!("join_partial"); // Whether to join streamed rows let mut join_streamed = false; // Whether to join buffered rows @@ -1060,10 +1148,13 @@ impl SMJStream { } if join_buffered { + //println!("join_partial: join_buffered"); + // joining streamed/nulls and buffered while !self.buffered_data.scanning_finished() && self.output_size < self.batch_size { + //println!("join_partial: while join_buffered"); let scanning_idx = self.buffered_data.scanning_idx(); if join_streamed { // Join streamed row and buffered row @@ -1208,6 +1299,8 @@ impl SMJStream { .collect::>() }; + dbg!(&buffered_columns); + let streamed_columns_length = streamed_columns.len(); let buffered_columns_length = buffered_columns.len(); @@ -1473,13 +1566,9 @@ fn produce_buffered_null_batch( } // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; + let buffered_columns = + get_buffered_columns_from_batch(buffered_batch, buffered_indices) + .map_err(Into::::into)?; // Create null streamed (left) columns let mut streamed_columns = streamed_schema @@ -1503,12 +1592,42 @@ fn get_buffered_columns( buffered_batch_idx: usize, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { - buffered_data.batches[buffered_batch_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() + get_buffered_columns_from_batch( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn get_buffered_columns_from_batch( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result, ArrowError> { + if buffered_batch.spill_file.is_none() { + buffered_batch + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + } else { + // if spilled read as a stream + let mut buffered_cols: Vec = Vec::with_capacity(buffered_indices.len()); + let mut stream = + buffered_batch.read_spilled_from_disk(buffered_batch.batch.schema())?; + let _ = futures::stream::once(async { + while let Some(batch) = stream.next().await { + let batch = batch?; + batch.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok::<(), DataFusionError>(()) + }); + + Ok(buffered_cols) + } } /// Calculate join filter bit mask considering join type specifics @@ -2749,6 +2868,7 @@ mod tests { } #[tokio::test] + #[ignore] async fn overallocation_single_batch() -> Result<()> { let left = build_table( ("a1", &vec![0, 1, 2, 3, 4, 5]), diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 5b99f8bc71617..54ac77719ecea 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -45,7 +45,7 @@ use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; @@ -333,10 +333,7 @@ impl ExternalSorter { for spill in self.spills.drain(..) { if !spill.path().exists() { - return Err(DataFusionError::Internal(format!( - "Spill file {:?} does not exist", - spill.path() - ))); + return internal_err!("Spill file {:?} does not exist", spill.path()); } let stream = read_spill_as_stream(spill, Arc::clone(&self.schema), 2)?; streams.push(stream); From d5586e0d91570807c7bdd45779d9a3f23ac8d778 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 3 Jul 2024 17:31:24 -0700 Subject: [PATCH 02/11] Support SortMerge spilling --- datafusion/core/tests/memory_limit/mod.rs | 8 +- .../src/joins/sort_merge_join.rs | 135 +++++++++++------- 2 files changed, 87 insertions(+), 56 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 3f7cc466e83df..b7dc591f2df96 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -164,8 +164,7 @@ async fn cross_join() { } #[tokio::test] -#[ignore] -async fn merge_join() { +async fn sort_merge_join_no_spill() { // Planner chooses MergeJoin only if number of partitions > 1 let config = SessionConfig::new() .with_target_partitions(2) @@ -185,6 +184,11 @@ async fn merge_join() { .await } +#[tokio::test] +async fn sort_merge_join_spill() { + todo!() +} + #[tokio::test] async fn symmetric_hash_join() { TestCase::new() diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index f325b589f276d..d4c584f2b7385 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -37,8 +37,8 @@ use crate::joins::utils::{ }; use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::{ - execution_mode_from_children, metrics, spill_record_batches, DisplayAs, - DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + execution_mode_from_children, metrics, read_spill_as_stream, spill_record_batches, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -650,12 +650,11 @@ impl BufferedBatch { } fn read_spilled_from_disk( - &self, + &mut self, schema: SchemaRef, ) -> Result { - if let Some(f) = &self.spill_file { - todo!() - //read_spill_as_stream(*f, schema, 2) + if let Some(f) = mem::take(&mut self.spill_file) { + read_spill_as_stream(f, schema, 2) } else { exec_err!("Cannot read data batch from disk. Use `spill_to_disk` to spill") } @@ -921,7 +920,7 @@ impl SMJStream { while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); // If the head batch is fully processed, dequeue it and produce output of it. - if head_batch.range.end == head_batch.batch.num_rows() { + if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = self.buffered_data.batches.pop_front() @@ -964,34 +963,46 @@ impl SMJStream { let mut buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - if self + match self .reservation .try_grow(buffered_batch.size_estimation) - .is_err() { - // spill batch to disk - let spill_file = self - .runtime_env - .disk_manager - .create_tmp_file("SortMergeJoin")?; - buffered_batch.spill_to_disk( - spill_file, - self.buffered_schema.clone(), - )?; - - // update metrics to display spill - self.join_metrics.spill_count.add(1); - self.join_metrics - .spilled_bytes - .add(buffered_batch.size_estimation); - self.join_metrics - .spilled_rows - .add(buffered_batch.num_rows); - } else { - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - } + Ok(_) => { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) + if self + .runtime_env + .disk_manager + .tmp_files_enabled() => + { + // spill buffered batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("SortMergeJoinBuffered")?; + + buffered_batch.spill_to_disk( + spill_file, + self.buffered_schema.clone(), + )?; + + // update metrics to display spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics + .spilled_rows + .add(buffered_batch.num_rows); + + Ok(()) + } + Err(e) => Err(e), + }?; self.buffered_data.batches.push_back(buffered_batch); self.buffered_state = BufferedState::PollingRest; @@ -1001,10 +1012,10 @@ impl SMJStream { BufferedState::PollingRest => { println!("Polling Rest"); if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { if is_join_arrays_equal( &self.buffered_data.head_batch().join_arrays, @@ -1285,7 +1296,7 @@ impl SMJStream { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { get_buffered_columns( - &self.buffered_data, + &mut self.buffered_data, buffered_idx, &buffered_indices, )? @@ -1315,7 +1326,7 @@ impl SMJStream { ) { // unwrap is safe here as we check is_some on top of if statement let buffered_columns = get_buffered_columns( - &self.buffered_data, + &mut self.buffered_data, chunk.buffered_batch_idx.unwrap(), &buffered_indices, )?; @@ -1559,7 +1570,7 @@ fn produce_buffered_null_batch( schema: &SchemaRef, streamed_schema: &SchemaRef, buffered_indices: &PrimitiveArray, - buffered_batch: &BufferedBatch, + buffered_batch: &mut BufferedBatch, ) -> Result> { if buffered_indices.is_empty() { return Ok(None); @@ -1588,19 +1599,19 @@ fn produce_buffered_null_batch( /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` #[inline(always)] fn get_buffered_columns( - buffered_data: &BufferedData, + buffered_data: &mut BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { get_buffered_columns_from_batch( - &buffered_data.batches[buffered_batch_idx], + &mut buffered_data.batches[buffered_batch_idx], buffered_indices, ) } #[inline(always)] fn get_buffered_columns_from_batch( - buffered_batch: &BufferedBatch, + buffered_batch: &mut BufferedBatch, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { if buffered_batch.spill_file.is_none() { @@ -1973,6 +1984,7 @@ mod tests { assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::disk_manager::DiskManagerConfig; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_execution::TaskContext; @@ -2868,8 +2880,7 @@ mod tests { } #[tokio::test] - #[ignore] - async fn overallocation_single_batch() -> Result<()> { + async fn overallocation_single_batch_no_spill() -> Result<()> { let left = build_table( ("a1", &vec![0, 1, 2, 3, 4, 5]), ("b1", &vec![1, 2, 3, 4, 5, 6]), @@ -2895,14 +2906,17 @@ mod tests { JoinType::LeftAnti, ]; - for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(runtime.clone()); let task_ctx = Arc::new(task_ctx); let join = join_with_options( @@ -2928,7 +2942,7 @@ mod tests { } #[tokio::test] - async fn overallocation_multi_batch() -> Result<()> { + async fn overallocation_multi_batch_no_spill() -> Result<()> { let left_batch_1 = build_table_i32( ("a1", &vec![0, 1]), ("b1", &vec![1, 1]), @@ -2975,13 +2989,17 @@ mod tests { JoinType::LeftAnti, ]; + // Disable DiskManager to prevent spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(runtime.clone()); let task_ctx = Arc::new(task_ctx); let join = join_with_options( Arc::clone(&left), @@ -3005,6 +3023,15 @@ mod tests { Ok(()) } + #[tokio::test] + async fn overallocation_single_batch_spill() -> Result<()> { + todo!() + } + #[tokio::test] + async fn overallocation_multi_batch_spill() -> Result<()> { + todo!() + } + #[tokio::test] async fn left_semi_join_filtered_mask() -> Result<()> { assert_eq!( From 957cae54d9051e021bcfe6ef837714d35b685504 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 8 Jul 2024 16:39:25 -0700 Subject: [PATCH 03/11] Support SortMerge spilling --- datafusion/core/tests/memory_limit/mod.rs | 15 +- datafusion/execution/src/memory_pool/mod.rs | 17 +- .../src/joins/sort_merge_join.rs | 250 ++++++++++++++---- 3 files changed, 226 insertions(+), 56 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index b7dc591f2df96..81e7129f30c5b 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -186,7 +186,20 @@ async fn sort_merge_join_no_spill() { #[tokio::test] async fn sort_merge_join_spill() { - todo!() + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_memory_limit(1_000) + .with_config(config) + .with_disk_manager_config(DiskManagerConfig::NewOs) + .run() + .await } #[tokio::test] diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 3f66a304dc18c..e565b9a0177b9 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,7 +18,7 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy]` for //! help with allocation accounting. -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use std::{cmp::Ordering, sync::Arc}; mod pool; @@ -220,6 +220,21 @@ impl MemoryReservation { self.size = new_size } + /// Tries to free `capacity` bytes from this reservation + /// if `capacity` does not exceed [`Self::size`] + pub fn try_shrink(&mut self, capacity: usize) -> Result<()> { + if let Some(new_size) = self.size.checked_sub(capacity) { + self.registration.pool.shrink(self, capacity); + self.size = new_size; + Ok(()) + } else { + internal_err!( + "Cannot free the capacity {capacity} out of allocated size {}", + self.size + ) + } + } + /// Sets the size of this reservation to `capacity` pub fn resize(&mut self, capacity: usize) { match capacity.cmp(&self.size) { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d4c584f2b7385..1afd0bd1b60c4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -24,44 +24,45 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::Formatter; +use std::fs::File; +use std::io::BufReader; use std::mem; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, read_spill_as_stream, spill_record_batches, - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, -}; - use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow::ipc::reader::FileReader; use arrow_array::types::UInt64Type; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, - Result, + internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::runtime_env::RuntimeEnv; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + execution_mode_from_children, metrics, spill_record_batch_by_size, DisplayAs, + DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -638,27 +639,22 @@ impl BufferedBatch { &mut self, path: RefCountedTempFile, buffered_schema: SchemaRef, + batch_size: usize, ) -> Result<()> { let batch = std::mem::replace( &mut self.batch, RecordBatch::new_empty(buffered_schema.clone()), ); - let _ = spill_record_batches(vec![batch], path.path().into(), buffered_schema)?; + let _ = spill_record_batch_by_size( + batch, + path.path().into(), + buffered_schema, + batch_size, + ); self.spill_file = Some(path); - + dbg!(&self.spill_file); Ok(()) } - - fn read_spilled_from_disk( - &mut self, - schema: SchemaRef, - ) -> Result { - if let Some(f) = mem::take(&mut self.spill_file) { - read_spill_as_stream(f, schema, 2) - } else { - exec_err!("Cannot read data batch from disk. Use `spill_to_disk` to spill") - } - } } /// Sort-merge join stream that consumes streamed and buffered data stream @@ -925,8 +921,11 @@ impl SMJStream { if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - println!("shrink\n"); - self.reservation.shrink(buffered_batch.size_estimation); + // Noop on shrink complaints, this might happen + // on spilled batches + self.reservation + .try_shrink(buffered_batch.size_estimation) + .unwrap_or(()); } } else { // If the head batch is not fully processed, break the loop. @@ -988,6 +987,7 @@ impl SMJStream { buffered_batch.spill_to_disk( spill_file, self.buffered_schema.clone(), + self.batch_size, )?; // update metrics to display spill @@ -1614,30 +1614,46 @@ fn get_buffered_columns_from_batch( buffered_batch: &mut BufferedBatch, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { - if buffered_batch.spill_file.is_none() { + if let Some(spill_file) = mem::take(&mut buffered_batch.spill_file) { + // if spilled read as a stream + let mut buffered_cols: Vec = Vec::with_capacity(buffered_indices.len()); + // let mut stream = + // read_spill_as_stream(spill_file, buffered_batch.batch.schema(), 2)?; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = FileReader::try_new(file, None)?; + + for batch in reader { + let batch = batch?; + batch.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + // let _ = futures::stream::once(async { + // dbg!("in"); + // while let Some(batch) = stream.next().await { + // dbg!("stream spilled batch"); + // + // let batch = batch?; + // batch.columns().iter().for_each(|column| { + // buffered_cols.extend(take(column, &buffered_indices, None)) + // }); + // } + // + // Ok::<(), ArrowError>(()) + // }); + + dbg!(&buffered_cols); + + Ok(buffered_cols) + } else { buffered_batch .batch .columns() .iter() .map(|column| take(column, &buffered_indices, None)) .collect::, ArrowError>>() - } else { - // if spilled read as a stream - let mut buffered_cols: Vec = Vec::with_capacity(buffered_indices.len()); - let mut stream = - buffered_batch.read_spilled_from_disk(buffered_batch.batch.schema())?; - let _ = futures::stream::once(async { - while let Some(batch) = stream.next().await { - let batch = batch?; - batch.columns().iter().for_each(|column| { - buffered_cols.extend(take(column, &buffered_indices, None)) - }); - } - - Ok::<(), DataFusionError>(()) - }); - - Ok(buffered_cols) } } @@ -1674,6 +1690,7 @@ fn get_filtered_join_mask( // we don't need to check any others for the same index JoinType::LeftSemi => { // have we seen a filter match for a streaming index before + // have we seen a filter match for are streaming index before for i in 0..streamed_indices_length { // LeftSemi respects only first true values for specific streaming index, // others true values for the same index must be false @@ -3025,11 +3042,136 @@ mod tests { #[tokio::test] async fn overallocation_single_batch_spill() -> Result<()> { - todo!() + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + //JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + + for join_type in join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(runtime.clone()); + let task_ctx = Arc::new(task_ctx); + + println!("{join_type}"); + + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let _ = common::collect(stream).await.unwrap(); + } + + Ok(()) } + #[tokio::test] async fn overallocation_multi_batch_spill() -> Result<()> { - todo!() + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + //JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + // Enable DiskManager to allow spilling + let runtime_config = RuntimeConfig::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + + for join_type in join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(runtime.clone()); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let _ = common::collect(stream).await.unwrap(); + } + + Ok(()) } #[tokio::test] From 9d27bb57cccf8d5ce0842fa5be1bf86f2bae11f3 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 9 Jul 2024 08:49:42 -0700 Subject: [PATCH 04/11] Support SortMerge spilling --- .../src/joins/sort_merge_join.rs | 180 +++++++----------- 1 file changed, 68 insertions(+), 112 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 1afd0bd1b60c4..e7a064f13c201 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -643,7 +643,7 @@ impl BufferedBatch { ) -> Result<()> { let batch = std::mem::replace( &mut self.batch, - RecordBatch::new_empty(buffered_schema.clone()), + RecordBatch::new_empty(Arc::clone(&buffered_schema)), ); let _ = spill_record_batch_by_size( batch, @@ -652,7 +652,6 @@ impl BufferedBatch { batch_size, ); self.spill_file = Some(path); - dbg!(&self.spill_file); Ok(()) } } @@ -886,7 +885,6 @@ impl SMJStream { self.streamed_state = StreamedState::Exhausted; } Poll::Ready(Some(batch)) => { - println!("\nstreamed rows {}", batch.num_rows()); if batch.num_rows() > 0 { self.freeze_streamed()?; self.join_metrics.input_batches.add(1); @@ -907,6 +905,43 @@ impl SMJStream { } } + fn mem_allocate_batch(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // spill buffered batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("SortMergeJoinBuffered")?; + + buffered_batch.spill_to_disk( + spill_file, + Arc::clone(&self.buffered_schema), + self.batch_size, + )?; + + // update metrics to display spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + + Ok(()) + } + Err(e) => Err(e), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + /// Poll next buffered batches fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { loop { @@ -921,11 +956,12 @@ impl SMJStream { if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - // Noop on shrink complaints, this might happen - // on spilled batches - self.reservation - .try_shrink(buffered_batch.size_estimation) - .unwrap_or(()); + // Shrink mem usage for non spilled batches only + if buffered_batch.spill_file.is_none() { + self.reservation + .try_shrink(buffered_batch.size_estimation) + .unwrap_or(()); + } } } else { // If the head batch is not fully processed, break the loop. @@ -953,64 +989,16 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); - println!( - "\nbatch rows {} mem {}", - batch.num_rows(), - self.reservation.size() - ); if batch.num_rows() > 0 { - let mut buffered_batch = + let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - match self - .reservation - .try_grow(buffered_batch.size_estimation) - { - Ok(_) => { - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - Ok(()) - } - Err(_) - if self - .runtime_env - .disk_manager - .tmp_files_enabled() => - { - // spill buffered batch to disk - let spill_file = self - .runtime_env - .disk_manager - .create_tmp_file("SortMergeJoinBuffered")?; - - buffered_batch.spill_to_disk( - spill_file, - self.buffered_schema.clone(), - self.batch_size, - )?; - - // update metrics to display spill - self.join_metrics.spill_count.add(1); - self.join_metrics - .spilled_bytes - .add(buffered_batch.size_estimation); - self.join_metrics - .spilled_rows - .add(buffered_batch.num_rows); - - Ok(()) - } - Err(e) => Err(e), - }?; - - self.buffered_data.batches.push_back(buffered_batch); + self.mem_allocate_batch(buffered_batch)?; self.buffered_state = BufferedState::PollingRest; } } }, BufferedState::PollingRest => { - println!("Polling Rest"); if self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().num_rows { @@ -1038,7 +1026,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { - // This code is unreachable! Think about dropping it + // Multi batch self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -1047,12 +1035,7 @@ impl SMJStream { 0..0, &self.on_buffered, ); - self.reservation - .try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.mem_allocate_batch(buffered_batch)?; } } } @@ -1090,7 +1073,6 @@ impl SMJStream { /// Produce join and fill output buffer until reaching target batch size /// or the join is finished fn join_partial(&mut self) -> Result<()> { - println!("join_partial"); // Whether to join streamed rows let mut join_streamed = false; // Whether to join buffered rows @@ -1159,13 +1141,10 @@ impl SMJStream { } if join_buffered { - //println!("join_partial: join_buffered"); - // joining streamed/nulls and buffered while !self.buffered_data.scanning_finished() && self.output_size < self.batch_size { - //println!("join_partial: while join_buffered"); let scanning_idx = self.buffered_data.scanning_idx(); if join_streamed { // Join streamed row and buffered row @@ -1296,7 +1275,7 @@ impl SMJStream { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { get_buffered_columns( - &mut self.buffered_data, + &self.buffered_data, buffered_idx, &buffered_indices, )? @@ -1310,8 +1289,6 @@ impl SMJStream { .collect::>() }; - dbg!(&buffered_columns); - let streamed_columns_length = streamed_columns.len(); let buffered_columns_length = buffered_columns.len(); @@ -1326,7 +1303,7 @@ impl SMJStream { ) { // unwrap is safe here as we check is_some on top of if statement let buffered_columns = get_buffered_columns( - &mut self.buffered_data, + &self.buffered_data, chunk.buffered_batch_idx.unwrap(), &buffered_indices, )?; @@ -1570,7 +1547,7 @@ fn produce_buffered_null_batch( schema: &SchemaRef, streamed_schema: &SchemaRef, buffered_indices: &PrimitiveArray, - buffered_batch: &mut BufferedBatch, + buffered_batch: &BufferedBatch, ) -> Result> { if buffered_indices.is_empty() { return Ok(None); @@ -1599,53 +1576,34 @@ fn produce_buffered_null_batch( /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` #[inline(always)] fn get_buffered_columns( - buffered_data: &mut BufferedData, + buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { get_buffered_columns_from_batch( - &mut buffered_data.batches[buffered_batch_idx], + &buffered_data.batches[buffered_batch_idx], buffered_indices, ) } #[inline(always)] fn get_buffered_columns_from_batch( - buffered_batch: &mut BufferedBatch, + buffered_batch: &BufferedBatch, buffered_indices: &UInt64Array, ) -> Result, ArrowError> { - if let Some(spill_file) = mem::take(&mut buffered_batch.spill_file) { - // if spilled read as a stream + if let Some(spill_file) = &buffered_batch.spill_file { + // if spilled read from disk in smaller sub batches let mut buffered_cols: Vec = Vec::with_capacity(buffered_indices.len()); - // let mut stream = - // read_spill_as_stream(spill_file, buffered_batch.batch.schema(), 2)?; let file = BufReader::new(File::open(spill_file.path())?); let reader = FileReader::try_new(file, None)?; for batch in reader { - let batch = batch?; - batch.columns().iter().for_each(|column| { + batch?.columns().iter().for_each(|column| { buffered_cols.extend(take(column, &buffered_indices, None)) }); } - // let _ = futures::stream::once(async { - // dbg!("in"); - // while let Some(batch) = stream.next().await { - // dbg!("stream spilled batch"); - // - // let batch = batch?; - // batch.columns().iter().for_each(|column| { - // buffered_cols.extend(take(column, &buffered_indices, None)) - // }); - // } - // - // Ok::<(), ArrowError>(()) - // }); - - dbg!(&buffered_cols); - Ok(buffered_cols) } else { buffered_batch @@ -2933,7 +2891,7 @@ mod tests { for join_type in join_types { let task_ctx = TaskContext::default() .with_session_config(session_config.clone()) - .with_runtime(runtime.clone()); + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( @@ -3016,7 +2974,7 @@ mod tests { for join_type in join_types { let task_ctx = TaskContext::default() .with_session_config(session_config.clone()) - .with_runtime(runtime.clone()); + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( Arc::clone(&left), @@ -3062,7 +3020,7 @@ mod tests { JoinType::Inner, JoinType::Left, JoinType::Right, - //JoinType::Full, + JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, ]; @@ -3077,14 +3035,12 @@ mod tests { for join_type in join_types { let task_ctx = TaskContext::default() .with_session_config(session_config.clone()) - .with_runtime(runtime.clone()); + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); - println!("{join_type}"); - let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), @@ -3141,14 +3097,14 @@ mod tests { JoinType::Inner, JoinType::Left, JoinType::Right, - //JoinType::Full, + JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, ]; // Enable DiskManager to allow spilling let runtime_config = RuntimeConfig::new() - .with_memory_limit(100, 1.0) + .with_memory_limit(500, 1.0) .with_disk_manager(DiskManagerConfig::NewOs); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); let session_config = SessionConfig::default().with_batch_size(50); @@ -3156,11 +3112,11 @@ mod tests { for join_type in join_types { let task_ctx = TaskContext::default() .with_session_config(session_config.clone()) - .with_runtime(runtime.clone()); + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), From c55019182755285a1ac7d38131e5e1731cdafbb9 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 12 Jul 2024 13:38:47 -0700 Subject: [PATCH 05/11] address comments --- .../src/joins/sort_merge_join.rs | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e7a064f13c201..cb7b6f68890f2 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -239,11 +239,6 @@ impl SortMergeJoinExec { impl DisplayAs for SortMergeJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let on = self @@ -255,7 +250,12 @@ impl DisplayAs for SortMergeJoinExec { write!( f, "SortMergeJoin: join_type={:?}, on=[{}]{}", - self.join_type, on, display_filter + self.join_type, + on, + self.filter + .as_ref() + .map(|f| format!(", filter={}", f.expression())) + .unwrap_or("".to_string()) ) } } @@ -959,8 +959,7 @@ impl SMJStream { // Shrink mem usage for non spilled batches only if buffered_batch.spill_file.is_none() { self.reservation - .try_shrink(buffered_batch.size_estimation) - .unwrap_or(()); + .shrink(buffered_batch.size_estimation); } } } else { @@ -2911,6 +2910,10 @@ mod tests { "Resources exhausted: Failed to allocate additional" ); assert_contains!(err.to_string(), "SMJStream[0]"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); } Ok(()) @@ -2993,6 +2996,10 @@ mod tests { "Resources exhausted: Failed to allocate additional" ); assert_contains!(err.to_string(), "SMJStream[0]"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); } Ok(()) @@ -3049,6 +3056,11 @@ mod tests { let stream = join.execute(0, task_ctx)?; let _ = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); } Ok(()) @@ -3125,6 +3137,10 @@ mod tests { let stream = join.execute(0, task_ctx)?; let _ = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); } Ok(()) From a62f808c8760639f6081038f5003568dc19f31ed Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 12 Jul 2024 15:54:13 -0700 Subject: [PATCH 06/11] address comments --- datafusion/core/tests/memory_limit/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 81e7129f30c5b..2494108d92588 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -180,6 +180,7 @@ async fn sort_merge_join_no_spill() { ]) .with_memory_limit(1_000) .with_config(config) + .with_scenario(Scenario::AccessLogStreaming) .run() .await } @@ -198,6 +199,7 @@ async fn sort_merge_join_spill() { .with_memory_limit(1_000) .with_config(config) .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_scenario(Scenario::AccessLogStreaming) .run() .await } @@ -471,7 +473,7 @@ impl TestCase { let table = scenario.table(); let rt_config = RuntimeConfig::new() - // do not allow spilling + // disk manager setting controls the spilling .with_disk_manager(disk_manager_config) .with_memory_limit(memory_limit, MEMORY_FRACTION); From 8380758f4dd899b2c7cc736d6bb2257876fa7264 Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 16 Jul 2024 08:01:57 -0700 Subject: [PATCH 07/11] address comments --- datafusion/execution/src/memory_pool/mod.rs | 6 +- .../src/joins/sort_merge_join.rs | 176 +++++++++++------- 2 files changed, 115 insertions(+), 67 deletions(-) diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index e565b9a0177b9..92ed1b2918de0 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -222,11 +222,13 @@ impl MemoryReservation { /// Tries to free `capacity` bytes from this reservation /// if `capacity` does not exceed [`Self::size`] - pub fn try_shrink(&mut self, capacity: usize) -> Result<()> { + /// Returns new reservation size + /// or error if shrinking capacity is more than allocated size + pub fn try_shrink(&mut self, capacity: usize) -> Result { if let Some(new_size) = self.size.checked_sub(capacity) { self.registration.pool.shrink(self, capacity); self.size = new_size; - Ok(()) + Ok(new_size) } else { internal_err!( "Cannot free the capacity {capacity} out of allocated size {}", diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index cb7b6f68890f2..23faa924bd544 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -583,7 +583,7 @@ impl StreamedBatch { #[derive(Debug)] struct BufferedBatch { /// The buffered record batch - pub batch: RecordBatch, + pub batch: Option, /// The range in which the rows share the same join key pub range: Range, /// Array refs of the join key @@ -595,7 +595,11 @@ struct BufferedBatch { /// The indices of buffered batch that failed the join filter. /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. pub join_filter_failed_idxs: HashSet, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive pub num_rows: usize, + /// A temp spill file name on the disk if the batch spilled pub spill_file: Option, } @@ -624,7 +628,7 @@ impl BufferedBatch { let num_rows = batch.num_rows(); BufferedBatch { - batch, + batch: Some(batch), range, join_arrays, null_joined: vec![], @@ -634,26 +638,6 @@ impl BufferedBatch { spill_file: None, } } - - fn spill_to_disk( - &mut self, - path: RefCountedTempFile, - buffered_schema: SchemaRef, - batch_size: usize, - ) -> Result<()> { - let batch = std::mem::replace( - &mut self.batch, - RecordBatch::new_empty(Arc::clone(&buffered_schema)), - ); - let _ = spill_record_batch_by_size( - batch, - path.path().into(), - buffered_schema, - batch_size, - ); - self.spill_file = Some(path); - Ok(()) - } } /// Sort-merge join stream that consumes streamed and buffered data stream @@ -905,7 +889,17 @@ impl SMJStream { } } - fn mem_allocate_batch(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { + // Shrink memory usage for in memory batches only + if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { + self.reservation + .try_shrink(buffered_batch.size_estimation)?; + } + + Ok(()) + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { match self.reservation.try_grow(buffered_batch.size_estimation) { Ok(_) => { self.join_metrics @@ -920,18 +914,23 @@ impl SMJStream { .disk_manager .create_tmp_file("SortMergeJoinBuffered")?; - buffered_batch.spill_to_disk( - spill_file, - Arc::clone(&self.buffered_schema), - self.batch_size, - )?; - - // update metrics to display spill - self.join_metrics.spill_count.add(1); - self.join_metrics - .spilled_bytes - .add(buffered_batch.size_estimation); - self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + if let Some(batch) = &buffered_batch.batch { + spill_record_batch_by_size( + batch, + spill_file.path().into(), + Arc::clone(&self.buffered_schema), + self.batch_size, + )?; + buffered_batch.spill_file = Some(spill_file); + buffered_batch.batch = None; + + // update metrics to register spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + } Ok(()) } @@ -956,11 +955,7 @@ impl SMJStream { if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - // Shrink mem usage for non spilled batches only - if buffered_batch.spill_file.is_none() { - self.reservation - .shrink(buffered_batch.size_estimation); - } + self.free_reservation(buffered_batch)?; } } else { // If the head batch is not fully processed, break the loop. @@ -988,11 +983,12 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); + if batch.num_rows() > 0 { let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - self.mem_allocate_batch(buffered_batch)?; + self.allocate_reservation(buffered_batch)?; self.buffered_state = BufferedState::PollingRest; } } @@ -1034,7 +1030,7 @@ impl SMJStream { 0..0, &self.on_buffered, ); - self.mem_allocate_batch(buffered_batch)?; + self.allocate_reservation(buffered_batch)?; } } } @@ -1554,8 +1550,7 @@ fn produce_buffered_null_batch( // Take buffered (right) columns let buffered_columns = - get_buffered_columns_from_batch(buffered_batch, buffered_indices) - .map_err(Into::::into)?; + get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; // Create null streamed (left) columns let mut streamed_columns = streamed_schema @@ -1578,7 +1573,7 @@ fn get_buffered_columns( buffered_data: &BufferedData, buffered_batch_idx: usize, buffered_indices: &UInt64Array, -) -> Result, ArrowError> { +) -> Result> { get_buffered_columns_from_batch( &buffered_data.batches[buffered_batch_idx], buffered_indices, @@ -1589,28 +1584,33 @@ fn get_buffered_columns( fn get_buffered_columns_from_batch( buffered_batch: &BufferedBatch, buffered_indices: &UInt64Array, -) -> Result, ArrowError> { - if let Some(spill_file) = &buffered_batch.spill_file { - // if spilled read from disk in smaller sub batches - let mut buffered_cols: Vec = Vec::with_capacity(buffered_indices.len()); - - let file = BufReader::new(File::open(spill_file.path())?); - let reader = FileReader::try_new(file, None)?; - - for batch in reader { - batch?.columns().iter().for_each(|column| { - buffered_cols.extend(take(column, &buffered_indices, None)) - }); - } - - Ok(buffered_cols) - } else { - buffered_batch - .batch +) -> Result> { + match (&buffered_batch.spill_file, &buffered_batch.batch) { + // In memory batch + (None, Some(batch)) => Ok(batch .columns() .iter() .map(|column| take(column, &buffered_indices, None)) .collect::, ArrowError>>() + .map_err(Into::::into)?), + // If the batch was spilled to disk, less likely + (Some(spill_file), None) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = FileReader::try_new(file, None)?; + + for batch in reader { + batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + // Invalid combination + _ => internal_err!("Buffered batch spill status is in the inconsistent state."), } } @@ -3055,12 +3055,35 @@ mod tests { )?; let stream = join.execute(0, task_ctx)?; - let _ = common::collect(stream).await.unwrap(); + let spilled_join_result = common::collect(stream).await.unwrap(); assert!(join.metrics().is_some()); assert!(join.metrics().unwrap().spill_count().unwrap() > 0); assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); } Ok(()) @@ -3136,11 +3159,34 @@ mod tests { )?; let stream = join.execute(0, task_ctx)?; - let _ = common::collect(stream).await.unwrap(); + let spilled_join_result = common::collect(stream).await.unwrap(); assert!(join.metrics().is_some()); assert!(join.metrics().unwrap().spill_count().unwrap() > 0); assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); } Ok(()) From 87810f2946bba1ef00f82c6f0f32792773274da3 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 17 Jul 2024 08:41:51 -0700 Subject: [PATCH 08/11] spill entire file --- .../src/joins/sort_merge_join.rs | 203 +++++++++--------- 1 file changed, 101 insertions(+), 102 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 23faa924bd544..486a9b022f13d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -57,12 +57,7 @@ use crate::joins::utils::{ symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, }; use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, spill_record_batch_by_size, DisplayAs, - DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, -}; +use crate::{execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, spill_record_batches}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -914,12 +909,11 @@ impl SMJStream { .disk_manager .create_tmp_file("SortMergeJoinBuffered")?; - if let Some(batch) = &buffered_batch.batch { - spill_record_batch_by_size( - batch, + if let Some(batch) = buffered_batch.batch { + spill_record_batches( + vec![batch], spill_file.path().into(), - Arc::clone(&self.buffered_schema), - self.batch_size, + Arc::clone(&self.buffered_schema) )?; buffered_batch.spill_file = Some(spill_file); buffered_batch.batch = None; @@ -3037,53 +3031,56 @@ mod tests { .with_memory_limit(100, 1.0) .with_disk_manager(DiskManagerConfig::NewOs); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); - - for join_type in join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Run the test with no spill configuration as - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - // Compare spilled and non spilled data to check spill logic doesn't corrupt the data - assert_eq!(spilled_join_result, no_spilled_join_result); + for batch_size in vec![1,50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type.clone(), + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type.clone(), + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } } Ok(()) @@ -3142,51 +3139,53 @@ mod tests { .with_memory_limit(500, 1.0) .with_disk_manager(DiskManagerConfig::NewOs); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); - - for join_type in join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Run the test with no spill configuration as - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - // Compare spilled and non spilled data to check spill logic doesn't corrupt the data - assert_eq!(spilled_join_result, no_spilled_join_result); + for batch_size in vec![1,50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type.clone(), + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type.clone(), + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } } Ok(()) From 186b2ffb66bd4acd3967302bd265a5d99078c53d Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 17 Jul 2024 08:44:15 -0700 Subject: [PATCH 09/11] fmt --- .../src/joins/sort_merge_join.rs | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 486a9b022f13d..00486d1e54f01 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -57,7 +57,12 @@ use crate::joins::utils::{ symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, }; use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, spill_record_batches}; +use crate::{ + execution_mode_from_children, metrics, spill_record_batches, DisplayAs, + DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -913,7 +918,7 @@ impl SMJStream { spill_record_batches( vec![batch], spill_file.path().into(), - Arc::clone(&self.buffered_schema) + Arc::clone(&self.buffered_schema), )?; buffered_batch.spill_file = Some(spill_file); buffered_batch.batch = None; @@ -3017,7 +3022,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ + let join_types = [ JoinType::Inner, JoinType::Left, JoinType::Right, @@ -3032,7 +3037,7 @@ mod tests { .with_disk_manager(DiskManagerConfig::NewOs); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - for batch_size in vec![1,50] { + for batch_size in [1, 50] { let session_config = SessionConfig::default().with_batch_size(batch_size); for join_type in &join_types { @@ -3045,7 +3050,7 @@ mod tests { Arc::clone(&left), Arc::clone(&right), on.clone(), - join_type.clone(), + *join_type, sort_options.clone(), false, )?; @@ -3067,7 +3072,7 @@ mod tests { Arc::clone(&left), Arc::clone(&right), on.clone(), - join_type.clone(), + *join_type, sort_options.clone(), false, )?; @@ -3125,7 +3130,7 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ + let join_types = [ JoinType::Inner, JoinType::Left, JoinType::Right, @@ -3139,7 +3144,7 @@ mod tests { .with_memory_limit(500, 1.0) .with_disk_manager(DiskManagerConfig::NewOs); let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - for batch_size in vec![1,50] { + for batch_size in [1, 50] { let session_config = SessionConfig::default().with_batch_size(batch_size); for join_type in &join_types { @@ -3151,7 +3156,7 @@ mod tests { Arc::clone(&left), Arc::clone(&right), on.clone(), - join_type.clone(), + *join_type, sort_options.clone(), false, )?; @@ -3172,7 +3177,7 @@ mod tests { Arc::clone(&left), Arc::clone(&right), on.clone(), - join_type.clone(), + *join_type, sort_options.clone(), false, )?; From 1d9c7d41c0cf2f31ce447ce985757eae6f080b90 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 17 Jul 2024 12:36:21 -0700 Subject: [PATCH 10/11] merge --- .../src/joins/sort_merge_join.rs | 8 +- datafusion/physical-plan/src/spill.rs | 101 ++++++++++++++++++ 2 files changed, 105 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 00486d1e54f01..555df2613ff7b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -57,11 +57,11 @@ use crate::joins::utils::{ symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, }; use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::spill::spill_record_batches; use crate::{ - execution_mode_from_children, metrics, spill_record_batches, DisplayAs, - DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; /// join execution plan executes partitions in parallel and combines them into a set of diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs index 0018a27bd22bb..75797eb25624e 100644 --- a/datafusion/physical-plan/src/spill.rs +++ b/datafusion/physical-plan/src/spill.rs @@ -85,3 +85,104 @@ fn read_spill(sender: Sender>, path: &Path) -> Result<()> { } Ok(()) } + +/// Spill the `RecordBatch` to disk as smaller batches +/// split by `batch_size_rows` +/// Return `total_rows` what is spilled +pub fn spill_record_batch_by_size( + batch: &RecordBatch, + path: PathBuf, + schema: SchemaRef, + batch_size_rows: usize, +) -> Result { + let mut offset = 0; + let total_rows = batch.num_rows(); + let mut writer = IPCWriter::new(&path, schema.as_ref())?; + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, batch_size_rows); + let batch = batch.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + writer.finish()?; + + Ok(total_rows) +} + +#[cfg(test)] +mod tests { + use crate::spill::{spill_record_batch_by_size, spill_record_batches}; + use crate::test::build_table_i32; + use datafusion_common::Result; + use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::DiskManager; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + + #[test] + fn test_batch_spill_and_read() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + + let batch2 = build_table_i32( + ("a2", &vec![10, 11, 12]), + ("b2", &vec![13, 14, 15]), + ("c2", &vec![14, 15, 16]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + let num_rows = batch1.num_rows() + batch2.num_rows(); + let cnt = spill_record_batches( + vec![batch1, batch2], + spill_file.path().into(), + Arc::clone(&schema), + ); + assert_eq!(cnt.unwrap(), num_rows); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 2); + assert_eq!(reader.schema(), schema); + + Ok(()) + } + + #[test] + fn test_batch_spill_by_size() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2, 3]), + ("b2", &vec![3, 4, 5, 6]), + ("c2", &vec![4, 5, 6, 7]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + let num_rows = batch1.num_rows(); + let cnt = spill_record_batch_by_size( + &batch1, + spill_file.path().into(), + Arc::clone(&schema), + 1, + ); + assert_eq!(cnt.unwrap(), num_rows); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 4); + assert_eq!(reader.schema(), schema); + + Ok(()) + } +} From 93643b61e8de78ede59318da9e99a8eda0e8a9e0 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 18 Jul 2024 18:33:48 -0700 Subject: [PATCH 11/11] Comments --- datafusion/core/tests/memory_limit/mod.rs | 3 +- .../src/joins/sort_merge_join.rs | 44 +++++++++---------- datafusion/physical-plan/src/spill.rs | 14 +++--- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 2494108d92588..bc2c3315da592 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -175,8 +175,9 @@ async fn sort_merge_join_no_spill() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", + "Failed to allocate additional", "SMJStream", + "Disk spilling disabled", ]) .with_memory_limit(1_000) .with_config(config) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 555df2613ff7b..5fde028c7f488 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -42,7 +42,8 @@ use futures::{Stream, StreamExt}; use hashbrown::HashSet; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -252,10 +253,10 @@ impl DisplayAs for SortMergeJoinExec { "SortMergeJoin: join_type={:?}, on=[{}]{}", self.join_type, on, - self.filter - .as_ref() - .map(|f| format!(", filter={}", f.expression())) - .unwrap_or("".to_string()) + self.filter.as_ref().map_or("".to_string(), |f| format!( + ", filter={}", + f.expression() + )) ) } } @@ -583,6 +584,7 @@ impl StreamedBatch { #[derive(Debug)] struct BufferedBatch { /// The buffered record batch + /// None if the batch spilled to disk th pub batch: Option, /// The range in which the rows share the same join key pub range: Range, @@ -599,7 +601,9 @@ struct BufferedBatch { /// but if batch is spilled to disk this property is preferable /// and less expensive pub num_rows: usize, - /// A temp spill file name on the disk if the batch spilled + /// An optional temp spill file name on the disk if the batch spilled + /// None by default + /// Some(fileName) if the batch spilled to the disk pub spill_file: Option, } @@ -890,7 +894,7 @@ impl SMJStream { } fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { - // Shrink memory usage for in memory batches only + // Shrink memory usage for in-memory batches only if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { self.reservation .try_shrink(buffered_batch.size_estimation)?; @@ -912,7 +916,7 @@ impl SMJStream { let spill_file = self .runtime_env .disk_manager - .create_tmp_file("SortMergeJoinBuffered")?; + .create_tmp_file("sort_merge_join_buffered_spill")?; if let Some(batch) = buffered_batch.batch { spill_record_batches( @@ -929,11 +933,12 @@ impl SMJStream { .spilled_bytes .add(buffered_batch.size_estimation); self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + Ok(()) + } else { + internal_err!("Buffered batch has empty body") } - - Ok(()) } - Err(e) => Err(e), + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), }?; self.buffered_data.batches.push_back(buffered_batch); @@ -1020,7 +1025,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { - // Multi batch + // Polling batches coming concurrently as multiple partitions self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -1609,7 +1614,7 @@ fn get_buffered_columns_from_batch( Ok(buffered_cols) } // Invalid combination - _ => internal_err!("Buffered batch spill status is in the inconsistent state."), + (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), } } @@ -1646,7 +1651,6 @@ fn get_filtered_join_mask( // we don't need to check any others for the same index JoinType::LeftSemi => { // have we seen a filter match for a streaming index before - // have we seen a filter match for are streaming index before for i in 0..streamed_indices_length { // LeftSemi respects only first true values for specific streaming index, // others true values for the same index must be false @@ -2904,11 +2908,9 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); assert!(join.metrics().is_some()); assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); @@ -2990,11 +2992,9 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); assert!(join.metrics().is_some()); assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs index 75797eb25624e..21ca58fa0a9fa 100644 --- a/datafusion/physical-plan/src/spill.rs +++ b/datafusion/physical-plan/src/spill.rs @@ -40,7 +40,7 @@ use crate::stream::RecordBatchReceiverStream; /// `path` - temp file /// `schema` - batches schema, should be the same across batches /// `buffer` - internal buffer of capacity batches -pub fn read_spill_as_stream( +pub(crate) fn read_spill_as_stream( path: RefCountedTempFile, schema: SchemaRef, buffer: usize, @@ -56,7 +56,7 @@ pub fn read_spill_as_stream( /// Spills in-memory `batches` to disk. /// /// Returns total number of the rows spilled to disk. -pub fn spill_record_batches( +pub(crate) fn spill_record_batches( batches: Vec, path: PathBuf, schema: SchemaRef, @@ -94,7 +94,7 @@ pub fn spill_record_batch_by_size( path: PathBuf, schema: SchemaRef, batch_size_rows: usize, -) -> Result { +) -> Result<()> { let mut offset = 0; let total_rows = batch.num_rows(); let mut writer = IPCWriter::new(&path, schema.as_ref())?; @@ -107,7 +107,7 @@ pub fn spill_record_batch_by_size( } writer.finish()?; - Ok(total_rows) + Ok(()) } #[cfg(test)] @@ -168,14 +168,12 @@ mod tests { let spill_file = disk_manager.create_tmp_file("Test Spill")?; let schema = batch1.schema(); - let num_rows = batch1.num_rows(); - let cnt = spill_record_batch_by_size( + spill_record_batch_by_size( &batch1, spill_file.path().into(), Arc::clone(&schema), 1, - ); - assert_eq!(cnt.unwrap(), num_rows); + )?; let file = BufReader::new(File::open(spill_file.path())?); let reader = arrow::ipc::reader::FileReader::try_new(file, None)?;