Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/duckdb_py/arrow/arrow_array_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "duckdb/common/assert.hpp"
#include "duckdb/common/common.hpp"
#include "duckdb/common/limits.hpp"
#include "duckdb/main/client_config.hpp"

namespace duckdb {

Expand All @@ -30,13 +29,12 @@ void VerifyArrowDatasetLoaded() {

py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
ArrowStreamParameters &parameters,
const ClientProperties &client_properties) {
const shared_ptr<ClientContext> &client_context) {
D_ASSERT(!py::isinstance<py::capsule>(arrow_obj_handle));
ArrowSchemaWrapper schema;
PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema);
ArrowTableSchema arrow_table;
ArrowTableFunction::PopulateArrowTableSchema(*client_properties.client_context.get_mutable(), arrow_table,
schema.arrow_schema);
ArrowTableFunction::PopulateArrowTableSchema(*client_context, arrow_table, schema.arrow_schema);

auto filters = parameters.filters;
auto &column_list = parameters.projected_columns.columns;
Expand All @@ -50,8 +48,9 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_
}

if (has_filter) {
auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map,
filter_to_col, client_properties, arrow_table);
auto filter =
PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col,
client_context->GetClientProperties(), arrow_table);
if (!filter.is(py::none())) {
kwargs["filter"] = filter;
}
Expand All @@ -78,7 +77,7 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
try {
auto filter_expr = PolarsFilterPushdown::TransformFilter(
*filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col,
factory->client_properties);
factory->client_context->GetClientProperties());
if (!filter_expr.is(py::none())) {
lf = lf.attr("filter")(filter_expr);
filters_pushed = true;
Expand Down Expand Up @@ -139,7 +138,7 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
auto &import_cache = *DuckDBPyConnection::ImportCache();
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
py::handle reader_handle = reader;
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties);
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_context);
auto record_batches = scanner.attr("to_reader")();
auto res = make_uniq<ArrowArrayStreamWrapper>();
auto export_to_c = record_batches.attr("_export_to_c");
Expand Down Expand Up @@ -177,12 +176,12 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
// If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack
// scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain
auto record_batches = arrow_obj_handle.attr("to_reader")();
scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_properties);
scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_context);
break;
}
case PyArrowObjectType::Dataset: {
py::object arrow_scanner = arrow_obj_handle.attr("__class__").attr("scanner");
scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_properties);
scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_context);
break;
}
default: {
Expand Down
4 changes: 2 additions & 2 deletions src/duckdb_py/arrow/arrow_export_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ namespace duckdb {
namespace pyarrow {

py::object ToArrowTable(const vector<LogicalType> &types, const vector<string> &names, const py::list &batches,
ClientProperties &options) {
ClientContext &client_context) {
py::gil_scoped_acquire acquire;

auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
auto from_batches_func = pyarrow_lib_module.attr("Table").attr("from_batches");
auto schema_import_func = pyarrow_lib_module.attr("Schema").attr("_import_from_c");
ArrowSchema schema;
ArrowConverter::ToArrowSchema(&schema, types, names, options);
ArrowConverter::ToArrowSchema(&schema, types, names, client_context);
auto schema_obj = schema_import_func(reinterpret_cast<uint64_t>(&schema));

return py::cast<duckdb::pyarrow::Table>(from_batches_func(batches, schema_obj));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ PyArrowObjectType GetArrowType(const py::handle &obj);

class PythonTableArrowArrayStreamFactory {
public:
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p,
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const shared_ptr<ClientContext> &client_context,
PyArrowObjectType arrow_type_p)
: arrow_object(arrow_table), client_properties(client_properties_p), cached_arrow_type(arrow_type_p) {
: arrow_object(arrow_table), client_context(client_context), cached_arrow_type(arrow_type_p) {
cached_schema.release = nullptr;
}

Expand All @@ -94,7 +94,7 @@ class PythonTableArrowArrayStreamFactory {
//! Arrow Object (i.e., Scanner, Record Batch Reader, Table, Dataset)
PyObject *arrow_object;

const ClientProperties client_properties;
const shared_ptr<ClientContext> client_context;
const PyArrowObjectType cached_arrow_type;

//! Cached Arrow table from an unfiltered .collect().to_arrow() on a LazyFrame.
Expand All @@ -106,7 +106,8 @@ class PythonTableArrowArrayStreamFactory {
bool schema_cached = false;

static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
ArrowStreamParameters &parameters, const ClientProperties &client_properties);
ArrowStreamParameters &parameters,
const shared_ptr<ClientContext> &client_context);
};
} // namespace duckdb

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace duckdb {
namespace pyarrow {

py::object ToArrowTable(const vector<LogicalType> &types, const vector<string> &names, const py::list &batches,
ClientProperties &options);
ClientContext &client_context);

} // namespace pyarrow

Expand Down
2 changes: 1 addition & 1 deletion src/duckdb_py/include/duckdb_python/pyresult.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct DuckDBPyResult {
const vector<string> &GetNames();
const vector<LogicalType> &GetTypes();

ClientProperties GetClientProperties();
shared_ptr<ClientContext> GetClientContext() const;

private:
void FillNumpy(py::dict &res, idx_t col_idx, NumpyResultConversion &conversion, const char *name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ namespace duckdb {

struct PythonReplacementScan {
public:
static unique_ptr<TableRef> Replace(ClientContext &context, ReplacementScanInput &input,
static unique_ptr<TableRef> Replace(ClientContext &client_context, ReplacementScanInput &input,
optional_ptr<ReplacementScanData> data);
//! Try to perform a replacement, returns NULL on error
static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const string &name,
ClientContext &context, bool relation = false);
ClientContext &client_context, bool relation = false);
//! Perform a replacement or throw if it failed
static unique_ptr<TableRef> ReplacementObject(const py::object &entry, const string &name, ClientContext &context,
bool relation = false);
static unique_ptr<TableRef> ReplacementObject(const py::object &entry, const string &name,
ClientContext &client_context, bool relation = false);
};

} // namespace duckdb
3 changes: 1 addition & 2 deletions src/duckdb_py/pyconnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,7 @@ shared_ptr<DuckDBPyConnection> DuckDBPyConnection::Append(const string &name, co
shared_ptr<DuckDBPyConnection> DuckDBPyConnection::RegisterPythonObject(const string &name,
const py::object &python_object) {
auto &connection = con.GetConnection();
auto &client = *connection.context;
auto object = PythonReplacementScan::ReplacementObject(python_object, name, client);
auto object = PythonReplacementScan::ReplacementObject(python_object, name, *connection.context);
auto view_rel = make_shared_ptr<ViewRelation>(connection.context, std::move(object), name);
bool replace = registered_objects.count(name);
view_rel->CreateView(name, replace, true);
Expand Down
10 changes: 5 additions & 5 deletions src/duckdb_py/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,18 +1024,18 @@ PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) {
ArrowSchema arrow_schema;
auto result_names = names;
QueryResult::DeduplicateColumns(result_names);
ClientProperties client_properties;
shared_ptr<ClientContext> client_context;
if (rel) {
client_properties = rel->context->GetContext()->GetClientProperties();
client_context = rel->context->GetContext();
} else if (result) {
client_properties = result->GetClientProperties();
client_context = result->GetClientContext();
} else {
throw InternalException("DuckDBPyRelation To Polars must have a valid relation or result");
}
ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties);
ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, *client_context);
py::list batches;
// Now we create an empty arrow table
auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), client_properties);
auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), *client_context);

