Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,76 @@ async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> {
Ok(())
}

/// Regression test: with TopK aggregation enabled, a group whose only
/// aggregate input is NULL must still produce an output row.
///
/// The table `t` holds two groups: `g = "a"` with `x = NULL` and `g = "b"`
/// with `x = 5`. Both the `MIN ... ORDER BY ASC NULLS LAST` and the
/// `MAX ... ORDER BY DESC NULLS FIRST` queries use `LIMIT 20`, so both groups
/// fit under the limit and neither row may be dropped.
#[tokio::test]
async fn null_min_max_topk_preserves_group_rows() -> Result<()> {
    let mut config = SessionConfig::new();
    config.options_mut().optimizer.enable_topk_aggregation = true;
    let ctx = SessionContext::new_with_config(config);

    let schema = Arc::new(Schema::new(vec![
        Field::new("g", DataType::Utf8, true),
        Field::new("x", DataType::Int64, true),
    ]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(StringArray::from(vec![Some("a"), Some("b")])),
            Arc::new(Int64Array::from(vec![None, Some(5)])),
        ],
    )?;
    let table = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(table))?;

    let min_topk_df = ctx
        .sql(
            "SELECT min_x FROM (SELECT g, MIN(x) AS min_x FROM t GROUP BY g) q \
             ORDER BY min_x ASC NULLS LAST LIMIT 20",
        )
        .await?;
    // Build the physical plan from a clone so the DataFrame can still be
    // collected afterwards.
    let min_topk_plan = min_topk_df.clone().create_physical_plan().await?;
    let min_topk_batches = min_topk_df.collect().await?;

    let min_topk_plan_display =
        displayable(min_topk_plan.as_ref()).indent(true).to_string();
    assert!(
        min_topk_plan_display.contains("lim=[20]"),
        "Expected TopK aggregation optimization to remain enabled: {min_topk_plan_display}"
    );
    // Every expected line must be padded to the full column width:
    // `assert_batches_eq!` compares the pretty-printed lines exactly.
    assert_batches_eq!(
        &[
            "+-------+",
            "| min_x |",
            "+-------+",
            "| 5     |",
            "|       |",
            "+-------+"
        ],
        &min_topk_batches
    );

    let max_topk_batches = ctx
        .sql(
            "SELECT max_x FROM (SELECT g, MAX(x) AS max_x FROM t GROUP BY g) q \
             ORDER BY max_x DESC NULLS FIRST LIMIT 20",
        )
        .await?
        .collect()
        .await?;
    assert_batches_eq!(
        &[
            "+-------+",
            "| max_x |",
            "+-------+",
            "|       |",
            "| 5     |",
            "+-------+"
        ],
        &max_topk_batches
    );

    Ok(())
}

#[tokio::test]
async fn test_count_distinct_optimization() -> Result<()> {
struct TestCase {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-optimizer/src/topk_aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ impl TopKAggregation {
let order = sort.properties().output_ordering()?;
let order = order.iter().exactly_one().ok()?;
let order_desc = order.options.descending;
if order.options.nulls_first != !order_desc {
return None;
}
let order = order.expr.downcast_ref::<Column>()?;
let mut cur_col_name = order.name().to_string();
let limit = sort.fetch()?;
Expand Down
129 changes: 92 additions & 37 deletions datafusion/physical-plan/src/aggregates/topk/heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
//! Supported value types include Arrow primitives (integers, floats, decimals, intervals)
//! and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`) using lexicographic ordering.

use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive};
use arrow::array::{
Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive,
};
use arrow::array::{LargeStringBuilder, StringBuilder, StringViewBuilder};
use arrow::array::{
StringArray,
cast::AsArray,
types::{IntervalDayTime, IntervalMonthDayNano},
};
use arrow::buffer::ScalarBuffer;
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::{DataType, i256};
use datafusion_common::Result;
use datafusion_common::exec_datafusion_err;
Expand Down Expand Up @@ -93,17 +95,19 @@ pub trait ArrowHeap {
/// An implementation of `ArrowHeap` that deals with primitive values
pub struct PrimitiveHeap<VAL: ArrowPrimitiveType>
where
<VAL as ArrowPrimitiveType>::Native: Comparable,
<VAL as ArrowPrimitiveType>::Native: Comparable + Default,
Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
batch: ArrayRef,
heap: TopKHeap<VAL::Native>,
heap: TopKHeap<Option<VAL::Native>>,
desc: bool,
data_type: DataType,
}

impl<VAL: ArrowPrimitiveType> PrimitiveHeap<VAL>
where
<VAL as ArrowPrimitiveType>::Native: Comparable,
<VAL as ArrowPrimitiveType>::Native: Comparable + Default,
Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self {
let owned: ArrayRef = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
Expand All @@ -114,11 +118,21 @@ where
data_type,
}
}

/// Reads the value at `row_idx` from the current batch.
///
/// Returns `None` when the slot is null, otherwise `Some` of the native
/// value stored there.
fn value(&self, row_idx: usize) -> Option<VAL::Native> {
    let column = self.batch.as_primitive::<VAL>();
    let is_valid = !column.is_null(row_idx);
    // Only read the slot once it is known to be non-null.
    is_valid.then(|| column.value(row_idx))
}
}

