diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 22d6b2e3c7..9c4cccf08c 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -365,6 +365,17 @@ for buf in tbl.scan().to_arrow_batch_reader():
     print(f"Buffer contains {len(buf)} rows")
 ```
 
+### Streaming writes from a `RecordBatchReader`
+
+`tbl.append()` and `tbl.overwrite()` also accept a `pyarrow.RecordBatchReader` directly, which lets you write datasets that don't fit in memory without materialising them as a `pa.Table` first. PyIceberg consumes the reader once and microbatches it into Parquet files of approximately `write.target-file-size-bytes` (default 512 MiB), keeping memory usage bounded by the target size. All files are committed in a single snapshot.
+
+```python
+reader = pa.RecordBatchReader.from_batches(schema, batch_iter)
+tbl.append(reader)
+```
+
+Streaming writes are currently only supported on **unpartitioned** tables. For a partitioned table, materialise the reader as a `pa.Table` first, or follow [#2152](https://github.com/apache/iceberg-python/issues/2152) for the partitioned support tracked as a follow-up.
+
 To avoid any type inconsistencies during writing, you can convert the Iceberg table schema to Arrow:
 
 ```python
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 4517ae7327..d24eceb9a5 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2666,6 +2666,18 @@ def write_parquet(task: WriteTask) -> DataFile:
 
 
 def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[list[pa.RecordBatch]]:
+    """Bin-pack ``tbl`` into groups of RecordBatches, each ~``target_file_size``.
+
+    Note:
+        ``target_file_size`` is measured in **uncompressed in-memory** Arrow bytes
+        (``Table.nbytes`` / ``RecordBatch.nbytes``), not compressed on-disk Parquet
+        bytes. The resulting Parquet file after compression (zstd by default,
+        plus dictionary/RLE encoding) is typically 3-10× smaller than
+        ``target_file_size``. This is a coarse proxy for the spec-defined
+        ``write.target-file-size-bytes`` and will be tightened to true on-disk
+        bytes once the writer is switched to a rolling-``ParquetWriter`` with
+        ``OutputStream.tell()`` (#2998).
+    """
     from pyiceberg.utils.bin_packing import PackingIterator
 
     avg_row_size_bytes = tbl.nbytes / tbl.num_rows
@@ -2681,6 +2693,41 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[list[
     return bin_packed_record_batches
 
 
+def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size: int) -> Iterator[list[pa.RecordBatch]]:
+    """Microbatch a single-pass stream of RecordBatches into target-sized groups.
+
+    Unlike :func:`bin_pack_arrow_table`, this consumes ``batches`` lazily and
+    holds at most one in-flight buffer in memory, bounded by ``target_file_size``.
+    Suitable for streaming inputs (``pa.RecordBatchReader``,
+    ``Iterator[pa.RecordBatch]``) where the total size is unknown up front and
+    the caller cannot afford to materialise the full dataset.
+
+    Each yielded list of batches is intended to be written as a single Parquet
+    data file. Because this is single-pass FIFO accumulation (no lookback), the
+    last bin may be smaller than ``target_file_size``.
+
+    Note:
+        ``target_file_size`` is measured in **uncompressed in-memory** Arrow
+        bytes (``RecordBatch.nbytes``), not compressed on-disk Parquet bytes.
+        The resulting Parquet file after compression is typically 3-10×
+        smaller than ``target_file_size``. Matches the existing
+        :func:`bin_pack_arrow_table` semantics; both will be tightened to true
+        on-disk bytes once the writer is switched to a rolling-
+        ``ParquetWriter`` with ``OutputStream.tell()`` (#2998).
+    """
+    buffer: list[pa.RecordBatch] = []
+    buffer_bytes = 0
+    for batch in batches:
+        buffer.append(batch)
+        buffer_bytes += batch.nbytes
+        if buffer_bytes >= target_file_size:
+            yield buffer
+            buffer = []
+            buffer_bytes = 0
+    if buffer:
+        yield buffer
+
+
 def _check_pyarrow_schema_compatible(
     requested_schema: Schema,
     provided_schema: pa.Schema,
@@ -2800,15 +2847,24 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]:
 
 def _dataframe_to_data_files(
     table_metadata: TableMetadata,
-    df: pa.Table,
+    df: pa.Table | pa.RecordBatchReader,
     io: FileIO,
     write_uuid: uuid.UUID | None = None,
     counter: itertools.count[int] | None = None,
 ) -> Iterable[DataFile]:
-    """Convert a PyArrow table into a DataFile.
+    """Convert a PyArrow Table or RecordBatchReader into DataFiles.
+
+    For a ``pa.Table`` the data is materialised in memory and bin-packed into
+    target-sized files (with partition splitting if the table is partitioned).
+
+    For a ``pa.RecordBatchReader`` batches are streamed and microbatched into
+    target-sized files using bounded memory (see :func:`bin_pack_record_batches`).
+    Streaming writes are currently only supported on unpartitioned tables;
+    partitioned support is tracked in
+    https://github.com/apache/iceberg-python/issues/2152.
 
     Returns:
-        An iterable that supplies datafiles that represent the table.
+        An iterable that supplies datafiles that represent the input data.
     """
     from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties, WriteTask
@@ -2828,6 +2884,23 @@ def _dataframe_to_data_files(
         format_version=table_metadata.format_version,
     )
 
+    if isinstance(df, pa.RecordBatchReader):
+        if not table_metadata.spec().is_unpartitioned():
+            raise NotImplementedError(
+                "Writing a pa.RecordBatchReader to a partitioned table is not yet supported. "
+                "Materialise the reader as a pa.Table first, or follow "
+                "https://github.com/apache/iceberg-python/issues/2152 for partitioned streaming support."
+            )
+        yield from write_file(
+            io=io,
+            table_metadata=table_metadata,
+            tasks=(
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
+                for batches in bin_pack_record_batches(df, target_file_size)
+            ),
+        )
+        return
+
     if table_metadata.spec().is_unpartitioned():
         yield from write_file(
             io=io,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index b8d87143c9..c5367c8679 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -450,12 +450,53 @@ def update_statistics(self) -> UpdateStatistics:
         """
         return UpdateStatistics(transaction=self)
 
-    def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None:
+    def append(
+        self,
+        df: pa.Table | pa.RecordBatchReader,
+        snapshot_properties: dict[str, str] = EMPTY_DICT,
+        branch: str | None = MAIN_BRANCH,
+    ) -> None:
         """
-        Shorthand API for appending a PyArrow table to a table transaction.
+        Shorthand API for appending PyArrow data to a table transaction.
+
+        Accepts either a fully materialised ``pa.Table`` or a streaming
+        ``pa.RecordBatchReader``. Streaming is microbatched by
+        ``write.target-file-size-bytes`` so memory stays bounded; the reader is
+        consumed once and cannot be reused.
+
+        Streaming writes are currently only supported on unpartitioned tables;
+        passing a ``pa.RecordBatchReader`` for a partitioned table raises
+        ``NotImplementedError``. See
+        https://github.com/apache/iceberg-python/issues/2152.
+
+        Note:
+            When ``df`` is a ``pa.RecordBatchReader`` the reader is consumed
+            once and cannot be replayed. If the catalog commit fails (e.g.
+            ``CommitFailedException`` from a concurrent writer) the reader is
+            already drained and a naive retry will append zero rows. Callers
+            that need at-least-once semantics should either:
+
+            - reconstruct the reader on each attempt via a factory callable,
+              or
+            - use a two-stage pattern — write Parquet files explicitly and
+              then call :meth:`add_files` (whose input is a replayable list of
+              paths) within a retry loop.
+
+            Failures during the write stage (mid-stream reader exception, S3
+            errors) do not commit a snapshot, but may leave orphan data files
+            in storage that are not referenced by any snapshot. Clean these
+            up with expire/orphan-file maintenance jobs.
+
+            ``write.target-file-size-bytes`` is currently interpreted as
+            uncompressed in-memory Arrow bytes (the bin-packing weight) rather
+            than compressed on-disk Parquet bytes. The resulting files are
+            typically 3-10× smaller than the property suggests after
+            compression. This matches the existing ``pa.Table`` write path and
+            will be tightened once the writer is switched to a
+            rolling-``ParquetWriter`` with ``OutputStream.tell()`` (#2998).
 
         Args:
-            df: The Arrow dataframe that will be appended to overwrite the table
+            df: An Arrow Table or a RecordBatchReader of records to append.
             snapshot_properties: Custom properties to be added to the snapshot summary
             branch: Branch Reference to run the append operation
         """
@@ -466,8 +507,8 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT,
 
         from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
 
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
+        if not isinstance(df, (pa.Table, pa.RecordBatchReader)):
+            raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}")
 
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
@@ -478,12 +519,14 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT,
         )
 
         with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
-            # skip writing data files if the dataframe is empty
-            if df.shape[0] > 0:
-                data_files = list(
-                    _dataframe_to_data_files(
-                        table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
-                    )
+            # For pa.Table we can short-circuit empty inputs cheaply. For a
+            # RecordBatchReader the stream is consumed lazily by
+            # _dataframe_to_data_files and an empty reader simply yields zero
+            # data files (the snapshot is still committed for symmetry with the
+            # pa.Table case where empty inputs also produce a snapshot).
+            if isinstance(df, pa.RecordBatchReader) or df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
                 )
                 for data_file in data_files:
                     append_files.append_data_file(data_file)
@@ -555,14 +598,50 @@ def dynamic_partition_overwrite(
     def overwrite(
         self,
-        df: pa.Table,
+        df: pa.Table | pa.RecordBatchReader,
         overwrite_filter: BooleanExpression | str = ALWAYS_TRUE,
         snapshot_properties: dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
         branch: str | None = MAIN_BRANCH,
     ) -> None:
         """
-        Shorthand for adding a table overwrite with a PyArrow table to the transaction.
+        Shorthand for adding a table overwrite with a PyArrow table or RecordBatchReader to the transaction.
+
+        Accepts either a fully materialised ``pa.Table`` or a streaming
+        ``pa.RecordBatchReader``. Streaming is microbatched by
+        ``write.target-file-size-bytes`` so memory stays bounded; the reader is
+        consumed once and cannot be reused.
+
+        Streaming writes are currently only supported on unpartitioned tables;
+        passing a ``pa.RecordBatchReader`` for a partitioned table raises
+        ``NotImplementedError``. See
+        https://github.com/apache/iceberg-python/issues/2152.
+
+        Note:
+            When ``df`` is a ``pa.RecordBatchReader`` the reader is consumed
+            once and cannot be replayed. If the catalog commit fails (e.g.
+            ``CommitFailedException`` from a concurrent writer) the reader is
+            already drained and a naive retry will write zero rows. Callers
+            that need at-least-once semantics should either:
+
+            - reconstruct the reader on each attempt via a factory callable,
+              or
+            - use a two-stage pattern — write Parquet files explicitly and
+              then call :meth:`add_files` (whose input is a replayable list
+              of paths) within a retry loop.
+
+            Failures during the write stage (mid-stream reader exception, S3
+            errors) do not commit a snapshot, but may leave orphan data files
+            in storage that are not referenced by any snapshot. Clean these
+            up with expire/orphan-file maintenance jobs.
+
+            ``write.target-file-size-bytes`` is currently interpreted as
+            uncompressed in-memory Arrow bytes (the bin-packing weight) rather
+            than compressed on-disk Parquet bytes. The resulting files are
+            typically 3-10× smaller than the property suggests after
+            compression. This matches the existing ``pa.Table`` write path and
+            will be tightened once the writer is switched to a
+            rolling-``ParquetWriter`` with ``OutputStream.tell()`` (#2998).
 
         An overwrite may produce zero or more snapshots based on the operation:
 
             - DELETE: In case existing Parquet files can be dropped completely.
@@ -571,7 +650,7 @@ def overwrite(
 
             - APPEND: In case new data is being inserted into the table.
 
         Args:
-            df: The Arrow dataframe that will be used to overwrite the table
+            df: An Arrow Table or a RecordBatchReader of records to write.
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite
             snapshot_properties: Custom properties to be added to the snapshot summary
@@ -585,8 +664,8 @@ def overwrite(
 
         from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
 
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
+        if not isinstance(df, (pa.Table, pa.RecordBatchReader)):
+            raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}")
 
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
@@ -606,8 +685,8 @@ def overwrite(
         )
         with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
-            # skip writing data files if the dataframe is empty
-            if df.shape[0] > 0:
+            # See append() for the empty-input handling rationale.
+            if isinstance(df, pa.RecordBatchReader) or df.shape[0] > 0:
                 data_files = _dataframe_to_data_files(
                     table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
                 )
                 for data_file in data_files:
@@ -1373,12 +1452,21 @@ def upsert(
             snapshot_properties=snapshot_properties,
         )
 
-    def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None:
+    def append(
+        self,
+        df: pa.Table | pa.RecordBatchReader,
+        snapshot_properties: dict[str, str] = EMPTY_DICT,
+        branch: str | None = MAIN_BRANCH,
+    ) -> None:
         """
-        Shorthand API for appending a PyArrow table to the table.
+        Shorthand API for appending PyArrow data to the table.
+
+        Accepts either a ``pa.Table`` or a streaming ``pa.RecordBatchReader``.
+        See :meth:`Transaction.append` for streaming semantics and partition
+        limitations.
 
         Args:
-            df: The Arrow dataframe that will be appended to overwrite the table
+            df: An Arrow Table or a RecordBatchReader of records to append.
             snapshot_properties: Custom properties to be added to the snapshot summary
             branch: Branch Reference to run the append operation
         """
@@ -1401,14 +1489,18 @@ def dynamic_partition_overwrite(
 
     def overwrite(
         self,
-        df: pa.Table,
+        df: pa.Table | pa.RecordBatchReader,
         overwrite_filter: BooleanExpression | str = ALWAYS_TRUE,
         snapshot_properties: dict[str, str] = EMPTY_DICT,
         case_sensitive: bool = True,
         branch: str | None = MAIN_BRANCH,
     ) -> None:
         """
-        Shorthand for overwriting the table with a PyArrow table.
+        Shorthand for overwriting the table with a PyArrow Table or RecordBatchReader.
+
+        Accepts either a ``pa.Table`` or a streaming ``pa.RecordBatchReader``.
+        See :meth:`Transaction.overwrite` for streaming semantics and partition
+        limitations.
 
         An overwrite may produce zero or more snapshots based on the operation:
 
@@ -1417,7 +1509,7 @@ def overwrite(
 
             - APPEND: In case new data is being inserted into the table.
 
         Args:
-            df: The Arrow dataframe that will be used to overwrite the table
+            df: An Arrow Table or a RecordBatchReader of records to write.
             overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite
             snapshot_properties: Custom properties to be added to the snapshot summary
diff --git a/tests/catalog/test_catalog_behaviors.py b/tests/catalog/test_catalog_behaviors.py
index 01e0d2ce31..0a10c556e3 100644
--- a/tests/catalog/test_catalog_behaviors.py
+++ b/tests/catalog/test_catalog_behaviors.py
@@ -20,6 +20,7 @@
 """
 
 import os
+from collections.abc import Generator
 from pathlib import Path
 from typing import Any
@@ -1190,3 +1191,164 @@ def test_drop_namespace_raises_error_when_namespace_not_empty(
     catalog.create_table(test_table_identifier, table_schema_nested)
     with pytest.raises(NamespaceNotEmptyError, match=f"Namespace {'.'.join(namespace)} is not empty"):
         catalog.drop_namespace(namespace)
+
+
+# RecordBatchReader streaming append/overwrite tests
+#
+# Streaming writes accept a pa.RecordBatchReader and microbatch it into target-sized
+# Parquet files instead of materialising the full Arrow Table in memory. Tracks
+# https://github.com/apache/iceberg-python/issues/2152.
+
+
+def _simple_arrow_table() -> pa.Table:
+    return pa.Table.from_pydict(
+        {"foo": ["a", None, "z"]},
+        schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
+    )
+
+
+def _simple_record_batch_reader(num_batches: int = 3) -> tuple[pa.RecordBatchReader, int]:
+    """Build an N-batch reader of the simple schema. Returns (reader, total_rows)."""
+    pa_table = _simple_arrow_table()
+    batches = pa_table.to_batches() * num_batches
+    reader = pa.RecordBatchReader.from_batches(pa_table.schema, iter(batches))
+    return reader, sum(b.num_rows for b in batches)
+
+
+def test_append_record_batch_reader(catalog: Catalog) -> None:
+    catalog.create_namespace("default")
+    identifier = f"default.append_record_batch_reader_{catalog.name}"
+    reader, total_rows = _simple_record_batch_reader(num_batches=3)
+    tbl = catalog.create_table(identifier=identifier, schema=reader.schema)
+
+    tbl.append(reader)
+
+    assert len(tbl.scan().to_arrow()) == total_rows
+
+
+def test_append_record_batch_reader_microbatched(catalog: Catalog) -> None:
+    """A reader bigger than the per-file target produces multiple Parquet files
+    in a single snapshot — verifying the byte-budget microbatching path."""
+    catalog.create_namespace("default")
+    identifier = f"default.append_record_batch_reader_microbatch_{catalog.name}"
+    reader, total_rows = _simple_record_batch_reader(num_batches=8)
+    # Force every batch to roll a new file by setting an absurdly small target size.
+    tbl = catalog.create_table(
+        identifier=identifier,
+        schema=reader.schema,
+        properties={TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: "1"},
+    )
+
+    tbl.append(reader)
+
+    snapshot = tbl.metadata.current_snapshot()
+    assert snapshot is not None
+    assert snapshot.summary is not None
+    added_files = snapshot.summary["added-data-files"]
+    assert added_files is not None and int(added_files) > 1, snapshot.summary
+    assert len(tbl.scan().to_arrow()) == total_rows
+
+
+def test_append_record_batch_reader_empty(catalog: Catalog) -> None:
+    catalog.create_namespace("default")
+    identifier = f"default.append_record_batch_reader_empty_{catalog.name}"
+    schema = _simple_arrow_table().schema
+    reader = pa.RecordBatchReader.from_batches(schema, iter([]))
+    tbl = catalog.create_table(identifier=identifier, schema=schema)
+
+    tbl.append(reader)
+
+    assert len(tbl.scan().to_arrow()) == 0
+
+
+def test_overwrite_record_batch_reader(catalog: Catalog) -> None:
+    catalog.create_namespace("default")
+    identifier = f"default.overwrite_record_batch_reader_{catalog.name}"
+    pa_table = _simple_arrow_table()
+    tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema)
+    tbl.append(pa_table)
+    assert len(tbl.scan().to_arrow()) == pa_table.num_rows
+
+    reader, total_rows = _simple_record_batch_reader(num_batches=2)
+    tbl.overwrite(reader)
+
+    assert len(tbl.scan().to_arrow()) == total_rows
+
+
+def test_append_record_batch_reader_to_partitioned_table_raises(catalog: Catalog) -> None:
+    catalog.create_namespace("default")
+    identifier = f"default.append_record_batch_reader_partitioned_{catalog.name}"
+    iceberg_schema = Schema(
+        NestedField(1, "id", IntegerType(), required=False),
+        NestedField(2, "bucket", StringType(), required=False),
+    )
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bucket"),
+    )
+    tbl = catalog.create_table(identifier=identifier, schema=iceberg_schema, partition_spec=partition_spec)
+
+    arrow_schema = schema_to_pyarrow(iceberg_schema)
+    reader = pa.RecordBatchReader.from_batches(arrow_schema, iter([]))
+    with pytest.raises(NotImplementedError, match="partitioned table"):
+        tbl.append(reader)
+
+
+def test_append_invalid_input_type_raises(catalog: Catalog) -> None:
+    catalog.create_namespace("default")
+    identifier = f"default.append_invalid_input_{catalog.name}"
+    pa_table = _simple_arrow_table()
+    tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema)
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader"):
+        tbl.append("not an arrow object")
+
+
+def test_record_batch_reader_consumed_exactly_once(catalog: Catalog) -> None:
+    """The streaming path must consume the underlying generator exactly once.
+    A regression that drained the reader twice (e.g. an extra .schema access
+    that materialised the iterator, or a retry-loop without a fresh reader)
+    would silently lose data — the second pass is empty.
+    """
+    catalog.create_namespace("default")
+    identifier = f"default.record_batch_reader_consumed_once_{catalog.name}"
+    pa_table = _simple_arrow_table()
+    consumed_batches = 0
+
+    def tracking_batches() -> Generator[pa.RecordBatch, None, None]:
+        nonlocal consumed_batches
+        for batch in pa_table.to_batches() * 3:
+            consumed_batches += 1
+            yield batch
+
+    reader = pa.RecordBatchReader.from_batches(pa_table.schema, tracking_batches())
+    tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema)
+
+    tbl.append(reader)
+
+    # The generator should have been driven to exhaustion exactly once: 3 batches.
+    assert consumed_batches == 3
+    assert len(tbl.scan().to_arrow()) == pa_table.num_rows * 3
+
+
+def test_record_batch_reader_schema_mismatch_writes_no_files(catalog: Catalog) -> None:
+    """A schema mismatch must fail before any data files are written. Otherwise
+    we'd leak orphan parquet files in storage (and a partial commit that picks
+    them up later via add_files would be a correctness disaster).
+    """
+    catalog.create_namespace("default")
+    identifier = f"default.record_batch_reader_schema_mismatch_{catalog.name}"
+    iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=False))
+    tbl = catalog.create_table(identifier=identifier, schema=iceberg_schema)
+
+    bad_schema = pa.schema([pa.field("foo", pa.int64(), nullable=True)])
+    bad_reader = pa.RecordBatchReader.from_batches(
+        bad_schema,
+        iter([pa.RecordBatch.from_pylist([{"foo": 1}], schema=bad_schema)]),
+    )
+
+    with pytest.raises(ValueError):
+        tbl.append(bad_reader)
+
+    # No snapshot should have been produced: the schema check runs before
+    # _append_snapshot_producer opens.
+    assert tbl.metadata.current_snapshot() is None
+    assert len(tbl.scan().to_arrow()) == 0
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 2bc4985609..1d1488255f 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -768,7 +768,7 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         properties={"format-version": "1"},
     )
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.append("not a df")
 
 
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 0a09867656..80d7cce6bc 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -791,10 +791,10 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_
     identifier = "default.arrow_data_files"
     tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.overwrite("not a df")
 
-    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"):
+    with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"):
         tbl.append("not a df")
 
 
@@ -2571,3 +2571,87 @@ def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Cat
     assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), (
         "Expected next_row_id to be incremented by the number of added rows"
     )
+
+
+# RecordBatchReader streaming append/overwrite — see https://github.com/apache/iceberg-python/issues/2152
+#
+# These integration tests prove Spark can read tables written via the new streaming
+# path. Equivalent in-process scan coverage lives in tests/catalog/test_catalog_behaviors.py
+# but only Spark exercises the manifest stats + Parquet metadata produced by the
+# write_file → fast_append pipeline against an external reader.
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_record_batch_reader(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.streaming_append_record_batch_reader_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)})
+
+    # 4 batches × 3 rows each — exercises the multi-batch streaming path while
+    # keeping the assertion data tractable for Spark.
+    batches = arrow_table_with_null.to_batches() * 4
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.append(reader)
+
+    assert len(tbl.scan().to_arrow()) == expected_rows
+    df = spark.table(identifier)
+    assert df.count() == expected_rows
+    # Spot-check that Spark agrees on the schema as written
+    assert sorted(df.columns) == sorted(arrow_table_with_null.column_names)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_overwrite_record_batch_reader(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = f"default.streaming_overwrite_record_batch_reader_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])
+    assert len(tbl.scan().to_arrow()) == arrow_table_with_null.num_rows
+
+    batches = arrow_table_with_null.to_batches() * 2
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.overwrite(reader)
+
+    # Existing rows replaced, only the streamed rows remain
+    assert len(tbl.scan().to_arrow()) == expected_rows
+    df = spark.table(identifier)
+    assert df.count() == expected_rows
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_record_batch_reader_multifile(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    """Forcing a tiny target file size should produce >1 data file in a single
+    snapshot, proving the byte-budget rollover in bin_pack_record_batches fires
+    end-to-end and the resulting files are valid Iceberg data files (Spark reads
+    them all)."""
+    identifier = f"default.streaming_append_multifile_v{format_version}"
+    tbl = _create_table(
+        session_catalog,
+        identifier,
+        {"format-version": str(format_version), TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: "1"},
+    )
+
+    batches = arrow_table_with_null.to_batches(max_chunksize=1) * 4
+    reader = pa.RecordBatchReader.from_batches(arrow_table_with_null.schema, iter(batches))
+    expected_rows = sum(b.num_rows for b in batches)
+
+    tbl.append(reader)
+
+    snapshot = tbl.metadata.current_snapshot()
+    assert snapshot is not None
+    assert snapshot.summary is not None
+    added_files = snapshot.summary["added-data-files"]
+    assert added_files is not None and int(added_files) > 1, snapshot.summary
+
+    df = spark.table(identifier)
+    assert df.count() == expected_rows
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 2170741bdd..a05b295fc1 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -20,6 +20,7 @@
 import tempfile
 import uuid
 import warnings
+from collections.abc import Iterator
 from datetime import date, datetime, timezone
 from pathlib import Path
 from typing import Any
@@ -76,6 +77,7 @@
     _task_to_record_batches,
     _to_requested_schema,
     bin_pack_arrow_table,
+    bin_pack_record_batches,
     compute_statistics_plan,
     data_file_statistics_from_parquet_metadata,
     expression_to_pyarrow,
@@ -2364,6 +2366,50 @@ def test_bin_pack_arrow_table_target_size_smaller_than_row(arrow_table_with_null
     assert sum(batch.num_rows for bin_ in bin_packed for batch in bin_) == arrow_table_with_null.num_rows
 
 
+def test_bin_pack_record_batches_single_bin(arrow_table_with_null: pa.Table) -> None:
+    batches = arrow_table_with_null.to_batches()
+    bins = list(bin_pack_record_batches(iter(batches), target_file_size=arrow_table_with_null.nbytes * 10))
+    # everything fits in one bin
+    assert len(bins) == 1
+    assert sum(b.num_rows for b in bins[0]) == arrow_table_with_null.num_rows
+
+
+def test_bin_pack_record_batches_microbatched(arrow_table_with_null: pa.Table) -> None:
+    # repeat the per-row batches so we have many small inputs to pack
+    batches = list(arrow_table_with_null.to_batches(max_chunksize=1)) * 5
+    bin_size = arrow_table_with_null.nbytes // 2  # forces multiple bins
+    bins = list(bin_pack_record_batches(iter(batches), target_file_size=bin_size))
+    assert len(bins) > 1
+    assert sum(b.num_rows for bin_ in bins for b in bin_) == arrow_table_with_null.num_rows * 5
+    # All but the last bin should have crossed the size threshold.
+    for bin_ in bins[:-1]:
+        assert sum(b.nbytes for b in bin_) >= bin_size
+
+
+def test_bin_pack_record_batches_empty() -> None:
+    assert list(bin_pack_record_batches(iter([]), target_file_size=1024)) == []
+
+
+def test_bin_pack_record_batches_is_lazy(arrow_table_with_null: pa.Table) -> None:
+    # Streams are single-pass: confirm the helper consumes its input batch-by-batch
+    # rather than materialising the whole iterator before yielding the first bin.
+    consumed: list[int] = []
+
+    def tracking_iter() -> Iterator[pa.RecordBatch]:
+        for i, batch in enumerate(arrow_table_with_null.to_batches(max_chunksize=1)):
+            consumed.append(i)
+            yield batch
+
+    target = max(1, arrow_table_with_null.nbytes // 4)
+    bins_iter = bin_pack_record_batches(tracking_iter(), target_file_size=target)
+    first_bin = next(bins_iter)
+    assert len(first_bin) >= 1
+    # Generator should not have walked the entire input upon yielding the first bin
+    assert len(consumed) < arrow_table_with_null.num_rows
+    list(bins_iter)
+    assert len(consumed) == arrow_table_with_null.num_rows
+
+
 def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
     other_schema = pa.schema(
         (
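
To make the retry guidance in the new `append()`/`overwrite()` docstrings concrete, here is a minimal sketch of the factory-callable pattern, kept outside the diff itself. The helper name `append_with_retry`, the `make_reader` factory, and the attempt count are illustrative assumptions and not part of this change; only `Table.append`, `Table.refresh`, and `CommitFailedException` are existing PyIceberg APIs.

```python
from collections.abc import Callable

import pyarrow as pa

from pyiceberg.exceptions import CommitFailedException
from pyiceberg.table import Table


def append_with_retry(tbl: Table, make_reader: Callable[[], pa.RecordBatchReader], attempts: int = 3) -> None:
    """Append a streamed RecordBatchReader, rebuilding the stream on commit conflicts.

    ``make_reader`` must return a fresh ``pa.RecordBatchReader`` on every call:
    readers are single-pass, so retrying with the same drained object would
    silently append nothing.
    """
    for attempt in range(attempts):
        try:
            tbl.append(make_reader())  # streaming write path introduced in this diff
            return
        except CommitFailedException:
            if attempt == attempts - 1:
                raise
            # A concurrent writer won the commit; pick up its snapshot and try again.
            tbl.refresh()
```

The factory rebuilds the stream per attempt rather than reusing the reader object, which is exactly the failure mode the docstrings warn about: a drained reader appends zero rows on a second pass.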