Skip to content

Commit 4b41413

Browse files
committed
Update hashjoin to identify right probe side (ref: apache#17518)
1 parent 408e1e4 commit 4b41413

File tree

2 files changed

+254
-27
lines changed

2 files changed

+254
-27
lines changed

datafusion/physical-plan/src/joins/hash_join/exec.rs

Lines changed: 182 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -463,12 +463,25 @@ impl HashJoinExec {
463463
})
464464
}
465465

466-
fn create_dynamic_filter(on: &JoinOn) -> Arc<DynamicFilterPhysicalExpr> {
467-
// Extract the right-side keys (probe side keys) from the `on` clauses
468-
// Dynamic filter will be created from build side values (left side) and applied to probe side (right side)
469-
let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect();
466+
fn join_exprs_for_side(on: &JoinOn, pushdown_side: JoinSide) -> Vec<PhysicalExprRef> {
467+
match pushdown_side {
468+
JoinSide::Left => on.iter().map(|(l, _)| Arc::clone(l)).collect(),
469+
JoinSide::Right => on.iter().map(|(_, r)| Arc::clone(r)).collect(),
470+
JoinSide::None => return vec![],
471+
}
472+
}
473+
474+
fn create_dynamic_filter(
475+
on: &JoinOn,
476+
pushdown_side: JoinSide,
477+
) -> Result<Arc<DynamicFilterPhysicalExpr>> {
478+
if pushdown_side == JoinSide::None {
479+
return internal_err!("dynamic filter side must be specified");
480+
}
481+
// Extract the join key expressions from the side that will receive the dynamic filter
482+
let keys = Self::join_exprs_for_side(on, pushdown_side);
470483
// Initialize with a placeholder expression (true) that will be updated when the hash table is built
471-
Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true)))
484+
Ok(Arc::new(DynamicFilterPhysicalExpr::new(keys, lit(true))))
472485
}
473486

474487
/// left (build) side which gets hashed
@@ -780,6 +793,21 @@ impl DisplayAs for HashJoinExec {
780793
}
781794
}
782795

