From cc43c7b050ab187ebd156e94076f188874851979 Mon Sep 17 00:00:00 2001 From: "xiyu.zk" Date: Thu, 28 May 2026 21:06:11 +0800 Subject: [PATCH] [python][daft] Make Daft Paimon write sink serializable --- paimon-python/pypaimon/daft/daft_datasink.py | 128 +++++++++++++++++- .../pypaimon/tests/daft/daft_sink_test.py | 22 +++ 2 files changed, 148 insertions(+), 2 deletions(-) diff --git a/paimon-python/pypaimon/daft/daft_datasink.py b/paimon-python/pypaimon/daft/daft_datasink.py index c019b16a31f3..7e6b871f06dd 100644 --- a/paimon-python/pypaimon/daft/daft_datasink.py +++ b/paimon-python/pypaimon/daft/daft_datasink.py @@ -32,6 +32,93 @@ from pypaimon.table.file_store_table import FileStoreTable +_PaimonIdentifier = tuple[str, str, str | None] + + +def _options_to_dict(options: Any) -> dict[str, Any]: + if options is None: + return {} + if isinstance(options, dict): + return dict(options) + + to_map = getattr(options, "to_map", None) + if callable(to_map): + return dict(to_map()) + + data = getattr(options, "data", None) + if isinstance(data, dict): + return dict(data) + + return {} + + +def _extract_catalog_options(table: FileStoreTable) -> dict[str, Any]: + file_io = getattr(table, "file_io", None) + properties = getattr(file_io, "properties", None) + if properties is None: + properties = getattr(file_io, "catalog_options", None) + return _options_to_dict(properties) + + +def _extract_identifier(table: FileStoreTable) -> _PaimonIdentifier | None: + identifier = getattr(table, "identifier", None) + if identifier is None: + return None + + get_database_name = getattr(identifier, "get_database_name", None) + get_table_name = getattr(identifier, "get_table_name", None) + get_branch_name = getattr(identifier, "get_branch_name", None) + + database_name = ( + get_database_name() + if callable(get_database_name) + else getattr(identifier, "database", None) + ) + table_name = ( + get_table_name() + if callable(get_table_name) + else getattr(identifier, "object", None) + ) + branch_name = ( + get_branch_name() + if callable(get_branch_name) + else getattr(identifier, "branch", None) + ) + if database_name is None or table_name is None: + return None + return database_name, table_name, branch_name + + +def _to_paimon_identifier(identifier: _PaimonIdentifier) -> Any: + database_name, table_name, branch_name = identifier + if branch_name: + from pypaimon.common.identifier import Identifier + + return Identifier(database_name, table_name, branch_name) + return f"{database_name}.{table_name}" + + +def _load_table( + catalog_options: dict[str, Any], + table_identifier: _PaimonIdentifier | None, + table_path: str | None, +) -> FileStoreTable: + if catalog_options and table_identifier is not None: + from pypaimon.catalog.catalog_factory import CatalogFactory + + catalog = CatalogFactory.create(catalog_options) + return catalog.get_table(_to_paimon_identifier(table_identifier)) + + if table_path: + from pypaimon.table.file_store_table import FileStoreTable + + return FileStoreTable.from_path(table_path) + + raise RuntimeError( + "Unable to reconstruct Paimon table while deserializing PaimonDataSink." + ) + + class PaimonDataSink(DataSink[list[Any]]): """DataSink for writing data to an Apache Paimon table. @@ -45,14 +132,51 @@ class PaimonDataSink(DataSink[list[Any]]): def __init__(self, table: FileStoreTable, mode: str = "append") -> None: if mode not in ("append", "overwrite"): raise ValueError(f"Only 'append' or 'overwrite' mode is supported, got: {mode!r}") - self._table = table self._mode = mode + self._catalog_options = _extract_catalog_options(table) + self._table_identifier = _extract_identifier(table) + table_path = getattr(table, "table_path", None) + self._table_path = str(table_path) if table_path is not None else None + self._commit_user: str | None = None + self._init_table(table) + + def __getstate__(self) -> dict[str, Any]: + return { + "_mode": self._mode, + "_catalog_options": self._catalog_options, + "_table_identifier": self._table_identifier, + "_table_path": self._table_path, + "_commit_user": self._commit_user, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self._mode = state["_mode"] + self._catalog_options = state["_catalog_options"] + self._table_identifier = state["_table_identifier"] + self._table_path = state["_table_path"] + self._commit_user = state["_commit_user"] + table = _load_table( + self._catalog_options, + self._table_identifier, + self._table_path, + ) + self._init_table(table) + + def _init_table(self, table: FileStoreTable) -> None: + self._table = table from pypaimon.schema.data_types import PyarrowFieldParser self._target_schema: pa.Schema = PyarrowFieldParser.from_paimon_schema(table.fields) self._write_builder = table.new_batch_write_builder() - if mode == "overwrite": + if ( + self._commit_user is not None + and hasattr(self._write_builder, "commit_user") + ): + self._write_builder.commit_user = self._commit_user + else: + self._commit_user = getattr(self._write_builder, "commit_user", None) + if self._mode == "overwrite": self._write_builder.overwrite({}) def name(self) -> str: diff --git a/paimon-python/pypaimon/tests/daft/daft_sink_test.py b/paimon-python/pypaimon/tests/daft/daft_sink_test.py index 3cca5487d97f..ad963de4447e 100644 --- a/paimon-python/pypaimon/tests/daft/daft_sink_test.py +++ b/paimon-python/pypaimon/tests/daft/daft_sink_test.py @@ -303,6 +303,28 @@ def test_write_paimon_invalid_mode(append_only_table): _write_table(df, table, mode="upsert") +def test_write_paimon_sink_serializes_without_file_io(append_only_table): + """PaimonDataSink should not pickle table FileIO objects.""" + from daft.pickle import dumps, loads + + class Unpicklable: + def __reduce__(self): + raise TypeError("file io marker should not be serialized") + + table, _ = append_only_table + table.file_io._unpicklable_marker = Unpicklable() + sink = PaimonDataSink(table, mode="overwrite") + commit_user = sink._write_builder.commit_user + + restored = loads(dumps(sink)) + + assert restored.name() == sink.name() + assert restored._mode == "overwrite" + assert restored._write_builder.commit_user == commit_user + assert restored._write_builder.static_partition == {} + assert restored._table.identifier.get_full_name() == table.identifier.get_full_name() + + def test_write_paimon_rejects_extra_columns(local_paimon_catalog): """Extra input columns should fail instead of being silently dropped.""" catalog, _ = local_paimon_catalog