diff --git a/crates/iceberg/src/arrow/reader/predicate_visitor.rs b/crates/iceberg/src/arrow/reader/predicate_visitor.rs
index 272de49390..7dde9b1939 100644
--- a/crates/iceberg/src/arrow/reader/predicate_visitor.rs
+++ b/crates/iceberg/src/arrow/reader/predicate_visitor.rs
@@ -25,7 +25,9 @@ use std::sync::Arc;
 use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
 use arrow_array::cast::AsArray;
 use arrow_array::types::{Float32Type, Float64Type};
-use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
+use arrow_array::{
+    Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar, StructArray,
+};
 use arrow_buffer::BooleanBuffer;
 use arrow_cast::cast::cast;
 use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
@@ -198,7 +200,8 @@ impl BoundPredicateVisitor for CollectFieldIdVisitor {
 
 /// A visitor to convert Iceberg bound predicates to Arrow predicates.
 pub(super) struct PredicateConverter<'a> {
-    /// The Parquet schema descriptor.
+    /// The Parquet schema descriptor. Used to resolve the parquet column path
+    /// from a leaf column index when the predicate targets a nested field.
     pub(super) parquet_schema: &'a SchemaDescriptor,
     /// The map between field id and leaf column index in Parquet schema.
     pub(super) column_map: &'a HashMap<i32, usize>,
@@ -207,26 +210,20 @@ pub(super) struct PredicateConverter<'a> {
 }
 
 impl PredicateConverter<'_> {
-    /// When visiting a bound reference, we return index of the leaf column in the
-    /// required column indices which is used to project the column in the record batch.
-    /// Return None if the field id is not found in the column map, which is possible
-    /// due to schema evolution.
-    fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<usize>> {
-        // The leaf column's index in Parquet schema.
+    /// When visiting a bound reference, we return the parquet column path
+    /// (root -> leaf) so the predicate closure can extract the primitive leaf
+    /// from the projected `RecordBatch`. `ProjectionMask::leaves` preserves
+    /// the parquet schema's nesting structure, so a predicate that targets
+    /// `nested.value` sees a `RecordBatch` whose `nested` column is a
+    /// `StructArray` holding the `value` leaf. We descend by name to reach
+    /// the primitive.
+    ///
+    /// Returns `Ok(None)` when the field id is not present in the parquet
+    /// schema (possible under schema evolution).
+    fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<Arc<[String]>>> {
         if let Some(column_idx) = self.column_map.get(&reference.field().id) {
-            if self.parquet_schema.get_column_root(*column_idx).is_group() {
-                return Err(Error::new(
-                    ErrorKind::DataInvalid,
-                    format!(
-                        "Leaf column `{}` in predicates isn't a root column in Parquet schema.",
-                        reference.field().name
-                    ),
-                ));
-            }
-
-            // The leaf column's index in the required column indices.
-            let index = self
-                .column_indices
+            // Sanity-check that the leaf is among the projected columns.
+            self.column_indices
                 .iter()
                 .position(|&idx| idx == *column_idx)
                 .ok_or(Error::new(
                     ErrorKind::DataInvalid,
                     format!(
@@ -237,7 +234,16 @@ impl PredicateConverter<'_> {
                     ),
                 ))?;
 
-            Ok(Some(index))
+            let path: Arc<[String]> = self
+                .parquet_schema
+                .column(*column_idx)
+                .path()
+                .parts()
+                .iter()
+                .cloned()
+                .collect();
+
+            Ok(Some(path))
         } else {
             Ok(None)
         }
@@ -258,20 +264,46 @@ impl PredicateConverter<'_> {
     }
 }
 
