@@ -30,8 +30,8 @@ use crate::error::Result;
3030use crate :: physical_optimizer:: PhysicalOptimizerRule ;
3131use crate :: physical_plan:: joins:: utils:: { ColumnIndex , JoinFilter } ;
3232use crate :: physical_plan:: joins:: {
33- CrossJoinExec , HashJoinExec , PartitionMode , StreamJoinPartitionMode ,
34- SymmetricHashJoinExec ,
33+ CrossJoinExec , HashJoinExec , NestedLoopJoinExec , PartitionMode ,
34+ StreamJoinPartitionMode , SymmetricHashJoinExec ,
3535} ;
3636use crate :: physical_plan:: projection:: ProjectionExec ;
3737use crate :: physical_plan:: { ExecutionPlan , ExecutionPlanProperties } ;
@@ -199,6 +199,38 @@ fn swap_hash_join(
199199 }
200200}
201201
202+ /// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` is required
203+ fn swap_nl_join ( join : & NestedLoopJoinExec ) -> Result < Arc < dyn ExecutionPlan > > {
204+ let new_filter = swap_join_filter ( join. filter ( ) ) ;
205+ let new_join_type = & swap_join_type ( * join. join_type ( ) ) ;
206+
207+ let new_join = NestedLoopJoinExec :: try_new (
208+ Arc :: clone ( join. right ( ) ) ,
209+ Arc :: clone ( join. left ( ) ) ,
210+ new_filter,
211+ new_join_type,
212+ ) ?;
213+
214+ // For Semi/Anti joins, swap result will produce same output schema,
215+ // no need to wrap them into additional projection
216+ let plan: Arc < dyn ExecutionPlan > = if matches ! (
217+ join. join_type( ) ,
218+ JoinType :: LeftSemi
219+ | JoinType :: RightSemi
220+ | JoinType :: LeftAnti
221+ | JoinType :: RightAnti
222+ ) {
223+ Arc :: new ( new_join)
224+ } else {
225+ let projection =
226+ swap_reverting_projection ( & join. left ( ) . schema ( ) , & join. right ( ) . schema ( ) ) ;
227+
228+ Arc :: new ( ProjectionExec :: try_new ( projection, Arc :: new ( new_join) ) ?)
229+ } ;
230+
231+ Ok ( plan)
232+ }
233+
202234/// When the order of the join is changed by the optimizer, the columns in
203235/// the output should not be impacted. This function creates the expressions
204236/// that will allow to swap back the values from the original left as the first
@@ -461,6 +493,14 @@ fn statistical_join_selection_subrule(
461493 } else {
462494 None
463495 }
496+ } else if let Some ( nl_join) = plan. as_any ( ) . downcast_ref :: < NestedLoopJoinExec > ( ) {
497+ let left = nl_join. left ( ) ;
498+ let right = nl_join. right ( ) ;
499+ if should_swap_join_order ( & * * left, & * * right) ? {
500+ swap_nl_join ( nl_join) . map ( Some ) ?
501+ } else {
502+ None
503+ }
464504 } else {
465505 None
466506 } ;
@@ -697,9 +737,12 @@ mod tests_statistical {
697737
698738 use arrow:: datatypes:: { DataType , Field , Schema } ;
699739 use datafusion_common:: { stats:: Precision , JoinType , ScalarValue } ;
700- use datafusion_physical_expr:: expressions:: Column ;
740+ use datafusion_expr:: Operator ;
741+ use datafusion_physical_expr:: expressions:: { BinaryExpr , Column } ;
701742 use datafusion_physical_expr:: { PhysicalExpr , PhysicalExprRef } ;
702743
744+ use rstest:: rstest;
745+
703746 /// Return statistcs for empty table
704747 fn empty_statistics ( ) -> Statistics {
705748 Statistics {
@@ -785,6 +828,35 @@ mod tests_statistical {
785828 } ]
786829 }
787830
831+ /// Create join filter for NLJoinExec with expression `big_col > small_col`
832+ /// where both columns are 0-indexed and come from left and right inputs respectively
833+ fn nl_join_filter ( ) -> Option < JoinFilter > {
834+ let column_indices = vec ! [
835+ ColumnIndex {
836+ index: 0 ,
837+ side: JoinSide :: Left ,
838+ } ,
839+ ColumnIndex {
840+ index: 0 ,
841+ side: JoinSide :: Right ,
842+ } ,
843+ ] ;
844+ let intermediate_schema = Schema :: new ( vec ! [
845+ Field :: new( "big_col" , DataType :: Int32 , false ) ,
846+ Field :: new( "small_col" , DataType :: Int32 , false ) ,
847+ ] ) ;
848+ let expression = Arc :: new ( BinaryExpr :: new (
849+ Arc :: new ( Column :: new_with_schema ( "big_col" , & intermediate_schema) . unwrap ( ) ) ,
850+ Operator :: Gt ,
851+ Arc :: new ( Column :: new_with_schema ( "big_col" , & intermediate_schema) . unwrap ( ) ) ,
852+ ) ) as _ ;
853+ Some ( JoinFilter :: new (
854+ expression,
855+ column_indices,
856+ intermediate_schema,
857+ ) )
858+ }
859+
788860 /// Returns three plans with statistics of (min, max, distinct_count)
789861 /// * big 100K rows @ (0, 50k, 50k)
790862 /// * medium 10K rows @ (1k, 5k, 1k)
@@ -1151,6 +1223,137 @@ mod tests_statistical {
11511223 crosscheck_plans ( join) . unwrap ( ) ;
11521224 }
11531225
1226+ #[ rstest(
1227+ join_type,
1228+ case:: inner( JoinType :: Inner ) ,
1229+ case:: left( JoinType :: Left ) ,
1230+ case:: right( JoinType :: Right ) ,
1231+ case:: full( JoinType :: Full )
1232+ ) ]
1233+ #[ tokio:: test]
1234+ async fn test_nl_join_with_swap ( join_type : JoinType ) {
1235+ let ( big, small) = create_big_and_small ( ) ;
1236+
1237+ let join = Arc :: new (
1238+ NestedLoopJoinExec :: try_new (
1239+ Arc :: clone ( & big) ,
1240+ Arc :: clone ( & small) ,
1241+ nl_join_filter ( ) ,
1242+ & join_type,
1243+ )
1244+ . unwrap ( ) ,
1245+ ) ;
1246+
1247+ let optimized_join = JoinSelection :: new ( )
1248+ . optimize ( join. clone ( ) , & ConfigOptions :: new ( ) )
1249+ . unwrap ( ) ;
1250+
1251+ let swapping_projection = optimized_join
1252+ . as_any ( )
1253+ . downcast_ref :: < ProjectionExec > ( )
1254+ . expect ( "A proj is required to swap columns back to their original order" ) ;
1255+
1256+ assert_eq ! ( swapping_projection. expr( ) . len( ) , 2 ) ;
1257+ let ( col, name) = & swapping_projection. expr ( ) [ 0 ] ;
1258+ assert_eq ! ( name, "big_col" ) ;
1259+ assert_col_expr ( col, "big_col" , 1 ) ;
1260+ let ( col, name) = & swapping_projection. expr ( ) [ 1 ] ;
1261+ assert_eq ! ( name, "small_col" ) ;
1262+ assert_col_expr ( col, "small_col" , 0 ) ;
1263+
1264+ let swapped_join = swapping_projection
1265+ . input ( )
1266+ . as_any ( )
1267+ . downcast_ref :: < NestedLoopJoinExec > ( )
1268+ . expect ( "The type of the plan should not be changed" ) ;
1269+
1270+ // Assert join side of big_col swapped in filter expression
1271+ let swapped_filter = swapped_join. filter ( ) . unwrap ( ) ;
1272+ let swapped_big_col_idx = swapped_filter. schema ( ) . index_of ( "big_col" ) . unwrap ( ) ;
1273+ let swapped_big_col_side = swapped_filter
1274+ . column_indices ( )
1275+ . get ( swapped_big_col_idx)
1276+ . unwrap ( )
1277+ . side ;
1278+ assert_eq ! (
1279+ swapped_big_col_side,
1280+ JoinSide :: Right ,
1281+ "Filter column side should be swapped"
1282+ ) ;
1283+
1284+ assert_eq ! (
1285+ swapped_join. left( ) . statistics( ) . unwrap( ) . total_byte_size,
1286+ Precision :: Inexact ( 8192 )
1287+ ) ;
1288+ assert_eq ! (
1289+ swapped_join. right( ) . statistics( ) . unwrap( ) . total_byte_size,
1290+ Precision :: Inexact ( 2097152 )
1291+ ) ;
1292+ crosscheck_plans ( join. clone ( ) ) . unwrap ( ) ;
1293+ }
1294+
1295+ #[ rstest(
1296+ join_type,
1297+ case:: left_semi( JoinType :: LeftSemi ) ,
1298+ case:: left_anti( JoinType :: LeftAnti ) ,
1299+ case:: right_semi( JoinType :: RightSemi ) ,
1300+ case:: right_anti( JoinType :: RightAnti )
1301+ ) ]
1302+ #[ tokio:: test]
1303+ async fn test_nl_join_with_swap_no_proj ( join_type : JoinType ) {
1304+ let ( big, small) = create_big_and_small ( ) ;
1305+
1306+ let join = Arc :: new (
1307+ NestedLoopJoinExec :: try_new (
1308+ Arc :: clone ( & big) ,
1309+ Arc :: clone ( & small) ,
1310+ nl_join_filter ( ) ,
1311+ & join_type,
1312+ )
1313+ . unwrap ( ) ,
1314+ ) ;
1315+
1316+ let optimized_join = JoinSelection :: new ( )
1317+ . optimize ( join. clone ( ) , & ConfigOptions :: new ( ) )
1318+ . unwrap ( ) ;
1319+
1320+ let swapped_join = optimized_join
1321+ . as_any ( )
1322+ . downcast_ref :: < NestedLoopJoinExec > ( )
1323+ . expect ( "The type of the plan should not be changed" ) ;
1324+
1325+ // Assert before/after schemas are equal
1326+ assert_eq ! (
1327+ join. schema( ) ,
1328+ swapped_join. schema( ) ,
1329+ "Join schema should not be modified while optimization"
1330+ ) ;
1331+
1332+ // Assert join side of big_col swapped in filter expression
1333+ let swapped_filter = swapped_join. filter ( ) . unwrap ( ) ;
1334+ let swapped_big_col_idx = swapped_filter. schema ( ) . index_of ( "big_col" ) . unwrap ( ) ;
1335+ let swapped_big_col_side = swapped_filter
1336+ . column_indices ( )
1337+ . get ( swapped_big_col_idx)
1338+ . unwrap ( )
1339+ . side ;
1340+ assert_eq ! (
1341+ swapped_big_col_side,
1342+ JoinSide :: Right ,
1343+ "Filter column side should be swapped"
1344+ ) ;
1345+
1346+ assert_eq ! (
1347+ swapped_join. left( ) . statistics( ) . unwrap( ) . total_byte_size,
1348+ Precision :: Inexact ( 8192 )
1349+ ) ;
1350+ assert_eq ! (
1351+ swapped_join. right( ) . statistics( ) . unwrap( ) . total_byte_size,
1352+ Precision :: Inexact ( 2097152 )
1353+ ) ;
1354+ crosscheck_plans ( join. clone ( ) ) . unwrap ( ) ;
1355+ }
1356+
11541357 #[ tokio:: test]
11551358 async fn test_swap_reverting_projection ( ) {
11561359 let left_schema = Schema :: new ( vec ! [
0 commit comments