From 06e64ae88b1ca17f1eb82829634cef2000887ed1 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 4 Jun 2026 12:30:50 +0530 Subject: [PATCH 1/4] test: initial repro test by codex --- .../physical-plan/src/joins/hash_join/exec.rs | 263 +++++++++++++++++- 1 file changed, 262 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 03387c316b8e1..f47fe1aadc08d 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -2125,6 +2125,7 @@ mod tests { Ok((left_schema, right_schema, on)) } + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{TestMemoryExec, assert_join_metrics}; @@ -2145,9 +2146,14 @@ mod tests { exec_err, internal_err, }; use datafusion_execution::config::SessionConfig; + use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::{ + AggregateExprBuilder, AggregateFunctionExpr, + }; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal, col}; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; use rstest::*; @@ -2500,6 +2506,261 @@ mod tests { Ok((columns, batches, metrics)) } + #[derive(Clone)] + struct FinalAggregateBuildInput { + raw_schema: SchemaRef, + partial_schema: SchemaRef, + group_by: PhysicalGroupBy, + aggregates: Vec>, + partial_batches: Vec, + num_groups: usize, + } + + fn memory_limited_aggregate_join_task_ctx( + batch_size: usize, + memory_limit: Option, + ) -> Result> { + let mut session_config = SessionConfig::default().with_batch_size(batch_size); + + // Keep the repro focused on normal hash aggregation and hash join paths. + session_config + .options_mut() + .execution + .skip_partial_aggregation_probe_rows_threshold = usize::MAX; + session_config + .options_mut() + .execution + .perfect_hash_join_small_build_threshold = 0; + session_config + .options_mut() + .execution + .perfect_hash_join_min_key_density = f64::INFINITY; + + let mut runtime_builder = RuntimeEnvBuilder::new().with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ); + if let Some(memory_limit) = memory_limit { + runtime_builder = runtime_builder.with_memory_limit(memory_limit, 1.0); + } + + Ok(Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime_builder.build_arc()?), + )) + } + + async fn final_aggregate_build_input( + num_groups: usize, + batch_size: usize, + ) -> Result { + let raw_schema = Arc::new(Schema::new(vec![ + Field::new("group_key", DataType::UInt32, false), + Field::new("value", DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&raw_schema), + vec![ + Arc::new(UInt32Array::from_iter_values(0..num_groups as u32)), + Arc::new(UInt64Array::from(vec![1; num_groups])), + ], + )?; + let input = + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&raw_schema), None)?; + + let group_by = PhysicalGroupBy::new_single(vec![( + col("group_key", &raw_schema)?, + "group_key".to_string(), + )]); + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value", &raw_schema)?]) + .schema(Arc::clone(&raw_schema)) + .alias("count_value") + .build()?, + )]; + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregates.clone(), + vec![None], + input, + Arc::clone(&raw_schema), + )?); + let partial_schema = partial_aggregate.schema(); + let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; + let partial_batches = + common::collect(partial_aggregate.execute(0, task_ctx)?).await?; + + Ok(FinalAggregateBuildInput { + raw_schema, + partial_schema, + group_by, + aggregates, + partial_batches, + num_groups, + }) + } + + fn final_aggregate(input: &FinalAggregateBuildInput) -> Result> { + let partial_batches = input.partial_batches.clone(); + let partial_input = TestMemoryExec::try_new_exec( + &[partial_batches], + Arc::clone(&input.partial_schema), + None, + )?; + + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + input.group_by.as_final(), + input.aggregates.clone(), + vec![None; input.aggregates.len()], + partial_input, + Arc::clone(&input.raw_schema), + )?)) + } + + fn probe_side(num_groups: usize) -> Result> { + let schema = Arc::new(Schema::new(vec![Field::new( + "probe_key", + DataType::UInt32, + false, + )])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![ + 0, + (num_groups / 2) as u32, + (num_groups - 1) as u32, + ]))], + )?; + + let exec: Arc = + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?; + + Ok(exec) + } + + async fn final_aggregate_peak_mem_used( + input: &FinalAggregateBuildInput, + batch_size: usize, + ) -> Result { + let aggregate = final_aggregate(input)?; + let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; + let batches = common::collect(aggregate.execute(0, task_ctx)?).await?; + + assert!( + batches.len() > 1, + "expected final aggregate output to be split into multiple batches" + ); + assert_eq!( + batches.iter().map(RecordBatch::num_rows).sum::(), + input.num_groups + ); + + let metrics = aggregate.metrics().expect("aggregate metrics"); + let peak_mem_used = metrics + .sum_by_name("peak_mem_used") + .expect("peak_mem_used metric") + .as_usize(); + assert!( + peak_mem_used > 0, + "expected non-zero final aggregate peak memory" + ); + + Ok(peak_mem_used) + } + + async fn run_aggregate_build_side_join( + input: &FinalAggregateBuildInput, + batch_size: usize, + memory_limit: usize, + ) -> Result> { + let aggregate: Arc = final_aggregate(input)?; + let right = probe_side(input.num_groups)?; + let on = vec![( + Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _, + Arc::new(Column::new_with_schema("probe_key", &right.schema())?) as _, + )]; + let join = HashJoinExec::try_new( + aggregate, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + + let task_ctx = + memory_limited_aggregate_join_task_ctx(batch_size, Some(memory_limit))?; + common::collect(join.execute(0, task_ctx)?).await + } + + async fn first_passing_aggregate_build_side_join_multiplier( + input: &FinalAggregateBuildInput, + batch_size: usize, + aggregate_peak_mem_used: usize, + max_multiplier: usize, + ) -> Result> { + for multiplier in 3..=max_multiplier { + if run_aggregate_build_side_join( + input, + batch_size, + aggregate_peak_mem_used * multiplier, + ) + .await + .is_ok() + { + return Ok(Some(multiplier)); + } + } + + Ok(None) + } + + #[tokio::test] + async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> { + const BATCH_SIZE: usize = 8192; + const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1; + const EXPECTED_JOIN_ROWS: usize = 3; + + let aggregate_input = final_aggregate_build_input(NUM_GROUPS, BATCH_SIZE).await?; + let aggregate_peak_mem_used = + final_aggregate_peak_mem_used(&aggregate_input, BATCH_SIZE).await?; + let memory_limit = aggregate_peak_mem_used * 2; + + match run_aggregate_build_side_join(&aggregate_input, BATCH_SIZE, memory_limit) + .await + { + Ok(batches) => { + assert_eq!( + batches.iter().map(RecordBatch::num_rows).sum::(), + EXPECTED_JOIN_ROWS + ); + } + Err(err) => { + let passing_multiplier = + first_passing_aggregate_build_side_join_multiplier( + &aggregate_input, + BATCH_SIZE, + aggregate_peak_mem_used, + 64, + ) + .await?; + panic!( + "HashJoinExec build side should pass with a memory limit of 2x \ + final AggregateExec peak grouped memory ({aggregate_peak_mem_used} bytes), \ + but failed with limit {memory_limit} bytes: {err}. Current smallest \ + passing multiplier up to 64x: {passing_multiplier:?}" + ); + } + } + + Ok(()) + } + #[apply(hash_join_exec_configs)] #[tokio::test] async fn join_inner_one( From 20e5ada0b5f678713fe0443b3123a09233234209 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 4 Jun 2026 12:33:53 +0530 Subject: [PATCH 2/4] test: a bit smaller test --- .../physical-plan/src/joins/hash_join/exec.rs | 55 +++---------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index f47fe1aadc08d..366ab4de4b54e 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -2698,28 +2698,6 @@ mod tests { common::collect(join.execute(0, task_ctx)?).await } - async fn first_passing_aggregate_build_side_join_multiplier( - input: &FinalAggregateBuildInput, - batch_size: usize, - aggregate_peak_mem_used: usize, - max_multiplier: usize, - ) -> Result> { - for multiplier in 3..=max_multiplier { - if run_aggregate_build_side_join( - input, - batch_size, - aggregate_peak_mem_used * multiplier, - ) - .await - .is_ok() - { - return Ok(Some(multiplier)); - } - } - - Ok(None) - } - #[tokio::test] async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> { const BATCH_SIZE: usize = 8192; @@ -2731,32 +2709,13 @@ mod tests { final_aggregate_peak_mem_used(&aggregate_input, BATCH_SIZE).await?; let memory_limit = aggregate_peak_mem_used * 2; - match run_aggregate_build_side_join(&aggregate_input, BATCH_SIZE, memory_limit) - .await - { - Ok(batches) => { - assert_eq!( - batches.iter().map(RecordBatch::num_rows).sum::(), - EXPECTED_JOIN_ROWS - ); - } - Err(err) => { - let passing_multiplier = - first_passing_aggregate_build_side_join_multiplier( - &aggregate_input, - BATCH_SIZE, - aggregate_peak_mem_used, - 64, - ) - .await?; - panic!( - "HashJoinExec build side should pass with a memory limit of 2x \ - final AggregateExec peak grouped memory ({aggregate_peak_mem_used} bytes), \ - but failed with limit {memory_limit} bytes: {err}. Current smallest \ - passing multiplier up to 64x: {passing_multiplier:?}" - ); - } - } + let batches = + run_aggregate_build_side_join(&aggregate_input, BATCH_SIZE, memory_limit) + .await?; + assert_eq!( + batches.iter().map(RecordBatch::num_rows).sum::(), + EXPECTED_JOIN_ROWS + ); Ok(()) } From d3235a9f19372ce1567a8754d8480ec2c0229d59 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 4 Jun 2026 12:36:54 +0530 Subject: [PATCH 3/4] test: more simplify --- .../physical-plan/src/joins/hash_join/exec.rs | 68 +++++++------------ 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 366ab4de4b54e..04eecdc85c1c8 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -2150,9 +2150,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_physical_expr::aggregate::{ - AggregateExprBuilder, AggregateFunctionExpr, - }; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{BinaryExpr, Literal, col}; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; @@ -2506,16 +2504,6 @@ mod tests { Ok((columns, batches, metrics)) } - #[derive(Clone)] - struct FinalAggregateBuildInput { - raw_schema: SchemaRef, - partial_schema: SchemaRef, - group_by: PhysicalGroupBy, - aggregates: Vec>, - partial_batches: Vec, - num_groups: usize, - } - fn memory_limited_aggregate_join_task_ctx( batch_size: usize, memory_limit: Option, @@ -2550,10 +2538,10 @@ mod tests { )) } - async fn final_aggregate_build_input( + async fn final_aggregate_build_side( num_groups: usize, batch_size: usize, - ) -> Result { + ) -> Result> { let raw_schema = Arc::new(Schema::new(vec![ Field::new("group_key", DataType::UInt32, false), Field::new("value", DataType::UInt64, false), @@ -2590,32 +2578,19 @@ mod tests { let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; let partial_batches = common::collect(partial_aggregate.execute(0, task_ctx)?).await?; - - Ok(FinalAggregateBuildInput { - raw_schema, - partial_schema, - group_by, - aggregates, - partial_batches, - num_groups, - }) - } - - fn final_aggregate(input: &FinalAggregateBuildInput) -> Result> { - let partial_batches = input.partial_batches.clone(); let partial_input = TestMemoryExec::try_new_exec( &[partial_batches], - Arc::clone(&input.partial_schema), + Arc::clone(&partial_schema), None, )?; Ok(Arc::new(AggregateExec::try_new( AggregateMode::Final, - input.group_by.as_final(), - input.aggregates.clone(), - vec![None; input.aggregates.len()], + group_by.as_final(), + aggregates.clone(), + vec![None; aggregates.len()], partial_input, - Arc::clone(&input.raw_schema), + Arc::clone(&raw_schema), )?)) } @@ -2641,10 +2616,10 @@ mod tests { } async fn final_aggregate_peak_mem_used( - input: &FinalAggregateBuildInput, + aggregate: &Arc, + num_groups: usize, batch_size: usize, ) -> Result { - let aggregate = final_aggregate(input)?; let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; let batches = common::collect(aggregate.execute(0, task_ctx)?).await?; @@ -2654,7 +2629,7 @@ mod tests { ); assert_eq!( batches.iter().map(RecordBatch::num_rows).sum::(), - input.num_groups + num_groups ); let metrics = aggregate.metrics().expect("aggregate metrics"); @@ -2671,12 +2646,13 @@ mod tests { } async fn run_aggregate_build_side_join( - input: &FinalAggregateBuildInput, + aggregate: Arc, + num_groups: usize, batch_size: usize, memory_limit: usize, ) -> Result> { - let aggregate: Arc = final_aggregate(input)?; - let right = probe_side(input.num_groups)?; + let aggregate: Arc = aggregate; + let right = probe_side(num_groups)?; let on = vec![( Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _, Arc::new(Column::new_with_schema("probe_key", &right.schema())?) as _, @@ -2704,14 +2680,18 @@ mod tests { const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1; const EXPECTED_JOIN_ROWS: usize = 3; - let aggregate_input = final_aggregate_build_input(NUM_GROUPS, BATCH_SIZE).await?; + let aggregate = final_aggregate_build_side(NUM_GROUPS, BATCH_SIZE).await?; let aggregate_peak_mem_used = - final_aggregate_peak_mem_used(&aggregate_input, BATCH_SIZE).await?; + final_aggregate_peak_mem_used(&aggregate, NUM_GROUPS, BATCH_SIZE).await?; let memory_limit = aggregate_peak_mem_used * 2; - let batches = - run_aggregate_build_side_join(&aggregate_input, BATCH_SIZE, memory_limit) - .await?; + let batches = run_aggregate_build_side_join( + aggregate, + NUM_GROUPS, + BATCH_SIZE, + memory_limit, + ) + .await?; assert_eq!( batches.iter().map(RecordBatch::num_rows).sum::(), EXPECTED_JOIN_ROWS From 3b8a2777403bfd714dc97fca20b47e143d2131cb Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 4 Jun 2026 12:45:56 +0530 Subject: [PATCH 4/4] test: simplify even more --- .../physical-plan/src/joins/hash_join/exec.rs | 175 ++++++------------ 1 file changed, 53 insertions(+), 122 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 04eecdc85c1c8..14adc08a7c056 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -2146,7 +2146,6 @@ mod tests { exec_err, internal_err, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_functions_aggregate::count::count_udaf; @@ -2504,44 +2503,11 @@ mod tests { Ok((columns, batches, metrics)) } - fn memory_limited_aggregate_join_task_ctx( - batch_size: usize, - memory_limit: Option, - ) -> Result> { - let mut session_config = SessionConfig::default().with_batch_size(batch_size); - - // Keep the repro focused on normal hash aggregation and hash join paths. - session_config - .options_mut() - .execution - .skip_partial_aggregation_probe_rows_threshold = usize::MAX; - session_config - .options_mut() - .execution - .perfect_hash_join_small_build_threshold = 0; - session_config - .options_mut() - .execution - .perfect_hash_join_min_key_density = f64::INFINITY; - - let mut runtime_builder = RuntimeEnvBuilder::new().with_disk_manager_builder( - DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), - ); - if let Some(memory_limit) = memory_limit { - runtime_builder = runtime_builder.with_memory_limit(memory_limit, 1.0); - } - - Ok(Arc::new( - TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime_builder.build_arc()?), - )) + fn aggregate_join_group_key(i: usize) -> u32 { + (i as u32) * 1000 } - async fn final_aggregate_build_side( - num_groups: usize, - batch_size: usize, - ) -> Result> { + async fn final_aggregate_build_side(num_groups: usize) -> Result> { let raw_schema = Arc::new(Schema::new(vec![ Field::new("group_key", DataType::UInt32, false), Field::new("value", DataType::UInt64, false), @@ -2549,7 +2515,9 @@ mod tests { let batch = RecordBatch::try_new( Arc::clone(&raw_schema), vec![ - Arc::new(UInt32Array::from_iter_values(0..num_groups as u32)), + Arc::new(UInt32Array::from_iter_values( + (0..num_groups).map(aggregate_join_group_key), + )), Arc::new(UInt64Array::from(vec![1; num_groups])), ], )?; @@ -2575,9 +2543,10 @@ mod tests { Arc::clone(&raw_schema), )?); let partial_schema = partial_aggregate.schema(); - let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; - let partial_batches = - common::collect(partial_aggregate.execute(0, task_ctx)?).await?; + let partial_batches = common::collect( + partial_aggregate.execute(0, Arc::new(TaskContext::default()))?, + ) + .await?; let partial_input = TestMemoryExec::try_new_exec( &[partial_batches], Arc::clone(&partial_schema), @@ -2594,73 +2563,54 @@ mod tests { )?)) } - fn probe_side(num_groups: usize) -> Result> { - let schema = Arc::new(Schema::new(vec![Field::new( + #[tokio::test] + async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> { + const BATCH_SIZE: usize = 8192; + const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1; + const EXPECTED_JOIN_ROWS: usize = 3; + + let aggregate = final_aggregate_build_side(NUM_GROUPS).await?; + let aggregate_batches = + common::collect(aggregate.execute(0, Arc::new(TaskContext::default()))?) + .await?; + assert!(aggregate_batches.len() > 1); + assert_eq!( + aggregate_batches + .iter() + .map(RecordBatch::num_rows) + .sum::(), + NUM_GROUPS + ); + let aggregate_batch = concat_batches(&aggregate.schema(), &aggregate_batches)?; + let memory_limit = get_record_batch_memory_size(&aggregate_batch) * 4; + + let probe_schema = Arc::new(Schema::new(vec![Field::new( "probe_key", DataType::UInt32, false, )])); - let batch = RecordBatch::try_new( - Arc::clone(&schema), + let probe_batch = RecordBatch::try_new( + Arc::clone(&probe_schema), vec![Arc::new(UInt32Array::from(vec![ - 0, - (num_groups / 2) as u32, - (num_groups - 1) as u32, + aggregate_join_group_key(0), + aggregate_join_group_key(NUM_GROUPS / 2), + aggregate_join_group_key(NUM_GROUPS - 1), ]))], )?; + let probe: Arc = TestMemoryExec::try_new_exec( + &[vec![probe_batch]], + Arc::clone(&probe_schema), + None, + )?; - let exec: Arc = - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?; - - Ok(exec) - } - - async fn final_aggregate_peak_mem_used( - aggregate: &Arc, - num_groups: usize, - batch_size: usize, - ) -> Result { - let task_ctx = memory_limited_aggregate_join_task_ctx(batch_size, None)?; - let batches = common::collect(aggregate.execute(0, task_ctx)?).await?; - - assert!( - batches.len() > 1, - "expected final aggregate output to be split into multiple batches" - ); - assert_eq!( - batches.iter().map(RecordBatch::num_rows).sum::(), - num_groups - ); - - let metrics = aggregate.metrics().expect("aggregate metrics"); - let peak_mem_used = metrics - .sum_by_name("peak_mem_used") - .expect("peak_mem_used metric") - .as_usize(); - assert!( - peak_mem_used > 0, - "expected non-zero final aggregate peak memory" - ); - - Ok(peak_mem_used) - } - - async fn run_aggregate_build_side_join( - aggregate: Arc, - num_groups: usize, - batch_size: usize, - memory_limit: usize, - ) -> Result> { let aggregate: Arc = aggregate; - let right = probe_side(num_groups)?; - let on = vec![( - Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _, - Arc::new(Column::new_with_schema("probe_key", &right.schema())?) as _, - )]; let join = HashJoinExec::try_new( - aggregate, - right, - on, + Arc::clone(&aggregate), + probe, + vec![( + Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _, + Arc::new(Column::new_with_schema("probe_key", &probe_schema)?) as _, + )], None, &JoinType::Inner, None, @@ -2668,30 +2618,11 @@ mod tests { NullEquality::NullEqualsNothing, false, )?; - - let task_ctx = - memory_limited_aggregate_join_task_ctx(batch_size, Some(memory_limit))?; - common::collect(join.execute(0, task_ctx)?).await - } - - #[tokio::test] - async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> { - const BATCH_SIZE: usize = 8192; - const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1; - const EXPECTED_JOIN_ROWS: usize = 3; - - let aggregate = final_aggregate_build_side(NUM_GROUPS, BATCH_SIZE).await?; - let aggregate_peak_mem_used = - final_aggregate_peak_mem_used(&aggregate, NUM_GROUPS, BATCH_SIZE).await?; - let memory_limit = aggregate_peak_mem_used * 2; - - let batches = run_aggregate_build_side_join( - aggregate, - NUM_GROUPS, - BATCH_SIZE, - memory_limit, - ) - .await?; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + let batches = common::collect(join.execute(0, task_ctx)?).await?; assert_eq!( batches.iter().map(RecordBatch::num_rows).sum::(), EXPECTED_JOIN_ROWS