-/// Gets the leaf column from the record batch for the required column index. Only
-/// supports top-level columns for now.
-fn project_column(
-    batch: &RecordBatch,
-    column_idx: usize,
-) -> std::result::Result<ArrayRef, ArrowError> {
-    let column = batch.column(column_idx);
-
-    match column.data_type() {
-        DataType::Struct(_) => Err(ArrowError::SchemaError(
-            "Does not support struct column yet.".to_string(),
-        )),
-        _ => Ok(column.clone()),
-    }
+/// Walks the parquet column path (root -> leaf) through the projected
+/// `RecordBatch` to reach a primitive leaf. Top-level paths return the
+/// matching column directly; nested paths descend through `StructArray`
+/// children by name.
+fn project_column(batch: &RecordBatch, path: &[String]) -> std::result::Result<ArrayRef, ArrowError> {
+    let (root_name, rest) = path.split_first().ok_or_else(|| {
+        ArrowError::SchemaError("Predicate column path is empty.".to_string())
+    })?;
+    let mut current = batch
+        .column_by_name(root_name)
+        .ok_or_else(|| {
+            ArrowError::SchemaError(format!(
+                "Predicate column root `{root_name}` not found in projected RecordBatch."
+            ))
+        })?
+        .clone();
+    for part in rest {
+        let struct_array = current.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
+            ArrowError::SchemaError(format!(
+                "Predicate column path expected a struct at `{part}` but got {:?}.",
+                current.data_type()
+            ))
+        })?;
+        current = struct_array
+            .column_by_name(part)
+            .ok_or_else(|| {
+                ArrowError::SchemaError(format!(
+                    "Predicate column nested field `{part}` not found in struct."
+                ))
+            })?
+            .clone();
+    }
+
+    if matches!(current.data_type(), DataType::Struct(_)) {
+        return Err(ArrowError::SchemaError(
+            "Predicate column path resolved to a struct, expected a primitive leaf.".to_string(),
+        ));
+    }
+
+    Ok(current)
 }
 
 fn compute_is_nan(array: &ArrayRef) -> std::result::Result<BooleanArray, ArrowError> {
@@ -353,9 +385,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         reference: &BoundReference,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             Ok(Box::new(move |batch| {
-                let column = project_column(&batch, idx)?;
+                let column = project_column(&batch, &path)?;
                 is_null(&column)
             }))
         } else {
@@ -369,9 +401,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         reference: &BoundReference,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             Ok(Box::new(move |batch| {
-                let column = project_column(&batch, idx)?;
+                let column = project_column(&batch, &path)?;
                 is_not_null(&column)
             }))
         } else {
@@ -385,9 +417,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         reference: &BoundReference,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             Ok(Box::new(move |batch| {
-                let column = project_column(&batch, idx)?;
+                let column = project_column(&batch, &path)?;
                 compute_is_nan(&column)
             }))
         } else {
@@ -401,9 +433,9 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         reference: &BoundReference,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             Ok(Box::new(move |batch| {
-                let column = project_column(&batch, idx)?;
+                let column = project_column(&batch, &path)?;
                 let is_nan = compute_is_nan(&column)?;
                 not(&is_nan)
             }))
         } else {
@@ -419,11 +451,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 lt(&left, literal.as_ref())
             }))
@@ -439,11 +471,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 lt_eq(&left, literal.as_ref())
             }))
@@ -459,11 +491,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 gt(&left, literal.as_ref())
             }))
@@ -479,11 +511,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 gt_eq(&left, literal.as_ref())
             }))
@@ -499,11 +531,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 eq(&left, literal.as_ref())
             }))
@@ -519,11 +551,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 neq(&left, literal.as_ref())
             }))
@@ -539,11 +571,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 starts_with(&left, literal.as_ref())
             }))
@@ -559,11 +591,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literal: &Datum,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literal = get_arrow_datum(literal)?;
 
             Ok(Box::new(move |batch| {
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let literal = try_cast_literal(&literal, left.data_type())?;
                 // update here if arrow ever adds a native not_starts_with
                 not(&starts_with(&left, literal.as_ref())?)
             }))
@@ -580,7 +612,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literals: &FnvHashSet<Datum>,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literals: Vec<_> = literals
                 .iter()
                 .map(|lit| get_arrow_datum(lit).unwrap())
