Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
///
/// The output is `true` for rows where the filter is `Some(true)`, and `false`
/// for rows where the filter is `Some(false)` or `None`.
pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
pub fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
let Some(filter_nulls) = filter.nulls() else {
return filter.values().clone();
};
Expand Down
97 changes: 94 additions & 3 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_validity;
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
Expand Down Expand Up @@ -552,10 +553,12 @@ impl<S: ValueState> FirstLastGroupsAccumulator<S> {
LexicographicalComparator::try_new(&sort_columns)?
};

for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
let group_idx = *group_idx;
let filter_validity = opt_filter.map(filter_to_validity);

let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
for (idx_in_val, &group_idx) in group_indices.iter().enumerate() {
let passed_filter = filter_validity
.as_ref()
.is_none_or(|validity| validity.value(idx_in_val));
let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

if !passed_filter || !is_set {
Expand Down Expand Up @@ -1416,13 +1419,53 @@ mod tests {

use arrow::{
array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray},
buffer::NullBuffer,
compute::SortOptions,
datatypes::Schema,
};
use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};

use super::*;

/// Test helper: builds a `FirstLastGroupsAccumulator` over `Int64` values whose
/// single sort key is column `c` (default sort options) of a two-column
/// nullable `Int64` schema (`a`, `c`).
///
/// `pick_first_in_group` is forwarded unchanged to
/// `FirstLastGroupsAccumulator::try_new`.
fn new_int64_first_last_group_acc(
    pick_first_in_group: bool,
) -> Result<FirstLastGroupsAccumulator<PrimitiveValueState<Int64Type>>> {
    let fields = vec![
        Field::new("a", DataType::Int64, true),
        Field::new("c", DataType::Int64, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let order_by_c = PhysicalSortExpr {
        expr: col("c", &schema).unwrap(),
        options: SortOptions::default(),
    };
    let value_state = PrimitiveValueState::<Int64Type>::new(DataType::Int64);

    FirstLastGroupsAccumulator::try_new(
        value_state,
        [order_by_c].into(),
        true,
        &[DataType::Int64],
        pick_first_in_group,
    )
}

/// Test helper: assembles a `BooleanArray` from raw value bits and a separate
/// validity mask, so a row can be NULL while its underlying value bit is still
/// `true`.
fn nullable_bool_filter(values: Vec<bool>, validity: Vec<bool>) -> BooleanArray {
    let value_bits = BooleanBuffer::from(values);
    let null_mask = NullBuffer::from(validity);
    BooleanArray::new(value_bits, Some(null_mask))
}

/// Test helper: evaluates every group of `group_acc` (`EmitTo::All`) and
/// asserts that the output, viewed as an `Int64Array`, equals `expected`.
///
/// Panics if the evaluated array is not an `Int64Array` or if the contents
/// differ; propagates any error returned by `evaluate`.
fn assert_group_acc_int64_result(
    group_acc: &mut FirstLastGroupsAccumulator<PrimitiveValueState<Int64Type>>,
    expected: Int64Array,
) -> Result<()> {
    let evaluated = group_acc.evaluate(EmitTo::All)?;
    let actual = evaluated.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(actual, &expected);
    Ok(())
}

#[test]
fn test_first_last_value_value() -> Result<()> {
let mut first_accumulator =
Expand Down Expand Up @@ -1621,6 +1664,54 @@ mod tests {
Ok(())
}

#[test]
fn test_first_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> {
    let mut acc = new_int64_first_last_group_acc(true)?;

    // First array: the aggregated values; second array: the ordering keys.
    let values: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
    let ordering: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
    let inputs = vec![values, ordering];

    // Row 0 is a NULL filter entry whose value bit is `true`;
    // row 1 is a valid `false`. Neither row may pass the filter.
    let filter = nullable_bool_filter(vec![true, false], vec![false, true]);

    acc.update_batch(&inputs, &[0, 0], Some(&filter), 1)?;

    // No row was accepted, so the single group must evaluate to NULL.
    assert_group_acc_int64_result(&mut acc, Int64Array::from(vec![None]))
}

#[test]
fn test_last_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> {
    let mut acc = new_int64_first_last_group_acc(false)?;

    // First array: the aggregated values; second array: the ordering keys.
    let values: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30]));
    let ordering: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let inputs = vec![values, ordering];

    // Row 0: NULL filter entry with a `true` value bit (must be rejected).
    // Row 1: valid `true` (passes). Row 2: valid `false` (rejected).
    let filter =
        nullable_bool_filter(vec![true, true, false], vec![false, true, true]);

    acc.update_batch(&inputs, &[0, 0, 0], Some(&filter), 1)?;

    // Only row 1 passed the filter, so the group's value is 20.
    assert_group_acc_int64_result(&mut acc, Int64Array::from(vec![Some(20)]))
}

#[test]
fn test_first_group_acc_merge_rejects_null_filter_with_true_value_bit() -> Result<()>
{
    let mut acc = new_int64_first_last_group_acc(true)?;

    // Partial state passed to `merge_batch`: two Int64 arrays followed by a
    // boolean array (presumably values, ordering keys and is-set flags --
    // confirm against the accumulator's state layout).
    let state_values: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
    let state_ordering: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
    let state_flags: ArrayRef = Arc::new(BooleanArray::from(vec![true, true]));
    let partial_states = vec![state_values, state_ordering, state_flags];

    // Row 0: NULL filter entry with a `true` value bit (must be rejected).
    // Row 1: valid `true` (passes).
    let filter = nullable_bool_filter(vec![true, true], vec![false, true]);

    acc.merge_batch(&partial_states, &[0, 0], Some(&filter), 1)?;

    // Row 0 would win on ordering but is filtered out, leaving 20.
    assert_group_acc_int64_result(&mut acc, Int64Array::from(vec![Some(20)]))
}

#[test]
fn test_group_acc_size_of_ordering() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down
44 changes: 44 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6197,6 +6197,50 @@ GROUP BY g
----
0 0

# Grouped first_value/last_value must apply aggregate FILTER with Some(true)
# semantics: a row passes only when the predicate is TRUE. Rows where the
# predicate evaluates to NULL or FALSE must be excluded.
#
# Rows per group (predicate is b < 1):
# g=1: (a=10, b=NULL -> NULL), (a=20, b=2 -> FALSE) => no rows pass
# g=2: (a=30, b=0 -> TRUE), (a=40, b=NULL -> NULL),
# (a=50, b=-5 -> TRUE) => a=30 and a=50 pass
# g=3: (a=60, b=NULL -> NULL) => no rows pass
statement ok
CREATE TABLE first_last_filter_null_tests(g INT, a INT, b INT) AS VALUES
(1, 10, CAST(NULL AS INT)),
(1, 20, 2),
(2, 30, 0),
(2, 40, CAST(NULL AS INT)),
(2, 50, -5),
(3, 60, CAST(NULL AS INT));

# Groups 1 and 3 have no rows passing the filter -> NULL.
# Group 2 has a=30 and a=50 passing -> first_value ORDER BY a = 30.
query II
SELECT g, first_value(a ORDER BY a) FILTER (WHERE b < 1) AS fv
FROM first_last_filter_null_tests
GROUP BY g
ORDER BY g;
----
1 NULL
2 30
3 NULL

# Same groups via last_value: group 2 picks the largest passing a = 50.
query II
SELECT g, last_value(a ORDER BY a) FILTER (WHERE b < 1) AS lv
FROM first_last_filter_null_tests
GROUP BY g
ORDER BY g;
----
1 NULL
2 50
3 NULL

statement ok
DROP TABLE first_last_filter_null_tests;

# query_with_and_without_filter
query III rowsort
SELECT
Expand Down
Loading