diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 22d6b2e3c7..9acffb14c8 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -365,6 +365,17 @@ for buf in tbl.scan().to_arrow_batch_reader(): print(f"Buffer contains {len(buf)} rows") ``` +### Streaming writes from a `RecordBatchReader` + +`tbl.append()` and `tbl.overwrite()` also accept a `pyarrow.RecordBatchReader` directly, which lets you write datasets that don't fit in memory without materialising them as a `pa.Table` first. PyIceberg consumes the reader once, writing batches through a rolling Parquet writer that rolls a new file each time the on-disk file size hits `write.target-file-size-bytes` (default 512 MiB). Each input `RecordBatch` becomes a Parquet row group, capped at `write.parquet.row-group-limit` rows (default 1M) — caller batch size sets the lower bound on row group size, the property enforces the upper bound. All files are committed in a single snapshot. + +```python +reader = pa.RecordBatchReader.from_batches(schema, batch_iter) +tbl.append(reader) +``` + +Streaming writes are currently only supported on **unpartitioned** tables. For a partitioned table, materialise the reader as a `pa.Table` first, or follow [#2152](https://github.com/apache/iceberg-python/issues/2152) for the partitioned support tracked as a follow-up. 
+ To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow: ```python diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4517ae7327..bb82dbdfed 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2665,7 +2665,159 @@ def write_parquet(task: WriteTask) -> DataFile: return iter(data_files) +def _record_batches_to_data_files( + table_metadata: TableMetadata, + reader: pa.RecordBatchReader, + io: FileIO, + write_uuid: uuid.UUID | None = None, + counter: itertools.count[int] | None = None, +) -> Iterator[DataFile]: + """Stream a ``pa.RecordBatchReader`` into Parquet data files via a rolling ``pq.ParquetWriter``. + + Each input ``RecordBatch`` is written directly via + ``writer.write_batch``. File rollover is driven by ``OutputStream.tell()`` + (#2998): after each batch, if ``tell() >= write.target-file-size-bytes`` + the current writer is closed (footer written) and a new file is opened. + The threshold is measured in compressed on-disk bytes — matching the + spec interpretation of ``write.target-file-size-bytes``. + + Row groups are capped at ``write.parquet.row-group-limit`` rows (default + 1M) via the ``row_group_size`` argument to ``write_batch``: a batch + larger than the cap is split into multiple row groups, each ≤ the cap; + a batch smaller than the cap becomes a single row group of its own + size. Callers control the lower bound of row group size by their + choice of input batch size; pyiceberg enforces the upper bound. This + matches the materialised ``pa.Table`` write path's treatment of the + same property. + + Memory per writer is bounded by one input ``RecordBatch`` plus the + ``ParquetWriter``'s internal page buffer (~1 MiB by default). 
The + materialised ``pa.Table`` write path (``write_file``) keeps its + existing ``executor.map``-based file-level parallelism; streaming + writes are sequential — one rolling file at a time, with concurrency + provided by the underlying multipart upload pool. + + Streaming writes to partitioned tables are not yet supported — see + https://github.com/apache/iceberg-python/issues/2152. + """ + from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties + + if not table_metadata.spec().is_unpartitioned(): + raise NotImplementedError( + "Writing a pa.RecordBatchReader to a partitioned table is not yet supported. " + "Materialise the reader as a pa.Table first, or follow " + "https://github.com/apache/iceberg-python/issues/2152 for partitioned streaming support." + ) + + counter = counter or itertools.count(0) + write_uuid = write_uuid or uuid.uuid4() + target_file_size: int = property_as_int( # type: ignore # The property is set with non-None value. + properties=table_metadata.properties, + property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, + default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, + ) + name_mapping = table_metadata.schema().name_mapping + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + task_schema = pyarrow_to_schema( + reader.schema, + name_mapping=name_mapping, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + format_version=table_metadata.format_version, + ) + table_schema = table_metadata.schema() + if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: + file_schema = sanitized_schema + else: + file_schema = table_schema + + parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) + row_group_size = property_as_int( + properties=table_metadata.properties, + property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT, + default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT, + ) + location_provider = 
load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties) + stats_columns = compute_statistics_plan(file_schema, table_metadata.properties) + column_mapping = parquet_path_to_id_mapping(file_schema) + + def _transform(batch: pa.RecordBatch) -> pa.RecordBatch: + return _to_requested_schema( + requested_schema=file_schema, + file_schema=task_schema, + batch=batch, + downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, + include_field_ids=True, + ) + + def _new_data_file_path() -> str: + # Mirrors WriteTask.generate_data_file_filename to keep file names compatible + # with the materialised write path. + filename = f"00000-{next(counter)}-{write_uuid}.parquet" + return location_provider.new_data_location(data_file_name=filename) + + def _build_data_file(file_path: str, output_file: OutputFile, parquet_metadata: Any) -> DataFile: + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=parquet_metadata, + stats_columns=stats_columns, + parquet_column_mapping=column_mapping, + ) + return DataFile.from_args( + content=DataFileContent.DATA, + file_path=file_path, + file_format=FileFormat.PARQUET, + partition=Record(), + file_size_in_bytes=len(output_file), + sort_order_id=None, + spec_id=table_metadata.default_spec_id, + equality_ids=None, + key_metadata=None, + **statistics.to_serialized_dict(), + ) + + batches = iter(reader) + while True: + # Pull the next batch up front. If the reader is exhausted (either at the + # very start or between rolled files), we're done — yield nothing further. 
+ try: + first_batch = next(batches) + except StopIteration: + return + + transformed_first = _transform(first_batch) + file_path = _new_data_file_path() + output_file = io.new_output(file_path) + with output_file.create(overwrite=True) as fos: + with pq.ParquetWriter( + fos, + schema=transformed_first.schema, + store_decimal_as_integer=True, + **parquet_writer_kwargs, + ) as writer: + writer.write_batch(transformed_first, row_group_size=row_group_size) + # Keep writing into this file until the on-disk byte threshold + # is crossed. ``tell()`` advances as ``write_batch`` flushes + # encoded pages to the stream — files end up close to but + # slightly above ``target_file_size`` (lag bounded by one + # Parquet data page, ~1 MiB by default). + while fos.tell() < target_file_size: + try: + batch = next(batches) + except StopIteration: + break + writer.write_batch(_transform(batch), row_group_size=row_group_size) + # writer is closed (footer written) and the OutputStream is flushed. + # writer.writer.metadata is still readable post-close — same access + # pattern used by write_file(). + yield _build_data_file(file_path, output_file, writer.writer.metadata) + + def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[list[pa.RecordBatch]]: + """Bin-pack ``tbl`` into groups of RecordBatches, each ~``target_file_size`` uncompressed Arrow bytes. + + Used by the materialised ``pa.Table`` write path (``_dataframe_to_data_files`` + + ``write_file``) to split a fully in-memory table into multiple Parquet + files written in parallel. 
+ """ from pyiceberg.utils.bin_packing import PackingIterator avg_row_size_bytes = tbl.nbytes / tbl.num_rows @@ -2800,15 +2952,24 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]: def _dataframe_to_data_files( table_metadata: TableMetadata, - df: pa.Table, + df: pa.Table | pa.RecordBatchReader, io: FileIO, write_uuid: uuid.UUID | None = None, counter: itertools.count[int] | None = None, ) -> Iterable[DataFile]: - """Convert a PyArrow table into a DataFile. + """Convert a PyArrow Table or RecordBatchReader into DataFiles. + + For a ``pa.Table`` the data is materialised in memory and bin-packed into + target-sized files (with partition splitting if the table is partitioned). + + For a ``pa.RecordBatchReader`` batches are streamed and microbatched into + target-sized files using bounded memory (see :func:`bin_pack_record_batches`). + Streaming writes are currently only supported on unpartitioned tables; + partitioned support is tracked in + https://github.com/apache/iceberg-python/issues/2152. Returns: - An iterable that supplies datafiles that represent the table. + An iterable that supplies datafiles that represent the input data. """ from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties, WriteTask @@ -2828,6 +2989,20 @@ def _dataframe_to_data_files( format_version=table_metadata.format_version, ) + if isinstance(df, pa.RecordBatchReader): + # Streaming path: rolling ParquetWriter driven by OutputStream.tell() + # for constant-memory writes and on-disk-accurate file sizes. + # Partitioned-table support is the responsibility of + # _record_batches_to_data_files; the NotImplementedError is raised there. 
+ yield from _record_batches_to_data_files( + table_metadata=table_metadata, + reader=df, + io=io, + write_uuid=write_uuid, + counter=counter, + ) + return + if table_metadata.spec().is_unpartitioned(): yield from write_file( io=io, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b8d87143c9..616aa9c59c 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -450,12 +450,53 @@ def update_statistics(self) -> UpdateStatistics: """ return UpdateStatistics(transaction=self) - def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: + def append( + self, + df: pa.Table | pa.RecordBatchReader, + snapshot_properties: dict[str, str] = EMPTY_DICT, + branch: str | None = MAIN_BRANCH, + ) -> None: """ - Shorthand API for appending a PyArrow table to a table transaction. + Shorthand API for appending PyArrow data to a table transaction. + + Accepts either a fully materialised ``pa.Table`` or a streaming + ``pa.RecordBatchReader``. For a reader, batches are written through a + rolling ``pq.ParquetWriter`` and a new file is rolled each time the + on-disk file size hits ``write.target-file-size-bytes``. The reader is + consumed once and cannot be reused. + + Streaming writes are currently only supported on unpartitioned tables; + passing a ``pa.RecordBatchReader`` for a partitioned table raises + ``NotImplementedError``. See + https://github.com/apache/iceberg-python/issues/2152. + + Note: + When ``df`` is a ``pa.RecordBatchReader`` the reader is consumed + once and cannot be replayed. If the catalog commit fails (e.g. + ``CommitFailedException`` from a concurrent writer) the reader is + already drained and a naive retry will append zero rows. 
Callers + that need at-least-once semantics should either: + + - reconstruct the reader on each attempt via a factory callable, + or + - use a two-stage pattern — write Parquet files explicitly and + then call :meth:`add_files` (whose input is a replayable list of + paths) within a retry loop. + + Failures during the write stage (mid-stream reader exception, S3 + errors) do not commit a snapshot, but may leave orphan data files + in storage that are not referenced by any snapshot. Clean these + up with expire/orphan-file maintenance jobs. + + For streaming inputs (``pa.RecordBatchReader``) each input + ``RecordBatch`` becomes one Parquet row group. The + ``write.parquet.row-group-limit`` property (rows, default 1M) + caps row group size — batches larger than the cap are split, + smaller batches are not combined. Caller batch size sets the + lower bound; pyiceberg enforces the upper bound. Args: - df: The Arrow dataframe that will be appended to overwrite the table + df: An Arrow Table or a RecordBatchReader of records to append. 
snapshot_properties: Custom properties to be added to the snapshot summary branch: Branch Reference to run the append operation """ @@ -466,8 +507,8 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files - if not isinstance(df, pa.Table): - raise ValueError(f"Expected PyArrow table, got: {df}") + if not isinstance(df, (pa.Table, pa.RecordBatchReader)): + raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}") downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( @@ -478,12 +519,14 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, ) with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: - data_files = list( - _dataframe_to_data_files( - table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io - ) + # For pa.Table we can short-circuit empty inputs cheaply. For a + # RecordBatchReader the stream is consumed lazily by + # _dataframe_to_data_files and an empty reader simply yields zero + # data files (the snapshot is still committed for symmetry with the + # pa.Table case where empty inputs also produce a snapshot). 
+ if isinstance(df, pa.RecordBatchReader) or df.shape[0] > 0: + data_files = _dataframe_to_data_files( + table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io ) for data_file in data_files: append_files.append_data_file(data_file) @@ -555,14 +598,50 @@ def dynamic_partition_overwrite( def overwrite( self, - df: pa.Table, + df: pa.Table | pa.RecordBatchReader, overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, branch: str | None = MAIN_BRANCH, ) -> None: """ - Shorthand for adding a table overwrite with a PyArrow table to the transaction. + Shorthand for adding a table overwrite with a PyArrow table or RecordBatchReader to the transaction. + + Accepts either a fully materialised ``pa.Table`` or a streaming + ``pa.RecordBatchReader``. For a reader, batches are written through a + rolling ``pq.ParquetWriter`` and a new file is rolled each time the + on-disk file size hits ``write.target-file-size-bytes``. The reader is + consumed once and cannot be reused. + + Streaming writes are currently only supported on unpartitioned tables; + passing a ``pa.RecordBatchReader`` for a partitioned table raises + ``NotImplementedError``. See + https://github.com/apache/iceberg-python/issues/2152. + + Note: + When ``df`` is a ``pa.RecordBatchReader`` the reader is consumed + once and cannot be replayed. If the catalog commit fails (e.g. + ``CommitFailedException`` from a concurrent writer) the reader is + already drained and a naive retry will write zero rows. Callers + that need at-least-once semantics should either: + + - reconstruct the reader on each attempt via a factory callable, + or + - use a two-stage pattern — write Parquet files explicitly and + then call :meth:`add_files` (whose input is a replayable list + of paths) within a retry loop. 
+ + Failures during the write stage (mid-stream reader exception, S3 + errors) do not commit a snapshot, but may leave orphan data files + in storage that are not referenced by any snapshot. Clean these + up with expire/orphan-file maintenance jobs. + + For streaming inputs (``pa.RecordBatchReader``) each input + ``RecordBatch`` becomes one Parquet row group. The + ``write.parquet.row-group-limit`` property (rows, default 1M) + caps row group size — batches larger than the cap are split, + smaller batches are not combined. Caller batch size sets the + lower bound; pyiceberg enforces the upper bound. An overwrite may produce zero or more snapshots based on the operation: @@ -571,7 +650,7 @@ def overwrite( - APPEND: In case new data is being inserted into the table. Args: - df: The Arrow dataframe that will be used to overwrite the table + df: An Arrow Table or a RecordBatchReader of records to write. overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite snapshot_properties: Custom properties to be added to the snapshot summary @@ -585,8 +664,8 @@ def overwrite( from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files - if not isinstance(df, pa.Table): - raise ValueError(f"Expected PyArrow table, got: {df}") + if not isinstance(df, (pa.Table, pa.RecordBatchReader)): + raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}") downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( @@ -606,8 +685,8 @@ def overwrite( ) with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: + # See append() for the empty-input handling rationale. 
+ if isinstance(df, pa.RecordBatchReader) or df.shape[0] > 0: data_files = _dataframe_to_data_files( table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io ) @@ -1373,12 +1452,21 @@ def upsert( snapshot_properties=snapshot_properties, ) - def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: + def append( + self, + df: pa.Table | pa.RecordBatchReader, + snapshot_properties: dict[str, str] = EMPTY_DICT, + branch: str | None = MAIN_BRANCH, + ) -> None: """ - Shorthand API for appending a PyArrow table to the table. + Shorthand API for appending PyArrow data to the table. + + Accepts either a ``pa.Table`` or a streaming ``pa.RecordBatchReader``. + See :meth:`Transaction.append` for streaming semantics and partition + limitations. Args: - df: The Arrow dataframe that will be appended to overwrite the table + df: An Arrow Table or a RecordBatchReader of records to append. snapshot_properties: Custom properties to be added to the snapshot summary branch: Branch Reference to run the append operation """ @@ -1401,14 +1489,18 @@ def dynamic_partition_overwrite( def overwrite( self, - df: pa.Table, + df: pa.Table | pa.RecordBatchReader, overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, branch: str | None = MAIN_BRANCH, ) -> None: """ - Shorthand for overwriting the table with a PyArrow table. + Shorthand for overwriting the table with a PyArrow Table or RecordBatchReader. + + Accepts either a ``pa.Table`` or a streaming ``pa.RecordBatchReader``. + See :meth:`Transaction.overwrite` for streaming semantics and partition + limitations. An overwrite may produce zero or more snapshots based on the operation: @@ -1417,7 +1509,7 @@ def overwrite( - APPEND: In case new data is being inserted into the table. 
Args: - df: The Arrow dataframe that will be used to overwrite the table + df: An Arrow Table or a RecordBatchReader of records to write. overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite snapshot_properties: Custom properties to be added to the snapshot summary diff --git a/tests/catalog/test_catalog_behaviors.py b/tests/catalog/test_catalog_behaviors.py index 01e0d2ce31..76473db812 100644 --- a/tests/catalog/test_catalog_behaviors.py +++ b/tests/catalog/test_catalog_behaviors.py @@ -20,6 +20,7 @@ """ import os +from collections.abc import Generator from pathlib import Path from typing import Any @@ -1190,3 +1191,278 @@ def test_drop_namespace_raises_error_when_namespace_not_empty( catalog.create_table(test_table_identifier, table_schema_nested) with pytest.raises(NamespaceNotEmptyError, match=f"Namespace {'.'.join(namespace)} is not empty"): catalog.drop_namespace(namespace) + + +# RecordBatchReader streaming append/overwrite tests +# +# Streaming writes accept a pa.RecordBatchReader and write it through a rolling +# Parquet writer (row groups flushed at write.parquet.row-group-limit, files +# rolled at write.target-file-size-bytes via OutputStream.tell()) instead of +# materialising the full Arrow Table in memory. Tracks +# https://github.com/apache/iceberg-python/issues/2152. + + +def _simple_arrow_table() -> pa.Table: + return pa.Table.from_pydict( + {"foo": ["a", None, "z"]}, + schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]), + ) + + +def _simple_record_batch_reader(num_batches: int = 3) -> tuple[pa.RecordBatchReader, int]: + """Build an N-batch reader of the simple schema. 
Returns (reader, total_rows).""" + pa_table = _simple_arrow_table() + batches = pa_table.to_batches() * num_batches + reader = pa.RecordBatchReader.from_batches(pa_table.schema, iter(batches)) + return reader, sum(b.num_rows for b in batches) + + +def test_append_record_batch_reader(catalog: Catalog) -> None: + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_{catalog.name}" + reader, total_rows = _simple_record_batch_reader(num_batches=3) + tbl = catalog.create_table(identifier=identifier, schema=reader.schema) + + tbl.append(reader) + + assert len(tbl.scan().to_arrow()) == total_rows + + +def test_append_record_batch_reader_microbatched(catalog: Catalog) -> None: + """A reader bigger than the per-file target produces multiple Parquet files + in a single snapshot — verifies file rollover via ``OutputStream.tell()``. + + Sets a tiny ``target-file-size-bytes`` so each batch's flush rolls a new + file. Each input ``RecordBatch`` is its own row group, so ``tell()`` + advances after every ``write_batch``. + """ + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_microbatch_{catalog.name}" + reader, total_rows = _simple_record_batch_reader(num_batches=8) + tbl = catalog.create_table( + identifier=identifier, + schema=reader.schema, + properties={TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: "1"}, + ) + + tbl.append(reader) + + snapshot = tbl.metadata.current_snapshot() + assert snapshot is not None + assert snapshot.summary is not None + added_files = snapshot.summary["added-data-files"] + assert added_files is not None and int(added_files) > 1, snapshot.summary + assert len(tbl.scan().to_arrow()) == total_rows + + +def test_append_record_batch_reader_row_group_limit_is_cap(catalog: Catalog) -> None: + """``write.parquet.row-group-limit`` caps the maximum rows per Parquet + row group. A single input batch larger than the cap is split into + multiple row groups, each ≤ the cap. 
The streaming path enforces the + upper bound; callers control the lower bound by their choice of input + batch size. + """ + import pyarrow.parquet as pq + + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_row_group_limit_cap_{catalog.name}" + + row_group_cap = 250 + total_rows = 1000 # 4× the cap + schema = pa.schema([("id", pa.int64())]) + # One big batch — pyiceberg should split it into ⌈1000 / 250⌉ = 4 row groups + # of exactly 250 rows each. + big_batch = pa.RecordBatch.from_pylist( + [{"id": i} for i in range(total_rows)], + schema=schema, + ) + reader = pa.RecordBatchReader.from_batches(schema, iter([big_batch])) + + tbl = catalog.create_table( + identifier=identifier, + schema=schema, + properties={ + TableProperties.PARQUET_ROW_GROUP_LIMIT: str(row_group_cap), + # Big enough that everything fits in one file; we're testing row + # group size, not file rollover. + TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: str(64 * 1024 * 1024), + }, + ) + tbl.append(reader) + + assert len(tbl.scan().to_arrow()) == total_rows + + files = tbl.inspect.files().select(["file_path"]).to_pylist() + assert len(files) == 1, files + + # Read the parquet footer and check row group sizes + file_path = files[0]["file_path"] + metadata = pq.read_metadata(tbl.io.new_input(file_path).open()) + row_group_sizes = [metadata.row_group(i).num_rows for i in range(metadata.num_row_groups)] + + # Expect 4 row groups of exactly row_group_cap rows each. Without the cap + # passed to write_batch, the whole 1000-row batch would become one row + # group — the test would fail loudly. 
+ assert metadata.num_row_groups == total_rows // row_group_cap, row_group_sizes + for rg_size in row_group_sizes: + assert rg_size == row_group_cap, row_group_sizes + + +def test_append_record_batch_reader_target_file_size_is_on_disk_bytes(catalog: Catalog) -> None: + """The streaming write path interprets ``write.target-file-size-bytes`` as + actual on-disk compressed Parquet bytes (via ``OutputStream.tell()``), not + uncompressed in-memory Arrow bytes. This test sets a small file target, + streams several batches, and asserts each rolled file is close to the + target size — proving the spec-correct semantics. + """ + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_target_size_{catalog.name}" + + target_bytes = 32 * 1024 # 32 KiB target — small so we get multiple files quickly + schema = pa.schema([("id", pa.int64()), ("payload", pa.large_string())]) + # ~80 bytes per row uncompressed; with zstd ~10x compression we expect + # approximately 4000 rows per ~32 KiB file. + rows_per_batch = 1000 + total_batches = 12 + batches = [ + pa.RecordBatch.from_pylist( + [{"id": i * rows_per_batch + j, "payload": f"row_{i * rows_per_batch + j:08d}"} for j in range(rows_per_batch)], + schema=schema, + ) + for i in range(total_batches) + ] + reader = pa.RecordBatchReader.from_batches(schema, iter(batches)) + expected_rows = total_batches * rows_per_batch + + tbl = catalog.create_table( + identifier=identifier, + schema=schema, + properties={TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: str(target_bytes)}, + ) + tbl.append(reader) + + # All rows landed + assert len(tbl.scan().to_arrow()) == expected_rows + + # Multiple files were rolled + snapshot = tbl.metadata.current_snapshot() + assert snapshot is not None and snapshot.summary is not None + added_files = int(snapshot.summary["added-data-files"]) # type: ignore[arg-type] + assert added_files >= 2, snapshot.summary + + # Per-file size: every rolled file (i.e. 
all but possibly the last) should be + # *close to* target_bytes. The lag between tell() and write_batch is bounded + # by one Parquet data page (~1 MiB by default), so files end up slightly + # above target. We assert each rolled file is between 0.5x and 5x the + # target — a loose bound that catches the old uncompressed-Arrow-bytes + # behaviour (where files would be ~3-10x SMALLER than target). + files = tbl.inspect.files().select(["file_path", "file_size_in_bytes"]).to_pylist() + rolled_files = files[:-1] if len(files) > 1 else files + for f in rolled_files: + size = f["file_size_in_bytes"] + assert target_bytes // 2 <= size <= target_bytes * 5, f"{f['file_path']}: {size} bytes (target {target_bytes})" + + +def test_append_record_batch_reader_empty(catalog: Catalog) -> None: + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_empty_{catalog.name}" + schema = _simple_arrow_table().schema + reader = pa.RecordBatchReader.from_batches(schema, iter([])) + tbl = catalog.create_table(identifier=identifier, schema=schema) + + tbl.append(reader) + + assert len(tbl.scan().to_arrow()) == 0 + + +def test_overwrite_record_batch_reader(catalog: Catalog) -> None: + catalog.create_namespace("default") + identifier = f"default.overwrite_record_batch_reader_{catalog.name}" + pa_table = _simple_arrow_table() + tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema) + tbl.append(pa_table) + assert len(tbl.scan().to_arrow()) == pa_table.num_rows + + reader, total_rows = _simple_record_batch_reader(num_batches=2) + tbl.overwrite(reader) + + assert len(tbl.scan().to_arrow()) == total_rows + + +def test_append_record_batch_reader_to_partitioned_table_raises(catalog: Catalog) -> None: + catalog.create_namespace("default") + identifier = f"default.append_record_batch_reader_partitioned_{catalog.name}" + iceberg_schema = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "bucket", StringType(), 
required=False), + ) + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bucket"), + ) + tbl = catalog.create_table(identifier=identifier, schema=iceberg_schema, partition_spec=partition_spec) + + arrow_schema = schema_to_pyarrow(iceberg_schema) + reader = pa.RecordBatchReader.from_batches(arrow_schema, iter([])) + with pytest.raises(NotImplementedError, match="partitioned table"): + tbl.append(reader) + + +def test_append_invalid_input_type_raises(catalog: Catalog) -> None: + catalog.create_namespace("default") + identifier = f"default.append_invalid_input_{catalog.name}" + pa_table = _simple_arrow_table() + tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema) + with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader"): + tbl.append("not an arrow object") + + +def test_record_batch_reader_consumed_exactly_once(catalog: Catalog) -> None: + """The streaming path must consume the underlying generator exactly once. + A regression that drained the reader twice (e.g. an extra .schema access + that materialised the iterator, or a retry-loop without a fresh reader) + would silently lose data — the second pass is empty. + """ + catalog.create_namespace("default") + identifier = f"default.record_batch_reader_consumed_once_{catalog.name}" + pa_table = _simple_arrow_table() + consumed_batches = 0 + + def tracking_batches() -> Generator[pa.RecordBatch, None, None]: + nonlocal consumed_batches + for batch in pa_table.to_batches() * 3: + consumed_batches += 1 + yield batch + + reader = pa.RecordBatchReader.from_batches(pa_table.schema, tracking_batches()) + tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema) + + tbl.append(reader) + + # The generator should have been driven to exhaustion exactly once: 3 batches. 
+    assert consumed_batches == 3
+    assert len(tbl.scan().to_arrow()) == pa_table.num_rows * 3
+
+
+def test_record_batch_reader_schema_mismatch_writes_no_files(catalog: Catalog) -> None:
+    """A schema mismatch must fail before any data files are written. Otherwise
+    we'd leak orphan parquet files in storage (and a partial commit that picks
+    them up later via add_files would be a correctness disaster).
+    """
+    catalog.create_namespace("default")
+    identifier = f"default.record_batch_reader_schema_mismatch_{catalog.name}"
+    iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=False))
+    tbl = catalog.create_table(identifier=identifier, schema=iceberg_schema)
+
+    bad_schema = pa.schema([pa.field("foo", pa.int64(), nullable=True)])
+    bad_reader = pa.RecordBatchReader.from_batches(
+        bad_schema,
+        iter([pa.RecordBatch.from_pylist([{"foo": 1}], schema=bad_schema)]),
+    )
+
+    with pytest.raises(ValueError):
+        tbl.append(bad_reader)
+
+    # No snapshot should have been produced: the schema check runs before
+    # _append_snapshot_producer opens.
+    assert tbl.metadata.current_snapshot() is None
+    assert len(tbl.scan().to_arrow()) == 0
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 2bc4985609..1d1488255f 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -768,7 +768,7 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         properties={"format-version": "1"},
     )
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.append("not a df")
 
 
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 0a09867656..f2e5f55cd8 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -791,10 +791,10 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_
     identifier = "default.arrow_data_files"
     tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.overwrite("not a df")
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.append("not a df")
 
 
@@ -2571,3 +2571,91 @@ def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Cat
     assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), (
         "Expected next_row_id to be incremented by the number of added rows"
     )
+
+
+# RecordBatchReader streaming append/overwrite — see https://github.com/apache/iceberg-python/issues/2152
+#
+# These integration tests prove Spark can read tables written via the new
+# streaming path (rolling pq.ParquetWriter + fast_append commit). Equivalent
+# in-process scan coverage lives in tests/catalog/test_catalog_behaviors.py;
+# only Spark exercises the resulting manifest stats and Parquet metadata
+# against an external reader.
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_record_batch_reader(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.streaming_append_record_batch_reader_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)})
+
+    # 4 batches × 3 rows each — exercises the multi-batch streaming path while
+    # keeping the assertion data tractable for Spark.
+    batches = arrow_table_with_null.to_batches() * 4
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.append(reader)
+
+    assert len(tbl.scan().to_arrow()) == expected_rows
+    df = spark.table(identifier)
+    assert df.count() == expected_rows
+    # Spot-check that Spark agrees on the schema as written
+    assert sorted(df.columns) == sorted(arrow_table_with_null.column_names)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_overwrite_record_batch_reader(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.streaming_overwrite_record_batch_reader_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+    assert len(tbl.scan().to_arrow()) == arrow_table_with_null.num_rows
+
+    batches = arrow_table_with_null.to_batches() * 2
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.overwrite(reader)
+
+    # Existing rows replaced, only the streamed rows remain
+    assert len(tbl.scan().to_arrow()) == expected_rows
+    df = spark.table(identifier)
+    assert df.count() == expected_rows
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_record_batch_reader_multifile(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    """Forcing a tiny target file size should produce >1 data file in a single
+    snapshot, proving the rolling ParquetWriter's tell()-based rollover fires
+    end-to-end and the resulting files are valid Iceberg data files (Spark
+    reads them all)."""
+    identifier = f"default.streaming_append_multifile_v{format_version}"
+    tbl = _create_table(
+        session_catalog,
+        identifier,
+        {
+            "format-version": str(format_version),
+            TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: "1",
+        },
+    )
+
+    batches = arrow_table_with_null.to_batches(max_chunksize=1) * 4
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.append(reader)
+
+    snapshot = tbl.metadata.current_snapshot()
+    assert snapshot is not None
+    assert snapshot.summary is not None
+    added_files = snapshot.summary["added-data-files"]
+    assert added_files is not None and int(added_files) > 1, snapshot.summary
+
+    df = spark.table(identifier)
+    assert df.count() == expected_rows