@@ -588,7 +620,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
                 .collect();
 
             Ok(Box::new(move |batch| {
                 // update this if arrow ever adds a native is_in kernel
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
 
                 let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
                 for literal in &literals {
@@ -610,7 +642,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
         literals: &FnvHashSet<Datum>,
         _predicate: &BoundPredicate,
     ) -> Result<Box<PredicateResult>> {
-        if let Some(idx) = self.bound_reference(reference)? {
+        if let Some(path) = self.bound_reference(reference)? {
             let literals: Vec<_> = literals
                 .iter()
                 .map(|lit| get_arrow_datum(lit).unwrap())
@@ -618,7 +650,7 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
                 .collect();
 
             Ok(Box::new(move |batch| {
                 // update this if arrow ever adds a native not_in kernel
-                let left = project_column(&batch, idx)?;
+                let left = project_column(&batch, &path)?;
                 let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
                 for literal in &literals {
                     let literal = try_cast_literal(literal, left.data_type())?;
diff --git a/crates/iceberg/src/arrow/reader/row_filter.rs b/crates/iceberg/src/arrow/reader/row_filter.rs
index 80432a0437..ff77df274f 100644
--- a/crates/iceberg/src/arrow/reader/row_filter.rs
+++ b/crates/iceberg/src/arrow/reader/row_filter.rs
@@ -196,12 +196,14 @@ mod tests {
     use std::sync::Arc;
 
     use arrow_array::cast::AsArray;
-    use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
+    use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray, StructArray};
     use arrow_schema::{DataType, Field, Schema as ArrowSchema};
     use futures::TryStreamExt;
     use parquet::arrow::{ArrowWriter, PARQUET_FIELD_ID_META_KEY};
     use parquet::basic::Compression;
     use parquet::file::properties::WriterProperties;
+    use parquet::schema::parser::parse_message_type;
+    use parquet::schema::types::SchemaDescriptor;
     use tempfile::TempDir;
 
     use crate::arrow::{ArrowReader, ArrowReaderBuilder};
@@ -422,6 +424,168 @@ mod tests {
         }
     }
 
+    /// Predicate filters that reference a nested-leaf column used to be
+    /// rejected wholesale by `PredicateConverter::bound_reference` because the
+    /// parquet column root for a nested leaf is the surrounding group. The
+    /// downstream filter pipeline (`ProjectionMask::leaves` -> `ArrowPredicateFn`
+    /// -> `RowFilter`) handles nested struct leaves correctly, so the converter
+    /// should accept them.
+    #[test]
+    fn test_get_row_filter_accepts_predicate_on_nested_leaf() {
+        let schema = Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_fields(vec![
+                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::required(
+                        2,
+                        "nested",
+                        Type::Struct(crate::spec::StructType::new(vec![
+                            NestedField::required(3, "value", Type::Primitive(PrimitiveType::Long))
+                                .into(),
+                        ])),
+                    )
+                    .into(),
+                ])
+                .build()
+                .unwrap(),
+        );
+
+        let message_type = "
+message schema {
+    required int32 id = 1;
+    required group nested = 2 {
+        required int64 value = 3;
+    }
+}
+        ";
+        let parquet_type = parse_message_type(message_type).expect("parse parquet message type");
+        let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_type));
+
+        let predicate = Reference::new("nested.value").equal_to(Datum::long(42));
+        let bound_predicate = predicate.bind(schema.clone(), true).unwrap();
+        let (iceberg_field_ids, field_id_map) =
+            ArrowReader::build_field_id_set_and_map(&parquet_schema, &bound_predicate)
+                .expect("build field id map");
+
+        ArrowReader::get_row_filter(
+            &bound_predicate,
+            &parquet_schema,
+            &iceberg_field_ids,
+            &field_id_map,
+        )
+        .expect("get_row_filter should accept a predicate on a nested struct leaf");
+    }
+
+    /// End-to-end check that a predicate on a nested-leaf column actually
+    /// filters rows through the full Arrow reader pipeline.
+    #[tokio::test]
+    async fn test_perform_read_with_nested_leaf_predicate() {
+        use arrow_array::{Int32Array, Int64Array};
+        use arrow_schema::Fields;
+
+        let schema = Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_fields(vec![
+                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::required(
+                        2,
+                        "nested",
+                        Type::Struct(crate::spec::StructType::new(vec![
+                            NestedField::required(3, "value", Type::Primitive(PrimitiveType::Long))
+                                .into(),
+                        ])),
+                    )
+                    .into(),
+                ])
+                .build()
+                .unwrap(),
+        );
+
+        let value_arrow_field = Field::new("value", DataType::Int64, false).with_metadata(
+            HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "3".to_string())]),
+        );
+        let nested_struct_fields = Fields::from(vec![value_arrow_field.clone()]);
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
+                PARQUET_FIELD_ID_META_KEY.to_string(),
+                "1".to_string(),
+            )])),
+            Field::new("nested", DataType::Struct(nested_struct_fields.clone()), false)
+                .with_metadata(HashMap::from([(
+                    PARQUET_FIELD_ID_META_KEY.to_string(),
+                    "2".to_string(),
+                )])),
+        ]));
+
+        let id_data = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
+        let value_data = Arc::new(Int64Array::from(vec![10_i64, 20, 30])) as ArrayRef;
+        let nested_data = Arc::new(StructArray::from(vec![(
+            Arc::new(value_arrow_field),
+            value_data,
+        )])) as ArrayRef;
+        let to_write =
+            RecordBatch::try_new(arrow_schema.clone(), vec![id_data, nested_data]).unwrap();
+
+        let tmp_dir = TempDir::new().unwrap();
+        let table_location = tmp_dir.path().to_str().unwrap().to_string();
+        let file_io = FileIO::new_with_fs();
+
+        let props = WriterProperties::builder()
+            .set_compression(Compression::SNAPPY)
+            .build();
+        let file = File::create(format!("{table_location}/1.parquet")).unwrap();
+        let mut writer = ArrowWriter::try_new(file, to_write.schema(), Some(props)).unwrap();
+        writer.write(&to_write).expect("Writing batch");
+        writer.close().unwrap();
+
+        let predicate = Reference::new("nested.value").equal_to(Datum::long(20));
+        let tasks = Box::pin(futures::stream::iter(
+            vec![Ok(FileScanTask {
+                file_size_in_bytes: std::fs::metadata(format!("{table_location}/1.parquet"))
+                    .unwrap()
+                    .len(),
+                start: 0,
+                length: 0,
+                record_count: None,
+                data_file_path: format!("{table_location}/1.parquet"),
+                data_file_format: DataFileFormat::Parquet,
+                schema: schema.clone(),
+                project_field_ids: vec![1, 2],
+                predicate: Some(predicate.bind(schema.clone(), true).unwrap()),
+                deletes: vec![],
+                partition: None,
+                partition_spec: None,
+                name_mapping: None,
+                case_sensitive: true,
+            })]
+            .into_iter(),
+        )) as FileScanTaskStream;
+
+        let reader = ArrowReaderBuilder::new(file_io).build();
+        let result = reader
+            .read(tasks)
+            .unwrap()
+            .stream()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let total_rows: usize = result.iter().map(RecordBatch::num_rows).sum();
+        assert_eq!(
+            total_rows, 1,
+            "predicate on nested leaf should leave exactly one row"
+        );
+
+        let id_column = result[0].column_by_name("id").expect("id column present");
+        let id_values = id_column
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .expect("id is int32");
+        assert_eq!(id_values.values(), &[2_i32], "expected the row with value=20");
+    }
+
     /// Verifies that file splits respect byte ranges and only read specific row groups.
     #[tokio::test]
     async fn test_file_splits_respect_byte_ranges() {