diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs index 0fa60ae20d2be..57b8434bc3c45 100644 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs @@ -410,6 +410,76 @@ async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> { Ok(()) } +#[tokio::test] +async fn null_min_max_topk_preserves_group_rows() -> Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().optimizer.enable_topk_aggregation = true; + let ctx = SessionContext::new_with_config(config); + + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Utf8, true), + Field::new("x", DataType::Int64, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(StringArray::from(vec![Some("a"), Some("b")])), + Arc::new(Int64Array::from(vec![None, Some(5)])), + ], + )?; + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(table))?; + + let min_topk_df = ctx + .sql( + "SELECT min_x FROM (SELECT g, MIN(x) AS min_x FROM t GROUP BY g) q \ + ORDER BY min_x ASC NULLS LAST LIMIT 20", + ) + .await?; + let min_topk_plan = min_topk_df.clone().create_physical_plan().await?; + let min_topk_batches = min_topk_df.collect().await?; + + let min_topk_plan_display = + displayable(min_topk_plan.as_ref()).indent(true).to_string(); + assert!( + min_topk_plan_display.contains("lim=[20]"), + "Expected TopK aggregation optimization to remain enabled: {min_topk_plan_display}" + ); + assert_batches_eq!( + &[ + "+-------+", + "| min_x |", + "+-------+", + "| 5 |", + "| |", + "+-------+" + ], + &min_topk_batches + ); + + let max_topk_batches = ctx + .sql( + "SELECT max_x FROM (SELECT g, MAX(x) AS max_x FROM t GROUP BY g) q \ + ORDER BY max_x DESC NULLS FIRST LIMIT 20", + ) + .await? + .collect() + .await?; + assert_batches_eq!( + &[ + "+-------+", + "| max_x |", + "+-------+", + "| |", + "| 5 |", + "+-------+" + ], + &max_topk_batches + ); + + Ok(()) +} + #[tokio::test] async fn test_count_distinct_optimization() -> Result<()> { struct TestCase { diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index e1779c04a6a92..757d1345d22a3 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -100,6 +100,9 @@ impl TopKAggregation { let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; let order_desc = order.options.descending; + if order.options.nulls_first != !order_desc { + return None; + } let order = order.expr.downcast_ref::()?; let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 889fe04bf830a..328e6ab215e13 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -23,14 +23,16 @@ //! Supported value types include Arrow primitives (integers, floats, decimals, intervals) //! and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`) using lexicographic ordering. -use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive}; +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive, +}; use arrow::array::{LargeStringBuilder, StringBuilder, StringViewBuilder}; use arrow::array::{ StringArray, cast::AsArray, types::{IntervalDayTime, IntervalMonthDayNano}, }; -use arrow::buffer::ScalarBuffer; +use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; use datafusion_common::exec_datafusion_err; @@ -93,17 +95,19 @@ pub trait ArrowHeap { /// An implementation of `ArrowHeap` that deals with primitive values pub struct PrimitiveHeap where - ::Native: Comparable, + ::Native: Comparable + Default, + Option<::Native>: Comparable, { batch: ArrayRef, - heap: TopKHeap, + heap: TopKHeap>, desc: bool, data_type: DataType, } impl PrimitiveHeap where - ::Native: Comparable, + ::Native: Comparable + Default, + Option<::Native>: Comparable, { pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { let owned: ArrayRef = Arc::new(PrimitiveArray::::builder(0).finish()); @@ -114,11 +118,21 @@ where data_type, } } + + fn value(&self, row_idx: usize) -> Option { + let vals = self.batch.as_primitive::(); + if vals.is_null(row_idx) { + None + } else { + Some(vals.value(row_idx)) + } + } } impl ArrowHeap for PrimitiveHeap where - ::Native: Comparable, + ::Native: Comparable + Default, + Option<::Native>: Comparable, { fn set_batch(&mut self, vals: ArrayRef) { self.batch = vals; @@ -128,10 +142,10 @@ where if !self.heap.is_full() { return false; } - let vals = self.batch.as_primitive::(); - let new_val = vals.value(row_idx); + let new_val = self.value(row_idx); let worst_val = self.heap.worst_val().expect("Missing root"); - (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val) + (!self.desc && new_val.comp(worst_val) == Ordering::Greater) + || (self.desc && new_val.comp(worst_val) == Ordering::Less) } fn worst_map_idx(&self) -> usize { @@ -139,9 +153,8 @@ where } fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { - let vals = self.batch.as_primitive::(); - let new_val = vals.value(row_idx); - self.heap.append_or_replace(new_val, map_idx, map); + self.heap + .append_or_replace(self.value(row_idx), map_idx, map); } fn replace_if_better( @@ -150,15 +163,37 @@ where row_idx: usize, map: &mut Vec<(usize, usize)>, ) { - let vals = self.batch.as_primitive::(); - let new_val = vals.value(row_idx); - self.heap.replace_if_better(heap_idx, new_val, map); + let new_val = self.value(row_idx); + let existing = self.heap.heap[heap_idx] + .as_ref() + .expect("Missing heap item"); + + match (&existing.val, &new_val) { + // MIN/MAX ignore null inputs, so a non-null row must replace an + // existing null aggregate state for the same group. + (None, Some(_)) if !self.desc => { + self.heap.heap[heap_idx] + .as_mut() + .expect("Missing heap item") + .val = new_val; + self.heap.heapify_up(heap_idx, map); + } + (Some(_), None) => {} + _ => self.heap.replace_if_better(heap_idx, new_val, map), + } } fn drain(&mut self) -> (ArrayRef, Vec) { - let nulls = None; let (vals, map_idxs) = self.heap.drain(); - let arr = PrimitiveArray::::new(ScalarBuffer::from(vals), nulls) + let nulls = vals.iter().any(|val| val.is_none()).then(|| { + NullBuffer::from(vals.iter().map(|val| val.is_some()).collect::>()) + }); + let values = ScalarBuffer::from( + vals.iter() + .map(|val| val.unwrap_or_default()) + .collect::>(), + ); + let arr = PrimitiveArray::::new(values, nulls) .with_data_type(self.data_type.clone()); (Arc::new(arr), map_idxs) } @@ -194,10 +229,24 @@ impl StringHeap { /// Panics if the row index is out of bounds or if the data type is not one of /// the supported UTF-8 string types. /// - /// Note: Null values should not appear in the input; the aggregation layer - /// ensures nulls are filtered before reaching this code. - fn value(&self, row_idx: usize) -> &str { - extract_string_value(&self.batch, &self.data_type, row_idx) + fn value(&self, row_idx: usize) -> Option<&str> { + if self.batch.is_null(row_idx) { + None + } else { + Some(extract_string_value(&self.batch, &self.data_type, row_idx)) + } + } +} + +fn compare_optional_str( + lhs: Option<&str>, + rhs: &Option, +) -> Ordering { + match (lhs, rhs.as_deref()) { + (Some(lhs), Some(rhs)) => lhs.cmp(rhs), + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (None, None) => Ordering::Equal, } } @@ -234,13 +283,8 @@ impl ArrowHeap for StringHeap { // existing heap entry. let new_val = self.value(row_idx); let worst_val = self.heap.worst_val().expect("Missing root"); - match worst_val { - None => false, - Some(worst_str) => { - (!self.desc && new_val > worst_str.as_str()) - || (self.desc && new_val < worst_str.as_str()) - } - } + (!self.desc && compare_optional_str(new_val, worst_val) == Ordering::Greater) + || (self.desc && compare_optional_str(new_val, worst_val) == Ordering::Less) } fn worst_map_idx(&self) -> usize { @@ -252,8 +296,7 @@ impl ArrowHeap for StringHeap { // because it will be stored in the heap. For replacements we avoid // allocation until `replace_if_better` confirms a replacement is // necessary. - let new_str = self.value(row_idx).to_string(); - let new_val = Some(new_str); + let new_val = self.value(row_idx).map(ToString::to_string); self.heap.append_or_replace(new_val, map_idx, map); } @@ -273,16 +316,28 @@ impl ArrowHeap for StringHeap { // only allocate (`to_string()`) when a replacement is required. match &existing.val { None => { - // Existing is null; new value always wins - let new_val = Some(new_str.to_string()); - self.heap.replace_if_better(heap_idx, new_val, map); + // MIN/MAX ignore null inputs, so a non-null row must replace an + // existing null aggregate state for the same group. + if let Some(new_str) = new_str { + let new_val = Some(new_str.to_string()); + if self.desc { + self.heap.replace_if_better(heap_idx, new_val, map); + } else { + self.heap.heap[heap_idx] + .as_mut() + .expect("Missing heap item") + .val = new_val; + self.heap.heapify_up(heap_idx, map); + } + } } - Some(existing_str) => { + Some(_) => { // Compare borrowed strings first - if (!self.desc && new_str < existing_str.as_str()) - || (self.desc && new_str > existing_str.as_str()) + if (!self.desc && compare_optional_str(new_str, &existing.val) == Ordering::Less) + || (self.desc + && compare_optional_str(new_str, &existing.val) == Ordering::Greater) { - let new_val = Some(new_str.to_string()); + let new_val = new_str.map(ToString::to_string); self.heap.replace_if_better(heap_idx, new_val, map); } // Else: no improvement, no allocation diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 97f4662c11342..7c2bbf51da0e1 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -106,8 +106,6 @@ impl GroupedTopKAggregateStream { ); } - // Note: Null values in aggregate columns are filtered by the aggregation layer - // before reaching the heap, so the heap implementations don't need explicit null handling. let priority_map = PriorityMap::new(kt, vt, limit, desc)?; Ok(GroupedTopKAggregateStream { @@ -145,14 +143,7 @@ impl GroupedTopKAggregateStream { self.priority_map .set_batch(Arc::clone(ids), Arc::clone(vals)); - let has_nulls = vals.null_count() > 0; - if has_nulls && self.is_group_by_only() { - self.null_group_seen = true; - } for row_idx in 0..len { - if has_nulls && vals.is_null(row_idx) { - continue; - } self.priority_map.insert(row_idx)?; } Ok(())