// And we extract the polars schema from the arrow table
auto polars_df = py::cast<PolarsDataFrame>(pybind11::module_::import("polars").attr("DataFrame")(empty_table));
Expand Down
30 changes: 15 additions & 15 deletions src/duckdb_py/pyresult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ DuckDBPyResult::~DuckDBPyResult() {
}
}

ClientProperties DuckDBPyResult::GetClientProperties() {
return result->client_properties;
shared_ptr<ClientContext> DuckDBPyResult::GetClientContext() const {
return result->client_context;
}

const vector<string> &DuckDBPyResult::GetNames() {
Expand Down Expand Up @@ -138,7 +138,8 @@ Optional<py::tuple> DuckDBPyResult::Fetchone() {
continue;
}
auto val = current_chunk->data[col_idx].GetValue(chunk_offset);
res[col_idx] = PythonObject::FromValue(val, result->types[col_idx], result->client_properties);
res[col_idx] =
PythonObject::FromValue(val, result->types[col_idx], result->client_context->GetClientProperties());
}
chunk_offset++;
return res;
Expand Down Expand Up @@ -225,8 +226,8 @@ unique_ptr<NumpyResultConversion> DuckDBPyResult::InitializeNumpyConversion(bool
initial_capacity = materialized.RowCount();
}

auto conversion =
make_uniq<NumpyResultConversion>(result->types, initial_capacity, result->client_properties, pandas);
auto conversion = make_uniq<NumpyResultConversion>(result->types, initial_capacity,
result->client_context->GetClientProperties(), pandas);
return conversion;
}