impl<VAL: ArrowPrimitiveType> ArrowHeap for PrimitiveHeap<VAL>
where
<VAL as ArrowPrimitiveType>::Native: Comparable,
<VAL as ArrowPrimitiveType>::Native: Comparable + Default,
Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
{
fn set_batch(&mut self, vals: ArrayRef) {
self.batch = vals;
Expand All @@ -128,20 +142,19 @@ where
if !self.heap.is_full() {
return false;
}
let vals = self.batch.as_primitive::<VAL>();
let new_val = vals.value(row_idx);
let new_val = self.value(row_idx);
let worst_val = self.heap.worst_val().expect("Missing root");
(!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val)
(!self.desc && new_val.comp(worst_val) == Ordering::Greater)
|| (self.desc && new_val.comp(worst_val) == Ordering::Less)
}

/// Returns the result of `worst_map_idx()` on the underlying `TopKHeap`.
///
/// NOTE(review): presumably this is the map index of the entry currently
/// holding the heap's worst value — confirm against `TopKHeap`.
fn worst_map_idx(&self) -> usize {
self.heap.worst_map_idx()
}

fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) {
let vals = self.batch.as_primitive::<VAL>();
let new_val = vals.value(row_idx);
self.heap.append_or_replace(new_val, map_idx, map);
self.heap
.append_or_replace(self.value(row_idx), map_idx, map);
}

fn replace_if_better(
Expand All @@ -150,15 +163,37 @@ where
row_idx: usize,
map: &mut Vec<(usize, usize)>,
) {
let vals = self.batch.as_primitive::<VAL>();
let new_val = vals.value(row_idx);
self.heap.replace_if_better(heap_idx, new_val, map);
let new_val = self.value(row_idx);
let existing = self.heap.heap[heap_idx]
.as_ref()
.expect("Missing heap item");

match (&existing.val, &new_val) {
// MIN/MAX ignore null inputs, so a non-null row must replace an
// existing null aggregate state for the same group.
(None, Some(_)) if !self.desc => {
self.heap.heap[heap_idx]
.as_mut()
.expect("Missing heap item")
.val = new_val;
self.heap.heapify_up(heap_idx, map);
}
(Some(_), None) => {}
_ => self.heap.replace_if_better(heap_idx, new_val, map),
}
}

