
Commit e00a55c

Handle Empty RecordBatch within _task_to_record_batches (#1026)
1 parent 01e8ce2 commit e00a55c

3 files changed: 100 additions & 2 deletions


pyiceberg/io/pyarrow.py

Lines changed: 5 additions & 2 deletions
@@ -1205,9 +1205,11 @@ def _task_to_record_batches(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        current_index = 0
+        next_index = 0
         batches = fragment_scanner.to_batches()
         for batch in batches:
+            next_index = next_index + len(batch)
+            current_index = next_index - len(batch)
             if positional_deletes:
                 # Create the mask of indices that we're interested in
                 indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
@@ -1219,9 +1221,10 @@ def _task_to_record_batches(
                     # https://github.com/apache/arrow/issues/39220
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
+                    if len(arrow_table) == 0:
+                        continue
                     batch = arrow_table.to_batches()[0]
             yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
-            current_index += len(batch)
 
 
 def _task_to_table(
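
Note on the change above: the running offset for positional deletes is now computed from the raw batch at the top of the loop (next_index / current_index), so delete indices keep pointing at row positions in the data file even when a batch is filtered down or skipped, and the new continue guard steps over batches that the row filter empties out entirely, since an all-filtered table has no non-empty record batch to take [0] from. A minimal standalone sketch of that second situation, with assumed example values rather than anything from this commit:

import pyarrow as pa
import pyarrow.compute as pc

# Assumed example data, not taken from the commit.
table = pa.table({"number": pa.array([1, 2, 3], type=pa.int32())})

# A filter that matches nothing leaves a zero-row table; there is no
# non-empty record batch to pull out with to_batches()[0], which is the
# case the continue guard above steps over.
filtered = table.filter(pc.field("number") > 100)
print(filtered.num_rows)      # 0
print(filtered.to_batches())  # empty, or zero-row batches only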

tests/integration/test_deletes.py

Lines changed: 67 additions & 0 deletions
@@ -222,6 +222,73 @@ def test_delete_partitioned_table_positional_deletes(spark: SparkSession, sessio
     assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]}
 
 
+@pytest.mark.integration
+def test_delete_partitioned_table_positional_deletes_empty_batch(spark: SparkSession, session_catalog: RestCatalog) -> None:
+    identifier = "default.test_delete_partitioned_table_positional_deletes_empty_batch"
+
+    run_spark_commands(
+        spark,
+        [
+            f"DROP TABLE IF EXISTS {identifier}",
+            f"""
+            CREATE TABLE {identifier} (
+                number_partitioned int,
+                number int
+            )
+            USING iceberg
+            PARTITIONED BY (number_partitioned)
+            TBLPROPERTIES(
+                'format-version' = 2,
+                'write.delete.mode'='merge-on-read',
+                'write.update.mode'='merge-on-read',
+                'write.merge.mode'='merge-on-read',
+                'write.parquet.row-group-limit'=1
+            )
+            """,
+        ],
+    )
+
+    tbl = session_catalog.load_table(identifier)
+
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([10, 10, 10]),
+            pa.array([1, 2, 3]),
+        ],
+        schema=pa.schema([pa.field("number_partitioned", pa.int32()), pa.field("number", pa.int32())]),
+    )
+
+    tbl.append(arrow_table)
+
+    assert len(tbl.scan().to_arrow()) == 3
+
+    run_spark_commands(
+        spark,
+        [
+            # Generate a positional delete
+            f"""
+            DELETE FROM {identifier} WHERE number = 1
+            """,
+        ],
+    )
+    # Assert that there is just a single Parquet file, that has one merge on read file
+    tbl = tbl.refresh()
+
+    files = list(tbl.scan().plan_files())
+    assert len(files) == 1
+    assert len(files[0].delete_files) == 1
+
+    assert len(tbl.scan().to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 10").to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 1").to_arrow()) == 0
+
+    reader = tbl.scan(row_filter="number_partitioned == 1").to_arrow_batch_reader()
+    assert isinstance(reader, pa.RecordBatchReader)
+    assert len(reader.read_all()) == 0
+
+
 @pytest.mark.integration
 def test_overwrite_partitioned_table(spark: SparkSession, session_catalog: RestCatalog) -> None:
     identifier = "default.table_partitioned_delete"

tests/integration/test_reads.py

Lines changed: 28 additions & 0 deletions
@@ -707,3 +707,31 @@ def test_empty_scan_ordered_str(catalog: Catalog) -> None:
     table_empty_scan_ordered_str = catalog.load_table("default.test_empty_scan_ordered_str")
     arrow_table = table_empty_scan_ordered_str.scan(EqualTo("id", "b")).to_arrow()
     assert len(arrow_table) == 0
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_table_scan_empty_table(catalog: Catalog) -> None:
+    identifier = "default.test_table_scan_empty_table"
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([]),
+        ],
+        schema=pa.schema([pa.field("colA", pa.string())]),
+    )
+
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = catalog.create_table(
+        identifier,
+        schema=arrow_table.schema,
+    )
+
+    tbl.append(arrow_table)
+
+    result_table = tbl.scan().to_arrow()
+
+    assert len(result_table) == 0
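
The new read test appends a zero-row Arrow table and scans it back, so the read path has to return an empty result without tripping over missing batches. A minimal sketch of what such a zero-row payload looks like (column name and type assumed to mirror the test):

import pyarrow as pa

# One string column, zero rows, matching the schema used in the test above.
empty = pa.Table.from_arrays([pa.array([], type=pa.string())], names=["colA"])
print(empty.schema)    # colA: string
print(empty.num_rows)  # 0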
