Skip to content
Draft
2 changes: 2 additions & 0 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ use datafusion_expr::{ColumnarValue, expr_vec_fmt};

mod array_static_filter;
mod primitive_filter;
mod result;
mod static_filter;
mod strategy;
mod transform;

use static_filter::StaticFilter;
use strategy::instantiate_static_filter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,104 @@ use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{SortOptions, take};
use arrow::datatypes::DataType;
use arrow::util::bit_iterator::BitIndexIterator;
use datafusion_common::HashMap;
use datafusion_common::Result;
use datafusion_common::hash_utils::{RandomState, with_hashes};
use hashbrown::hash_map::RawEntryMut;
use hashbrown::HashTable;

use super::result::build_in_list_result;
use super::static_filter::StaticFilter;

/// Static filter for InList that keeps the list values (`in_array`) together
/// with a hash table of indices into them, so that membership checks are
/// hash lookups rather than scans of the list.
#[derive(Debug, Clone)]
pub(super) struct ArrayStaticFilter {
/// The IN-list values (the "haystack") that needles are compared against.
in_array: ArrayRef,
/// Hasher state used both when building `table` and when hashing needles,
/// so that equal values land on the same hash.
state: RandomState,
/// Used to provide a lookup from value to in list index.
/// Stores indices into `in_array` for O(1) lookups; entries are keyed by
/// the hash of the value they point at, not by the hash of the index.
table: HashTable<usize>,
}

impl ArrayStaticFilter {
/// Computes a [`StaticFilter`] for the provided [`Array`] if there
/// are nulls present or there are more than the configured number of
/// elements.
///
/// Note: usize::hash is not used, instead the raw entry
/// API is used to store entries w.r.t their value
map: HashMap<usize, (), ()>,
/// Note: This is split into a separate function as higher-rank trait bounds currently
/// cause type inference to misbehave
pub(super) fn try_new(in_array: ArrayRef) -> Result<ArrayStaticFilter> {
    let state = RandomState::default();

    // A `Null`-typed list has no natural order to build a comparator from,
    // so it gets an empty table instead of going through the hashing path.
    let table = if matches!(in_array.data_type(), DataType::Null) {
        HashTable::new()
    } else {
        Self::build_haystack_table(&in_array, &state)?
    };

    Ok(ArrayStaticFilter {
        in_array,
        state,
        table,
    })
}

/// Builds the lookup table of indices into `haystack`.
///
/// Every valid (non-null) slot of `haystack` is hashed with `state` and its
/// index is inserted unless an equal value is already present, so the table
/// holds one index per distinct non-null value.
///
/// # Errors
///
/// Propagates any error from `make_comparator` or `with_hashes`.
fn build_haystack_table(
    haystack: &ArrayRef,
    state: &RandomState,
) -> Result<HashTable<usize>> {
    // Null slots are never inserted and duplicates are skipped, so the number
    // of valid slots is an upper bound on the entry count. Reserving it up
    // front avoids re-hashing the table while it grows.
    let mut table = HashTable::with_capacity(haystack.len() - haystack.null_count());

    with_hashes([haystack.as_ref()], state, |hashes| -> Result<()> {
        let cmp = make_comparator(haystack, haystack, SortOptions::default())?;

        let insert_value = |idx| {
            let hash = hashes[idx];
            // Only insert if not already present (deduplication)
            if table.find(hash, |&x| cmp(x, idx).is_eq()).is_none() {
                // The re-hash closure maps a stored index back to the hash of
                // the value it refers to.
                table.insert_unique(hash, idx, |&x| hashes[x]);
            }
        };

        match haystack.nulls() {
            // Visit only the slots whose validity bit is set.
            Some(nulls) => {
                BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
                    .for_each(insert_value)
            }
            None => (0..haystack.len()).for_each(insert_value),
        }

        Ok(())
    })?;

    Ok(table)
}

/// Checks each slot of `needles` for membership in the stored IN-list.
///
/// Needles are hashed with the filter's own `state`, and a needle counts as
/// found when `table` holds a haystack index whose value compares equal to
/// it. The per-slot lookup is handed to `build_in_list_result` together with
/// the needles' logical nulls, whether the haystack contains nulls, and
/// `negated`; that helper assembles the final `BooleanArray` (presumably
/// applying SQL three-valued logic — confirm against its definition).
///
/// # Errors
///
/// Propagates any error from `make_comparator` or `with_hashes`.
fn find_needles_in_haystack(
    &self,
    needles: &dyn Array,
    negated: bool,
) -> Result<BooleanArray> {
    let haystack_has_nulls = self.in_array.null_count() > 0;
    let needle_nulls = needles.logical_nulls();

    with_hashes([needles], &self.state, |hashes| {
        // Left index addresses `needles`, right index addresses the haystack.
        let compare = make_comparator(needles, &self.in_array, SortOptions::default())?;

        Ok(build_in_list_result(
            needles.len(),
            needle_nulls.as_ref(),
            haystack_has_nulls,
            negated,
            #[inline(always)]
            |needle_idx| {
                self.table
                    .find(hashes[needle_idx], |&haystack_idx| {
                        compare(needle_idx, haystack_idx).is_eq()
                    })
                    .is_some()
            },
        ))
    })
}
}

impl StaticFilter for ArrayStaticFilter {
Expand Down Expand Up @@ -76,85 +157,6 @@ impl StaticFilter for ArrayStaticFilter {
_ => {}
}

let needle_nulls = v.logical_nulls();
let needle_nulls = needle_nulls.as_ref();
let haystack_has_nulls = self.in_array.null_count() != 0;

with_hashes([v], &self.state, |hashes| {
let cmp = make_comparator(v, &self.in_array, SortOptions::default())?;
Ok((0..v.len())
.map(|i| {
// SQL three-valued logic: null IN (...) is always null
if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) {
return None;
}

let hash = hashes[i];
let contains = self
.map
.raw_entry()
.from_hash(hash, |idx| cmp(i, *idx).is_eq())
.is_some();

match contains {
true => Some(!negated),
false if haystack_has_nulls => None,
false => Some(negated),
}
})
.collect())
})
}
}

impl ArrayStaticFilter {
/// Computes a [`StaticFilter`] for the provided [`Array`] if there
/// are nulls present or there are more than the configured number of
/// elements.
///
/// Note: This is split into a separate function as higher-rank trait bounds currently
/// cause type inference to misbehave
pub(super) fn try_new(in_array: ArrayRef) -> Result<ArrayStaticFilter> {
// Null type has no natural order - return empty hash set
if in_array.data_type() == &DataType::Null {
return Ok(ArrayStaticFilter {
in_array,
state: RandomState::default(),
map: HashMap::with_hasher(()),
});
}

let state = RandomState::default();
let mut map: HashMap<usize, (), ()> = HashMap::with_hasher(());

with_hashes([&in_array], &state, |hashes| -> Result<()> {
let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?;

let insert_value = |idx| {
let hash = hashes[idx];
if let RawEntryMut::Vacant(v) = map
.raw_entry_mut()
.from_hash(hash, |x| cmp(*x, idx).is_eq())
{
v.insert_with_hasher(hash, idx, (), |x| hashes[*x]);
}
};

match in_array.nulls() {
Some(nulls) => {
BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len())
.for_each(insert_value)
}
None => (0..in_array.len()).for_each(insert_value),
}

Ok(())
})?;

Ok(Self {
in_array,
state,
map,
})
self.find_needles_in_haystack(v, negated)
}
}
Loading
Loading