diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 03387c316b8e1..14adc08a7c056 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -2125,6 +2125,7 @@ mod tests { Ok((left_schema, right_schema, on)) } + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{TestMemoryExec, assert_join_metrics}; @@ -2147,7 +2148,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal, col}; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; use rstest::*; @@ -2500,6 +2503,134 @@ mod tests { Ok((columns, batches, metrics)) } + fn aggregate_join_group_key(i: usize) -> u32 { + (i as u32) * 1000 + } + + async fn final_aggregate_build_side(num_groups: usize) -> Result> { + let raw_schema = Arc::new(Schema::new(vec![ + Field::new("group_key", DataType::UInt32, false), + Field::new("value", DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&raw_schema), + vec![ + Arc::new(UInt32Array::from_iter_values( + (0..num_groups).map(aggregate_join_group_key), + )), + Arc::new(UInt64Array::from(vec![1; num_groups])), + ], + )?; + let input = + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&raw_schema), None)?; + + let group_by = PhysicalGroupBy::new_single(vec![( + col("group_key", &raw_schema)?, + "group_key".to_string(), + )]); + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value", &raw_schema)?]) + .schema(Arc::clone(&raw_schema)) + .alias("count_value") + .build()?, + )]; + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregates.clone(), + vec![None], + input, + Arc::clone(&raw_schema), + )?); + let partial_schema = partial_aggregate.schema(); + let partial_batches = common::collect( + partial_aggregate.execute(0, Arc::new(TaskContext::default()))?, + ) + .await?; + let partial_input = TestMemoryExec::try_new_exec( + &[partial_batches], + Arc::clone(&partial_schema), + None, + )?; + + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + aggregates.clone(), + vec![None; aggregates.len()], + partial_input, + Arc::clone(&raw_schema), + )?)) + } + + #[tokio::test] + async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> { + const BATCH_SIZE: usize = 8192; + const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1; + const EXPECTED_JOIN_ROWS: usize = 3; + + let aggregate = final_aggregate_build_side(NUM_GROUPS).await?; + let aggregate_batches = + common::collect(aggregate.execute(0, Arc::new(TaskContext::default()))?) + .await?; + assert!(aggregate_batches.len() > 1); + assert_eq!( + aggregate_batches + .iter() + .map(RecordBatch::num_rows) + .sum::(), + NUM_GROUPS + ); + let aggregate_batch = concat_batches(&aggregate.schema(), &aggregate_batches)?; + let memory_limit = get_record_batch_memory_size(&aggregate_batch) * 4; + + let probe_schema = Arc::new(Schema::new(vec![Field::new( + "probe_key", + DataType::UInt32, + false, + )])); + let probe_batch = RecordBatch::try_new( + Arc::clone(&probe_schema), + vec![Arc::new(UInt32Array::from(vec![ + aggregate_join_group_key(0), + aggregate_join_group_key(NUM_GROUPS / 2), + aggregate_join_group_key(NUM_GROUPS - 1), + ]))], + )?; + let probe: Arc = TestMemoryExec::try_new_exec( + &[vec![probe_batch]], + Arc::clone(&probe_schema), + None, + )?; + + let aggregate: Arc = aggregate; + let join = HashJoinExec::try_new( + Arc::clone(&aggregate), + probe, + vec![( + Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _, + Arc::new(Column::new_with_schema("probe_key", &probe_schema)?) as _, + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + let batches = common::collect(join.execute(0, task_ctx)?).await?; + assert_eq!( + batches.iter().map(RecordBatch::num_rows).sum::(), + EXPECTED_JOIN_ROWS + ); + + Ok(()) + } + #[apply(hash_join_exec_configs)] #[tokio::test] async fn join_inner_one(