Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 96 additions & 64 deletions crates/iceberg/src/arrow/reader/predicate_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use std::sync::Arc;
use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, Float64Type};
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
use arrow_array::{
Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar, StructArray,
};
use arrow_buffer::BooleanBuffer;
use arrow_cast::cast::cast;
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
Expand Down Expand Up @@ -198,7 +200,8 @@ impl BoundPredicateVisitor for CollectFieldIdVisitor {

/// A visitor to convert Iceberg bound predicates to Arrow predicates.
pub(super) struct PredicateConverter<'a> {
/// The Parquet schema descriptor.
/// The Parquet schema descriptor. Used to resolve the parquet column path
/// from a leaf column index when the predicate targets a nested field.
pub(super) parquet_schema: &'a SchemaDescriptor,
/// The map between field id and leaf column index in Parquet schema.
pub(super) column_map: &'a HashMap<i32, usize>,
Expand All @@ -207,26 +210,20 @@ pub(super) struct PredicateConverter<'a> {
}

impl PredicateConverter<'_> {
/// When visiting a bound reference, we return index of the leaf column in the
/// required column indices which is used to project the column in the record batch.
/// Return None if the field id is not found in the column map, which is possible
/// due to schema evolution.
fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<usize>> {
// The leaf column's index in Parquet schema.
/// When visiting a bound reference, we return the parquet column path
/// (root -> leaf) so the predicate closure can extract the primitive leaf
/// from the projected `RecordBatch`. `ProjectionMask::leaves` preserves
/// the parquet schema's nesting structure, so a predicate that targets
/// `nested.value` sees a `RecordBatch` whose `nested` column is a
/// `StructArray` holding the `value` leaf. We descend by name to reach
/// the primitive.
///
/// Returns `Ok(None)` when the field id is not present in the parquet
/// schema (possible under schema evolution).
fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<Arc<[String]>>> {
if let Some(column_idx) = self.column_map.get(&reference.field().id) {
if self.parquet_schema.get_column_root(*column_idx).is_group() {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Leaf column `{}` in predicates isn't a root column in Parquet schema.",
reference.field().name
),
));
}

// The leaf column's index in the required column indices.
let index = self
.column_indices
// Sanity-check that the leaf is among the projected columns.
self.column_indices
.iter()
.position(|&idx| idx == *column_idx)
.ok_or(Error::new(
Expand All @@ -237,7 +234,16 @@ impl PredicateConverter<'_> {
),
))?;

Ok(Some(index))
let path: Arc<[String]> = self
.parquet_schema
.column(*column_idx)
.path()
.parts()
.iter()
.cloned()
.collect();

Ok(Some(path))
} else {
Ok(None)
}
Expand All @@ -258,20 +264,46 @@ impl PredicateConverter<'_> {
}
}

/// Gets the leaf column from the record batch for the required column index. Only
/// supports top-level columns for now.
fn project_column(
batch: &RecordBatch,
column_idx: usize,
) -> std::result::Result<ArrayRef, ArrowError> {
let column = batch.column(column_idx);

match column.data_type() {
DataType::Struct(_) => Err(ArrowError::SchemaError(
"Does not support struct column yet.".to_string(),
)),
_ => Ok(column.clone()),
}
/// Walks the parquet column path (root -> leaf) through the projected
/// `RecordBatch` to reach a primitive leaf. Top-level paths return the
/// matching column directly; nested paths descend through `StructArray`
/// children by name.
fn project_column(batch: &RecordBatch, path: &[String]) -> std::result::Result<ArrayRef, ArrowError> {
let (root_name, rest) = path.split_first().ok_or_else(|| {
ArrowError::SchemaError("Predicate column path is empty.".to_string())
})?;
let mut current = batch
.column_by_name(root_name)
.ok_or_else(|| {
ArrowError::SchemaError(format!(
"Predicate column root `{root_name}` not found in projected RecordBatch."
))
})?
.clone();
for part in rest {
let struct_array = current.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
ArrowError::SchemaError(format!(
"Predicate column path expected a struct at `{part}` but got {:?}.",
current.data_type()
))
})?;
current = struct_array
.column_by_name(part)
.ok_or_else(|| {
ArrowError::SchemaError(format!(
"Predicate column nested field `{part}` not found in struct."
))
})?
.clone();
}

if matches!(current.data_type(), DataType::Struct(_)) {
return Err(ArrowError::SchemaError(
"Predicate column path resolved to a struct, expected a primitive leaf.".to_string(),
));
}

Ok(current)
}

fn compute_is_nan(array: &ArrayRef) -> std::result::Result<BooleanArray, ArrowError> {
Expand Down Expand Up @@ -353,9 +385,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
let column = project_column(&batch, &path)?;
is_null(&column)
}))
} else {
Expand All @@ -369,9 +401,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
let column = project_column(&batch, &path)?;
is_not_null(&column)
}))
} else {
Expand All @@ -385,9 +417,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
let column = project_column(&batch, &path)?;
compute_is_nan(&column)
}))
} else {
Expand All @@ -401,9 +433,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
let column = project_column(&batch, &path)?;
let is_nan = compute_is_nan(&column)?;
not(&is_nan)
}))
Expand All @@ -419,11 +451,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
lt(&left, literal.as_ref())
}))
Expand All @@ -439,11 +471,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
lt_eq(&left, literal.as_ref())
}))
Expand All @@ -459,11 +491,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
gt(&left, literal.as_ref())
}))
Expand All @@ -479,11 +511,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
gt_eq(&left, literal.as_ref())
}))
Expand All @@ -499,11 +531,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
eq(&left, literal.as_ref())
}))
Expand All @@ -519,11 +551,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
neq(&left, literal.as_ref())
}))
Expand All @@ -539,11 +571,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
starts_with(&left, literal.as_ref())
}))
Expand All @@ -559,11 +591,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literal: &Datum,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literal = get_arrow_datum(literal)?;

Ok(Box::new(move |batch| {
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let literal = try_cast_literal(&literal, left.data_type())?;
// update here if arrow ever adds a native not_starts_with
not(&starts_with(&left, literal.as_ref())?)
Expand All @@ -580,15 +612,15 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literals: Vec<_> = literals
.iter()
.map(|lit| get_arrow_datum(lit).unwrap())
.collect();

Ok(Box::new(move |batch| {
// update this if arrow ever adds a native is_in kernel
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;

let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
for literal in &literals {
Expand All @@ -610,15 +642,15 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if let Some(idx) = self.bound_reference(reference)? {
if let Some(path) = self.bound_reference(reference)? {
let literals: Vec<_> = literals
.iter()
.map(|lit| get_arrow_datum(lit).unwrap())
.collect();

Ok(Box::new(move |batch| {
// update this if arrow ever adds a native not_in kernel
let left = project_column(&batch, idx)?;
let left = project_column(&batch, &path)?;
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
for literal in &literals {
let literal = try_cast_literal(literal, left.data_type())?;
Expand Down
Loading
Loading