fn drain(&mut self) -> (ArrayRef, Vec<usize>) {
let nulls = None;
let (vals, map_idxs) = self.heap.drain();
let arr = PrimitiveArray::<VAL>::new(ScalarBuffer::from(vals), nulls)
let nulls = vals.iter().any(|val| val.is_none()).then(|| {
NullBuffer::from(vals.iter().map(|val| val.is_some()).collect::<Vec<_>>())
});
let values = ScalarBuffer::from(
vals.iter()
.map(|val| val.unwrap_or_default())
.collect::<Vec<_>>(),
);
let arr = PrimitiveArray::<VAL>::new(values, nulls)
.with_data_type(self.data_type.clone());
(Arc::new(arr), map_idxs)
}
Expand Down Expand Up @@ -194,10 +229,24 @@ impl StringHeap {
/// Panics if the row index is out of bounds or if the data type is not one of
/// the supported UTF-8 string types.
///
/// Note: Null values should not appear in the input; the aggregation layer
/// ensures nulls are filtered before reaching this code.
fn value(&self, row_idx: usize) -> &str {
extract_string_value(&self.batch, &self.data_type, row_idx)
fn value(&self, row_idx: usize) -> Option<&str> {
if self.batch.is_null(row_idx) {
None
} else {
Some(extract_string_value(&self.batch, &self.data_type, row_idx))
}
}
}

/// Orders a borrowed optional string against an owned optional string.
///
/// `None` sorts before any `Some`, two `None`s are equal, and two `Some`
/// values compare by `str` ordering. This is exactly the standard `Ord`
/// implementation for `Option<&str>`, so the comparison delegates to it
/// after borrowing the right-hand side.
fn compare_optional_str(lhs: Option<&str>, rhs: &Option<String>) -> Ordering {
    let borrowed_rhs: Option<&str> = rhs.as_deref();
    lhs.cmp(&borrowed_rhs)
}

Expand Down Expand Up @@ -234,13 +283,8 @@ impl ArrowHeap for StringHeap {
// existing heap entry.
let new_val = self.value(row_idx);
let worst_val = self.heap.worst_val().expect("Missing root");
match worst_val {
None => false,
Some(worst_str) => {
(!self.desc && new_val > worst_str.as_str())
|| (self.desc && new_val < worst_str.as_str())
}
}
(!self.desc && compare_optional_str(new_val, worst_val) == Ordering::Greater)
|| (self.desc && compare_optional_str(new_val, worst_val) == Ordering::Less)
}

fn worst_map_idx(&self) -> usize {
Expand All @@ -252,8 +296,7 @@ impl ArrowHeap for StringHeap {
// because it will be stored in the heap. For replacements we avoid
// allocation until `replace_if_better` confirms a replacement is
// necessary.
let new_str = self.value(row_idx).to_string();
let new_val = Some(new_str);
let new_val = self.value(row_idx).map(ToString::to_string);
self.heap.append_or_replace(new_val, map_idx, map);
}

Expand All @@ -273,16 +316,28 @@ impl ArrowHeap for StringHeap {
// only allocate (`to_string()`) when a replacement is required.
match &existing.val {
None => {
// Existing is null; new value always wins
let new_val = Some(new_str.to_string());
self.heap.replace_if_better(heap_idx, new_val, map);
// MIN/MAX ignore null inputs, so a non-null row must replace an
// existing null aggregate state for the same group.
if let Some(new_str) = new_str {
let new_val = Some(new_str.to_string());
if self.desc {
self.heap.replace_if_better(heap_idx, new_val, map);
} else {
self.heap.heap[heap_idx]
.as_mut()
.expect("Missing heap item")
.val = new_val;
self.heap.heapify_up(heap_idx, map);
}
}
}
Some(existing_str) => {
Some(_) => {
// Compare borrowed strings first
if (!self.desc && new_str < existing_str.as_str())
|| (self.desc && new_str > existing_str.as_str())
if (!self.desc && compare_optional_str(new_str, &existing.val) == Ordering::Less)
|| (self.desc
&& compare_optional_str(new_str, &existing.val) == Ordering::Greater)
{
let new_val = Some(new_str.to_string());
let new_val = new_str.map(ToString::to_string);
self.heap.replace_if_better(heap_idx, new_val, map);
}
// Else: no improvement, no allocation
Expand Down
9 changes: 0 additions & 9 deletions datafusion/physical-plan/src/aggregates/topk_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ impl GroupedTopKAggregateStream {
);
}

// Note: Null values in aggregate columns are filtered by the aggregation layer
// before reaching the heap, so the heap implementations don't need explicit null handling.
let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

Ok(GroupedTopKAggregateStream {
Expand Down Expand Up @@ -145,14 +143,7 @@ impl GroupedTopKAggregateStream {
self.priority_map
.set_batch(Arc::clone(ids), Arc::clone(vals));

let has_nulls = vals.null_count() > 0;
if has_nulls && self.is_group_by_only() {
self.null_group_seen = true;
}
for row_idx in 0..len {
if has_nulls && vals.is_null(row_idx) {
continue;
}
self.priority_map.insert(row_idx)?;
}
Ok(())
Expand Down
Loading