diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 365d8e9fd9a42..9f77e5b290773 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -565,12 +565,17 @@ message CsvScanExecNode { repeated string filename = 8; } +enum PartitionMode { + COLLECT_LEFT = 0; + PARTITIONED = 1; +} + message HashJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; repeated JoinOn on = 3; JoinType join_type = 4; - + PartitionMode partition_mode = 6; } message PhysicalColumn { diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 4b87be4105be0..cf4e824ab9e8c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -356,12 +356,24 @@ impl TryInto> for &protobuf::PhysicalPlanNode { protobuf::JoinType::Semi => JoinType::Semi, protobuf::JoinType::Anti => JoinType::Anti, }; + let partition_mode = + protobuf::PartitionMode::from_i32(hashjoin.partition_mode) + .ok_or_else(|| { + proto_error(format!( + "Received a HashJoinNode message with unknown PartitionMode {}", + hashjoin.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::PartitionMode::CollectLeft => PartitionMode::CollectLeft, + protobuf::PartitionMode::Partitioned => PartitionMode::Partitioned, + }; Ok(Arc::new(HashJoinExec::try_new( left, right, on, &join_type, - PartitionMode::CollectLeft, + partition_mode, )?)) } PhysicalPlanType::ShuffleReader(shuffle_reader) => { diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index c0fe81f0ffb91..a393d7fdab1f7 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -88,13 +88,29 @@ mod roundtrip_tests { Column::new("col", schema_right.index_of("col")?), )]; - roundtrip_test(Arc::new(HashJoinExec::try_new( - Arc::new(EmptyExec::new(false, Arc::new(schema_left))), - Arc::new(EmptyExec::new(false, Arc::new(schema_right))), - on, - &JoinType::Inner, - PartitionMode::CollectLeft, - )?)) + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for join_type in &[ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::Anti, + JoinType::Semi, + ] { + for partition_mode in + &[PartitionMode::Partitioned, PartitionMode::CollectLeft] + { + roundtrip_test(Arc::new(HashJoinExec::try_new( + Arc::new(EmptyExec::new(false, schema_left.clone())), + Arc::new(EmptyExec::new(false, schema_right.clone())), + on.clone(), + &join_type, + *partition_mode, + )?))?; + } + } + Ok(()) } #[test] diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index cf5401b650193..314155b0e21b0 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -34,7 +34,7 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; -use datafusion::physical_plan::hash_join::HashJoinExec; +use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::hash_utils::JoinType; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::ParquetExec; @@ -143,6 +143,10 @@ impl TryInto for Arc { JoinType::Semi => protobuf::JoinType::Semi, JoinType::Anti => protobuf::JoinType::Anti, }; + let partition_mode = match exec.partition_mode() { + PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, + PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, + }; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { @@ -150,6 +154,7 @@ impl TryInto for Arc { right: Some(Box::new(right)), on, join_type: join_type.into(), + partition_mode: partition_mode.into(), }, ))), }) diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index ad356079387a0..581743c698dff 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -177,6 +177,11 @@ impl HashJoinExec { &self.join_type } + /// The partitioning mode of this hash join + pub fn partition_mode(&self) -> &PartitionMode { + &self.mode + } + /// Calculates column indices and left/right placement on input / output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type {