Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 132 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,7 @@ mod tests {
Ok((left_schema, right_schema, on))
}

use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::joins::hash_join::stream::lookup_join_hashmap;
use crate::test::{TestMemoryExec, assert_join_metrics};
Expand All @@ -2147,7 +2148,9 @@ mod tests {
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal, col};
use hashbrown::HashTable;
use insta::{allow_duplicates, assert_snapshot};
use rstest::*;
Expand Down Expand Up @@ -2500,6 +2503,134 @@ mod tests {
Ok((columns, batches, metrics))
}

fn aggregate_join_group_key(i: usize) -> u32 {
(i as u32) * 1000
}

async fn final_aggregate_build_side(num_groups: usize) -> Result<Arc<AggregateExec>> {
let raw_schema = Arc::new(Schema::new(vec![
Field::new("group_key", DataType::UInt32, false),
Field::new("value", DataType::UInt64, false),
]));
let batch = RecordBatch::try_new(
Arc::clone(&raw_schema),
vec![
Arc::new(UInt32Array::from_iter_values(
(0..num_groups).map(aggregate_join_group_key),
)),
Arc::new(UInt64Array::from(vec![1; num_groups])),
],
)?;
let input =
TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&raw_schema), None)?;

let group_by = PhysicalGroupBy::new_single(vec![(
col("group_key", &raw_schema)?,
"group_key".to_string(),
)]);
let aggregates = vec![Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![col("value", &raw_schema)?])
.schema(Arc::clone(&raw_schema))
.alias("count_value")
.build()?,
)];
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
group_by.clone(),
aggregates.clone(),
vec![None],
input,
Arc::clone(&raw_schema),
)?);
let partial_schema = partial_aggregate.schema();
let partial_batches = common::collect(
partial_aggregate.execute(0, Arc::new(TaskContext::default()))?,
)
.await?;
let partial_input = TestMemoryExec::try_new_exec(
&[partial_batches],
Arc::clone(&partial_schema),
None,
)?;

Ok(Arc::new(AggregateExec::try_new(
AggregateMode::Final,
group_by.as_final(),
aggregates.clone(),
vec![None; aggregates.len()],
partial_input,
Arc::clone(&raw_schema),
)?))
}

#[tokio::test]
async fn build_side_final_aggregate_respects_grouped_memory_limit() -> Result<()> {
const BATCH_SIZE: usize = 8192;
const NUM_GROUPS: usize = BATCH_SIZE * 32 + 1;
const EXPECTED_JOIN_ROWS: usize = 3;

let aggregate = final_aggregate_build_side(NUM_GROUPS).await?;
let aggregate_batches =
common::collect(aggregate.execute(0, Arc::new(TaskContext::default()))?)
.await?;
assert!(aggregate_batches.len() > 1);
assert_eq!(
aggregate_batches
.iter()
.map(RecordBatch::num_rows)
.sum::<usize>(),
NUM_GROUPS
);
let aggregate_batch = concat_batches(&aggregate.schema(), &aggregate_batches)?;
let memory_limit = get_record_batch_memory_size(&aggregate_batch) * 4;

let probe_schema = Arc::new(Schema::new(vec![Field::new(
"probe_key",
DataType::UInt32,
false,
)]));
let probe_batch = RecordBatch::try_new(
Arc::clone(&probe_schema),
vec![Arc::new(UInt32Array::from(vec![
aggregate_join_group_key(0),
aggregate_join_group_key(NUM_GROUPS / 2),
aggregate_join_group_key(NUM_GROUPS - 1),
]))],
)?;
let probe: Arc<dyn ExecutionPlan> = TestMemoryExec::try_new_exec(
&[vec![probe_batch]],
Arc::clone(&probe_schema),
None,
)?;

let aggregate: Arc<dyn ExecutionPlan> = aggregate;
let join = HashJoinExec::try_new(
Arc::clone(&aggregate),
probe,
vec![(
Arc::new(Column::new_with_schema("group_key", &aggregate.schema())?) as _,
Arc::new(Column::new_with_schema("probe_key", &probe_schema)?) as _,
)],
None,
&JoinType::Inner,
None,
PartitionMode::CollectLeft,
NullEquality::NullEqualsNothing,
false,
)?;
let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.build_arc()?;
let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime));
let batches = common::collect(join.execute(0, task_ctx)?).await?;
assert_eq!(
batches.iter().map(RecordBatch::num_rows).sum::<usize>(),
EXPECTED_JOIN_ROWS
);

Ok(())
}

#[apply(hash_join_exec_configs)]
#[tokio::test]
async fn join_inner_one(
Expand Down
Loading