diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index c48a9faf8..43fef7473 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -2,17 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.1" @@ -58,64 +47,95 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e24e2bcd431a4aa0ff003fdd2dc21c78cfb42f31459c89d2312c2746fe17a5ac" +checksum = "aed9849f86164fad5cb66ce4732782b15f1bc97f8febab04e782c20cce9d4b6c" dependencies = [ - "ahash 0.8.1", + "ahash", "arrow-array", "arrow-buffer", + "arrow-cast", + "arrow-csv", "arrow-data", + "arrow-ipc", + "arrow-json", "arrow-schema", "arrow-select", - "bitflags", "chrono", "comfy-table", - "csv", - "flatbuffers", "half", - "hashbrown", - "indexmap", - "lazy_static", - "lexical-core", + "hashbrown 0.13.1", "multiversion", "num", "regex", "regex-syntax", - "serde_json", ] [[package]] name = "arrow-array" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9044300874385f19e77cbf90911e239bd23630d8f23bb0f948f9067998a13b7" +checksum = "6b8504cf0a6797e908eecf221a865e7d339892720587f87c8b90262863015b08" dependencies = [ - "ahash 0.8.1", + "ahash", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "half", - "hashbrown", + "hashbrown 0.13.1", "num", ] [[package]] name = "arrow-buffer" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78476cbe9e3f808dcecab86afe42d573863c63e149c62e6e379ed2522743e626" +checksum = "d6de64a27cea684b24784647d9608314bc80f7c4d55acb44a425e05fab39d916" dependencies = [ "half", "num", ] +[[package]] +name = "arrow-cast" +version = "28.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec4a54502eefe05923c385c90a005d69474fa06ca7aa2a2b123c9f9532f6178" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "chrono", + "lexical-core", + "num", +] + +[[package]] +name = "arrow-csv" +version = "28.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7902bbf8127eac48554fe902775303377047ad49a9fd473c2b8cb399d092080" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "csv", + "lazy_static", + "lexical-core", + "regex", +] + [[package]] name = "arrow-data" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d916feee158c485dad4f701cba31bc9a90a8db87d9df8e2aa8adc0c20a2bbb9" +checksum = "7e4882efe617002449d5c6b5de9ddb632339074b36df8a96ea7147072f1faa8a" dependencies = [ "arrow-buffer", "arrow-schema", @@ -123,17 +143,49 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-ipc" +version = "28.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa0703a6de2785828561b03a4d7793ecd333233e1b166316b4bfc7cfce55a4a7" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "flatbuffers", +] + +[[package]] +name = "arrow-json" +version = "28.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bd23fc8c6d251f96cd63b96fece56bbb9710ce5874a627cb786e2600673595a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap", + "num", + "serde_json", +] + [[package]] name = "arrow-schema" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f9406eb7834ca6bd8350d1baa515d18b9fcec487eddacfb62f5e19511f7bd37" +checksum = "da9f143882a80be168538a60e298546314f50f11f2a288c8d73e11108da39d26" [[package]] name = "arrow-select" -version = "26.0.0" +version = "28.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6593a01586751c74498495d2f5a01fcd438102b52965c11dd98abf4ebcacef37" +checksum = "520406331d4ad60075359524947ebd804e479816439af82bcb17f8d280d9b38c" dependencies = [ "arrow-array", "arrow-buffer", @@ -229,9 +281,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfd4d1b31faaa3a89d7934dbded3111da0d2ef28e3ebccdb4f0179f5929d1ef1" +checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" dependencies = [ "iana-time-zone", "num-integer", @@ -407,23 +459,22 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15f1ffcbc1f040c9ab99f41db1c743d95aff267bb2e7286aaa010738b7402251" +checksum = "7b17262b899f79afdf502846d1138a8b48441afe24dc6e07c922105289248137" dependencies = [ "arrow", "chrono", - "ordered-float", "sqlparser", ] [[package]] name = "datafusion-expr" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1883d9590d303ef38fa295567e7fdb9f8f5f511fcc167412d232844678cd295c" +checksum = "533d2226b4636a1306d1f6f4ac02e436947c5d6e8bfc85f6d8f91a425c10a407" dependencies = [ - "ahash 0.8.1", + "ahash", "arrow", "datafusion-common", "log", @@ -432,9 +483,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2127d46d566ab3463d70da9675fc07b9d634be8d17e80d0e1ce79600709fe651" +checksum = "ce7ba274267b6baf1714a67727249aa56d648c8814b0f4c43387fbe6d147e619" dependencies = [ "arrow", "async-trait", @@ -442,17 +493,17 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown", + "hashbrown 0.13.1", "log", ] [[package]] name = "datafusion-physical-expr" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d108b6fe8eeb317ecad1d74619e8758de49cccc8c771b56c97962fd52eaae23" +checksum = "f35cb53e6c2f9c40accdf45aef2be7fde030ea3051b1145a059d96109e65b0bf" dependencies = [ - "ahash 0.8.1", + "ahash", "arrow", "arrow-buffer", "arrow-schema", @@ -463,12 +514,11 @@ dependencies = [ "datafusion-expr", "datafusion-row", "half", - "hashbrown", + "hashbrown 0.13.1", "itertools", "lazy_static", "md-5", "num-traits", - "ordered-float", "paste", "rand", "regex", @@ -479,9 +529,9 @@ dependencies = [ [[package]] name = "datafusion-row" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43537b6377d506e4788bf21e9ed943340e076b48ca4d077e6ea4405ca5e54a1c" +checksum = "27c77b1229ae5cf6a6e0e2ba43ed4e98131dbf1cc4a97fad17c94230b32e0812" dependencies = [ "arrow", "datafusion-common", @@ -491,11 +541,11 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "14.0.0" +version = "15.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "244d08d4710e1088d9c0949c9b5b8d68d9cf2cde7203134a4cc389e870fe2354" +checksum = "569423fa8a50db39717080949e3b4f8763582b87baf393cc3fcf27cc21467ba7" dependencies = [ - "arrow", + "arrow-schema", "datafusion-common", "datafusion-expr", "sqlparser", @@ -600,8 +650,14 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ff8ae62cd3a9102e5637afc8452c55acf3844001bd5374e0b0bd7b6616c038" dependencies = [ - "ahash 0.7.6", + "ahash", ] [[package]] @@ -665,7 +721,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", ] [[package]] @@ -998,15 +1054,6 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" -[[package]] -name = "ordered-float" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84eb1409416d254e4a9c8fa56cc24701755025b458f0fcd8e59e1f5f40c23bf" -dependencies = [ - "num-traits", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -1268,9 +1315,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "sqlparser" -version = "0.26.0" +version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86be66ea0b2b22749cfa157d16e2e84bf793e626a3375f4d378dc289fa03affb" +checksum = "aba319938d4bfe250a769ac88278b629701024fe16f34257f9563bc628081970" dependencies = [ "log", ] diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 4562294ed..4366649a5 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -9,12 +9,12 @@ edition = "2021" rust-version = "1.62" [dependencies] -arrow = { version = "26.0.0", features = ["prettyprint"] } +arrow = { version = "28.0.0", features = ["prettyprint"] } async-trait = "0.1.59" -datafusion-common = "14.0.0" -datafusion-expr = "14.0.0" -datafusion-optimizer = "14.0.0" -datafusion-sql = "14.0.0" +datafusion-common = "15.0.0" +datafusion-expr = "15.0.0" +datafusion-optimizer = "15.0.0" +datafusion-sql = "15.0.0" env_logger = "0.10" log = "^0.4" mimalloc = { version = "*", default-features = false } diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index f69f120ed..9fe06f839 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -475,7 +475,10 @@ impl PyExpr { ScalarValue::LargeBinary(_value) => "LargeBinary", ScalarValue::Date32(_value) => "Date32", ScalarValue::Date64(_value) => "Date64", - ScalarValue::Time64(_value) => "Time64", + ScalarValue::Time32Second(_value) => "Time32", + ScalarValue::Time32Millisecond(_value) => "Time32", + ScalarValue::Time64Microsecond(_value) => "Time64", + ScalarValue::Time64Nanosecond(_value) => "Time64", ScalarValue::Null => "Null", ScalarValue::TimestampSecond(..) => "TimestampSecond", ScalarValue::TimestampMillisecond(..) => "TimestampMillisecond", @@ -591,7 +594,7 @@ impl PyExpr { } #[pyo3(name = "getDecimal128Value")] - pub fn decimal_128_value(&mut self) -> PyResult<(Option, u8, u8)> { + pub fn decimal_128_value(&mut self) -> PyResult<(Option, u8, i8)> { match self.get_scalar_value()? { ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), other => Err(unexpected_literal_value(other)), @@ -681,7 +684,7 @@ impl PyExpr { #[pyo3(name = "getTime64Value")] pub fn time_64_value(&self) -> PyResult> { match self.get_scalar_value()? { - ScalarValue::Time64(value) => Ok(*value), + ScalarValue::Time64Nanosecond(value) => Ok(*value), other => Err(unexpected_literal_value(other)), } } diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 4b9c13438..585537779 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -12,7 +12,7 @@ pub mod types; use std::{collections::HashMap, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use datafusion_common::{DFSchema, DataFusionError}; +use datafusion_common::{DFSchema, DataFusionError, ScalarValue}; use datafusion_expr::{ logical_plan::Extension, AccumulatorFunctionImplementation, @@ -370,6 +370,10 @@ impl ContextProvider for DaskSQLContext { fn get_variable_type(&self, _: &[String]) -> Option { unimplemented!("RUST: get_variable_type is not yet implemented for DaskSQLContext") } + + fn get_config_option(&self, _option: &str) -> Option { + None + } } #[pymethods] diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs index 2f2843763..0168a1c28 100644 --- a/dask_planner/src/sql/optimizer.rs +++ b/dask_planner/src/sql/optimizer.rs @@ -6,17 +6,17 @@ use datafusion_optimizer::{ common_subexpr_eliminate::CommonSubexprEliminate, decorrelate_where_exists::DecorrelateWhereExists, decorrelate_where_in::DecorrelateWhereIn, + eliminate_cross_join::EliminateCrossJoin, // TODO: need to handle EmptyRelation for GPU cases // eliminate_filter::EliminateFilter, eliminate_limit::EliminateLimit, + eliminate_outer_join::EliminateOuterJoin, filter_null_join_keys::FilterNullJoinKeys, - filter_push_down::FilterPushDown, inline_table_scan::InlineTableScan, limit_push_down::LimitPushDown, optimizer::{Optimizer, OptimizerRule}, projection_push_down::ProjectionPushDown, - reduce_cross_join::ReduceCrossJoin, - reduce_outer_join::ReduceOuterJoin, + push_down_filter::PushDownFilter, rewrite_disjunctive_predicate::RewriteDisjunctivePredicate, scalar_subquery_to_join::ScalarSubqueryToJoin, simplify_expressions::SimplifyExpressions, @@ -30,6 +30,9 @@ use log::trace; mod eliminate_agg_distinct; use eliminate_agg_distinct::EliminateAggDistinct; +mod join_reorder; +use join_reorder::JoinReorder; + /// Houses the optimization logic for Dask-SQL. This optimization controls the optimizations /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance pub struct DaskSqlOptimizer { @@ -56,13 +59,13 @@ impl DaskSqlOptimizer { Arc::new(SimplifyExpressions::new()), // TODO: need to handle EmptyRelation for GPU cases // Arc::new(EliminateFilter::new()), - Arc::new(ReduceCrossJoin::new()), + Arc::new(EliminateCrossJoin::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(FilterNullJoinKeys::default()), - Arc::new(ReduceOuterJoin::new()), - Arc::new(FilterPushDown::new()), + Arc::new(EliminateOuterJoin::new()), + Arc::new(PushDownFilter::new()), Arc::new(LimitPushDown::new()), // Dask-SQL specific optimizations Arc::new(EliminateAggDistinct::new()), @@ -72,6 +75,7 @@ impl DaskSqlOptimizer { Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(ProjectionPushDown::new()), + Arc::new(JoinReorder::default()), ]; Self { @@ -102,7 +106,7 @@ mod tests { use std::{any::Any, collections::HashMap, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{DataFusionError, Result}; + use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; use datafusion_sql::{ planner::{ContextProvider, SqlToRel}, @@ -122,14 +126,15 @@ mod tests { AND (cast('2002-05-08' as date) + interval '5 days')\ )"; let plan = test_sql(sql)?; - let expected = - "Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\ - \n CrossJoin:\ - \n TableScan: test projection=[col_int32]\ - \n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ - \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ - \n TableScan: test projection=[col_int32, col_utf8]"; + let expected = r#"Projection: test.col_int32 + Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value + CrossJoin: + TableScan: test projection=[col_int32] + SubqueryAlias: __sq_1 + Projection: AVG(test.col_int32) AS __value + Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]] + Filter: test.col_utf8 >= Utf8("2002-05-08") AND test.col_utf8 <= Utf8("2002-05-13") + TableScan: test projection=[col_int32, col_utf8]"#; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } @@ -189,6 +194,10 @@ mod tests { fn get_variable_type(&self, _variable_names: &[String]) -> Option { None } + + fn get_config_option(&self, _option: &str) -> Option { + None + } } struct MyTableSource { diff --git a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs index 411e0a25a..cd0539b73 100644 --- a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs +++ b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs @@ -282,7 +282,6 @@ fn create_plan( LogicalPlan::Projection(Projection::try_new( projected_cols, Arc::new(second_aggregate), - None, )?) }; @@ -349,7 +348,6 @@ fn create_plan( LogicalPlan::Projection(Projection::try_new( projected_cols, Arc::new(second_aggregate), - None, )?) }; diff --git a/dask_planner/src/sql/optimizer/join_reorder.rs b/dask_planner/src/sql/optimizer/join_reorder.rs new file mode 100644 index 000000000..23888c6ef --- /dev/null +++ b/dask_planner/src/sql/optimizer/join_reorder.rs @@ -0,0 +1,625 @@ +//! Join reordering based on the paper "Improving Join Reordering for Large Scale Distributed Computing" +//! https://ieeexplore.ieee.org/document/9378281 + +use std::collections::HashSet; + +use datafusion_common::{Column, Result}; +use datafusion_expr::{Expr, Join, JoinType, LogicalPlan, LogicalPlanBuilder}; +use datafusion_optimizer::{utils, utils::split_conjunction, OptimizerConfig, OptimizerRule}; + +use crate::sql::table::DaskTableSource; + +pub struct JoinReorder { + /// Maximum number of fact tables to allow in a join + max_fact_tables: usize, + /// Ratio of the size of the dimension tables to fact tables + fact_dimension_ratio: f64, + /// Whether to preserve user-defined order of unfiltered dimensions + preserve_user_order: bool, + /// Constant to use when determining the number of rows produced by a + /// filtered relation + filter_selectivity: f64, +} + +impl Default for JoinReorder { + fn default() -> Self { + Self { + max_fact_tables: 2, + fact_dimension_ratio: 0.3, + preserve_user_order: true, + filter_selectivity: 1.0, + } + } +} + +impl OptimizerRule for JoinReorder { + fn name(&self) -> &str { + "join_reorder" + } + + fn try_optimize( + &self, + plan: &LogicalPlan, + _config: &mut OptimizerConfig, + ) -> Result> { + //TODO is this transformUp or transformDown? + // TODO too many clones - use Box/Rc/Arc to reduce + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + if !is_supported_join(join) { + println!("Not a supported join"); + return Ok(None); + } + println!( + "JoinReorder attempting to optimize join: {}", + plan.display_indent() + ); + + // extract the relations and join conditions + let (rels, mut conds) = extract_inner_joins(plan); + + // split rels into facts and dims + let rels: Vec = rels.into_iter().map(|rel| Relation::new(rel)).collect(); + let largest_rel = rels.iter().map(|rel| rel.size).max().unwrap() as f64; + let mut facts = vec![]; + let mut dims = vec![]; + for rel in &rels { + if rel.size as f64 / largest_rel > self.fact_dimension_ratio { + facts.push(rel.clone()); + } else { + dims.push(rel.clone()); + } + } + println!("There are {} facts and {} dims", facts.len(), dims.len()); + if facts.is_empty() { + return Ok(None); + } + if facts.len() > self.max_fact_tables { + println!("Too many fact tables"); + return Ok(None); + } + + let mut unfiltered_dimensions = get_unfiltered_dimensions(&dims); + if !self.preserve_user_order { + unfiltered_dimensions.sort_by(|a, b| a.size.cmp(&b.size)); + } + + let filtered_dimensions = get_filtered_dimensions(&dims); + let mut filtered_dimensions: Vec = filtered_dimensions + .iter() + .map(|rel| Relation { + plan: rel.plan.clone(), + size: (rel.size as f64 * self.filter_selectivity) as usize, + }) + .collect(); + + filtered_dimensions.sort_by(|a, b| a.size.cmp(&b.size)); + for dim in &unfiltered_dimensions { + println!("UNFILTERED: {} {}", dim.size, dim.plan.display_indent()); + } + + for dim in &filtered_dimensions { + println!("FILTERED: {} {}", dim.size, dim.plan.display_indent()); + } + + // Merge both the lists of dimensions by giving user order + // the preference for tables without a selective predicate, + // whereas for tables with selective predicates giving preference + // to smaller tables. When comparing the top of both + // the lists, if size of the top table in the selective predicate + // list is smaller than top of the other list, choose it otherwise + // vice-versa. + // This algorithm is a greedy approach where smaller + // joins with filtered dimension table are preferred for execution + // earlier than other Joins to improve Join performance. We try to keep + // the user order intact when unsure about reordering to make sure + // regressions are minimized. + let mut result = vec![]; + while filtered_dimensions.len() > 0 || unfiltered_dimensions.len() > 0 { + if filtered_dimensions.len() > 0 && unfiltered_dimensions.len() > 0 { + if filtered_dimensions[0].size < unfiltered_dimensions[0].size { + result.push(filtered_dimensions.remove(0)); + } else { + result.push(unfiltered_dimensions.remove(0)); + } + } else if filtered_dimensions.len() > 0 { + result.push(filtered_dimensions.remove(0)); + } else { + result.push(unfiltered_dimensions.remove(0)); + } + } + assert!(filtered_dimensions.is_empty()); + assert!(unfiltered_dimensions.is_empty()); + + let dim_plans: Vec = + result.iter().map(|rel| rel.plan.clone()).collect(); + let optimized = if facts.len() == 1 { + build_join_tree(&facts[0].plan, &dim_plans, &mut conds)? + } else { + // build one join tree for each fact table + let fact_dim_joins = facts + .iter() + .map(|f| build_join_tree(&f.plan, &dim_plans, &mut conds)) + .collect::>>()?; + // join the trees together + build_join_tree(&fact_dim_joins[0], &fact_dim_joins[1..], &mut conds)? + }; + + if conds.is_empty() { + println!("Optimized: {}", optimized.display_indent()); + return Ok(Some(optimized)); + } else { + println!("Did not use all join conditions"); + return Ok(None); + } + } + _ => { + println!("not a join"); + // TODO do we need to manually recurse here + Ok(Some(utils::optimize_children(self, plan, _config)?)) + } + } + } + + fn optimize(&self, _plan: &LogicalPlan, _config: &mut OptimizerConfig) -> Result { + // this method is not needed because we implement try_optimize instead + unimplemented!() + } +} + +/// Represents a Fact or Dimension table, possibly nested in a filter. +#[derive(Clone, Debug)] +struct Relation { + /// Plan containing the table scan for the fact or dimension table. May also contain + /// Filter and SubqueryAlias. + plan: LogicalPlan, + /// Estimated size of the underlying table before any filtering is applied + size: usize, +} + +impl Relation { + fn new(plan: LogicalPlan) -> Self { + let size = get_table_size(&plan).unwrap_or(100); + Self { plan, size } + } + + /// Determine if this plan contains any filters + fn has_filter(&self) -> bool { + has_filter(&self.plan) + } +} + +fn has_filter(plan: &LogicalPlan) -> bool { + /// We want to ignore "IsNotNull" filters that are added for join keys since they exist + /// for most dimension tables + fn is_real_filter(predicate: &Expr) -> bool { + let exprs = split_conjunction(predicate); + let x = exprs + .iter() + .filter(|e| match e { + Expr::IsNotNull(_) => false, + _ => true, + }) + .count(); + x > 0 + } + + match plan { + LogicalPlan::Filter(filter) => is_real_filter(filter.predicate()), + LogicalPlan::TableScan(scan) => scan.filters.iter().any(is_real_filter), + _ => plan.inputs().iter().any(|child| has_filter(child)), + } +} + +/// Extracts items of consecutive inner joins and join conditions. +/// This method works for bushy trees and left/right deep trees. +fn extract_inner_joins(plan: &LogicalPlan) -> (Vec, HashSet<(Column, Column)>) { + fn _extract_inner_joins( + plan: &LogicalPlan, + rels: &mut Vec, + conds: &mut HashSet<(Column, Column)>, + ) { + match plan { + LogicalPlan::Join(join) + if join.join_type == JoinType::Inner && join.filter.is_none() => + { + _extract_inner_joins(&join.left, rels, conds); + _extract_inner_joins(&join.right, rels, conds); + // TODO could also handle join conditions here? + + for (l, r) in &join.on { + conds.insert((l.clone(), r.clone())); + } + } + _ => { + if find_join(plan).is_some() { + for x in plan.inputs() { + _extract_inner_joins(x, rels, conds); + } + } else { + // leaf node + rels.push(plan.clone()) + } + } + } + } + + let mut rels = vec![]; + let mut conds = HashSet::new(); + _extract_inner_joins(plan, &mut rels, &mut conds); + (rels, conds) +} + +/// Simple Join Constraint: Only INNER Joins are consid- +/// ered which can be composed of other Joins too. But apart +/// from the Joins, none of the operator in both the left and +/// right side of the join should be non-deterministic, or have +/// output greater than the input to the operator. For instance, +/// Filter would be allowed operator as it reduces the output +/// over input, but a project adding extra column will not +/// be allowed. It is difficult to reason about operators that +/// add extra to output when dealing with just table sizes, so +/// instead we only allowed operators from selected set of +/// operators +fn is_supported_join(join: &Join) -> bool { + //TODO check for deterministic filter expressions + + fn is_supported_rel(plan: &LogicalPlan) -> bool { + // println!("is_simple_rel? {}", plan.display_indent()); + match plan { + LogicalPlan::Join(join) => { + join.join_type == JoinType::Inner + && join.filter.is_none() + && is_supported_rel(&join.left) + && is_supported_rel(&join.right) + } + LogicalPlan::Filter(filter) => is_supported_rel(filter.input()), + LogicalPlan::SubqueryAlias(sq) => is_supported_rel(&sq.input), + LogicalPlan::TableScan(_) => true, + _ => { + println!("not a simple join: {}", plan.display_indent()); + false + } + } + } + + is_supported_rel(&LogicalPlan::Join(join.clone())) +} + +/// find first (top-level) join in plan +fn find_join(plan: &LogicalPlan) -> Option { + match plan { + LogicalPlan::Join(join) => Some(join.clone()), + other => { + if other.inputs().len() == 0 { + None + } else { + for input in &other.inputs() { + if let Some(join) = find_join(*input) { + return Some(join); + } + } + None + } + } + } +} + +fn get_unfiltered_dimensions(dims: &[Relation]) -> Vec { + dims.iter().filter(|t| !t.has_filter()).cloned().collect() +} + +fn get_filtered_dimensions(dims: &[Relation]) -> Vec { + dims.iter().filter(|t| t.has_filter()).cloned().collect() +} + +fn build_join_tree( + fact: &LogicalPlan, + dims: &[LogicalPlan], + conds: &mut HashSet<(Column, Column)>, +) -> Result { + let mut b = LogicalPlanBuilder::from(fact.clone()); + for dim in dims { + // find join keys between the fact and this dim + let mut join_keys = vec![]; + for (l, r) in conds.iter() { + if (fact.schema().index_of_column(l).is_ok() && dim.schema().index_of_column(r).is_ok()) + || fact.schema().index_of_column(r).is_ok() + && dim.schema().index_of_column(l).is_ok() + { + join_keys.push((l.clone(), r.clone())); + } + } + if !join_keys.is_empty() { + let left_keys: Vec = join_keys.iter().map(|(l, _r)| l.clone()).collect(); + let right_keys: Vec = join_keys.iter().map(|(_l, r)| r.clone()).collect(); + + for key in join_keys { + conds.remove(&key); + } + + println!("Joining fact to dim on {:?} = {:?}", left_keys, right_keys); + b = b.join(&dim, JoinType::Inner, (left_keys, right_keys), None)?; + } + } + b.build() +} + +fn get_table_size(plan: &LogicalPlan) -> Option { + match plan { + LogicalPlan::TableScan(scan) => { + let source = scan + .source + .as_any() + .downcast_ref::() + .expect("should be a DaskTableSource"); + if let Some(stats) = source.statistics() { + stats.num_rows + } else { + // TODO hard-coded stats for manual testing until stats are available + // these numbers based on sf100 + let n = match scan.table_name.as_str() { + "call_center" => 30, + "catalog_page" => 20400, + "catalog_returns" => 14404374, + "catalog_sales" => 143997065, + "customer_address" => 1000000, + "customer_demographics" => 1920800, + "customer" => 2000000, + "date_dim" => 73049, + "household_demographics" => 7200, + "income_band" => 20, + "inventory" => 399330000, + "item" => 204000, + "promotion" => 1000, + "reason" => 55, + "ship_mode" => 20, + "store" => 402, + "store_returns" => 28795080, + "store_sales" => 287997024, + "time_dim" => 86400, + "warehouse" => 15, + "web_page" => 2040, + "web_returns" => 7197670, + "web_sales" => 72001237, + "web_site" => 24, + other => { + println!("No row count available for table '{}'", other); + 100 + } + }; + + Some(n) + } + } + _ => get_table_size(&plan.inputs()[0]), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, Statistics}; + use datafusion_expr::{col, lit, JoinType, LogicalPlan, LogicalPlanBuilder, SubqueryAlias}; + + use super::*; + use crate::sql::table::DaskTableSource; + + #[test] + fn inner_join_supported() -> Result<()> { + let a = test_table_scan("t1", 100); + let b = test_table_scan("t2", 100); + let join = LogicalPlanBuilder::from(a) + .join(&b, JoinType::Inner, (vec!["t1_a"], vec!["t2_b"]), None)? + .build()?; + if let LogicalPlan::Join(join) = join { + assert!(is_supported_join(&join)); + } else { + panic!(); + } + Ok(()) + } + + #[test] + fn outer_join_not_supported() -> Result<()> { + let a = test_table_scan("t1", 100); + let b = test_table_scan("t2", 100); + let join = LogicalPlanBuilder::from(a) + .join(&b, JoinType::Left, (vec!["t1_a"], vec!["t2_b"]), None)? + .build()?; + if let LogicalPlan::Join(join) = join { + assert!(!is_supported_join(&join)); + } else { + panic!(); + } + Ok(()) + } + + #[test] + fn test_extract_inner_joins() -> Result<()> { + let join = create_test_plan()?; + let (rels, conds) = extract_inner_joins(&join); + assert_eq!(4, rels.len()); + assert_eq!(3, conds.len()); + assert_eq!("TableScan: fact", &format!("{:?}", rels[0])); + assert_eq!("TableScan: dim1", &format!("{:?}", rels[1])); + assert_eq!("TableScan: dim2", &format!("{:?}", rels[2])); + assert_eq!( + "Filter: dim3.dim3_b <= Int32(100) + TableScan: dim3", + &format!("{:?}", rels[3]) + ); + Ok(()) + } + + #[test] + fn optimize_joins() -> Result<()> { + let plan = create_test_plan()?; + let formatted_plan = format!("{}", plan.display_indent()); + let expected_plan = r#"Inner Join: fact.fact_d = dim3.dim3_a + Inner Join: fact.fact_c = dim2.dim2_a + Inner Join: fact.fact_b = dim1.dim1_a + TableScan: fact + TableScan: dim1 + TableScan: dim2 + Filter: dim3.dim3_b <= Int32(100) + TableScan: dim3"#; + assert_eq!(expected_plan, formatted_plan); + let rule = JoinReorder::default(); + let mut config = OptimizerConfig::default(); + let optimized_plan = rule.try_optimize(&plan, &mut config)?.unwrap(); + let formatted_plan = format!("{}", optimized_plan.display_indent()); + let expected_plan = r#"Inner Join: fact.fact_c = dim2.dim2_a + Inner Join: fact.fact_b = dim1.dim1_a + Inner Join: fact.fact_d = dim3.dim3_a + TableScan: fact + Filter: dim3.dim3_b <= Int32(100) + TableScan: dim3 + TableScan: dim1 + TableScan: dim2"#; + assert_eq!(expected_plan, formatted_plan); + Ok(()) + } + + #[test] + fn optimize_joins_aliases() -> Result<()> { + let plan = create_test_plan_with_aliases()?; + let formatted_plan = format!("{}", plan.display_indent()); + let expected_plan = r#"Inner Join: fact.fact_d = dim3.date_dim_a + Inner Join: fact.fact_c = dim2.date_dim_a + Inner Join: fact.fact_b = dim1.date_dim_a + TableScan: fact + SubqueryAlias: dim1 + TableScan: date_dim + SubqueryAlias: dim2 + TableScan: date_dim + Filter: dim3.date_dim_b <= Int32(100) + SubqueryAlias: dim3 + TableScan: date_dim"#; + assert_eq!(expected_plan, formatted_plan); + let rule = JoinReorder::default(); + let mut config = OptimizerConfig::default(); + let optimized_plan = rule.try_optimize(&plan, &mut config)?.unwrap(); + let formatted_plan = format!("{}", optimized_plan.display_indent()); + let expected_plan = r#"Inner Join: fact.fact_c = dim2.date_dim_a + Inner Join: fact.fact_b = dim1.date_dim_a + Inner Join: fact.fact_d = dim3.date_dim_a + TableScan: fact + Filter: dim3.date_dim_b <= Int32(100) + SubqueryAlias: dim3 + TableScan: date_dim + SubqueryAlias: dim1 + TableScan: date_dim + SubqueryAlias: dim2 + TableScan: date_dim"#; + assert_eq!(expected_plan, formatted_plan); + Ok(()) + } + + fn create_test_plan() -> Result { + let dim1 = test_table_scan("dim1", 100); + let dim2 = test_table_scan("dim2", 200); + let dim3 = test_table_scan("dim3", 50); + let fact = test_table_scan("fact", 10000); + + // add a filter to one dimension + let dim3 = LogicalPlanBuilder::from(dim3) + .filter(col("dim3_b").lt_eq(lit(100)))? + .build()?; + + LogicalPlanBuilder::from(fact) + .join( + &dim1, + JoinType::Inner, + (vec!["fact_b"], vec!["dim1_a"]), + None, + )? + .join( + &dim2, + JoinType::Inner, + (vec!["fact_c"], vec!["dim2_a"]), + None, + )? + .join( + &dim3, + JoinType::Inner, + (vec!["fact_d"], vec!["dim3_a"]), + None, + )? + .build() + } + + fn create_test_plan_with_aliases() -> Result { + let dim1 = aliased_plan(test_table_scan("date_dim", 100), "dim1"); + let dim2 = aliased_plan(test_table_scan("date_dim", 200), "dim2"); + let dim3 = aliased_plan(test_table_scan("date_dim", 50), "dim3"); + let fact = test_table_scan("fact", 10000); + + // add a filter to one dimension + let dim3 = LogicalPlanBuilder::from(dim3) + .filter(col("dim3.date_dim_b").lt_eq(lit(100)))? + .build()?; + + LogicalPlanBuilder::from(fact) + .join( + &dim1, + JoinType::Inner, + (vec!["fact_b"], vec!["dim1.date_dim_a"]), + None, + )? + .join( + &dim2, + JoinType::Inner, + (vec!["fact_c"], vec!["dim2.date_dim_a"]), + None, + )? + .join( + &dim3, + JoinType::Inner, + (vec!["fact_d"], vec!["dim3.date_dim_a"]), + None, + )? + .build() + } + + fn aliased_plan(plan: LogicalPlan, alias: &str) -> LogicalPlan { + let schema = plan.schema().as_ref().clone(); + let schema = schema.replace_qualifier(alias); + LogicalPlan::SubqueryAlias(SubqueryAlias { + input: Arc::new(plan), + alias: alias.to_string(), + schema: Arc::new(schema), + }) + } + + fn test_table_scan(table_name: &str, size: usize) -> LogicalPlan { + let schema = Schema::new(vec![ + Field::new(&format!("{}_a", table_name), DataType::UInt32, false), + Field::new(&format!("{}_b", table_name), DataType::UInt32, false), + Field::new(&format!("{}_c", table_name), DataType::UInt32, false), + Field::new(&format!("{}_d", table_name), DataType::UInt32, false), + ]); + table_scan(Some(table_name), &schema, None, size) + .expect("creating scan") + .build() + .expect("building plan") + } + + fn table_scan( + name: Option<&str>, + table_schema: &Schema, + projection: Option>, + table_size: usize, + ) -> Result { + let tbl_schema = Arc::new(table_schema.clone()); + let mut statistics = Statistics::default(); + statistics.num_rows = Some(table_size); + let table_source = Arc::new(DaskTableSource::new_with_statistics( + tbl_schema, + Some(statistics), + )); + LogicalPlanBuilder::scan(name.unwrap_or("test"), table_source, projection) + } +} diff --git a/dask_planner/src/sql/table.rs b/dask_planner/src/sql/table.rs index 679559319..e8755fb0d 100644 --- a/dask_planner/src/sql/table.rs +++ b/dask_planner/src/sql/table.rs @@ -2,7 +2,7 @@ use std::{any::Any, sync::Arc}; use arrow::datatypes::{DataType, Field, SchemaRef}; use async_trait::async_trait; -use datafusion_common::DFField; +use datafusion_common::{DFField, Statistics}; use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource}; use datafusion_optimizer::utils::split_conjunction; use datafusion_sql::TableReference; @@ -25,12 +25,26 @@ use crate::{ /// DaskTable wrapper that is compatible with DataFusion logical query plans pub struct DaskTableSource { schema: SchemaRef, + statistics: Option, } impl DaskTableSource { /// Initialize a new `EmptyTable` from a schema. pub fn new(schema: SchemaRef) -> Self { - Self { schema } + Self { + schema, + statistics: None, + } + } + + /// Initialize a new `EmptyTable` from a schema. + pub fn new_with_statistics(schema: SchemaRef, statistics: Option) -> Self { + Self { schema, statistics } + } + + /// Access optional statistics associated with this table source + pub fn statistics(&self) -> Option<&Statistics> { + self.statistics.as_ref() } }