diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index 4f438dec..b9771f51 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -10,7 +10,6 @@ #include "duckdb/common/assert.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/limits.hpp" -#include "duckdb/main/client_config.hpp" namespace duckdb { @@ -30,13 +29,12 @@ void VerifyArrowDatasetLoaded() { py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, - const ClientProperties &client_properties) { + const shared_ptr &client_context) { D_ASSERT(!py::isinstance(arrow_obj_handle)); ArrowSchemaWrapper schema; PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema); ArrowTableSchema arrow_table; - ArrowTableFunction::PopulateArrowTableSchema(*client_properties.client_context.get_mutable(), arrow_table, - schema.arrow_schema); + ArrowTableFunction::PopulateArrowTableSchema(*client_context, arrow_table, schema.arrow_schema); auto filters = parameters.filters; auto &column_list = parameters.projected_columns.columns; @@ -50,8 +48,9 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_ } if (has_filter) { - auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, - filter_to_col, client_properties, arrow_table); + auto filter = + PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col, + client_context->GetClientProperties(), arrow_table); if (!filter.is(py::none())) { kwargs["filter"] = filter; } @@ -78,7 +77,7 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( try { auto filter_expr = PolarsFilterPushdown::TransformFilter( *filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col, - factory->client_properties); + factory->client_context->GetClientProperties()); if (!filter_expr.is(py::none())) { lf = lf.attr("filter")(filter_expr); filters_pushed = true; @@ -139,7 +138,7 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( auto &import_cache = *DuckDBPyConnection::ImportCache(); py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches"); py::handle reader_handle = reader; - auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties); + auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_context); auto record_batches = scanner.attr("to_reader")(); auto res = make_uniq(); auto export_to_c = record_batches.attr("_export_to_c"); @@ -177,12 +176,12 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( // If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack // scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain auto record_batches = arrow_obj_handle.attr("to_reader")(); - scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_properties); + scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_context); break; } case PyArrowObjectType::Dataset: { py::object arrow_scanner = arrow_obj_handle.attr("__class__").attr("scanner"); - scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_properties); + scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_context); break; } default: { diff --git a/src/duckdb_py/arrow/arrow_export_utils.cpp b/src/duckdb_py/arrow/arrow_export_utils.cpp index ea30b94b..d772ef13 100644 --- a/src/duckdb_py/arrow/arrow_export_utils.cpp +++ b/src/duckdb_py/arrow/arrow_export_utils.cpp @@ -18,14 +18,14 @@ namespace duckdb { namespace pyarrow { py::object ToArrowTable(const vector &types, const vector &names, const py::list &batches, - ClientProperties &options) { + ClientContext &client_context) { py::gil_scoped_acquire acquire; auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib"); auto from_batches_func = pyarrow_lib_module.attr("Table").attr("from_batches"); auto schema_import_func = pyarrow_lib_module.attr("Schema").attr("_import_from_c"); ArrowSchema schema; - ArrowConverter::ToArrowSchema(&schema, types, names, options); + ArrowConverter::ToArrowSchema(&schema, types, names, client_context); auto schema_obj = schema_import_func(reinterpret_cast(&schema)); return py::cast(from_batches_func(batches, schema_obj)); diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index 90974ff7..129b7aba 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -68,9 +68,9 @@ PyArrowObjectType GetArrowType(const py::handle &obj); class PythonTableArrowArrayStreamFactory { public: - explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p, + explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const shared_ptr &client_context, PyArrowObjectType arrow_type_p) - : arrow_object(arrow_table), client_properties(client_properties_p), cached_arrow_type(arrow_type_p) { + : arrow_object(arrow_table), client_context(client_context), cached_arrow_type(arrow_type_p) { cached_schema.release = nullptr; } @@ -94,7 +94,7 @@ class PythonTableArrowArrayStreamFactory { //! Arrow Object (i.e., Scanner, Record Batch Reader, Table, Dataset) PyObject *arrow_object; - const ClientProperties client_properties; + const shared_ptr client_context; const PyArrowObjectType cached_arrow_type; //! Cached Arrow table from an unfiltered .collect().to_arrow() on a LazyFrame. @@ -106,7 +106,8 @@ class PythonTableArrowArrayStreamFactory { bool schema_cached = false; static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle, - ArrowStreamParameters ¶meters, const ClientProperties &client_properties); + ArrowStreamParameters ¶meters, + const shared_ptr &client_context); }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_export_utils.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_export_utils.hpp index 39d1b3dc..25370b36 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_export_utils.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_export_utils.hpp @@ -7,7 +7,7 @@ namespace duckdb { namespace pyarrow { py::object ToArrowTable(const vector &types, const vector &names, const py::list &batches, - ClientProperties &options); + ClientContext &client_context); } // namespace pyarrow diff --git a/src/duckdb_py/include/duckdb_python/pyresult.hpp b/src/duckdb_py/include/duckdb_python/pyresult.hpp index 941a203b..132a00e1 100644 --- a/src/duckdb_py/include/duckdb_python/pyresult.hpp +++ b/src/duckdb_py/include/duckdb_python/pyresult.hpp @@ -59,7 +59,7 @@ struct DuckDBPyResult { const vector &GetNames(); const vector &GetTypes(); - ClientProperties GetClientProperties(); + shared_ptr GetClientContext() const; private: void FillNumpy(py::dict &res, idx_t col_idx, NumpyResultConversion &conversion, const char *name); diff --git a/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp b/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp index 8e329ea7..58014660 100644 --- a/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp +++ b/src/duckdb_py/include/duckdb_python/python_replacement_scan.hpp @@ -10,14 +10,14 @@ namespace duckdb { struct PythonReplacementScan { public: - static unique_ptr Replace(ClientContext &context, ReplacementScanInput &input, + static unique_ptr Replace(ClientContext &client_context, ReplacementScanInput &input, optional_ptr data); //! Try to perform a replacement, returns NULL on error static unique_ptr TryReplacementObject(const py::object &entry, const string &name, - ClientContext &context, bool relation = false); + ClientContext &client_context, bool relation = false); //! Perform a replacement or throw if it failed - static unique_ptr ReplacementObject(const py::object &entry, const string &name, ClientContext &context, - bool relation = false); + static unique_ptr ReplacementObject(const py::object &entry, const string &name, + ClientContext &client_context, bool relation = false); }; } // namespace duckdb diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 6883ba45..271c7e04 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -763,8 +763,7 @@ shared_ptr DuckDBPyConnection::Append(const string &name, co shared_ptr DuckDBPyConnection::RegisterPythonObject(const string &name, const py::object &python_object) { auto &connection = con.GetConnection(); - auto &client = *connection.context; - auto object = PythonReplacementScan::ReplacementObject(python_object, name, client); + auto object = PythonReplacementScan::ReplacementObject(python_object, name, *connection.context); auto view_rel = make_shared_ptr(connection.context, std::move(object), name); bool replace = registered_objects.count(name); view_rel->CreateView(name, replace, true); diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp index 35e33786..a033742c 100644 --- a/src/duckdb_py/pyrelation.cpp +++ b/src/duckdb_py/pyrelation.cpp @@ -1024,18 +1024,18 @@ PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) { ArrowSchema arrow_schema; auto result_names = names; QueryResult::DeduplicateColumns(result_names); - ClientProperties client_properties; + shared_ptr client_context; if (rel) { - client_properties = rel->context->GetContext()->GetClientProperties(); + client_context = rel->context->GetContext(); } else if (result) { - client_properties = result->GetClientProperties(); + client_context = result->GetClientContext(); } else { throw InternalException("DuckDBPyRelation To Polars must have a valid relation or result"); } - ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties); + ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, *client_context); py::list batches; // Now we create an empty arrow table - auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), client_properties); + auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), *client_context); // And we extract the polars schema from the arrow table auto polars_df = py::cast(pybind11::module_::import("polars").attr("DataFrame")(empty_table)); diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index 91118add..e7e31049 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -43,8 +43,8 @@ DuckDBPyResult::~DuckDBPyResult() { } } -ClientProperties DuckDBPyResult::GetClientProperties() { - return result->client_properties; +shared_ptr DuckDBPyResult::GetClientContext() const { + return result->client_context; } const vector &DuckDBPyResult::GetNames() { @@ -138,7 +138,8 @@ Optional DuckDBPyResult::Fetchone() { continue; } auto val = current_chunk->data[col_idx].GetValue(chunk_offset); - res[col_idx] = PythonObject::FromValue(val, result->types[col_idx], result->client_properties); + res[col_idx] = + PythonObject::FromValue(val, result->types[col_idx], result->client_context->GetClientProperties()); } chunk_offset++; return res; @@ -225,8 +226,8 @@ unique_ptr DuckDBPyResult::InitializeNumpyConversion(bool initial_capacity = materialized.RowCount(); } - auto conversion = - make_uniq(result->types, initial_capacity, result->client_properties, pandas); + auto conversion = make_uniq(result->types, initial_capacity, + result->client_context->GetClientProperties(), pandas); return conversion; } @@ -297,7 +298,8 @@ void DuckDBPyResult::ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_obje if (result->types[i] == LogicalType::TIMESTAMP_TZ) { // first localize to UTC then convert to timezone_config auto utc_local = df[names[i].c_str()].attr("dt").attr("tz_localize")("UTC"); - auto new_value = utc_local.attr("dt").attr("tz_convert")(result->client_properties.time_zone); + auto new_value = + utc_local.attr("dt").attr("tz_convert")(result->client_context->GetClientProperties().time_zone); // We need to create the column anew because the exact dt changed to a new timezone ReplaceDFColumn(df, names[i].c_str(), i, new_value); } else if (date_as_object && result->types[i] == LogicalType::DATE) { @@ -440,8 +442,7 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo } ArrowArray data = array->arrow_array; array->arrow_array.release = nullptr; - ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names, - arrow_result.client_properties); + ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names, *GetClientContext()); TransformDuckToArrowChunk(arrow_schema, data, batches); } } else { @@ -453,9 +454,9 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo { D_ASSERT(py::gil_check()); py::gil_scoped_release release; - count = ArrowUtil::FetchChunk(scan_state, query_result.client_properties, rows_per_batch, &data, - ArrowTypeExtensionData::GetExtensionTypes( - *query_result.client_properties.client_context, query_result.types)); + auto arrow_type_exts = + ArrowTypeExtensionData::GetExtensionTypes(*GetClientContext(), query_result.types); + count = ArrowUtil::FetchChunk(scan_state, *GetClientContext(), rows_per_batch, &data, arrow_type_exts); } if (count == 0) { break; @@ -465,13 +466,12 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo if (to_polars) { QueryResult::DeduplicateColumns(result_names); } - ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names, - query_result.client_properties); + ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names, *GetClientContext()); TransformDuckToArrowChunk(arrow_schema, data, batches); } } - return pyarrow::ToArrowTable(result->types, names, std::move(batches), result->client_properties); + return pyarrow::ToArrowTable(result->types, names, std::move(batches), *GetClientContext()); } ArrowArrayStream DuckDBPyResult::FetchArrowArrayStream(idx_t rows_per_batch) { @@ -623,7 +623,7 @@ struct ArrowQueryResultStreamWrapper { arrays = arrow_result.ConsumeArrays(); cached_schema.release = nullptr; - ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties); + ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, *result->client_context); stream.private_data = this; stream.get_schema = GetSchema; diff --git a/src/duckdb_py/python_replacement_scan.cpp b/src/duckdb_py/python_replacement_scan.cpp index 8bff9e8f..1575dc74 100644 --- a/src/duckdb_py/python_replacement_scan.cpp +++ b/src/duckdb_py/python_replacement_scan.cpp @@ -17,11 +17,12 @@ namespace duckdb { static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function, - vector> &children, ClientProperties &client_properties, - PyArrowObjectType type, DatabaseInstance &db) { + vector> &children, + const shared_ptr &client_context, PyArrowObjectType type) { + auto const db = client_context->db; shared_ptr external_dependency = make_shared_ptr(); if (type == PyArrowObjectType::MessageReader) { - if (!db.ExtensionIsLoaded("nanoarrow")) { + if (!db->ExtensionIsLoaded("nanoarrow")) { throw MissingExtensionException( "The nanoarrow community extension is needed to read the Arrow IPC protocol. \n You can install it " "with \"INSTALL nanoarrow FROM community;\". \n Then you can load it with \"LOAD nanoarrow;\""); @@ -51,7 +52,7 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR auto dependency_item = PythonDependencyItem::Create(stream_messages); external_dependency->AddDependency("replacement_cache", std::move(dependency_item)); } else { - auto stream_factory = make_uniq(entry.ptr(), client_properties, type); + auto stream_factory = make_uniq(entry.ptr(), client_context, type); auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce; auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema; @@ -98,8 +99,8 @@ static void ThrowScanFailureError(const py::object &entry, const string &name, c } unique_ptr PythonReplacementScan::ReplacementObject(const py::object &entry, const string &name, - ClientContext &context, bool relation) { - auto replacement = TryReplacementObject(entry, name, context, relation); + ClientContext &client_context, bool relation) { + auto replacement = TryReplacementObject(entry, name, client_context, relation); if (!replacement) { ThrowScanFailureError(entry, name); } @@ -107,8 +108,8 @@ unique_ptr PythonReplacementScan::ReplacementObject(const py::object & } unique_ptr PythonReplacementScan::TryReplacementObject(const py::object &entry, const string &name, - ClientContext &context, bool relation) { - auto client_properties = context.GetClientProperties(); + ClientContext &client_context, bool relation) { + auto client_properties = client_context.GetClientProperties(); auto table_function = make_uniq(); vector> children; NumpyObjectType numpytype; @@ -116,9 +117,10 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec if (DuckDBPyConnection::IsPandasDataframe(entry)) { if (PandasDataFrame::IsPyArrowBacked(entry)) { auto table = PandasDataFrame::ToArrowTable(entry); - CreateArrowScan(name, table, *table_function, children, client_properties, PyArrowObjectType::Table, - *context.db); + CreateArrowScan(name, table, *table_function, children, client_context.shared_from_this(), + PyArrowObjectType::Table); } else { + // TODO: this smells like a bug string name = "df_" + StringUtil::GenerateRandomName(); auto new_df = PandasScanFunction::PandasReplaceCopiedNames(entry); children.push_back(make_uniq(Value::POINTER(CastPointerToValue(new_df.ptr())))); @@ -130,7 +132,7 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec } } else if (DuckDBPyRelation::IsRelation(entry)) { auto pyrel = py::cast(entry); - if (!pyrel->CanBeRegisteredBy(context)) { + if (!pyrel->CanBeRegisteredBy(client_context)) { throw InvalidInputException( "Python Object \"%s\" of type \"DuckDBPyRelation\" not suitable for replacement scan.\nThe object was " "created by another Connection and can therefore not be used by this Connection.", @@ -149,14 +151,14 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec // Polars's __arrow_c_stream__() serializes from its internal layout on every call, // which is expensive for repeated scans. The .to_arrow() path converts once. auto arrow_dataset = entry.attr("to_arrow")(); - CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table, - *context.db); + CreateArrowScan(name, arrow_dataset, *table_function, children, client_context.shared_from_this(), + PyArrowObjectType::Table); } else if (PolarsDataFrame::IsLazyFrame(entry)) { - CreateArrowScan(name, entry, *table_function, children, client_properties, PyArrowObjectType::PolarsLazyFrame, - *context.db); + CreateArrowScan(name, entry, *table_function, children, client_context.shared_from_this(), + PyArrowObjectType::PolarsLazyFrame); } else if ((arrow_type = DuckDBPyConnection::GetArrowType(entry)) != PyArrowObjectType::Invalid && !(arrow_type == PyArrowObjectType::MessageReader && !relation)) { - CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db); + CreateArrowScan(name, entry, *table_function, children, client_context.shared_from_this(), arrow_type); } else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) { numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry); string np_name = "np_" + StringUtil::GenerateRandomName(); @@ -205,7 +207,7 @@ static bool IsBuiltinFunction(const py::object &object) { return py::isinstance(object, import_cache_py.types.BuiltinFunctionType()); } -static unique_ptr TryReplacement(py::dict &dict, const string &name, ClientContext &context, +static unique_ptr TryReplacement(py::dict &dict, const string &name, ClientContext &client_context, py::object ¤t_frame) { auto table_name = py::str(name); if (!dict.contains(table_name)) { @@ -218,7 +220,7 @@ static unique_ptr TryReplacement(py::dict &dict, const string &name, C return nullptr; } - auto result = PythonReplacementScan::TryReplacementObject(entry, name, context); + auto result = PythonReplacementScan::TryReplacementObject(entry, name, client_context); if (!result) { std::string location = py::cast(current_frame.attr("f_code").attr("co_filename")); location += ":"; @@ -228,9 +230,9 @@ static unique_ptr TryReplacement(py::dict &dict, const string &name, C return result; } -static unique_ptr ReplaceInternal(ClientContext &context, const string &table_name) { +static unique_ptr ReplaceInternal(ClientContext &client_context, const string &table_name) { Value result; - auto lookup_result = context.TryGetCurrentSetting("python_enable_replacements", result); + auto lookup_result = client_context.TryGetCurrentSetting("python_enable_replacements", result); D_ASSERT((bool)lookup_result); auto enabled = result.GetValue(); @@ -238,7 +240,7 @@ static unique_ptr ReplaceInternal(ClientContext &context, const string return nullptr; } - lookup_result = context.TryGetCurrentSetting("python_scan_all_frames", result); + lookup_result = client_context.TryGetCurrentSetting("python_scan_all_frames", result); D_ASSERT((bool)lookup_result); auto scan_all_frames = result.GetValue(); @@ -268,7 +270,7 @@ static unique_ptr ReplaceInternal(ClientContext &context, const string if (has_locals) { // search local dictionary auto local_dict = py::cast(local_dict_p); - auto result = TryReplacement(local_dict, table_name, context, current_frame); + auto result = TryReplacement(local_dict, table_name, client_context, current_frame); if (result) { return result; } @@ -283,7 +285,7 @@ static unique_ptr ReplaceInternal(ClientContext &context, const string if (has_globals) { auto global_dict = py::cast(global_dict_p); // search global dictionary - auto result = TryReplacement(global_dict, table_name, context, current_frame); + auto result = TryReplacement(global_dict, table_name, client_context, current_frame); if (result) { return result; } @@ -297,16 +299,16 @@ static unique_ptr ReplaceInternal(ClientContext &context, const string return nullptr; } -unique_ptr PythonReplacementScan::Replace(ClientContext &context, ReplacementScanInput &input, +unique_ptr PythonReplacementScan::Replace(ClientContext &client_context, ReplacementScanInput &input, optional_ptr data) { auto &table_name = input.table_name; - auto &config = DBConfig::GetConfig(context); + auto &config = DBConfig::GetConfig(client_context); if (!Settings::Get(config)) { return nullptr; } unique_ptr result; - result = ReplaceInternal(context, table_name); + result = ReplaceInternal(client_context, table_name); return result; } diff --git a/src/duckdb_py/python_udf.cpp b/src/duckdb_py/python_udf.cpp index a62004d4..01f06caf 100644 --- a/src/duckdb_py/python_udf.cpp +++ b/src/duckdb_py/python_udf.cpp @@ -8,13 +8,10 @@ #include "duckdb/common/arrow/arrow_converter.hpp" #include "duckdb/common/arrow/arrow_wrapper.hpp" #include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/result_arrow_wrapper.hpp" #include "duckdb_python/arrow/arrow_array_stream.hpp" #include "duckdb/function/table/arrow.hpp" #include "duckdb/function/function.hpp" -#include "duckdb_python/numpy/numpy_scan.hpp" #include "duckdb_python/arrow/arrow_export_utils.hpp" -#include "duckdb/common/types/arrow_aux_data.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/function/table/arrow/arrow_duck_schema.hpp" #include "duckdb_python/python_conversion.hpp" @@ -22,20 +19,20 @@ namespace duckdb { static py::list ConvertToSingleBatch(vector &types, vector &names, DataChunk &input, - ClientProperties &options, ClientContext &context) { + ClientContext &client_context) { ArrowSchema schema; - ArrowConverter::ToArrowSchema(&schema, types, names, options); + ArrowConverter::ToArrowSchema(&schema, types, names, client_context); py::list single_batch; - ArrowAppender appender(types, STANDARD_VECTOR_SIZE, options, - ArrowTypeExtensionData::GetExtensionTypes(context, types)); + ArrowAppender appender(types, STANDARD_VECTOR_SIZE, client_context.shared_from_this(), + ArrowTypeExtensionData::GetExtensionTypes(client_context, types)); appender.Append(input, 0, input.size(), input.size()); auto array = appender.Finalize(); TransformDuckToArrowChunk(schema, array, single_batch); return single_batch; } -static py::object ConvertDataChunkToPyArrowTable(DataChunk &input, ClientProperties &options, ClientContext &context) { +static py::object ConvertDataChunkToPyArrowTable(DataChunk &input, ClientContext &context) { auto types = input.GetTypes(); vector names; names.reserve(types.size()); @@ -43,7 +40,7 @@ static py::object ConvertDataChunkToPyArrowTable(DataChunk &input, ClientPropert names.push_back(StringUtil::Format("c%d", i)); } - return pyarrow::ToArrowTable(types, names, ConvertToSingleBatch(types, names, input, options, context), options); + return pyarrow::ToArrowTable(types, names, ConvertToSingleBatch(types, names, input, context), context); } // If these types are arrow canonical extensions, we must check if they are registered. @@ -75,7 +72,7 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie py::gil_scoped_release gil; auto stream_factory = - make_uniq(ptr, context.GetClientProperties(), PyArrowObjectType::Table); + make_uniq(ptr, context.shared_from_this(), PyArrowObjectType::Table); auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce; auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema; @@ -177,13 +174,6 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce // owning references py::object python_object; - // Convert the input datachunk to pyarrow - // ClientProperties options; - - // if (state.HasContext()) { - auto &context = state.GetContext(); - auto options = context.GetClientProperties(); - // } auto result_validity = FlatVector::Validity(result); SelectionVector selvec(input.size()); @@ -215,7 +205,7 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce } } - auto pyarrow_table = ConvertDataChunkToPyArrowTable(input, options, state.GetContext()); + auto pyarrow_table = ConvertDataChunkToPyArrowTable(input, state.GetContext()); py::tuple column_list = pyarrow_table.attr("columns"); auto count = input.size(); @@ -454,6 +444,7 @@ struct PythonUDFData { auto python_version = PY_VERSION_HEX; auto signature_func = py::module_::import("inspect").attr("signature"); + // TODO: we dropped support for Python < 3.10 as of DuckDB 1.5 if (python_version >= PYTHON_3_10_HEX) { return signature_func(udf, py::arg("eval_str") = true); } else {