Expand Down Expand Up @@ -297,7 +298,8 @@ void DuckDBPyResult::ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_obje
if (result->types[i] == LogicalType::TIMESTAMP_TZ) {
// first localize to UTC then convert to timezone_config
auto utc_local = df[names[i].c_str()].attr("dt").attr("tz_localize")("UTC");
auto new_value = utc_local.attr("dt").attr("tz_convert")(result->client_properties.time_zone);
auto new_value =
utc_local.attr("dt").attr("tz_convert")(result->client_context->GetClientProperties().time_zone);
// We need to create the column anew because the exact dt changed to a new timezone
ReplaceDFColumn(df, names[i].c_str(), i, new_value);
} else if (date_as_object && result->types[i] == LogicalType::DATE) {
Expand Down Expand Up @@ -440,8 +442,7 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
}
ArrowArray data = array->arrow_array;
array->arrow_array.release = nullptr;
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names,
arrow_result.client_properties);
ArrowConverter::ToArrowSchema(&arrow_schema, arrow_result.types, result_names, *GetClientContext());
TransformDuckToArrowChunk(arrow_schema, data, batches);
}
} else {
Expand All @@ -453,9 +454,9 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
{
D_ASSERT(py::gil_check());
py::gil_scoped_release release;
count = ArrowUtil::FetchChunk(scan_state, query_result.client_properties, rows_per_batch, &data,
ArrowTypeExtensionData::GetExtensionTypes(
*query_result.client_properties.client_context, query_result.types));
auto arrow_type_exts =
ArrowTypeExtensionData::GetExtensionTypes(*GetClientContext(), query_result.types);
count = ArrowUtil::FetchChunk(scan_state, *GetClientContext(), rows_per_batch, &data, arrow_type_exts);
}
if (count == 0) {
break;
Expand All @@ -465,13 +466,12 @@ duckdb::pyarrow::Table DuckDBPyResult::FetchArrowTable(idx_t rows_per_batch, boo
if (to_polars) {
QueryResult::DeduplicateColumns(result_names);
}
ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names,
query_result.client_properties);
ArrowConverter::ToArrowSchema(&arrow_schema, query_result.types, result_names, *GetClientContext());
TransformDuckToArrowChunk(arrow_schema, data, batches);
}
}

return pyarrow::ToArrowTable(result->types, names, std::move(batches), result->client_properties);
return pyarrow::ToArrowTable(result->types, names, std::move(batches), *GetClientContext());
}

ArrowArrayStream DuckDBPyResult::FetchArrowArrayStream(idx_t rows_per_batch) {
Expand Down Expand Up @@ -623,7 +623,7 @@ struct ArrowQueryResultStreamWrapper {
arrays = arrow_result.ConsumeArrays();

cached_schema.release = nullptr;
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties);
ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, *result->client_context);

stream.private_data = this;
stream.get_schema = GetSchema;
Expand Down
Loading
Loading