796+
fn find_filter_pushdown_sides(join_type: JoinType) -> JoinSide {
797+
match join_type {
798+
JoinType::Inner => JoinSide::Right,
799+
JoinType::Left => JoinSide::Right,
800+
JoinType::Right => JoinSide::Left,
801+
JoinType::Full => JoinSide::None,
802+
JoinType::LeftSemi => JoinSide::Right,
803+
JoinType::RightSemi => JoinSide::Left,
804+
JoinType::LeftAnti => JoinSide::Right,
805+
JoinType::RightAnti => JoinSide::Left,
806+
JoinType::LeftMark => JoinSide::Right,
807+
JoinType::RightMark => JoinSide::Left,
808+
}
809+
}
810+
783811
impl ExecutionPlan for HashJoinExec {
784812
fn name(&self) -> &'static str {
785813
"HashJoinExec"
@@ -929,8 +957,10 @@ impl ExecutionPlan for HashJoinExec {
929957
}
930958

931959
let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some();
960+
// let filter_pushdown_side = find_filter_pushdown_sides(self.join_type);
932961

933962
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
963+
let probe_side = find_filter_pushdown_sides(self.join_type);
934964
let left_fut = match self.mode {
935965
PartitionMode::CollectLeft => self.left_fut.try_once(|| {
936966
let left_stream = self.left.execute(0, Arc::clone(&context))?;
@@ -946,7 +976,7 @@ impl ExecutionPlan for HashJoinExec {
946976
reservation,
947977
need_produce_result_in_final(self.join_type),
948978
self.right().output_partitioning().partition_count(),
949-
enable_dynamic_filter_pushdown,
979+
enable_dynamic_filter_pushdown && probe_side == JoinSide::Right,
950980
))
951981
})?,
952982
PartitionMode::Partitioned => {
@@ -964,7 +994,7 @@ impl ExecutionPlan for HashJoinExec {
964994
reservation,
965995
need_produce_result_in_final(self.join_type),
966996
1,
967-
enable_dynamic_filter_pushdown,
997+
enable_dynamic_filter_pushdown && probe_side == JoinSide::Right,
968998
))
969999
}
9701000
PartitionMode::Auto => {
@@ -982,18 +1012,22 @@ impl ExecutionPlan for HashJoinExec {
9821012
.then(|| {
9831013
self.dynamic_filter.as_ref().map(|df| {
9841014
let filter = Arc::clone(&df.filter);
985-
let on_right = self
986-
.on
987-
.iter()
988-
.map(|(_, right_expr)| Arc::clone(right_expr))
989-
.collect::<Vec<_>>();
1015+
// Determine which side will receive the dynamic filter
1016+
let probe_side = find_filter_pushdown_sides(self.join_type);
1017+
// Bounds should be collected from the build side (opposite of probe side)
1018+
// let build_side = match probe_side {
1019+
// JoinSide::Left => JoinSide::Right,
1020+
// JoinSide::Right => JoinSide::Left,
1021+
// JoinSide::None => JoinSide::None,
1022+
// };
1023+
let on_expressions = Self::join_exprs_for_side(&self.on, probe_side);
9901024
Some(Arc::clone(df.bounds_accumulator.get_or_init(|| {
9911025
Arc::new(SharedBoundsAccumulator::new_from_partition_mode(
9921026
self.mode,
9931027
self.left.as_ref(),
9941028
self.right.as_ref(),
9951029
filter,
996-
on_right,
1030+
on_expressions,
9971031
))
9981032
})))
9991033
})
@@ -1126,7 +1160,7 @@ impl ExecutionPlan for HashJoinExec {
11261160
}
11271161

11281162
// Get basic filter descriptions for both children
1129-
let left_child = crate::filter_pushdown::ChildFilterDescription::from_child(
1163+
let mut left_child = crate::filter_pushdown::ChildFilterDescription::from_child(
11301164
&parent_filters,
11311165
self.left(),
11321166
)?;
@@ -1139,9 +1173,24 @@ impl ExecutionPlan for HashJoinExec {
11391173
if matches!(phase, FilterPushdownPhase::Post)
11401174
&& config.optimizer.enable_join_dynamic_filter_pushdown
11411175
{
1142-
// Add actual dynamic filter to right side (probe side)
1143-
let dynamic_filter = Self::create_dynamic_filter(&self.on);
1144-
right_child = right_child.with_self_filter(dynamic_filter);
1176+
let pushdown_side = find_filter_pushdown_sides(self.join_type);
1177+
let dynamic_filter = Self::create_dynamic_filter(&self.on, pushdown_side)?;
1178+
match pushdown_side {
1179+
JoinSide::None => {
1180+
// A join type that preserves both sides (e.g. FULL) cannot
1181+
// leverage dynamic filters. Return early before attempting to
1182+
// create one.
1183+
return Ok(FilterDescription::new()
1184+
.with_child(left_child)
1185+
.with_child(right_child));
1186+
}
1187+
JoinSide::Left => {
1188+
left_child = left_child.with_self_filter(dynamic_filter);
1189+
}
1190+
JoinSide::Right => {
1191+
right_child = right_child.with_self_filter(dynamic_filter);
1192+
}
1193+
}
11451194
}
11461195

11471196
Ok(FilterDescription::new()
@@ -1159,7 +1208,8 @@ impl ExecutionPlan for HashJoinExec {
11591208
// non-inner joins in `gather_filters_for_pushdown`.
11601209
// However it's a cheap check and serves to inform future devs touching this function that they need to be really
11611210
// careful pushing down filters through non-inner joins.
1162-
if self.join_type != JoinType::Inner {
1211+
let pushdown_side = find_filter_pushdown_sides(self.join_type);
1212+
if pushdown_side == JoinSide::None {
11631213
// Other types of joins can support *some* filters, but restrictions are complex and error prone.
11641214
// For now we don't support them.
11651215
// See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs
@@ -1170,9 +1220,13 @@ impl ExecutionPlan for HashJoinExec {
11701220

11711221
let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone());
11721222
assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children
1173-
let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child
1223+
let self_filters = match pushdown_side {
1224+
JoinSide::Left => &child_pushdown_result.self_filters[0],
1225+
JoinSide::Right => &child_pushdown_result.self_filters[1],
1226+
JoinSide::None => unreachable!(),
1227+
};
11741228
// We expect 0 or 1 self filters
1175-
if let Some(filter) = right_child_self_filters.first() {
1229+
if let Some(filter) = self_filters.first() {
11761230
// Note that we don't check PushdDownPredicate::discrimnant because even if nothing said
11771231
// "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating
11781232
let predicate = Arc::clone(&filter.predicate);
@@ -4518,4 +4572,112 @@ mod tests {
45184572
fn columns(schema: &Schema) -> Vec<String> {
45194573
schema.fields().iter().map(|f| f.name().clone()).collect()
45204574
}
4575+
4576+
#[test]
4577+
fn create_dynamic_filter_none_side_returns_error() {
4578+
let on: JoinOn = vec![];
4579+
let err = HashJoinExec::create_dynamic_filter(&on, JoinSide::None).unwrap_err();
4580+
assert_contains!(err.to_string(), "dynamic filter side must be specified");
4581+
}
4582+
4583+
#[test]
4584+
fn full_join_skips_dynamic_filter_creation() -> Result<()> {
4585+
use arrow::array::Int32Array;
4586+
use arrow::datatypes::{DataType, Field, Schema};
4587+
use datafusion_physical_expr::expressions::col;
4588+
4589+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
4590+
let batch = RecordBatch::try_new(
4591+
Arc::clone(&schema),
4592+
vec![Arc::new(Int32Array::from(vec![1]))],
4593+
)?;
4594+
let left =
4595+
TestMemoryExec::try_new(&[vec![batch.clone()]], Arc::clone(&schema), None)?;
4596+
let right = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?;
4597+
4598+
let on = vec![(col("a", &left.schema())?, col("a", &right.schema())?)];
4599+
let join = HashJoinExec::try_new(
4600+
Arc::new(left),
4601+
Arc::new(right),
4602+
on,
4603+
None,
4604+
&JoinType::Full,
4605+
None,
4606+
PartitionMode::CollectLeft,
4607+
NullEquality::NullEqualsNull,
4608+
)?;
4609+
4610+
let mut config = ConfigOptions::default();
4611+
config.optimizer.enable_dynamic_filter_pushdown = true;
4612+
4613+
let desc =
4614+
join.gather_filters_for_pushdown(FilterPushdownPhase::Post, vec![], &config)?;
4615+
assert!(desc.self_filters().iter().all(|f| f.is_empty()));
4616+
Ok(())
4617+
}
4618+
4619+
// This test verifies that when a HashJoinExec is created with a dynamic filter
4620+
// targeting the left side, the join build phase collects min/max bounds from
4621+
// the build-side input and reports them back into the dynamic filter for the
4622+
// other side. Concretely:
4623+
// - Left input has values [1, 3, 5]
4624+
// - Right (build) input has values [2, 4, 6]
4625+
// - JoinType::Right is used so that the right side acts as the build side
4626+
// and the dynamic filter is attached to the left side expression.
4627+
// - After fully executing the join, the dynamic filter should be updated
4628+
// with the observed bounds `a@0 >= 2 AND a@0 <= 6` (min=2, max=6).
4629+
// The test asserts that HashJoinExec correctly accumulates and reports these
4630+
// bounds so downstream consumers can use the dynamic predicate for pruning.
4631+
#[tokio::test]
4632+
async fn reports_bounds_when_dynamic_filter_side_left() -> Result<()> {
4633+
use datafusion_physical_expr::expressions::col;
4634+
4635+
let task_ctx = Arc::new(TaskContext::default());
4636+
4637+
let left_schema =
4638+
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
4639+
let left_batch = RecordBatch::try_new(
4640+
Arc::clone(&left_schema),
4641+
vec![Arc::new(Int32Array::from(vec![1, 3, 5]))],
4642+
)?;
4643+
let left = TestMemoryExec::try_new(&[vec![left_batch]], left_schema, None)?;
4644+
4645+
let right_schema =
4646+
Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)]));
4647+
let right_batch = RecordBatch::try_new(
4648+
Arc::clone(&right_schema),
4649+
vec![Arc::new(Int32Array::from(vec![2, 4, 6]))],
4650+
)?;
4651+
let right = TestMemoryExec::try_new(&[vec![right_batch]], right_schema, None)?;
4652+
4653+
let on = vec![(col("a", &left.schema())?, col("b", &right.schema())?)];
4654+
4655+
let mut join = HashJoinExec::try_new(
4656+
Arc::new(left),
4657+
Arc::new(right),
4658+
on,
4659+
None,
4660+
&JoinType::Right,
4661+
None,
4662+
PartitionMode::CollectLeft,
4663+
NullEquality::NullEqualsNull,
4664+
)?;
4665+
4666+
let dynamic_filter: Arc<DynamicFilterPhysicalExpr> =
4667+
HashJoinExec::create_dynamic_filter(&join.on, JoinSide::Left)?;
4668+
join.dynamic_filter = Some(HashJoinExecDynamicFilter {
4669+
filter: Arc::clone(&dynamic_filter),
4670+
bounds_accumulator: OnceLock::new(),
4671+
});
4672+
4673+
let stream = join.execute(0, task_ctx)?;
4674+
let _batches: Vec<RecordBatch> = stream.try_collect().await?;
4675+
4676+
assert_eq!(
4677+
format!("{}", dynamic_filter.current().unwrap()),
4678+
"a@0 >= 2 AND a@0 <= 6"
4679+
);
4680+
4681+
Ok(())
4682+
}
45214683
}

0 commit comments

Comments
 (0)