diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 776b7a50d1f8c..c986fc489807c 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -32,7 +32,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + DescribeTable, Partitioning, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -917,11 +917,16 @@ impl SessionContext { /// Creates a [`DataFrame`] for a [`TableProvider`] such as a /// [`ListingTable`] or a custom user defined provider. pub fn read_table(&self, provider: Arc) -> Result { - Ok(DataFrame::new( - self.state(), - LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? - .build()?, - )) + let state = self.state(); + let mut builder = + LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?; + let target_partitions = state.config.target_partitions(); + if target_partitions > 1 { + // Keep the data in the target number of partitions + builder = + builder.repartition(Partitioning::RoundRobinBatch(target_partitions))?; + } + Ok(DataFrame::new(state, builder.build()?)) } /// Creates a [`DataFrame`] for reading a [`RecordBatch`] diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index aecec35f2e16d..eddf6e3a7b460 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -685,7 +685,8 @@ async fn test_grouping_sets() -> Result<()> { #[tokio::test] async fn test_grouping_sets_count() -> Result<()> { - let ctx = SessionContext::new(); + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(config); let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ vec![col("c1")], @@ -725,7 +726,8 @@ async fn test_grouping_sets_count() -> Result<()> { #[tokio::test] async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { - let ctx = SessionContext::new(); + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(config); let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ vec![col("c1")], @@ -795,6 +797,18 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_read_partitioned() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(4); + let ctx = SessionContext::with_config(config); + + let df = aggregates_table(&ctx).await?; + let plan = df.create_physical_plan().await?; + + assert_eq!(plan.output_partitioning().partition_count(), 4); + Ok(()) +} + #[tokio::test] async fn join_with_alias_filter() -> Result<()> { let join_ctx = create_join_context()?;