diff --git a/.github/actions/ccache-action/action.yml b/.github/actions/ccache-action/action.yml deleted file mode 100644 index e3af01f4..00000000 --- a/.github/actions/ccache-action/action.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: 'Setup Ccache' -inputs: - key: - description: 'Cache key (defaults to github.job)' - required: false - default: '' -runs: - using: "composite" - steps: - - name: Setup Ccache - uses: hendrikmuhs/ccache-action@main - with: - key: ${{ inputs.key || github.job }} - save: ${{ github.repository != 'duckdb/duckdb-python' || contains('["refs/heads/main", "refs/heads/v1.4-andium", "refs/heads/v1.5-variegata"]', github.ref) }} - # Dump verbose ccache statistics report at end of CI job. - verbose: 1 - # Increase per-directory limit: 5*1024 MB / 16 = 320 MB. - # Note: `layout=subdirs` computes the size limit divided by 16 dirs. - # See also: https://ccache.dev/manual/4.9.html#_cache_size_management - max-size: 1500MB - # Evicts all cache files that were not touched during the job run. - # Removing cache files from previous runs avoids creating huge caches. - evict-old-files: 'job' diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index fed70203..c7f8c5d7 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -25,112 +25,15 @@ on: type: string jobs: - seed_wheels: - name: 'Seed: cp314-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' - strategy: - fail-fast: false - matrix: - python: [ cp314 ] - platform: - - { os: windows-2022, arch: amd64, cibw_system: win } - - { os: windows-11-arm, arch: ARM64, cibw_system: win } - - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } - - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } - - { os: macos-15, arch: arm64, cibw_system: macosx } - - { os: macos-15, arch: universal2, cibw_system: macosx } - - { os: macos-15-intel, arch: x86_64, cibw_system: macosx } - minimal: - - ${{ inputs.minimal }} - exclude: - - { minimal: true, platform: { arch: universal2 } } - runs-on: ${{ matrix.platform.os }} - env: - CCACHE_DIR: ${{ github.workspace }}/.ccache - ### cibuildwheel configuration - # - # This is somewhat brittle, so be careful with changes. Some notes for our future selves (and others): - # - cibw will change its cwd to a temp dir and create a separate venv for testing. It then installs the wheel it - # built into that venv, and run the CIBW_TEST_COMMAND. We have to install all dependencies ourselves, and make - # sure that the pytest config in pyproject.toml is available. - # - CIBW_BEFORE_TEST installs the test dependencies by exporting them into a pylock.toml. At the time of writing, - # `uv sync --no-install-project` had problems correctly resolving dependencies using resolution environments - # across all platforms we build for. This might be solved in newer uv versions. - # - CIBW_TEST_COMMAND specifies pytest conf from pyproject.toml. --confcutdir is needed to prevent pytest from - # traversing the full filesystem, which produces an error on Windows. - # - CIBW_TEST_SKIP we always skip tests for *-macosx_universal2 builds, because we run tests for arm64 and x86_64. - CIBW_TEST_SKIP: ${{ inputs.testsuite == 'none' && '*' || '*-macosx_universal2' }} - CIBW_TEST_SOURCES: tests - CIBW_BEFORE_TEST: > - uv export --only-group test --no-emit-project --quiet --output-file pylock.toml --directory {project} && - uv pip install -r pylock.toml - CIBW_TEST_COMMAND: > - uv run -v pytest --confcutdir=. --rootdir . -c {project}/pyproject.toml ${{ inputs.testsuite == 'fast' && './tests/fast' || './tests' }} - - steps: - - name: Checkout DuckDB Python - uses: actions/checkout@v4 - with: - ref: ${{ inputs.duckdb-python-sha }} - fetch-depth: 0 - submodules: true - - - name: Checkout DuckDB - shell: bash - if: ${{ inputs.duckdb-sha }} - run: | - cd external/duckdb - git fetch origin - git checkout ${{ inputs.duckdb-sha }} - - - name: Set CIBW_ENVIRONMENT - shell: bash - run: | - cibw_env="" - if [[ "${{ matrix.platform.cibw_system }}" == "manylinux" ]]; then - cibw_env="CCACHE_DIR=/host${{ github.workspace }}/.ccache" - fi - if [[ -n "${{ inputs.set-version }}" ]]; then - cibw_env="${cibw_env:+$cibw_env }OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" - fi - if [[ -n "$cibw_env" ]]; then - echo "CIBW_ENVIRONMENT=${cibw_env}" >> $GITHUB_ENV - fi - - - name: Setup Ccache - uses: ./.github/actions/ccache-action - with: - key: ${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - - # Install Astral UV, which will be used as build-frontend for cibuildwheel - - uses: astral-sh/setup-uv@v7 - with: - version: "0.9.0" - enable-cache: false - cache-suffix: -${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - - - name: Build${{ inputs.testsuite != 'none' && ' and test ' || ' ' }}wheels - uses: pypa/cibuildwheel@v3.2 - env: - CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} - CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - - - name: Upload wheel - uses: actions/upload-artifact@v4 - with: - name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - path: wheelhouse/*.whl - compression-level: 0 - build_wheels: name: 'Wheel: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' - needs: seed_wheels strategy: fail-fast: false matrix: - python: [ cp310, cp311, cp312, cp313 ] + python: [ cp310, cp311, cp312, cp313, cp314 ] platform: - - { os: windows-2025, arch: amd64, cibw_system: win } - - { os: windows-11-arm, arch: ARM64, cibw_system: win } + - { os: windows-2022, arch: amd64, cibw_system: win } + - { os: windows-11-arm, arch: ARM64, cibw_system: win } # cibw requires ARM64 to be uppercase - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } - { os: macos-15, arch: arm64, cibw_system: macosx } @@ -143,10 +46,9 @@ jobs: - { minimal: true, python: cp312 } - { minimal: true, python: cp313 } - { minimal: true, platform: { arch: universal2 } } - - { python: cp310, platform: { os: windows-11-arm, arch: ARM64 } } + - { python: cp310, platform: { os: windows-11-arm, arch: ARM64 } } # too many dependency problems for win arm64 runs-on: ${{ matrix.platform.os }} env: - CCACHE_DIR: ${{ github.workspace }}/.ccache ### cibuildwheel configuration # # This is somewhat brittle, so be careful with changes. Some notes for our future selves (and others): @@ -183,24 +85,11 @@ jobs: git fetch origin git checkout ${{ inputs.duckdb-sha }} - - name: Set CIBW_ENVIRONMENT + # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds + - name: Set OVERRIDE_GIT_DESCRIBE shell: bash - run: | - cibw_env="" - if [[ "${{ matrix.platform.cibw_system }}" == "manylinux" ]]; then - cibw_env="CCACHE_DIR=/host${{ github.workspace }}/.ccache" - fi - if [[ -n "${{ inputs.set-version }}" ]]; then - cibw_env="${cibw_env:+$cibw_env }OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" - fi - if [[ -n "$cibw_env" ]]; then - echo "CIBW_ENVIRONMENT=${cibw_env}" >> $GITHUB_ENV - fi - - - name: Setup Ccache - uses: ./.github/actions/ccache-action - with: - key: ${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + if: ${{ inputs.set-version != '' }} + run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV # Install Astral UV, which will be used as build-frontend for cibuildwheel - uses: astral-sh/setup-uv@v7 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 727f8027..5c89d6d4 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,6 +39,7 @@ defaults: jobs: build_sdist: name: Build an sdist and determine versions + if: ${{ github.ref != 'refs/heads/main' }} uses: ./.github/workflows/packaging_sdist.yml with: testsuite: all diff --git a/CMakeLists.txt b/CMakeLists.txt index 71200269..0c063bdb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.29) project(duckdb_py LANGUAGES CXX) -# Always use C++17 -set(CMAKE_CXX_STANDARD 17) +# Always use C++11 +set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi index 8770483f..1e6bf7e4 100644 --- a/_duckdb-stubs/__init__.pyi +++ b/_duckdb-stubs/__init__.pyi @@ -537,9 +537,7 @@ class DuckDBPyRelation: def distinct(self) -> DuckDBPyRelation: ... def except_(self, other_rel: Self) -> DuckDBPyRelation: ... def execute(self) -> DuckDBPyRelation: ... - def explain( - self, type: ExplainType | ExplainTypeLiteral = ExplainType.STANDARD, format: str | None = None - ) -> str: ... + def explain(self, type: ExplainType | ExplainTypeLiteral = ExplainType.STANDARD) -> str: ... def favg( self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index 43d04e7c..2e15e38b 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -19,7 +19,6 @@ IntegerType, LongType, MapType, - NullType, ShortType, StringType, StructField, @@ -28,7 +27,6 @@ TimeNTZType, TimestampMillisecondNTZType, TimestampNanosecondNTZType, - TimestampNanosecondType, TimestampNTZType, TimestampSecondNTZType, TimestampType, @@ -43,7 +41,6 @@ ) _sqltype_to_spark_class = { - "null": NullType, "boolean": BooleanType, "utinyint": UnsignedByteType, "tinyint": ByteType, @@ -65,10 +62,9 @@ "time with time zone": TimeType, "timestamp": TimestampNTZType, "timestamp with time zone": TimestampType, - "timestamp_ms": TimestampMillisecondNTZType, - "timestamp_ns": TimestampNanosecondNTZType, + "timestamp_ms": TimestampNanosecondNTZType, + "timestamp_ns": TimestampMillisecondNTZType, "timestamp_s": TimestampSecondNTZType, - "timestamptz_ns": TimestampNanosecondType, "interval": DayTimeIntervalType, "list": ArrayType, "struct": StructType, diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index e9609627..5bfff09f 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -49,7 +49,6 @@ "TimestampMillisecondNTZType", "TimestampNTZType", "TimestampNanosecondNTZType", - "TimestampNanosecondType", "TimestampSecondNTZType", "TimestampType", "UUIDType", @@ -240,26 +239,6 @@ def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) -class TimestampNanosecondType(AtomicType, metaclass=DataTypeSingleton): - """Timestamp (datetime.datetime) data type with timezone information with nanosecond precision.""" - - def __init__(self) -> None: # noqa: D107 - super().__init__(DuckDBPyType("TIMESTAMPTZ_NS")) - - def needConversion(self) -> bool: # noqa: D102 - return True - - @classmethod - def typeName(cls) -> str: # noqa: D102 - return "timestamptz_ns" - - def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 - raise ContributionsAcceptedError - - def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 - raise ContributionsAcceptedError - - class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 630f2493..51c3f01a 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -12,7 +12,7 @@ from ._versioning import format_version, parse_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only -MAIN_BRANCH_VERSIONING = True +MAIN_BRANCH_VERSIONING = False SCM_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" SCM_GLOBAL_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION" diff --git a/pyproject.toml b/pyproject.toml index 5cc4cc91..e23b259f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,12 +88,6 @@ version_scheme = "duckdb_packaging.setuptools_scm_version:version_scheme" local_scheme = "no-local-version" fallback_version = "0.0.1.dev1" -# main only: count dev distance from the last *minor* tag (v*.*.0), so a patch -# tag (e.g. v1.5.4) merged in from a release branch can't reset .devN. -# Release branches must NOT have this, they correctly count from v*.*.* (the default). -[tool.setuptools_scm.scm.git] -describe_command = "git describe --dirty --tags --long --abbrev=40 --match v*.*.0" - # Override: if COVERAGE is set then: # - we create a RelWithDebInfo build # - we make sure we use a persistent build dir so we get access to the .gcda files @@ -128,7 +122,7 @@ cmake.build-type = "Debug" [[tool.scikit-build.overrides]] if.state = "editable" if.env.COVERAGE = false -if.platform-system = "(?i)darwin" +if.platform-system = "Darwin" inherit.cmake.define = "append" cmake.define.DISABLE_UNITY = "1" diff --git a/scripts/cache_data.json b/scripts/cache_data.json index 1bb132b4..fea6034d 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -532,9 +532,7 @@ "polars.DataFrame", "polars.LazyFrame", "polars.col", - "polars.lit", - "polars.Series", - "polars.Decimal" + "polars.lit" ], "required": false }, @@ -824,17 +822,5 @@ "full_path": "polars.lit", "name": "lit", "children": [] - }, - "polars.Series": { - "type": "attribute", - "full_path": "polars.Series", - "name": "Series", - "children": [] - }, - "polars.Decimal": { - "type": "attribute", - "full_path": "polars.Decimal", - "name": "Decimal", - "children": [] } } \ No newline at end of file diff --git a/scripts/imports.py b/scripts/imports.py index 240a5f50..26e0394b 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -111,8 +111,6 @@ polars.LazyFrame polars.col polars.lit -polars.Series -polars.Decimal import duckdb import duckdb.filesystem diff --git a/src/duckdb_py/arrow/CMakeLists.txt b/src/duckdb_py/arrow/CMakeLists.txt index da8185a0..2f92f09b 100644 --- a/src/duckdb_py/arrow/CMakeLists.txt +++ b/src/duckdb_py/arrow/CMakeLists.txt @@ -1,7 +1,6 @@ # this is used for clang-tidy checks add_library( - python_arrow OBJECT - arrow_array_stream.cpp arrow_export_utils.cpp filter_pushdown_visitor.cpp - polars_filter_pushdown.cpp pyarrow_filter_pushdown.cpp) + python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp + polars_filter_pushdown.cpp pyarrow_filter_pushdown.cpp) target_link_libraries(python_arrow PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index ed9e2275..5b167e5e 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -43,7 +43,7 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_ auto &filter_to_col = parameters.projected_columns.filter_to_col; py::list projection_list = py::cast(column_list); - bool has_filter = filters && filters->HasFilters(); + bool has_filter = filters && !filters->filters.empty(); py::dict kwargs; if (!column_list.empty()) { kwargs["columns"] = projection_list; @@ -73,20 +73,18 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( auto filters = parameters.filters; bool filters_pushed = false; - // Translate DuckDB filters to Polars expressions and push into the lazy plan. - // The walker only fails (throws / returns py::none()) for filters that are not - // required for correctness — optional/runtime wrappers it skips, or shapes the - // optimizer keeps above the scan. A throw here would mean the optimizer fully - // pushed something we can't translate (a correctness bug), so we let it surface - // rather than silently returning unfiltered rows — the arrow scan does not - // re-apply pushed filters. Mirrors the pyarrow ProduceScanner path. - if (filters && filters->HasFilters()) { - auto filter_expr = PolarsFilterPushdown::TransformFilter( - *filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col, - factory->client_properties); - if (!filter_expr.is(py::none())) { - lf = lf.attr("filter")(filter_expr); - filters_pushed = true; + // Translate DuckDB filters to Polars expressions and push into the lazy plan + if (filters && !filters->filters.empty()) { + try { + auto filter_expr = PolarsFilterPushdown::TransformFilter( + *filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col, + factory->client_properties); + if (!filter_expr.is(py::none())) { + lf = lf.attr("filter")(filter_expr); + filters_pushed = true; + } + } catch (...) { + // Fallback: DuckDB handles filtering post-scan } } diff --git a/src/duckdb_py/arrow/filter_pushdown_visitor.cpp b/src/duckdb_py/arrow/filter_pushdown_visitor.cpp deleted file mode 100644 index 20db8f18..00000000 --- a/src/duckdb_py/arrow/filter_pushdown_visitor.cpp +++ /dev/null @@ -1,216 +0,0 @@ -#include "duckdb_python/arrow/filter_pushdown_visitor.hpp" - -#include "duckdb/function/scalar/struct_utils.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/filter/expression_filter.hpp" -#include "duckdb/planner/filter/table_filter_functions.hpp" - -namespace duckdb { - -namespace { - -bool ValueIsNan(const Value &value) { - if (value.type().id() == LogicalTypeId::FLOAT) { - return Value::IsNan(value.GetValue()); - } - if (value.type().id() == LogicalTypeId::DOUBLE) { - return Value::IsNan(value.GetValue()); - } - return false; -} - -// ResolveColumn walks a column-side expression to extract the (full path, leaf -// ArrowType) pair. Accepts a bare BoundReferenceExpression and (nested) -// `struct_extract` chains. Anything else throws NotImplementedException — -// that gives the OPTIONAL_FILTER catch point a chance to swallow it. -struct ResolvedColumn { - vector path; - const ArrowType *leaf_type; -}; - -ResolvedColumn ResolveColumn(const Expression &expr, const vector &root_path, const ArrowType *root_type) { - if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { - return {root_path, root_type}; - } - if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { - throw NotImplementedException("Cannot push down arrow scan filter on column-side expression: %s", - ExpressionClassToString(expr.GetExpressionClass())); - } - auto &func = expr.Cast(); - idx_t child_idx; - if (!TryGetStructExtractChildIndex(func, child_idx)) { - throw NotImplementedException("Cannot push down arrow scan filter on column-side function: %s", - ExpressionTypeToString(expr.GetExpressionType())); - } - // Recurse innermost-first so names accumulate root → leaf. - auto inner = ResolveColumn(*func.GetChildren()[0], root_path, root_type); - inner.path.push_back(StructType::GetChildName(func.GetChildren()[0]->GetReturnType(), child_idx)); - if (inner.leaf_type) { - inner.leaf_type = &inner.leaf_type->GetTypeInfo().GetChild(child_idx); - } - return inner; -} - -py::object EmitCompare(FilterBackend &backend, ExpressionType op, py::object col, const Value &constant, - const ArrowType *arrow_type, const string &timezone_config) { - if (ValueIsNan(constant)) { - return backend.NaNCompare(op, std::move(col)); - } - auto scalar = backend.MakeScalar(constant, arrow_type, timezone_config); - return backend.Compare(op, std::move(col), std::move(scalar)); -} - -} // anonymous namespace - -py::object TransformExpression(const Expression &expression, const vector &column_path, - FilterBackend &backend, const ArrowType *arrow_type, const string &timezone_config) { - auto expression_class = expression.GetExpressionClass(); - auto expression_type = expression.GetExpressionType(); - - if (expression_class == ExpressionClass::BOUND_FUNCTION) { - auto &bound_function_expression = expression.Cast(); - if (BoundComparisonExpression::IsComparison(expression_type)) { - auto &left = BoundComparisonExpression::Left(bound_function_expression); - auto &right = BoundComparisonExpression::Right(bound_function_expression); - - optional_ptr column_side; - optional_ptr constant_side; - - if (right.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - column_side = &left; - constant_side = &right.Cast(); - } else if (left.GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - column_side = &right; - constant_side = &left.Cast(); - expression_type = FlipComparisonExpression(expression_type); - } else { - throw NotImplementedException("Can only push down constant comparisons."); - } - - auto resolved = ResolveColumn(*column_side, column_path, arrow_type); - auto col = backend.MakeColumnRef(resolved.path); - return EmitCompare(backend, expression_type, std::move(col), constant_side->GetValue(), resolved.leaf_type, - timezone_config); - } - - // Internal table-filter functions. Since the table-filter -> expression-filter - // migration in core, optional / dynamic / bloom / perfect-hash-join / prefix-range - // filters no longer have dedicated TableFilter subtypes. They arrive as scalar - // function wrappers inside the ExpressionFilter expression tree (see - // table_filter_functions.hpp). - const auto &func_name = bound_function_expression.Function().GetName(); - - // OPTIONAL / SELECTIVITY_OPTIONAL wrap a child predicate that lives in `bind_info` - // (their `children` hold only a placeholder column ref). An optional filter is never - // required for correctness, so if its child can't be translated we push nothing for - // it rather than failing the whole scan. - if (func_name == OptionalFilterScalarFun::NAME || func_name == SelectivityOptionalFilterScalarFun::NAME) { - optional_ptr child; - if (bound_function_expression.BindInfo()) { - if (func_name == OptionalFilterScalarFun::NAME) { - child = bound_function_expression.BindInfo() - ->Cast() - .child_filter_expr.get(); - } else { - child = bound_function_expression.BindInfo() - ->Cast() - .child_filter_expr.get(); - } - } - if (!child) { - return py::none(); - } - try { - return TransformExpression(*child, column_path, backend, arrow_type, timezone_config); - } catch (const NotImplementedException &) { - return py::none(); - } - } - - // DYNAMIC / BLOOM / PERFECT_HASH_JOIN / PREFIX_RANGE are runtime filters with no - // static pyarrow/polars equivalent. They are not required for correctness (the - // engine applies them above the scan), so skip them. - if (TableFilterFunctions::IsTableFilterFunction(func_name)) { - return py::none(); - } - } - - if (expression_class == ExpressionClass::BOUND_OPERATOR) { - auto &op_expr = expression.Cast(); - if (expression_type == ExpressionType::OPERATOR_IS_NULL) { - auto resolved = ResolveColumn(*op_expr.GetChildren()[0], column_path, arrow_type); - auto col = backend.MakeColumnRef(resolved.path); - return backend.IsNull(std::move(col)); - } - if (expression_type == ExpressionType::OPERATOR_IS_NOT_NULL) { - auto resolved = ResolveColumn(*op_expr.GetChildren()[0], column_path, arrow_type); - auto col = backend.MakeColumnRef(resolved.path); - return backend.IsNotNull(std::move(col)); - } - if (expression_type == ExpressionType::COMPARE_IN) { - auto resolved = ResolveColumn(*op_expr.GetChildren()[0], column_path, arrow_type); - auto col = backend.MakeColumnRef(resolved.path); - vector values; - for (idx_t i = 1; i < op_expr.GetChildren().size(); i++) { - auto &const_expr = op_expr.GetChildren()[i]->Cast(); - values.push_back(const_expr.GetValue()); - } - auto col_type = op_expr.GetChildren()[0]->GetReturnType(); - return backend.IsIn(std::move(col), values, col_type, timezone_config); - } - } - - if (expression_class == ExpressionClass::BOUND_CONJUNCTION) { - if (expression_type == ExpressionType::CONJUNCTION_OR || expression_type == ExpressionType::CONJUNCTION_AND) { - const bool is_and = expression_type == ExpressionType::CONJUNCTION_AND; - auto &conj_expr = expression.Cast(); - py::object result = py::none(); - for (idx_t i = 0; i < conj_expr.GetChildren().size(); i++) { - py::object child_expression = - TransformExpression(*conj_expr.GetChildren()[i], column_path, backend, arrow_type, timezone_config); - if (child_expression.is(py::none())) { - if (is_and) { - // A conjunct we can't push can simply be dropped: the remaining AND - // terms still form a correct (if weaker) filter, and the engine - // re-applies the rest above the scan. - continue; - } - // An OR branch that can't be translated (e.g. a dynamic filter) would - // make the pushed-down predicate stricter than the engine intends — - // fall back to no pushdown for the whole disjunction. - return py::none(); - } - if (result.is(py::none())) { - result = std::move(child_expression); - } else if (is_and) { - result = backend.And(std::move(result), std::move(child_expression)); - } else { - result = backend.Or(std::move(result), std::move(child_expression)); - } - } - return result; - } - } - - throw NotImplementedException("Pushdown Filter Type %s is not currently supported in arrow scans", - ExpressionClassToString(expression_class)); -} - -py::object TransformFilter(const TableFilter &filter, const vector &column_path, FilterBackend &backend, - const ArrowType *arrow_type, const string &timezone_config) { - switch (filter.filter_type) { - case TableFilterType::EXPRESSION_FILTER: { - auto &expression_filter = filter.Cast(); - return TransformExpression(*expression_filter.expr, column_path, backend, arrow_type, timezone_config); - } - default: - throw NotImplementedException("Pushdown Filter Type %s is not currently supported in arrow scans", - EnumUtil::ToString(filter.filter_type)); - } -} - -} // namespace duckdb diff --git a/src/duckdb_py/arrow/polars_filter_pushdown.cpp b/src/duckdb_py/arrow/polars_filter_pushdown.cpp index 3bbd4736..493189a3 100644 --- a/src/duckdb_py/arrow/polars_filter_pushdown.cpp +++ b/src/duckdb_py/arrow/polars_filter_pushdown.cpp @@ -1,141 +1,151 @@ #include "duckdb_python/arrow/polars_filter_pushdown.hpp" -#include "duckdb_python/arrow/filter_pushdown_visitor.hpp" -#include "duckdb_python/import_cache/python_import_cache.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + #include "duckdb_python/pyconnection/pyconnection.hpp" #include "duckdb_python/python_objects.hpp" namespace duckdb { -namespace { - -struct PolarsBackend : public FilterBackend { - explicit PolarsBackend(const ClientProperties &client_properties_p) - : client_properties(client_properties_p), import_cache(*DuckDBPyConnection::ImportCache()) { - } +static py::object TransformFilterRecursive(TableFilter &filter, py::object col_expr, + const ClientProperties &client_properties) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + + // Check for NaN + bool is_nan = false; + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } - py::object MakeColumnRef(const vector &path) override { - // pl.col(path[0]).struct.field(path[1]).struct.field(...) — polars supports arbitrary - // chaining for nested struct access, verified empirically up to 3 levels. - py::object col = import_cache.polars.col()(path[0]); - for (idx_t i = 1; i < path.size(); i++) { - col = col.attr("struct").attr("field")(path[i].GetIdentifierName()); + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return col_expr.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return col_expr.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + return import_cache.polars.lit()(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return import_cache.polars.lit()(true); + default: + return py::none(); + } } - return col; - } - py::object MakeScalar(const Value &v, const ArrowType *arrow_type, const string &timezone_config) override { - // Polars handles type coercion for primitives; no ArrowType lookup is needed. - (void)arrow_type; - (void)timezone_config; - return PythonObject::FromValue(v, v.type(), client_properties); - } + // Convert DuckDB Value to Python object + auto py_value = PythonObject::FromValue(constant, constant_type, client_properties); - py::object Compare(ExpressionType op, py::object col, py::object scalar) override { - switch (op) { + switch (constant_filter.comparison_type) { case ExpressionType::COMPARE_EQUAL: - return col.attr("__eq__")(scalar); - case ExpressionType::COMPARE_NOTEQUAL: - return col.attr("__ne__")(scalar); + return col_expr.attr("__eq__")(py_value); case ExpressionType::COMPARE_LESSTHAN: - return col.attr("__lt__")(scalar); + return col_expr.attr("__lt__")(py_value); case ExpressionType::COMPARE_GREATERTHAN: - return col.attr("__gt__")(scalar); + return col_expr.attr("__gt__")(py_value); case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return col.attr("__le__")(scalar); + return col_expr.attr("__le__")(py_value); case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return col.attr("__ge__")(scalar); + return col_expr.attr("__ge__")(py_value); + case ExpressionType::COMPARE_NOTEQUAL: + return col_expr.attr("__ne__")(py_value); default: - throw NotImplementedException("Comparison Type %s can't be a polars pushdown filter", - ExpressionTypeToString(op)); + return py::none(); } } - - py::object NaNCompare(ExpressionType op, py::object col) override { - switch (op) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return col.attr("is_nan")(); - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - return col.attr("is_nan")().attr("__invert__")(); - case ExpressionType::COMPARE_GREATERTHAN: - // Nothing is greater than NaN. - return import_cache.polars.lit()(false); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // Everything is less than or equal to NaN. - return import_cache.polars.lit()(true); - default: - throw NotImplementedException("Unsupported comparison type (%s) for NaN values", - ExpressionTypeToString(op)); + case TableFilterType::IS_NULL: { + return col_expr.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + return col_expr.attr("is_not_null")(); + } + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto child_expression = TransformFilterRecursive(*and_filter.child_filters[i], col_expr, client_properties); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } } + return expression; } - - py::object IsNull(py::object col) override { - return col.attr("is_null")(); + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto child_expression = TransformFilterRecursive(*or_filter.child_filters[i], col_expr, client_properties); + if (child_expression.is(py::none())) { + // Can't skip children in OR + return py::none(); + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } + } + return expression; } - - py::object IsNotNull(py::object col) override { - return col.attr("is_not_null")(); + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto child_col = col_expr.attr("struct").attr("field")(struct_filter.child_name); + return TransformFilterRecursive(*struct_filter.child_filter, child_col, client_properties); } - - py::object IsIn(py::object col, const vector &values, const LogicalType &col_logical_type, - const string &timezone_config) override { - (void)timezone_config; + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); py::list py_values; - for (auto &val : values) { - py_values.append(PythonObject::FromValue(val, val.type(), client_properties)); - } - if (col_logical_type.id() == LogicalTypeId::DECIMAL) { - // Polars infers Decimal(38, scale) for a plain list of Python Decimal values, - // which doesn't match the column's declared Decimal(precision, scale) — the call - // then fails with `'is_in' cannot check for List(Decimal(38, _)) values in - // Decimal(p, s) data`. Build a typed Series matching the column to side-step - // that, and wrap it with `.implode()` to silence the - // `is_in`-with-same-dtype-Series deprecation (issue 22149). - uint8_t width; - uint8_t scale; - col_logical_type.GetDecimalProperties(width, scale); - py::object dtype = import_cache.polars.Decimal()(py::arg("precision") = width, py::arg("scale") = scale); - py::object typed_series = - import_cache.polars.Series()(py::arg("values") = py_values, py::arg("dtype") = dtype); - return col.attr("is_in")(typed_series.attr("implode")()); + for (const auto &value : in_filter.values) { + py_values.append(PythonObject::FromValue(value, value.type(), client_properties)); } - return col.attr("is_in")(py_values); + return col_expr.attr("is_in")(py_values); } - - py::object And(py::object a, py::object b) override { - return a.attr("__and__")(b); + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); + } + return TransformFilterRecursive(*optional_filter.child_filter, col_expr, client_properties); } - - py::object Or(py::object a, py::object b) override { - return a.attr("__or__")(b); + default: + // We skip DYNAMIC_FILTER, EXPRESSION_FILTER, BLOOM_FILTER + return py::none(); } - -private: - const ClientProperties &client_properties; - PythonImportCache &import_cache; -}; - -} // anonymous namespace +} py::object PolarsFilterPushdown::TransformFilter(const TableFilterSet &filter_collection, unordered_map &columns, const unordered_map &filter_to_col, const ClientProperties &client_properties) { - (void)filter_to_col; - PolarsBackend backend(client_properties); + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto &filters_map = filter_collection.filters; + py::object expression = py::none(); - for (auto &entry : filter_collection) { - auto column_idx = entry.GetIndex(); + for (auto &it : filters_map) { + auto column_idx = it.first; auto &column_name = columns[column_idx]; - D_ASSERT(columns.find(column_idx) != columns.end()); + auto col_expr = import_cache.polars.col()(column_name); - vector column_path = {Identifier(column_name)}; - // Polars does not need ArrowType information — `nullptr` here propagates through the - // shared walker; the PolarsBackend ignores the parameter in MakeScalar. - py::object child_expression = duckdb::TransformFilter(entry.Filter(), std::move(column_path), backend, nullptr, - client_properties.time_zone); + auto child_expression = TransformFilterRecursive(*it.second, col_expr, client_properties); if (child_expression.is(py::none())) { continue; } diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp index 761ccbf6..5f1d1f3d 100644 --- a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -1,15 +1,20 @@ #include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" -#include "duckdb_python/arrow/filter_pushdown_visitor.hpp" +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + #include "duckdb_python/pyconnection/pyconnection.hpp" -#include "duckdb_python/python_objects.hpp" #include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" #include "duckdb/function/table/arrow.hpp" namespace duckdb { -namespace { - string ConvertTimestampUnit(ArrowDateTimeType unit) { switch (unit) { case ArrowDateTimeType::MICROSECONDS: @@ -21,16 +26,16 @@ string ConvertTimestampUnit(ArrowDateTimeType unit) { case ArrowDateTimeType::SECONDS: return "s"; default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", - static_cast(unit)); + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); } } int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { auto input = timestamp_t(base_value); - if (!Value::IsFinite(input)) { + if (!Timestamp::IsFinite(input)) { return base_value; } + switch (datetime_type) { case ArrowDateTimeType::MICROSECONDS: return Timestamp::GetEpochMicroSeconds(input); @@ -45,10 +50,7 @@ int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_t } } -// Build a pyarrow.dataset scalar matching the given DuckDB Value and (optionally) ArrowType. -// The ArrowType is needed for timestamp unit / decimal precision / blob-view disambiguation; the -// DuckDB Value alone is not sufficient. -py::object MakePyArrowScalar(const Value &constant, const string &timezone_config, const ArrowType *arrow_type) { +py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { auto &import_cache = *DuckDBPyConnection::ImportCache(); auto scalar = import_cache.pyarrow.scalar(); py::handle dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); @@ -72,18 +74,6 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi py::handle date_type = import_cache.pyarrow.time64(); return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); } - case LogicalTypeId::TIME_NS: { - // Polars TIME columns round-trip through arrow as time64("ns"). - // `Value::GetValue()` has a hand-rolled fast-path switch for TIME but not - // TIME_NS — it falls through to GetValueInternal, which then tries - // Cast::Operation for which no specialization exists, and - // throws "Unimplemented type for cast (INT64 -> INT64)". Use the type-strong - // GetValueUnsafe() which reads `value_.time_ns` from the union - // directly. The `dtime_ns_t.micros` field name is a misnomer — it actually holds - // nanoseconds (see arrow_conversion.cpp:432). - py::handle date_type = import_cache.pyarrow.time64(); - return dataset_scalar(scalar(constant.GetValueUnsafe().micros, date_type("ns"))); - } case LogicalTypeId::TIMESTAMP: { py::handle date_type = import_cache.pyarrow.timestamp(); return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); @@ -101,10 +91,7 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); } case LogicalTypeId::TIMESTAMP_TZ: { - if (!arrow_type) { - throw NotImplementedException("Cannot push down TIMESTAMP_TZ filter without an arrow type"); - } - auto &datetime_info = arrow_type->GetTypeInfo(); + auto &datetime_info = type.GetTypeInfo(); auto base_value = constant.GetValue(); auto arrow_datetime_type = datetime_info.GetDateTimeType(); auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); @@ -112,11 +99,6 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi py::handle date_type = import_cache.pyarrow.timestamp(); return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); } - case LogicalTypeId::TIMESTAMP_TZ_NS: { - py::handle date_type = import_cache.pyarrow.timestamp(); - auto converted_value = Timestamp::GetEpochNanoSeconds(timestamp_t(constant.GetValue())); - return dataset_scalar(scalar(converted_value, date_type("ns", py::arg("tz") = timezone_config))); - } case LogicalTypeId::UTINYINT: { py::handle integer_type = import_cache.pyarrow.uint8(); return dataset_scalar(scalar(constant.GetValue(), integer_type())); @@ -140,19 +122,16 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi case LogicalTypeId::VARCHAR: return dataset_scalar(constant.ToString()); case LogicalTypeId::BLOB: { - if (arrow_type && arrow_type->GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { + if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { py::handle binary_view_type = import_cache.pyarrow.binary_view(); return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); } return dataset_scalar(py::bytes(constant.GetValueUnsafe())); } case LogicalTypeId::DECIMAL: { - if (!arrow_type) { - throw NotImplementedException("Cannot push down DECIMAL filter without an arrow type"); - } py::handle decimal_type; - auto &decimal_info = arrow_type->GetTypeInfo(); - auto bit_width = decimal_info.GetBitWidth(); + auto &datetime_info = type.GetTypeInfo(); + auto bit_width = datetime_info.GetBitWidth(); switch (bit_width) { case DecimalBitWidth::DECIMAL_32: decimal_type = import_cache.pyarrow.decimal32(); @@ -170,6 +149,7 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi uint8_t width; uint8_t scale; constant.type().GetDecimalProperties(width, scale); + // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 auto val = import_cache.decimal.Decimal()(constant.ToString()); return dataset_scalar( scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); @@ -180,120 +160,173 @@ py::object MakePyArrowScalar(const Value &constant, const string &timezone_confi } } -struct PyArrowBackend : public FilterBackend { - explicit PyArrowBackend(const ClientProperties &client_properties_p) : client_properties(client_properties_p) { - auto &import_cache = *DuckDBPyConnection::ImportCache(); - field_factory = import_cache.pyarrow.dataset().attr("field"); - dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); +static py::list TransformInList(const InFilter &in) { + py::list res; + ClientProperties default_properties; + for (auto &val : in.values) { + res.append(PythonObject::FromValue(val, val.type(), default_properties)); } + return res; +} - py::object MakeColumnRef(const vector &path) override { - vector str_path; - std::transform(path.begin(), path.end(), std::back_inserter(str_path), - [](const Identifier &segment) { return segment.GetIdentifierName(); }); - return field_factory(py::tuple(py::cast(str_path))); - } +py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, + const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + py::object field = import_cache.pyarrow.dataset().attr("field"); + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); - py::object MakeScalar(const Value &v, const ArrowType *arrow_type, const string &timezone_config) override { - return MakePyArrowScalar(v, timezone_config, arrow_type); - } + bool is_nan = false; + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + // Special handling for NaN comparisons (to explicitly violate IEEE-754) + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + // Nothing is greater than NaN + return import_cache.pyarrow.dataset().attr("scalar")(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // Everything is less than or equal to NaN + return import_cache.pyarrow.dataset().attr("scalar")(true); + default: + throw NotImplementedException("Unsupported comparison type (%s) for NaN values", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } - py::object Compare(ExpressionType op, py::object col, py::object scalar) override { - switch (op) { + switch (constant_filter.comparison_type) { case ExpressionType::COMPARE_EQUAL: - return col.attr("__eq__")(scalar); - case ExpressionType::COMPARE_NOTEQUAL: - return col.attr("__ne__")(scalar); + return constant_field.attr("__eq__")(constant_value); case ExpressionType::COMPARE_LESSTHAN: - return col.attr("__lt__")(scalar); + return constant_field.attr("__lt__")(constant_value); case ExpressionType::COMPARE_GREATERTHAN: - return col.attr("__gt__")(scalar); + return constant_field.attr("__gt__")(constant_value); case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return col.attr("__le__")(scalar); + return constant_field.attr("__le__")(constant_value); case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return col.attr("__ge__")(scalar); + return constant_field.attr("__ge__")(constant_value); + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("__ne__")(constant_value); default: throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", - ExpressionTypeToString(op)); + EnumUtil::ToString(constant_filter.comparison_type)); } } - - py::object NaNCompare(ExpressionType op, py::object col) override { - switch (op) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return col.attr("is_nan")(); - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - return col.attr("is_nan")().attr("__invert__")(); - case ExpressionType::COMPARE_GREATERTHAN: - // Nothing is greater than NaN. - return dataset_scalar(false); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // Everything is less than or equal to NaN. - return dataset_scalar(true); - default: - throw NotImplementedException("Unsupported comparison type (%s) for NaN values", - ExpressionTypeToString(op)); + //! We do not pushdown is null yet + case TableFilterType::IS_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_valid")(); + } + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto &child_filter = *or_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + // An OR branch that can't be translated (e.g. DYNAMIC_FILTER) means the pushed-down + // predicate would be stricter than the engine intends — fall back to no pushdown. + return py::none(); + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } } + return expression; } - - py::object IsNull(py::object col) override { - return col.attr("is_null")(); + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto &child_filter = *and_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; } + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto &child_name = struct_filter.child_name; + auto &struct_type_info = type.GetTypeInfo(); + auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); - py::object IsNotNull(py::object col) override { - return col.attr("is_valid")(); + column_ref.push_back(child_name); + auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, + struct_child_type); + return child_expr; } - - py::object IsIn(py::object col, const vector &values, const LogicalType &col_logical_type, - const string &timezone_config) override { - // PyArrow accepts a plain Python list of Python-typed scalars; type - // coercion happens inside the scanner. We don't need the column type. - (void)col_logical_type; - (void)timezone_config; - py::list py_values; - for (auto &val : values) { - py_values.append(PythonObject::FromValue(val, val.type(), client_properties)); + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); } - return col.attr("isin")(std::move(py_values)); + return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); } - - py::object And(py::object a, py::object b) override { - return a.attr("__and__")(b); + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto in_list = TransformInList(in_filter); + return constant_field.attr("isin")(std::move(in_list)); } - - py::object Or(py::object a, py::object b) override { - return a.attr("__or__")(b); + case TableFilterType::DYNAMIC_FILTER: { + //! Ignore dynamic filters for now, not necessary for correctness + return py::none(); } - -private: - const ClientProperties &client_properties; - py::object field_factory; - py::object dataset_scalar; -}; - -} // anonymous namespace + default: + throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", + EnumUtil::ToString(filter.filter_type)); + } +} py::object PyArrowFilterPushdown::TransformFilter(TableFilterSet &filter_collection, unordered_map &columns, unordered_map filter_to_col, const ClientProperties &config, const ArrowTableSchema &arrow_table) { - PyArrowBackend backend(config); + auto &filters_map = filter_collection.filters; + py::object expression = py::none(); - for (auto &entry : filter_collection) { - auto column_idx = entry.GetIndex(); + for (auto &it : filters_map) { + auto column_idx = it.first; auto &column_name = columns[column_idx]; + + vector column_ref; + column_ref.push_back(column_name); + D_ASSERT(columns.find(column_idx) != columns.end()); - vector column_path = {Identifier(column_name)}; auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); - py::object child_expression = duckdb::TransformFilter(entry.Filter(), std::move(column_path), backend, - arrow_type.get(), config.time_zone); + py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); if (child_expression.is(py::none())) { continue; - } - if (expression.is(py::none())) { + } else if (expression.is(py::none())) { expression = std::move(child_expression); } else { expression = expression.attr("__and__")(child_expression); diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 5a8506f9..ea2ac66d 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -9,6 +9,7 @@ #include "duckdb_python/pystatement.hpp" #include "duckdb_python/pyrelation.hpp" #include "duckdb_python/expression/pyexpression.hpp" +#include "duckdb_python/pyresult.hpp" #include "duckdb_python/pybind11/exceptions.hpp" #include "duckdb_python/typing.hpp" #include "duckdb_python/functional.hpp" @@ -21,6 +22,8 @@ #include "duckdb/common/enums/statement_type.hpp" #include "duckdb/common/adbc/adbc-init.hpp" +#include "duckdb.hpp" + #ifndef DUCKDB_PYTHON_LIB_NAME #define DUCKDB_PYTHON_LIB_NAME _duckdb #endif @@ -76,7 +79,7 @@ static void InitializeConnectionMethods(py::module_ &m) { // START_OF_CONNECTION_METHODS m.def( "cursor", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -85,7 +88,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Create a duplicate of the current connection", py::kw_only(), py::arg("connection") = py::none()); m.def( "register_filesystem", - [](AbstractFileSystem filesystem, std::shared_ptr conn = nullptr) { + [](AbstractFileSystem filesystem, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -95,7 +98,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "unregister_filesystem", - [](const py::str &name, std::shared_ptr conn = nullptr) { + [](const py::str &name, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -104,7 +107,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Unregister a filesystem", py::arg("name"), py::kw_only(), py::arg("connection") = py::none()); m.def( "list_filesystems", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -113,7 +116,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "List registered filesystems, including builtin ones", py::kw_only(), py::arg("connection") = py::none()); m.def( "filesystem_is_registered", - [](const string &name, std::shared_ptr conn = nullptr) { + [](const string &name, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -123,7 +126,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "get_profiling_information", - [](const std::string &format, std::shared_ptr conn = nullptr) { + [](const py::str &format, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -133,7 +136,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "enable_profiling", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -142,7 +145,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Enable profiling for the current connection", py::kw_only(), py::arg("connection") = py::none()); m.def( "disable_profiling", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -152,10 +155,10 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "create_function", [](const string &name, const py::function &udf, const py::object &arguments = py::none(), - const std::shared_ptr &return_type = nullptr, PythonUDFType type = PythonUDFType::NATIVE, + const shared_ptr &return_type = nullptr, PythonUDFType type = PythonUDFType::NATIVE, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, - bool side_effects = false, std::shared_ptr conn = nullptr) { + bool side_effects = false, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -169,7 +172,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "remove_function", - [](const string &name, std::shared_ptr conn = nullptr) { + [](const string &name, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -178,7 +181,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Remove a previously created function", py::arg("name"), py::kw_only(), py::arg("connection") = py::none()); m.def( "sqltype", - [](const string &type_str, std::shared_ptr conn = nullptr) { + [](const string &type_str, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -188,7 +191,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "dtype", - [](const string &type_str, std::shared_ptr conn = nullptr) { + [](const string &type_str, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -198,7 +201,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "type", - [](const string &type_str, std::shared_ptr conn = nullptr) { + [](const string &type_str, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -208,7 +211,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "array_type", - [](const std::shared_ptr &type, idx_t size, std::shared_ptr conn = nullptr) { + [](const shared_ptr &type, idx_t size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -218,7 +221,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "list_type", - [](const std::shared_ptr &type, std::shared_ptr conn = nullptr) { + [](const shared_ptr &type, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -228,7 +231,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "union_type", - [](const py::object &members, std::shared_ptr conn = nullptr) { + [](const py::object &members, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -238,7 +241,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "string_type", - [](const string &collation = string(), std::shared_ptr conn = nullptr) { + [](const string &collation = string(), shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -248,8 +251,8 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "enum_type", - [](const string &name, const std::shared_ptr &type, const py::list &values_p, - std::shared_ptr conn = nullptr) { + [](const string &name, const shared_ptr &type, const py::list &values_p, + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -259,7 +262,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("type"), py::arg("values"), py::kw_only(), py::arg("connection") = py::none()); m.def( "decimal_type", - [](int width, int scale, std::shared_ptr conn = nullptr) { + [](int width, int scale, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -269,7 +272,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "struct_type", - [](const py::object &fields, std::shared_ptr conn = nullptr) { + [](const py::object &fields, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -279,7 +282,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "row_type", - [](const py::object &fields, std::shared_ptr conn = nullptr) { + [](const py::object &fields, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -289,8 +292,8 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "map_type", - [](const std::shared_ptr &key_type, const std::shared_ptr &value_type, - std::shared_ptr conn = nullptr) { + [](const shared_ptr &key_type, const shared_ptr &value_type, + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -300,7 +303,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("value").none(false), py::kw_only(), py::arg("connection") = py::none()); m.def( "duplicate", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -309,8 +312,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Create a duplicate of the current connection", py::kw_only(), py::arg("connection") = py::none()); m.def( "execute", - [](const py::object &query, py::object params = py::list(), - std::shared_ptr conn = nullptr) { + [](const py::object &query, py::object params = py::list(), shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -320,8 +322,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("parameters") = py::none(), py::kw_only(), py::arg("connection") = py::none()); m.def( "executemany", - [](const py::object &query, py::object params = py::list(), - std::shared_ptr conn = nullptr) { + [](const py::object &query, py::object params = py::list(), shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -331,7 +332,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("query"), py::arg("parameters") = py::none(), py::kw_only(), py::arg("connection") = py::none()); m.def( "close", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -340,7 +341,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Close the connection", py::kw_only(), py::arg("connection") = py::none()); m.def( "interrupt", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -349,7 +350,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Interrupt pending operations", py::kw_only(), py::arg("connection") = py::none()); m.def( "query_progress", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -358,7 +359,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Query progress of pending operation", py::kw_only(), py::arg("connection") = py::none()); m.def( "fetchone", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -367,7 +368,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Fetch a single row from a result following execute", py::kw_only(), py::arg("connection") = py::none()); m.def( "fetchmany", - [](idx_t size, std::shared_ptr conn = nullptr) { + [](idx_t size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -377,7 +378,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "fetchall", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -386,7 +387,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Fetch all rows from a result following execute", py::kw_only(), py::arg("connection") = py::none()); m.def( "fetchnumpy", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -395,7 +396,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Fetch a result as list of NumPy arrays following execute", py::kw_only(), py::arg("connection") = py::none()); m.def( "fetchdf", - [](bool date_as_object, std::shared_ptr conn = nullptr) { + [](bool date_as_object, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -405,7 +406,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "fetch_df", - [](bool date_as_object, std::shared_ptr conn = nullptr) { + [](bool date_as_object, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -415,7 +416,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "df", - [](bool date_as_object, std::shared_ptr conn = nullptr) { + [](bool date_as_object, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -426,7 +427,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "fetch_df_chunk", [](const idx_t vectors_per_chunk = 1, bool date_as_object = false, - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -436,7 +437,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("date_as_object") = false, py::arg("connection") = py::none()); m.def( "pl", - [](idx_t rows_per_batch, bool lazy, std::shared_ptr conn = nullptr) { + [](idx_t rows_per_batch, bool lazy, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -446,7 +447,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("lazy") = false, py::arg("connection") = py::none()); m.def( "to_arrow_table", - [](idx_t batch_size, std::shared_ptr conn = nullptr) { + [](idx_t batch_size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -456,7 +457,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "to_arrow_reader", - [](idx_t batch_size, std::shared_ptr conn = nullptr) { + [](idx_t batch_size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -466,7 +467,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "fetch_arrow_table", - [](idx_t rows_per_batch, std::shared_ptr conn = nullptr) { + [](idx_t rows_per_batch, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -478,7 +479,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "fetch_record_batch", - [](const idx_t rows_per_batch, std::shared_ptr conn = nullptr) { + [](const idx_t rows_per_batch, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -490,7 +491,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "torch", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -500,7 +501,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "tf", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -510,7 +511,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "begin", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -519,7 +520,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Start a new transaction", py::kw_only(), py::arg("connection") = py::none()); m.def( "commit", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -528,7 +529,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Commit changes performed within a transaction", py::kw_only(), py::arg("connection") = py::none()); m.def( "rollback", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -537,7 +538,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Roll back changes performed within a transaction", py::kw_only(), py::arg("connection") = py::none()); m.def( "checkpoint", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -548,7 +549,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "append", [](const string &name, const PandasDataFrame &value, bool by_name, - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -558,7 +559,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("by_name") = false, py::arg("connection") = py::none()); m.def( "register", - [](const string &name, const py::object &python_object, std::shared_ptr conn = nullptr) { + [](const string &name, const py::object &python_object, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -568,7 +569,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("python_object"), py::kw_only(), py::arg("connection") = py::none()); m.def( "unregister", - [](const string &name, std::shared_ptr conn = nullptr) { + [](const string &name, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -577,7 +578,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Unregister the view name", py::arg("view_name"), py::kw_only(), py::arg("connection") = py::none()); m.def( "table", - [](const string &tname, std::shared_ptr conn = nullptr) { + [](const string &tname, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -587,7 +588,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "view", - [](const string &vname, std::shared_ptr conn = nullptr) { + [](const string &vname, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -597,7 +598,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "values", - [](const py::args ¶ms, std::shared_ptr conn = nullptr) { + [](const py::args ¶ms, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -606,7 +607,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Create a relation object from the passed values", py::kw_only(), py::arg("connection") = py::none()); m.def( "table_function", - [](const string &fname, py::object params = py::list(), std::shared_ptr conn = nullptr) { + [](const string &fname, py::object params = py::list(), shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -632,7 +633,7 @@ static void InitializeConnectionMethods(py::module_ &m) { const Optional &hive_partitioning = py::none(), const Optional &union_by_name = py::none(), const Optional &hive_types = py::none(), const Optional &hive_types_autocast = py::none(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -654,7 +655,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("hive_types_autocast") = py::none(), py::arg("connection") = py::none()); m.def( "extract_statements", - [](const string &query, std::shared_ptr conn = nullptr) { + [](const string &query, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -665,7 +666,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "sql", [](const py::object &query, string alias = "", py::object params = py::list(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -678,7 +679,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "query", [](const py::object &query, string alias = "", py::object params = py::list(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -691,7 +692,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "from_query", [](const py::object &query, string alias = "", py::object params = py::list(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -705,7 +706,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "read_csv", [](const py::object &name, py::kwargs &kwargs) { auto connection_arg = kwargs.contains("conn") ? kwargs["conn"] : py::none(); - auto conn = py::cast>(connection_arg); + auto conn = py::cast>(connection_arg); if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); @@ -717,7 +718,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "from_csv_auto", [](const py::object &name, py::kwargs &kwargs) { auto connection_arg = kwargs.contains("conn") ? kwargs["conn"] : py::none(); - auto conn = py::cast>(connection_arg); + auto conn = py::cast>(connection_arg); if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); @@ -727,7 +728,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Create a relation object from the CSV file in 'name'", py::arg("path_or_buffer"), py::kw_only()); m.def( "from_df", - [](const PandasDataFrame &value, std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &value, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -737,7 +738,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "from_arrow", - [](py::object &arrow_object, std::shared_ptr conn = nullptr) { + [](py::object &arrow_object, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -749,7 +750,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "from_parquet", [](const py::object &path_or_buffer, bool binary_as_string, bool file_row_number, bool filename, bool hive_partitioning, bool union_by_name, const py::object &compression = py::none(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -764,7 +765,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "read_parquet", [](const py::object &path_or_buffer, bool binary_as_string, bool file_row_number, bool filename, bool hive_partitioning, bool union_by_name, const py::object &compression = py::none(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -777,7 +778,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("union_by_name") = false, py::arg("compression") = py::none(), py::arg("connection") = py::none()); m.def( "get_table_names", - [](const string &query, bool qualified, std::shared_ptr conn = nullptr) { + [](const string &query, bool qualified, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -789,7 +790,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "install_extension", [](const string &extension, bool force_install = false, const py::object &repository = py::none(), const py::object &repository_url = py::none(), const py::object &version = py::none(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -800,7 +801,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("repository_url") = py::none(), py::arg("version") = py::none(), py::arg("connection") = py::none()); m.def( "load_extension", - [](const string &extension, std::shared_ptr conn = nullptr) { + [](const string &extension, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -810,7 +811,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "project", [](const PandasDataFrame &df, const py::args &args, const string &groups = "", - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -820,7 +821,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("groups") = "", py::arg("connection") = py::none()); m.def( "distinct", - [](const PandasDataFrame &df, std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &df, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -838,7 +839,7 @@ static void InitializeConnectionMethods(py::module_ &m) { const py::object &compression = py::none(), const py::object &overwrite = py::none(), const py::object &per_thread_output = py::none(), const py::object &use_tmp_file = py::none(), const py::object &partition_by = py::none(), const py::object &write_partition_columns = py::none(), - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -857,7 +858,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "aggregate", [](const PandasDataFrame &df, const py::object &expr, const string &groups = "", - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -867,7 +868,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("aggr_expr"), py::arg("group_expr") = "", py::kw_only(), py::arg("connection") = py::none()); m.def( "alias", - [](const PandasDataFrame &df, const string &expr, std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &df, const string &expr, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -877,7 +878,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "filter", - [](const PandasDataFrame &df, const py::object &expr, std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &df, const py::object &expr, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -887,8 +888,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "limit", - [](const PandasDataFrame &df, int64_t n, int64_t offset = 0, - std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &df, int64_t n, int64_t offset = 0, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -898,7 +898,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("offset") = 0, py::kw_only(), py::arg("connection") = py::none()); m.def( "order", - [](const PandasDataFrame &df, const string &expr, std::shared_ptr conn = nullptr) { + [](const PandasDataFrame &df, const string &expr, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -909,7 +909,7 @@ static void InitializeConnectionMethods(py::module_ &m) { m.def( "query_df", [](const PandasDataFrame &df, const string &view_name, const string &sql_query, - std::shared_ptr conn = nullptr) { + shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -920,7 +920,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "description", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -929,7 +929,7 @@ static void InitializeConnectionMethods(py::module_ &m) { "Get result set attributes, mainly column names", py::kw_only(), py::arg("connection") = py::none()); m.def( "rowcount", - [](std::shared_ptr conn = nullptr) { + [](shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -941,7 +941,7 @@ static void InitializeConnectionMethods(py::module_ &m) { // We define these "wrapper" methods manually because they are overloaded m.def( "arrow", - [](idx_t rows_per_batch, std::shared_ptr conn) -> duckdb::pyarrow::RecordBatchReader { + [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::RecordBatchReader { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -951,7 +951,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("rows_per_batch") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( "arrow", - [](py::object &arrow_object, std::shared_ptr conn) -> std::unique_ptr { + [](py::object &arrow_object, shared_ptr conn) -> unique_ptr { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -961,7 +961,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "df", - [](bool date_as_object, std::shared_ptr conn) -> PandasDataFrame { + [](bool date_as_object, shared_ptr conn) -> PandasDataFrame { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -971,8 +971,7 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("connection") = py::none()); m.def( "df", - [](const PandasDataFrame &value, - std::shared_ptr conn) -> std::unique_ptr { + [](const PandasDataFrame &value, shared_ptr conn) -> unique_ptr { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } @@ -1106,7 +1105,7 @@ PYBIND11_MODULE(DUCKDB_PYTHON_LIB_NAME, m) { // NOLINT "Tokenizes a SQL string, returning a list of (position, type) tuples that can be " "used for e.g., syntax highlighting", py::arg("query")); - py::enum_(m, "token_type") + py::enum_(m, "token_type", py::module_local()) .value("identifier", PySQLTokenType::PY_SQL_TOKEN_IDENTIFIER) .value("numeric_const", PySQLTokenType::PY_SQL_TOKEN_NUMERIC_CONSTANT) .value("string_const", PySQLTokenType::PY_SQL_TOKEN_STRING_CONSTANT) diff --git a/src/duckdb_py/include/duckdb_python/arrow/filter_pushdown_visitor.hpp b/src/duckdb_py/include/duckdb_python/arrow/filter_pushdown_visitor.hpp deleted file mode 100644 index 22111ea8..00000000 --- a/src/duckdb_py/include/duckdb_python/arrow/filter_pushdown_visitor.hpp +++ /dev/null @@ -1,94 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb_python/arrow/filter_pushdown_visitor.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/types/value.hpp" -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb_python/pybind11/pybind_wrapper.hpp" - -namespace duckdb { - -// A FilterBackend abstracts the Python side of an `ExpressionFilter` → -// expression translation. The shared walker in this file handles the -// structural recursion (CONJUNCTION_AND/OR, struct_extract column paths, the -// optional / selectivity-optional filter wrappers, and the internal runtime -// filter functions) and dispatches leaf operations to the backend. -// -// Two backends exist today: PyArrowBackend (emits pyarrow.dataset.Expression) -// and PolarsBackend (emits polars.Expr). Adding a new backend is purely a -// matter of implementing this interface; the walker itself is reused. -// -// Convention: a backend method that cannot push the given filter must throw -// `NotImplementedException`. The walker swallows it at optional-filter -// boundaries (an optional filter is not required for correctness) and the -// top-level entry points catch it too, returning `py::none()` for the affected -// column. Throwing keeps the "I can't push this" path uniform across backends, -// replacing the old polars walker's ad hoc `return py::none()` style. -struct FilterBackend { - virtual ~FilterBackend() = default; - - // Build a column expression from an accumulated path. `path` always has - // at least one element (the top-level column). For nested struct - // references the path accumulates one entry per `struct_extract`. - virtual py::object MakeColumnRef(const vector &path) = 0; - - // Convert a DuckDB Value to a backend-native Python scalar. `arrow_type` - // may be nullptr for backends that don't need Arrow type information - // (polars relies on DuckDB LogicalType only). `timezone_config` is the - // active session's time zone for `TIMESTAMP_TZ` handling. - virtual py::object MakeScalar(const Value &v, const ArrowType *arrow_type, const string &timezone_config) = 0; - - // Apply a comparison operator. `op` is one of the COMPARE_* ExpressionTypes. - // `scalar` is what MakeScalar returned. NaN special cases go through - // NaNCompare instead. - virtual py::object Compare(ExpressionType op, py::object col, py::object scalar) = 0; - - // NaN-specific comparison. DuckDB treats NaN as the greatest value, so - // each operator decomposes into is_nan / ~is_nan / lit(true|false). - virtual py::object NaNCompare(ExpressionType op, py::object col) = 0; - - virtual py::object IsNull(py::object col) = 0; - virtual py::object IsNotNull(py::object col) = 0; - - // IN list. `col_logical_type` is the column's DuckDB logical type — needed - // by polars to construct a typed Series with matching precision/scale for - // decimal columns. PyArrow ignores this parameter and uses MakeScalar - // per-element. - virtual py::object IsIn(py::object col, const vector &values, const LogicalType &col_logical_type, - const string &timezone_config) = 0; - - virtual py::object And(py::object a, py::object b) = 0; - virtual py::object Or(py::object a, py::object b) = 0; -}; - -// Walk a TableFilter and emit a backend-specific expression. Since the -// table-filter -> expression-filter migration in core, the only runtime filter -// type is `EXPRESSION_FILTER`; this unwraps it and walks the expression tree. -// - `column_path` is the top-level column name; struct paths are accumulated -// inside the expression walk via struct_extract. -// - `arrow_type` is the ArrowType for the current path leaf (nullable for -// backends that don't track Arrow types). -// - Returns `py::none()` if no part of the filter could be pushed. -py::object TransformFilter(const TableFilter &filter, const vector &column_path, FilterBackend &backend, - const ArrowType *arrow_type, const string &timezone_config); - -// Walk a bound Expression tree (the contents of an `ExpressionFilter`) and emit -// a backend-specific expression. Handles BOUND_FUNCTION comparisons, -// BOUND_OPERATOR (IS_NULL / IS_NOT_NULL / COMPARE_IN), BOUND_CONJUNCTION -// (AND/OR), struct_extract column chains, the optional / selectivity-optional -// wrappers (unwrapped from `bind_info`; an untranslatable child is swallowed), -// and the internal runtime filter functions (dynamic / bloom / perfect-hash-join -// / prefix-range, which are skipped). Returns `py::none()` for an optional or -// runtime filter that can't be pushed. -py::object TransformExpression(const Expression &expression, const vector &column_path, - FilterBackend &backend, const ArrowType *arrow_type, const string &timezone_config); - -} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp index a22d367e..adf485c9 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp @@ -8,7 +8,8 @@ #pragma once -#include "duckdb/planner/table_filter_set.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" #include "duckdb/main/client_properties.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" diff --git a/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp index bf029d76..4cc85a47 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp @@ -8,8 +8,10 @@ #pragma once +#include "duckdb/common/arrow/arrow_wrapper.hpp" #include "duckdb/function/table/arrow/arrow_duck_schema.hpp" -#include "duckdb/planner/table_filter_set.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" #include "duckdb/main/client_properties.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" diff --git a/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp b/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp index 2e741cd8..43c0c5c3 100644 --- a/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp +++ b/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp @@ -23,14 +23,14 @@ namespace duckdb { -struct DuckDBPyExpression : public std::enable_shared_from_this { +struct DuckDBPyExpression : public enable_shared_from_this { public: explicit DuckDBPyExpression(unique_ptr expr, OrderType order_type = OrderType::ORDER_DEFAULT, OrderByNullType null_order = OrderByNullType::ORDER_DEFAULT); public: - std::shared_ptr shared_from_this() { - return std::enable_shared_from_this::shared_from_this(); + shared_ptr shared_from_this() { + return enable_shared_from_this::shared_from_this(); } public: @@ -41,93 +41,92 @@ struct DuckDBPyExpression : public std::enable_shared_from_this Add(const DuckDBPyExpression &other) const; - std::shared_ptr Subtract(const DuckDBPyExpression &other) const; - std::shared_ptr Multiply(const DuckDBPyExpression &other) const; - std::shared_ptr Division(const DuckDBPyExpression &other) const; - std::shared_ptr FloorDivision(const DuckDBPyExpression &other) const; - std::shared_ptr Modulo(const DuckDBPyExpression &other) const; - std::shared_ptr Power(const DuckDBPyExpression &other) const; - std::shared_ptr Negate(); + shared_ptr Add(const DuckDBPyExpression &other) const; + shared_ptr Subtract(const DuckDBPyExpression &other) const; + shared_ptr Multiply(const DuckDBPyExpression &other) const; + shared_ptr Division(const DuckDBPyExpression &other) const; + shared_ptr FloorDivision(const DuckDBPyExpression &other) const; + shared_ptr Modulo(const DuckDBPyExpression &other) const; + shared_ptr Power(const DuckDBPyExpression &other) const; + shared_ptr Negate(); // Equality operations - std::shared_ptr Equality(const DuckDBPyExpression &other); - std::shared_ptr Inequality(const DuckDBPyExpression &other); - std::shared_ptr GreaterThan(const DuckDBPyExpression &other); - std::shared_ptr GreaterThanOrEqual(const DuckDBPyExpression &other); - std::shared_ptr LessThan(const DuckDBPyExpression &other); - std::shared_ptr LessThanOrEqual(const DuckDBPyExpression &other); + shared_ptr Equality(const DuckDBPyExpression &other); + shared_ptr Inequality(const DuckDBPyExpression &other); + shared_ptr GreaterThan(const DuckDBPyExpression &other); + shared_ptr GreaterThanOrEqual(const DuckDBPyExpression &other); + shared_ptr LessThan(const DuckDBPyExpression &other); + shared_ptr LessThanOrEqual(const DuckDBPyExpression &other); - std::shared_ptr SetAlias(const string &alias) const; - std::shared_ptr When(const DuckDBPyExpression &condition, const DuckDBPyExpression &value); - std::shared_ptr Else(const DuckDBPyExpression &value); + shared_ptr SetAlias(const string &alias) const; + shared_ptr When(const DuckDBPyExpression &condition, const DuckDBPyExpression &value); + shared_ptr Else(const DuckDBPyExpression &value); - std::shared_ptr Cast(const DuckDBPyType &type) const; - std::shared_ptr Between(const DuckDBPyExpression &lower, const DuckDBPyExpression &upper); - std::shared_ptr Collate(const string &collation); + shared_ptr Cast(const DuckDBPyType &type) const; + shared_ptr Between(const DuckDBPyExpression &lower, const DuckDBPyExpression &upper); + shared_ptr Collate(const string &collation); // AND, OR and NOT - std::shared_ptr Not(); - std::shared_ptr And(const DuckDBPyExpression &other) const; - std::shared_ptr Or(const DuckDBPyExpression &other) const; + shared_ptr Not(); + shared_ptr And(const DuckDBPyExpression &other) const; + shared_ptr Or(const DuckDBPyExpression &other) const; // IS NULL / IS NOT NULL - std::shared_ptr IsNull(); - std::shared_ptr IsNotNull(); + shared_ptr IsNull(); + shared_ptr IsNotNull(); // IN / NOT IN - std::shared_ptr CreateCompareExpression(ExpressionType compare_type, const py::args &args); - std::shared_ptr In(const py::args &args); - std::shared_ptr NotIn(const py::args &args); + shared_ptr CreateCompareExpression(ExpressionType compare_type, const py::args &args); + shared_ptr In(const py::args &args); + shared_ptr NotIn(const py::args &args); // Order modifiers - std::shared_ptr Ascending(); - std::shared_ptr Descending(); + shared_ptr Ascending(); + shared_ptr Descending(); // Null order modifiers - std::shared_ptr NullsFirst(); - std::shared_ptr NullsLast(); + shared_ptr NullsFirst(); + shared_ptr NullsLast(); public: const ParsedExpression &GetExpression() const; - std::shared_ptr Copy() const; + shared_ptr Copy() const; public: - static std::shared_ptr StarExpression(py::object exclude = py::none()); - static std::shared_ptr ColumnExpression(const py::args &column_name); - static std::shared_ptr DefaultExpression(); - static std::shared_ptr ConstantExpression(const py::object &value); - static std::shared_ptr LambdaExpression(const py::object &lhs, const DuckDBPyExpression &rhs); - static std::shared_ptr CaseExpression(const DuckDBPyExpression &condition, - const DuckDBPyExpression &value); - static std::shared_ptr FunctionExpression(const string &function_name, const py::args &args); - static std::shared_ptr Coalesce(const py::args &args); - static std::shared_ptr SQLExpression(string sql); + static shared_ptr StarExpression(py::object exclude = py::none()); + static shared_ptr ColumnExpression(const py::args &column_name); + static shared_ptr DefaultExpression(); + static shared_ptr ConstantExpression(const py::object &value); + static shared_ptr LambdaExpression(const py::object &lhs, const DuckDBPyExpression &rhs); + static shared_ptr CaseExpression(const DuckDBPyExpression &condition, + const DuckDBPyExpression &value); + static shared_ptr FunctionExpression(const string &function_name, const py::args &args); + static shared_ptr Coalesce(const py::args &args); + static shared_ptr SQLExpression(string sql); public: // Internal functions (not exposed to Python) - static std::shared_ptr InternalFunctionExpression(const string &function_name, - vector> children, - bool is_operator = false); - - static std::shared_ptr InternalUnaryOperator(ExpressionType type, - const DuckDBPyExpression &arg); - static std::shared_ptr InternalConjunction(ExpressionType type, const DuckDBPyExpression &arg, - const DuckDBPyExpression &other); - static std::shared_ptr InternalConstantExpression(Value value); - static std::shared_ptr - BinaryOperator(const string &function_name, const DuckDBPyExpression &arg_one, const DuckDBPyExpression &arg_two); - static std::shared_ptr ComparisonExpression(ExpressionType type, const DuckDBPyExpression &left, - const DuckDBPyExpression &right); - static std::shared_ptr InternalWhen(unique_ptr expr, - const DuckDBPyExpression &condition, - const DuckDBPyExpression &value); + static shared_ptr InternalFunctionExpression(const string &function_name, + vector> children, + bool is_operator = false); + + static shared_ptr InternalUnaryOperator(ExpressionType type, const DuckDBPyExpression &arg); + static shared_ptr InternalConjunction(ExpressionType type, const DuckDBPyExpression &arg, + const DuckDBPyExpression &other); + static shared_ptr InternalConstantExpression(Value value); + static shared_ptr BinaryOperator(const string &function_name, const DuckDBPyExpression &arg_one, + const DuckDBPyExpression &arg_two); + static shared_ptr ComparisonExpression(ExpressionType type, const DuckDBPyExpression &left, + const DuckDBPyExpression &right); + static shared_ptr InternalWhen(unique_ptr expr, + const DuckDBPyExpression &condition, + const DuckDBPyExpression &value); void AssertCaseExpression() const; private: diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp index aec73c0c..17f746fb 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp +++ b/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp @@ -28,7 +28,7 @@ struct PolarsCacheItem : public PythonImportCacheItem { public: PolarsCacheItem() : PythonImportCacheItem("polars"), DataFrame("DataFrame", this), LazyFrame("LazyFrame", this), col("col", this), - lit("lit", this), Series("Series", this), Decimal("Decimal", this) { + lit("lit", this) { } ~PolarsCacheItem() override { } @@ -37,8 +37,6 @@ struct PolarsCacheItem : public PythonImportCacheItem { PythonImportCacheItem LazyFrame; PythonImportCacheItem col; PythonImportCacheItem lit; - PythonImportCacheItem Series; - PythonImportCacheItem Decimal; protected: bool IsRequired() const override final { diff --git a/src/duckdb_py/include/duckdb_python/numpy/array_wrapper.hpp b/src/duckdb_py/include/duckdb_python/numpy/array_wrapper.hpp index 4b143aee..a9740e2c 100644 --- a/src/duckdb_py/include/duckdb_python/numpy/array_wrapper.hpp +++ b/src/duckdb_py/include/duckdb_python/numpy/array_wrapper.hpp @@ -41,8 +41,8 @@ struct NumpyAppendData { struct ArrayWrapper { explicit ArrayWrapper(const LogicalType &type, const ClientProperties &client_properties, bool pandas = false); - std::unique_ptr data; - std::unique_ptr mask; + unique_ptr data; + unique_ptr mask; bool requires_mask; const ClientProperties client_properties; bool pandas; diff --git a/src/duckdb_py/include/duckdb_python/numpy/numpy_array.hpp b/src/duckdb_py/include/duckdb_python/numpy/numpy_array.hpp deleted file mode 100644 index b9aae9f4..00000000 --- a/src/duckdb_py/include/duckdb_python/numpy/numpy_array.hpp +++ /dev/null @@ -1,77 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb_python/numpy/numpy_array.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb_python/pybind11/pybind_wrapper.hpp" -#include "duckdb.hpp" - -namespace duckdb { - -//! Thin façade over pybind11's `py::array`. -//! -//! This class is the SINGLE place in the codebase that names `py::array` as the -//! underlying numpy-array representation. A future migration to nanobind's -//! `nb::ndarray` should only require changing the member type and the handful of -//! small methods defined here -- every call site goes through this wrapper -//! instead of touching `py::array` directly. -//! -//! For operations that don't (yet) have a first-class method on the façade -//! (Python attribute access via `.attr(...)`, iteration, resizing, handing the -//! array back to Python, ...) use `GetArray()` to reach the underlying object. -class NumpyArray { -public: - NumpyArray() = default; - //! Wrap an existing numpy array. A `py::object` argument is implicitly - //! converted to a `py::array` (np.asarray semantics), matching the behaviour - //! the call sites relied on before this façade existed. - explicit NumpyArray(py::array arr) : array(std::move(arr)) { - } - - NumpyArray(NumpyArray &&) = default; - NumpyArray &operator=(NumpyArray &&) = default; - NumpyArray(const NumpyArray &) = default; - NumpyArray &operator=(const NumpyArray &) = default; - -public: - //! Allocate a fresh, contiguous 1-D numpy array of `count` elements with the - //! given dtype. - static NumpyArray Allocate(const py::dtype &dtype, idx_t count) { - return NumpyArray(py::array(py::dtype(dtype), count)); - } - - //! Produce a numpy array from an arbitrary Python object (np.asarray semantics). - static NumpyArray FromObject(py::object obj) { - return NumpyArray(py::array(std::move(obj))); - } - - //! Read-only pointer to the underlying data buffer (wraps `py::array::data()`). - const void *Data() const { - return array.data(); - } - - //! Mutable pointer to the underlying data buffer (wraps `py::array::mutable_data()`). - void *MutableData() { - return array.mutable_data(); - } - - //! Access the underlying array, e.g. for `.attr(...)` calls, iteration, or to - //! hand it back to Python. - py::array &GetArray() { - return array; - } - const py::array &GetArray() const { - return array; - } - -private: - //! The single data member -- the one spot that later becomes `nb::ndarray`. - py::array array; -}; - -} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/numpy/numpy_bind.hpp b/src/duckdb_py/include/duckdb_python/numpy/numpy_bind.hpp index b98d52d4..aa79961e 100644 --- a/src/duckdb_py/include/duckdb_python/numpy/numpy_bind.hpp +++ b/src/duckdb_py/include/duckdb_python/numpy/numpy_bind.hpp @@ -9,7 +9,7 @@ struct PandasColumnBindData; class ClientContext; struct NumpyBind { - static void Bind(ClientContext &config, py::handle df, vector &out, + static void Bind(const ClientContext &config, py::handle df, vector &out, vector &return_types, vector &names); }; diff --git a/src/duckdb_py/include/duckdb_python/numpy/numpy_scan.hpp b/src/duckdb_py/include/duckdb_python/numpy/numpy_scan.hpp index 9be459be..575cebb9 100644 --- a/src/duckdb_py/include/duckdb_python/numpy/numpy_scan.hpp +++ b/src/duckdb_py/include/duckdb_python/numpy/numpy_scan.hpp @@ -8,9 +8,8 @@ namespace duckdb { struct PandasColumnBindData; struct NumpyScan { - static void Scan(ClientContext &context, PandasColumnBindData &bind_data, idx_t count, idx_t offset, Vector &out); - static void ScanObjectColumn(ClientContext &context, PyObject **col, idx_t stride, idx_t count, idx_t offset, - Vector &out); + static void Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, Vector &out); + static void ScanObjectColumn(PyObject **col, idx_t stride, idx_t count, idx_t offset, Vector &out); }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/numpy/raw_array_wrapper.hpp b/src/duckdb_py/include/duckdb_python/numpy/raw_array_wrapper.hpp index d24e2612..124f2112 100644 --- a/src/duckdb_py/include/duckdb_python/numpy/raw_array_wrapper.hpp +++ b/src/duckdb_py/include/duckdb_python/numpy/raw_array_wrapper.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb_python/pybind11/pybind_wrapper.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb.hpp" namespace duckdb { @@ -18,7 +17,7 @@ struct RawArrayWrapper { explicit RawArrayWrapper(const LogicalType &type); - NumpyArray array; + py::array array; data_ptr_t data; LogicalType type; idx_t type_width; diff --git a/src/duckdb_py/include/duckdb_python/pandas/column/pandas_numpy_column.hpp b/src/duckdb_py/include/duckdb_python/pandas/column/pandas_numpy_column.hpp index 20b630d4..9d8587ee 100644 --- a/src/duckdb_py/include/duckdb_python/pandas/column/pandas_numpy_column.hpp +++ b/src/duckdb_py/include/duckdb_python/pandas/column/pandas_numpy_column.hpp @@ -2,20 +2,18 @@ #include "duckdb_python/pandas/pandas_column.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" namespace duckdb { class PandasNumpyColumn : public PandasColumn { public: - PandasNumpyColumn(NumpyArray array_p) : PandasColumn(PandasColumnBackend::NUMPY), array(std::move(array_p)) { - auto &arr = array.GetArray(); - D_ASSERT(py::hasattr(arr, "strides")); - stride = arr.attr("strides").attr("__getitem__")(0).cast(); + PandasNumpyColumn(py::array array_p) : PandasColumn(PandasColumnBackend::NUMPY), array(std::move(array_p)) { + D_ASSERT(py::hasattr(array, "strides")); + stride = array.attr("strides").attr("__getitem__")(0).cast(); } public: - NumpyArray array; + py::array array; idx_t stride; }; diff --git a/src/duckdb_py/include/duckdb_python/pandas/pandas_analyzer.hpp b/src/duckdb_py/include/duckdb_python/pandas/pandas_analyzer.hpp index 7b6501c8..70098c33 100644 --- a/src/duckdb_py/include/duckdb_python/pandas/pandas_analyzer.hpp +++ b/src/duckdb_py/include/duckdb_python/pandas/pandas_analyzer.hpp @@ -12,13 +12,14 @@ #include "duckdb/main/config.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb_python/pybind11/gil_wrapper.hpp" +#include "duckdb_python/numpy/numpy_type.hpp" #include "duckdb_python/python_conversion.hpp" namespace duckdb { class PandasAnalyzer { public: - explicit PandasAnalyzer(ClientContext &context) : context(context) { + explicit PandasAnalyzer(const ClientContext &context) { analyzed_type = LogicalType::SQLNULL; Value result; @@ -47,7 +48,6 @@ class PandasAnalyzer { PythonGILWrapper gil; //! The resulting analyzed type LogicalType analyzed_type; - ClientContext &context; }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/pandas/pandas_bind.hpp b/src/duckdb_py/include/duckdb_python/pandas/pandas_bind.hpp index 805f7cf7..5b58de59 100644 --- a/src/duckdb_py/include/duckdb_python/pandas/pandas_bind.hpp +++ b/src/duckdb_py/include/duckdb_python/pandas/pandas_bind.hpp @@ -3,7 +3,6 @@ #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb_python/pybind11/python_object_container.hpp" #include "duckdb_python/numpy/numpy_type.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb/common/helper.hpp" #include "duckdb_python/pandas/pandas_column.hpp" @@ -12,15 +11,15 @@ namespace duckdb { class ClientContext; struct RegisteredArray { - explicit RegisteredArray(NumpyArray numpy_array) : numpy_array(std::move(numpy_array)) { + explicit RegisteredArray(py::array numpy_array) : numpy_array(std::move(numpy_array)) { } - NumpyArray numpy_array; + py::array numpy_array; }; struct PandasColumnBindData { NumpyType numpy_type; - std::unique_ptr pandas_col; - std::unique_ptr mask; + unique_ptr pandas_col; + unique_ptr mask; //! Only for categorical types string internal_categorical_type; //! Hold ownership of objects created during scanning @@ -28,7 +27,7 @@ struct PandasColumnBindData { }; struct Pandas { - static void Bind(ClientContext &config, py::handle df, vector &out, + static void Bind(const ClientContext &config, py::handle df, vector &out, vector &return_types, vector &names); }; diff --git a/src/duckdb_py/include/duckdb_python/pandas/pandas_scan.hpp b/src/duckdb_py/include/duckdb_python/pandas/pandas_scan.hpp index 97c7a841..0ef9a24c 100644 --- a/src/duckdb_py/include/duckdb_python/pandas/pandas_scan.hpp +++ b/src/duckdb_py/include/duckdb_python/pandas/pandas_scan.hpp @@ -51,8 +51,7 @@ struct PandasScanFunction : public TableFunction { // Helper function that transform pandas df names to make them work with our binder static py::object PandasReplaceCopiedNames(const py::object &original_df); - static void PandasBackendScanSwitch(ClientContext &context, PandasColumnBindData &bind_data, idx_t count, - idx_t offset, Vector &out); + static void PandasBackendScanSwitch(PandasColumnBindData &bind_data, idx_t count, idx_t offset, Vector &out); static void PandasSerialize(Serializer &serializer, const optional_ptr bind_data, const TableFunction &function); diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/enum_string_caster.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/enum_string_caster.hpp deleted file mode 100644 index 0bb72026..00000000 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/enum_string_caster.hpp +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include -#include -#include - -//===----------------------------------------------------------------------===// -// Reusable pybind11 type_caster macros for "string / integer or enum" arguments -//===----------------------------------------------------------------------===// -// -// Several DuckDB enums are exposed to Python so that a binding parameter typed as -// the enum also accepts a string (and, for most, an integer) naming one of its -// values, while still accepting an actual registered enum instance. Every one of -// these casters had an identical shape: -// -// - if the source is a Python str -> value = FromString(...) -// - if the source is a Python int -> value = FromInteger(...) (optional) -// - otherwise delegate to a *local* type_caster_base for the registered -// enum instance. -// -// The macros below collapse that boilerplate into a single invocation per enum so -// the eventual nanobind port is a one-place change. Behavior is intentionally -// identical to the hand-written casters they replace. -// -// IMPORTANT (matches the original per-file notes): these casters own their value -// via PYBIND11_TYPE_CASTER and delegate ONLY the registered-instance case to a -// local base caster -- they do NOT inherit type_caster_base. Inheriting the base -// while also writing custom branches is what historically made a caster accept -// str XOR the enum depending on include visibility. Each specialization must be -// visible in every TU that converts the type (they live under the universally -// included pybind_wrapper.hpp umbrella), otherwise it is UB. -// -// Invoke these macros at GLOBAL scope (outside any namespace); each expands to a -// full `namespace pybind11 { namespace detail { ... } }` specialization. Pass -// fully-qualified names (e.g. duckdb::ExplainTypeFromString) for the conversion -// functions and the enum type. - -//! str + int + registered-enum form. -#define DUCKDB_PY_ENUM_STRING_INT_CASTER(EnumType, FromStringFn, FromIntegerFn, NameLiteral) \ - namespace PYBIND11_NAMESPACE { \ - namespace detail { \ - template <> \ - struct type_caster { \ - PYBIND11_TYPE_CASTER(EnumType, const_name(NameLiteral)); \ - \ - bool load(handle src, bool convert) { \ - if (isinstance(src)) { \ - value = FromStringFn(src.cast()); \ - return true; \ - } \ - if (isinstance(src)) { \ - value = FromIntegerFn(src.cast()); \ - return true; \ - } \ - type_caster_base base; \ - if (!base.load(src, convert)) { \ - return false; \ - } \ - value = *static_cast(base); \ - return true; \ - } \ - \ - static handle cast(EnumType src, return_value_policy policy, handle parent) { \ - return type_caster_base::cast(src, policy, parent); \ - } \ - }; \ - } /* namespace detail */ \ - } /* namespace PYBIND11_NAMESPACE */ - -//! str + registered-enum form (no integer accepted). -#define DUCKDB_PY_ENUM_STRING_CASTER(EnumType, FromStringFn, NameLiteral) \ - namespace PYBIND11_NAMESPACE { \ - namespace detail { \ - template <> \ - struct type_caster { \ - PYBIND11_TYPE_CASTER(EnumType, const_name(NameLiteral)); \ - \ - bool load(handle src, bool convert) { \ - if (isinstance(src)) { \ - value = FromStringFn(src.cast()); \ - return true; \ - } \ - type_caster_base base; \ - if (!base.load(src, convert)) { \ - return false; \ - } \ - value = *static_cast(base); \ - return true; \ - } \ - \ - static handle cast(EnumType src, return_value_policy policy, handle parent) { \ - return type_caster_base::cast(src, policy, parent); \ - } \ - }; \ - } /* namespace detail */ \ - } /* namespace PYBIND11_NAMESPACE */ diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/exception_handling_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/exception_handling_enum.hpp index 94adf3d7..acf407fe 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/exception_handling_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/exception_handling_enum.hpp @@ -3,35 +3,70 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" + +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; namespace duckdb { enum class PythonExceptionHandling : uint8_t { FORWARD_ERROR, RETURN_NULL }; -inline PythonExceptionHandling PythonExceptionHandlingFromString(const string &type) { +} // namespace duckdb + +using duckdb::PythonExceptionHandling; + +namespace py = pybind11; + +static PythonExceptionHandling PythonExceptionHandlingFromString(const string &type) { auto ltype = StringUtil::Lower(type); if (ltype.empty() || ltype == "default") { return PythonExceptionHandling::FORWARD_ERROR; - } - if (ltype == "return_null") { + } else if (ltype == "return_null") { return PythonExceptionHandling::RETURN_NULL; + } else { + throw InvalidInputException("'%s' is not a recognized type for 'exception_handling'", type); } - throw InvalidInputException("'%s' is not a recognized type for 'exception_handling'", type); } -inline PythonExceptionHandling PythonExceptionHandlingFromInteger(int64_t value) { +static PythonExceptionHandling PythonExceptionHandlingFromInteger(int64_t value) { if (value == 0) { return PythonExceptionHandling::FORWARD_ERROR; - } - if (value == 1) { + } else if (value == 1) { return PythonExceptionHandling::RETURN_NULL; + } else { + throw InvalidInputException("'%d' is not a recognized type for 'exception_handling'", value); } - throw InvalidInputException("'%d' is not a recognized type for 'exception_handling'", value); } -} // namespace duckdb +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + PythonExceptionHandling tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = PythonExceptionHandlingFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = PythonExceptionHandlingFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(PythonExceptionHandling src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; -//! See enum_string_caster.hpp for the rationale (composition over inheritance, umbrella visibility). -DUCKDB_PY_ENUM_STRING_INT_CASTER(duckdb::PythonExceptionHandling, duckdb::PythonExceptionHandlingFromString, - duckdb::PythonExceptionHandlingFromInteger, "PythonExceptionHandling") +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/explain_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/explain_enum.hpp index e88f0c02..d92bdb56 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/explain_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/explain_enum.hpp @@ -4,33 +4,63 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" -namespace duckdb { +using duckdb::ExplainType; +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; -inline ExplainType ExplainTypeFromString(const string &type) { +namespace py = pybind11; + +static ExplainType ExplainTypeFromString(const string &type) { auto ltype = StringUtil::Lower(type); if (ltype.empty() || ltype == "standard") { return ExplainType::EXPLAIN_STANDARD; - } - if (ltype == "analyze") { + } else if (ltype == "analyze") { return ExplainType::EXPLAIN_ANALYZE; + } else { + throw InvalidInputException("Unrecognized type for 'explain'"); } - throw InvalidInputException("Unrecognized type for 'explain'"); } -inline ExplainType ExplainTypeFromInteger(int64_t value) { +static ExplainType ExplainTypeFromInteger(int64_t value) { if (value == 0) { return ExplainType::EXPLAIN_STANDARD; - } - if (value == 1) { + } else if (value == 1) { return ExplainType::EXPLAIN_ANALYZE; + } else { + throw InvalidInputException("Unrecognized type for 'explain'"); } - throw InvalidInputException("Unrecognized type for 'explain'"); } -} // namespace duckdb +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + ExplainType tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = ExplainTypeFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = ExplainTypeFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(ExplainType src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; -//! See enum_string_caster.hpp for the rationale (composition over inheritance, umbrella visibility). -DUCKDB_PY_ENUM_STRING_INT_CASTER(duckdb::ExplainType, duckdb::ExplainTypeFromString, duckdb::ExplainTypeFromInteger, - "ExplainType") +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/identifier.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/identifier.hpp deleted file mode 100644 index 5364190f..00000000 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/identifier.hpp +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once -#include "duckdb_python/pybind11/pybind_wrapper.hpp" -#include "duckdb/common/identifier.hpp" - -namespace py = pybind11; - -namespace PYBIND11_NAMESPACE { -namespace detail { -template <> -class type_caster { - PYBIND11_TYPE_CASTER(duckdb::Identifier, const_name("str")); - - // Python str -> Identifier - bool load(handle src, bool) { - if (!PyUnicode_Check(src.ptr())) { - return false; - } - value = duckdb::Identifier(src.cast()); - return true; - } - - // Identifier -> Python str - static handle cast(const duckdb::Identifier &id, return_value_policy, handle) { - auto &str_value = id.GetIdentifierName(); - return PyUnicode_FromStringAndSize(str_value.data(), py::ssize_t(str_value.size())); - } -}; -} // namespace detail -} // namespace PYBIND11_NAMESPACE \ No newline at end of file diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/null_handling_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/null_handling_enum.hpp index e5172706..b9bbcf90 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/null_handling_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/null_handling_enum.hpp @@ -4,34 +4,63 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" -namespace duckdb { +using duckdb::FunctionNullHandling; +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; -inline FunctionNullHandling FunctionNullHandlingFromString(const string &type) { +namespace py = pybind11; + +static FunctionNullHandling FunctionNullHandlingFromString(const string &type) { auto ltype = StringUtil::Lower(type); if (ltype.empty() || ltype == "default") { return FunctionNullHandling::DEFAULT_NULL_HANDLING; - } - if (ltype == "special") { + } else if (ltype == "special") { return FunctionNullHandling::SPECIAL_HANDLING; + } else { + throw InvalidInputException("'%s' is not a recognized type for 'null_handling'", type); } - throw InvalidInputException("'%s' is not a recognized type for 'null_handling'", type); } -inline FunctionNullHandling FunctionNullHandlingFromInteger(int64_t value) { +static FunctionNullHandling FunctionNullHandlingFromInteger(int64_t value) { if (value == 0) { return FunctionNullHandling::DEFAULT_NULL_HANDLING; - } - if (value == 1) { + } else if (value == 1) { return FunctionNullHandling::SPECIAL_HANDLING; + } else { + throw InvalidInputException("'%d' is not a recognized type for 'null_handling'", value); } - throw InvalidInputException("'%d' is not a recognized type for 'null_handling'", value); } -} // namespace duckdb +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + FunctionNullHandling tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = FunctionNullHandlingFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = FunctionNullHandlingFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(FunctionNullHandling src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; -//! See enum_string_caster.hpp for why this owns its value and delegates the enum case to a local base caster -//! instead of inheriting type_caster_base. Must stay visible in every TU (included from pybind_wrapper.hpp). -DUCKDB_PY_ENUM_STRING_INT_CASTER(duckdb::FunctionNullHandling, duckdb::FunctionNullHandlingFromString, - duckdb::FunctionNullHandlingFromInteger, "FunctionNullHandling") +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp index ed35dc7e..d6ad6979 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/pyconnection_default.hpp @@ -4,28 +4,20 @@ #include "duckdb/common/helper.hpp" using duckdb::DuckDBPyConnection; +using duckdb::shared_ptr; namespace py = pybind11; namespace PYBIND11_NAMESPACE { namespace detail { -// NANOBIND PORTING NOTE (None handling): -// This caster maps a Python None (or an omitted `connection=None` argument) to the module-level default -// connection. It works under pybind11 because pybind11 forwards None into a holder/pointer argument's caster -// `load()` by default (argument_record.none defaults to true). nanobind inverts this: it REJECTS None for -// bound-type (shared_ptr / pointer) arguments BEFORE the caster runs, unless the binding annotates the argument -// with `.none()`. So the eventual nanobind port must (1) keep this None -> DefaultConnection() branch AND -// (2) add `.none()` to every `connection` argument that currently defaults to `py::none()` (see -// NANOBIND_NONE_AUDIT.md -- 81 sites in duckdb_python.cpp). Object-family arguments (py::object / Optional) -// do not need this annotation; their value casters accept None directly. template <> -class type_caster> - : public copyable_holder_caster> { +class type_caster> + : public copyable_holder_caster> { using type = DuckDBPyConnection; - using holder_caster = copyable_holder_caster>; + using holder_caster = copyable_holder_caster>; // This is used to generate documentation on duckdb-web - PYBIND11_TYPE_CASTER(std::shared_ptr, const_name("duckdb.DuckDBPyConnection")); + PYBIND11_TYPE_CASTER(shared_ptr, const_name("duckdb.DuckDBPyConnection")); bool load(handle src, bool convert) { if (py::none().is(src)) { @@ -35,19 +27,17 @@ class type_caster> if (!holder_caster::load(src, convert)) { return false; } - // pybind11's std::shared_ptr holder_caster (smart_holder bakein) has no `holder` member like the - // generic template did for duckdb::shared_ptr; extract the loaded pointer via its conversion operator. - value = static_cast &>(static_cast(*this)); + value = std::move(holder); return true; } - static handle cast(std::shared_ptr base, return_value_policy rvp, handle h) { + static handle cast(shared_ptr base, return_value_policy rvp, handle h) { return holder_caster::cast(base, rvp, h); } }; template <> -struct is_holder_type> : std::true_type {}; +struct is_holder_type> : std::true_type {}; } // namespace detail } // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp index 34325262..70fc2982 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp @@ -3,7 +3,10 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" + +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; namespace duckdb { @@ -42,7 +45,34 @@ struct PythonCSVLineTerminator { } // namespace duckdb -//! See enum_string_caster.hpp for the rationale (composition over inheritance, umbrella visibility). -//! Only a string or the enum itself are accepted (no integer form). -DUCKDB_PY_ENUM_STRING_CASTER(duckdb::PythonCSVLineTerminator::Type, duckdb::PythonCSVLineTerminator::FromString, - "CSVLineTerminator") +using duckdb::PythonCSVLineTerminator; + +namespace py = pybind11; + +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + PythonCSVLineTerminator::Type tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = duckdb::PythonCSVLineTerminator::FromString(py::str(src)); + value = &tmp; + return true; + } + return false; + } + + static handle cast(PythonCSVLineTerminator::Type src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_udf_type_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_udf_type_enum.hpp index 13799ba0..6a224090 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_udf_type_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_udf_type_enum.hpp @@ -3,38 +3,70 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" + +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; namespace duckdb { enum class PythonUDFType : uint8_t { NATIVE, ARROW }; -inline PythonUDFType PythonUDFTypeFromString(const string &type) { +} // namespace duckdb + +using duckdb::PythonUDFType; + +namespace py = pybind11; + +static PythonUDFType PythonUDFTypeFromString(const string &type) { auto ltype = StringUtil::Lower(type); if (ltype.empty() || ltype == "default" || ltype == "native") { return PythonUDFType::NATIVE; - } - if (ltype == "arrow") { + } else if (ltype == "arrow") { return PythonUDFType::ARROW; + } else { + throw InvalidInputException("'%s' is not a recognized type for 'udf_type'", type); } - throw InvalidInputException("'%s' is not a recognized type for 'udf_type'", type); } -inline PythonUDFType PythonUDFTypeFromInteger(int64_t value) { +static PythonUDFType PythonUDFTypeFromInteger(int64_t value) { if (value == 0) { return PythonUDFType::NATIVE; - } - if (value == 1) { + } else if (value == 1) { return PythonUDFType::ARROW; + } else { + throw InvalidInputException("'%d' is not a recognized type for 'udf_type'", value); } - throw InvalidInputException("'%d' is not a recognized type for 'udf_type'", value); } -} // namespace duckdb +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + PythonUDFType tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = PythonUDFTypeFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = PythonUDFTypeFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(PythonUDFType src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; -//! Accepts the registered PythonUDFType enum, or a string / integer naming one. See enum_string_caster.hpp for -//! the rationale (this owns its value via PYBIND11_TYPE_CASTER and delegates only the registered-enum case to a -//! local base caster instead of inheriting type_caster_base). Keeping the binding parameter typed as the enum -//! preserves the type + default in help()/stubs. -DUCKDB_PY_ENUM_STRING_INT_CASTER(duckdb::PythonUDFType, duckdb::PythonUDFTypeFromString, - duckdb::PythonUDFTypeFromInteger, "PythonUDFType") +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/render_mode_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/render_mode_enum.hpp index a6e0e6ea..72661f8c 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/render_mode_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/render_mode_enum.hpp @@ -5,26 +5,54 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/common/box_renderer.hpp" #include "duckdb/common/enum_util.hpp" -#include "duckdb_python/pybind11/conversions/enum_string_caster.hpp" -namespace duckdb { +using duckdb::InvalidInputException; +using duckdb::RenderMode; +using duckdb::string; +using duckdb::StringUtil; -inline RenderMode RenderModeFromString(const string &value) { - return EnumUtil::FromString(value.empty() ? "ROWS" : value); -} +namespace py = pybind11; -inline RenderMode RenderModeFromInteger(int64_t value) { +static RenderMode RenderModeFromInteger(int64_t value) { if (value == 0) { return RenderMode::ROWS; - } - if (value == 1) { + } else if (value == 1) { return RenderMode::COLUMNS; + } else { + throw InvalidInputException("Unrecognized type for 'render_mode'"); } - throw InvalidInputException("Unrecognized type for 'render_mode'"); } -} // namespace duckdb +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + RenderMode tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + string render_mode_str = py::str(src); + auto render_mode = + duckdb::EnumUtil::FromString(render_mode_str.empty() ? "ROWS" : render_mode_str); + value = &render_mode; + return true; + } else if (py::isinstance(src)) { + tmp = RenderModeFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(RenderMode src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; -//! See enum_string_caster.hpp for the rationale (composition over inheritance, umbrella visibility). -DUCKDB_PY_ENUM_STRING_INT_CASTER(duckdb::RenderMode, duckdb::RenderModeFromString, duckdb::RenderModeFromInteger, - "RenderMode") +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp b/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp index 618ab73a..d51ddea2 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/pybind_wrapper.hpp @@ -11,15 +11,6 @@ #include #include #include -// Custom type_caster specializations must be visible in every TU that converts the type (otherwise it is -// UB); keep ALL of them here, in this universally-included umbrella, never in scattered per-feature headers. -#include "duckdb_python/pybind11/conversions/identifier.hpp" -#include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" -#include "duckdb_python/pybind11/conversions/null_handling_enum.hpp" -#include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" -#include "duckdb_python/pybind11/conversions/explain_enum.hpp" -#include "duckdb_python/pybind11/conversions/render_mode_enum.hpp" -#include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/helper.hpp" diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 4fac0b52..74cdf6ce 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -10,6 +10,7 @@ #include "duckdb_python/arrow/arrow_array_stream.hpp" #include "duckdb.hpp" #include "duckdb_python/pybind11/pybind_wrapper.hpp" +#include "duckdb/common/unordered_map.hpp" #include "duckdb_python/import_cache/python_import_cache.hpp" #include "duckdb_python/numpy/numpy_type.hpp" #include "duckdb_python/pyrelation.hpp" @@ -22,6 +23,7 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" +#include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/shared_ptr.hpp" namespace duckdb { @@ -53,11 +55,11 @@ struct DefaultConnectionHolder { DefaultConnectionHolder &operator=(DefaultConnectionHolder &&other) = delete; public: - std::shared_ptr Get(); - void Set(std::shared_ptr conn); + shared_ptr Get(); + void Set(shared_ptr conn); private: - std::shared_ptr connection; + shared_ptr connection; mutex l; }; @@ -129,7 +131,7 @@ struct ConnectionGuard { void SetConnection(unique_ptr con) { connection = std::move(con); } - void SetResult(std::unique_ptr res) { + void SetResult(unique_ptr res) { result = std::move(res); } @@ -141,10 +143,10 @@ struct ConnectionGuard { private: shared_ptr database; unique_ptr connection; - std::unique_ptr result; + unique_ptr result; }; -struct DuckDBPyConnection : public std::enable_shared_from_this { +struct DuckDBPyConnection : public enable_shared_from_this { private: class Cursors { public: @@ -152,12 +154,12 @@ struct DuckDBPyConnection : public std::enable_shared_from_this conn); + void AddCursor(shared_ptr conn); void ClearCursors(); private: mutex lock; - vector> cursors; + vector> cursors; }; public: @@ -191,7 +193,7 @@ struct DuckDBPyConnection : public std::enable_shared_from_this internal_object_filesystem; + shared_ptr internal_object_filesystem; case_insensitive_map_t> registered_functions; case_insensitive_set_t registered_objects; @@ -204,7 +206,7 @@ struct DuckDBPyConnection : public std::enable_shared_from_this Enter(); + shared_ptr Enter(); static void Exit(DuckDBPyConnection &self, const py::object &exc_type, const py::object &exc, const py::object &traceback); @@ -212,16 +214,16 @@ struct DuckDBPyConnection : public std::enable_shared_from_this DefaultConnection(); - static void SetDefaultConnection(std::shared_ptr conn); + static shared_ptr DefaultConnection(); + static void SetDefaultConnection(shared_ptr conn); static PythonImportCache *ImportCache(); static bool IsInteractive(); - std::unique_ptr ReadCSV(const py::object &name, py::kwargs &kwargs); + unique_ptr ReadCSV(const py::object &name, py::kwargs &kwargs); py::list ExtractStatements(const string &query); - std::unique_ptr ReadJSON( + unique_ptr ReadJSON( const py::object &name, const Optional &columns = py::none(), const Optional &sample_size = py::none(), const Optional &maximum_depth = py::none(), const Optional &records = py::none(), const Optional &format = py::none(), @@ -237,27 +239,28 @@ struct DuckDBPyConnection : public std::enable_shared_from_this &union_by_name = py::none(), const Optional &hive_types = py::none(), const Optional &hive_types_autocast = py::none()); - std::shared_ptr MapType(const std::shared_ptr &key_type, - const std::shared_ptr &value_type); - std::shared_ptr StructType(const py::object &fields); - std::shared_ptr ListType(const std::shared_ptr &type); - std::shared_ptr ArrayType(const std::shared_ptr &type, idx_t size); - std::shared_ptr UnionType(const py::object &members); - std::shared_ptr EnumType(const string &name, const std::shared_ptr &type, - const py::list &values_p); - std::shared_ptr DecimalType(int width, int scale); - std::shared_ptr StringType(const string &collation = string()); - std::shared_ptr Type(const string &type_str); - - std::shared_ptr RegisterScalarUDF( - const string &name, const py::function &udf, const py::object &arguments = py::none(), - const std::shared_ptr &return_type = nullptr, PythonUDFType type = PythonUDFType::NATIVE, - FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, - PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, bool side_effects = false); - - std::shared_ptr UnregisterUDF(const string &name); - - std::shared_ptr ExecuteMany(const py::object &query, py::object params = py::list()); + shared_ptr MapType(const shared_ptr &key_type, + const shared_ptr &value_type); + shared_ptr StructType(const py::object &fields); + shared_ptr ListType(const shared_ptr &type); + shared_ptr ArrayType(const shared_ptr &type, idx_t size); + shared_ptr UnionType(const py::object &members); + shared_ptr EnumType(const string &name, const shared_ptr &type, + const py::list &values_p); + shared_ptr DecimalType(int width, int scale); + shared_ptr StringType(const string &collation = string()); + shared_ptr Type(const string &type_str); + + shared_ptr + RegisterScalarUDF(const string &name, const py::function &udf, const py::object &arguments = py::none(), + const shared_ptr &return_type = nullptr, PythonUDFType type = PythonUDFType::NATIVE, + FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING, + PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, + bool side_effects = false); + + shared_ptr UnregisterUDF(const string &name); + + shared_ptr ExecuteMany(const py::object &query, py::object params = py::list()); void ExecuteImmediately(vector> statements); unique_ptr PrepareQuery(unique_ptr statement); @@ -265,12 +268,12 @@ struct DuckDBPyConnection : public std::enable_shared_from_this PrepareAndExecuteInternal(unique_ptr statement, py::object params = py::list()); - std::shared_ptr Execute(const py::object &query, py::object params = py::list()); - std::shared_ptr ExecuteFromString(const string &query); + shared_ptr Execute(const py::object &query, py::object params = py::list()); + shared_ptr ExecuteFromString(const string &query); - std::shared_ptr Append(const string &name, const PandasDataFrame &value, bool by_name); + shared_ptr Append(const string &name, const PandasDataFrame &value, bool by_name); - std::shared_ptr RegisterPythonObject(const string &name, const py::object &python_object); + shared_ptr RegisterPythonObject(const string &name, const py::object &python_object); void InstallExtension(const string &extension, bool force_install = false, const py::object &repository = py::none(), const py::object &repository_url = py::none(), @@ -278,36 +281,35 @@ struct DuckDBPyConnection : public std::enable_shared_from_this RunQuery(const py::object &query, string alias = "", - py::object params = py::list()); + unique_ptr RunQuery(const py::object &query, string alias = "", py::object params = py::list()); - std::unique_ptr Table(const string &tname); + unique_ptr Table(const string &tname); - std::unique_ptr Values(const py::args ¶ms); + unique_ptr Values(const py::args ¶ms); - std::unique_ptr View(const string &vname); + unique_ptr View(const string &vname); - std::unique_ptr TableFunction(const string &fname, py::object params = py::list()); + unique_ptr TableFunction(const string &fname, py::object params = py::list()); - std::unique_ptr FromDF(const PandasDataFrame &value); + unique_ptr FromDF(const PandasDataFrame &value); - std::unique_ptr FromParquet(const py::object &path_or_buffer, bool binary_as_string, - bool file_row_number, bool filename, bool hive_partitioning, - bool union_by_name, const py::object &compression = py::none()); + unique_ptr FromParquet(const py::object &path_or_buffer, bool binary_as_string, + bool file_row_number, bool filename, bool hive_partitioning, + bool union_by_name, const py::object &compression = py::none()); - std::unique_ptr FromArrow(py::object &arrow_object); + unique_ptr FromArrow(py::object &arrow_object); unordered_set GetTableNames(const string &query, bool qualified); - std::shared_ptr UnregisterPythonObject(const string &name); + shared_ptr UnregisterPythonObject(const string &name); - std::shared_ptr Begin(); + shared_ptr Begin(); - std::shared_ptr Commit(); + shared_ptr Commit(); - std::shared_ptr Rollback(); + shared_ptr Rollback(); - std::shared_ptr Checkpoint(); + shared_ptr Checkpoint(); void Close(); @@ -318,7 +320,7 @@ struct DuckDBPyConnection : public std::enable_shared_from_this Cursor(); + shared_ptr Cursor(); Optional GetDescription(); @@ -344,12 +346,10 @@ struct DuckDBPyConnection : public std::enable_shared_from_this Connect(const py::object &database, bool read_only, - const py::dict &config); + static shared_ptr Connect(const py::object &database, bool read_only, const py::dict &config); - static vector TransformPythonParamList(ClientContext &context, const py::handle ¶ms); - static identifier_map_t TransformPythonParamDict(ClientContext &context, - const py::dict ¶ms); + static vector TransformPythonParamList(const py::handle ¶ms); + static case_insensitive_map_t TransformPythonParamDict(const py::dict ¶ms); void RegisterFilesystem(AbstractFileSystem filesystem); void UnregisterFilesystem(const py::str &name); @@ -357,10 +357,15 @@ struct DuckDBPyConnection : public std::enable_shared_from_this import_cache; + static bool IsPandasDataframe(const py::object &object); static PyArrowObjectType GetArrowType(const py::handle &obj); static bool IsAcceptedArrowObject(const py::object &object); @@ -369,15 +374,18 @@ struct DuckDBPyConnection : public std::enable_shared_from_this CompletePendingQuery(PendingQueryResult &pending_query); private: - std::unique_ptr CreateRelation(shared_ptr rel); - std::unique_ptr CreateRelation(std::shared_ptr result); + unique_ptr CreateRelation(shared_ptr rel); + unique_ptr CreateRelation(shared_ptr result); PathLike GetPathLike(const py::object &object); ScalarFunction CreateScalarUDF(const string &name, const py::function &udf, const py::object ¶meters, - const std::shared_ptr &return_type, bool vectorized, + const shared_ptr &return_type, bool vectorized, FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects); + void RegisterArrowObject(const py::object &arrow_object, const string &name); vector> GetStatements(const py::object &query); + static PythonEnvironmentType environment; + static std::string formatted_python_version; static void DetectEnvironment(); }; diff --git a/src/duckdb_py/include/duckdb_python/pyrelation.hpp b/src/duckdb_py/include/duckdb_python/pyrelation.hpp index f77c937f..50f39b5f 100644 --- a/src/duckdb_py/include/duckdb_python/pyrelation.hpp +++ b/src/duckdb_py/include/duckdb_python/pyrelation.hpp @@ -11,18 +11,24 @@ #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb.hpp" #include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb/main/external_dependencies.hpp" #include "duckdb_python/numpy/numpy_type.hpp" +#include "duckdb_python/pybind11/registered_py_object.hpp" #include "duckdb_python/pyresult.hpp" +#include "duckdb/parser/statement/explain_statement.hpp" +#include "duckdb_python/pybind11/conversions/explain_enum.hpp" #include "duckdb_python/pybind11/conversions/render_mode_enum.hpp" +#include "duckdb_python/pybind11/conversions/null_handling_enum.hpp" #include "duckdb_python/pybind11/dataframe.hpp" #include "duckdb_python/python_objects.hpp" +#include "duckdb/common/box_renderer.hpp" namespace duckdb { struct DuckDBPyRelation { public: explicit DuckDBPyRelation(shared_ptr rel); - explicit DuckDBPyRelation(std::shared_ptr result); + explicit DuckDBPyRelation(shared_ptr result); ~DuckDBPyRelation(); public: @@ -32,106 +38,99 @@ struct DuckDBPyRelation { void Close(); - std::unique_ptr GetAttribute(const string &name); + unique_ptr GetAttribute(const string &name); py::str GetAlias(); - static std::unique_ptr EmptyResult(const shared_ptr &context, - const vector &types, vector names); + static unique_ptr EmptyResult(const shared_ptr &context, + const vector &types, vector names); - std::unique_ptr SetAlias(const string &expr); + unique_ptr SetAlias(const string &expr); - std::unique_ptr ProjectFromExpression(const string &expr); - std::unique_ptr ProjectFromTypes(const py::object &types); - std::unique_ptr Project(const py::args &args, const string &groups = ""); - std::unique_ptr Filter(const py::object &expr); - std::unique_ptr FilterFromExpression(const string &expr); - std::unique_ptr Limit(int64_t n, int64_t offset = 0); - std::unique_ptr Order(const string &expr); - std::unique_ptr Sort(const py::args &args); + unique_ptr ProjectFromExpression(const string &expr); + unique_ptr ProjectFromTypes(const py::object &types); + unique_ptr Project(const py::args &args, const string &groups = ""); + unique_ptr Filter(const py::object &expr); + unique_ptr FilterFromExpression(const string &expr); + unique_ptr Limit(int64_t n, int64_t offset = 0); + unique_ptr Order(const string &expr); + unique_ptr Sort(const py::args &args); - std::unique_ptr Aggregate(const py::object &expr, const string &groups = ""); + unique_ptr Aggregate(const py::object &expr, const string &groups = ""); - std::unique_ptr GenericAggregator(const string &function_name, const string &aggregated_columns, - const string &groups = "", - const string &function_parameter = "", - const string &projected_columns = ""); + unique_ptr GenericAggregator(const string &function_name, const string &aggregated_columns, + const string &groups = "", const string &function_parameter = "", + const string &projected_columns = ""); /* General aggregate functions */ - std::unique_ptr AnyValue(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr ArgMax(const string &arg_column, const string &value_column, - const string &groups = "", const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr ArgMin(const string &arg_column, const string &value_column, - const string &groups = "", const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr Avg(const string &column, const string &groups = "", + unique_ptr AnyValue(const string &column, const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr BitAnd(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr BitOr(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr BitXor(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr BitStringAgg(const string &column, const Optional &min, - const Optional &max, const string &groups = "", - const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr BoolAnd(const string &column, const string &groups = "", + unique_ptr ArgMax(const string &arg_column, const string &value_column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); + unique_ptr ArgMin(const string &arg_column, const string &value_column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); + unique_ptr Avg(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr BitAnd(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr BitOr(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr BitXor(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr BitStringAgg(const string &column, const Optional &min, + const Optional &max, const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr BoolOr(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr ValueCounts(const string &column, const string &groups = ""); - std::unique_ptr Count(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr FAvg(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr First(const string &column, const string &groups = "", - const string &projected_columns = ""); - std::unique_ptr FSum(const string &column, const string &groups = "", + unique_ptr BoolAnd(const string &column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); + unique_ptr BoolOr(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr ValueCounts(const string &column, const string &groups = ""); + unique_ptr Count(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr FAvg(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr First(const string &column, const string &groups = "", + const string &projected_columns = ""); + unique_ptr FSum(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr GeoMean(const string &column, const string &groups = "", + const string &projected_columns = ""); + unique_ptr Histogram(const string &column, const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr GeoMean(const string &column, const string &groups = "", - const string &projected_columns = ""); - std::unique_ptr Histogram(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Last(const string &column, const string &groups = "", - const string &projected_columns = ""); - std::unique_ptr List(const string &column, const string &groups = "", + unique_ptr Last(const string &column, const string &groups = "", + const string &projected_columns = ""); + unique_ptr List(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr Max(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr Min(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr Product(const string &column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); + unique_ptr StringAgg(const string &column, const string &sep = ",", const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Max(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Min(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Product(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr StringAgg(const string &column, const string &sep = ",", - const string &groups = "", const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr Sum(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); + unique_ptr Sum(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); /* TODO: Approximate aggregate functions */ /* TODO: Statistical aggregate functions */ - std::unique_ptr Median(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Mode(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr QuantileCont(const string &column, const py::object &q, const string &groups = "", - const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr QuantileDisc(const string &column, const py::object &q, const string &groups = "", - const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr StdPop(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr StdSamp(const string &column, const string &groups = "", + unique_ptr Median(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr Mode(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr QuantileCont(const string &column, const py::object &q, const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr VarPop(const string &column, const string &groups = "", - const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr VarSamp(const string &column, const string &groups = "", + unique_ptr QuantileDisc(const string &column, const py::object &q, const string &groups = "", const string &window_spec = "", const string &projected_columns = ""); + unique_ptr StdPop(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr StdSamp(const string &column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); + unique_ptr VarPop(const string &column, const string &groups = "", const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr VarSamp(const string &column, const string &groups = "", + const string &window_spec = "", const string &projected_columns = ""); - std::unique_ptr Describe(); + unique_ptr Describe(); string ToSQL(); @@ -141,36 +140,35 @@ struct DuckDBPyRelation { py::tuple Shape(); - std::unique_ptr Unique(const string &aggr_columns); + unique_ptr Unique(const string &aggr_columns); - std::unique_ptr GenericWindowFunction(const string &function_name, - const string &function_parameters, - const string &aggr_columns, const string &window_spec, - const bool &ignore_nulls, const string &projected_columns); + unique_ptr GenericWindowFunction(const string &function_name, const string &function_parameters, + const string &aggr_columns, const string &window_spec, + const bool &ignore_nulls, const string &projected_columns); /* General purpose window functions */ - std::unique_ptr RowNumber(const string &window_spec, const string &projected_columns); - std::unique_ptr Rank(const string &window_spec, const string &projected_columns); - std::unique_ptr DenseRank(const string &window_spec, const string &projected_columns); - std::unique_ptr PercentRank(const string &window_spec, const string &projected_columns); - std::unique_ptr CumeDist(const string &window_spec, const string &projected_columns); - std::unique_ptr FirstValue(const string &column, const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr NTile(const string &window_spec, const int &num_buckets, - const string &projected_columns); - std::unique_ptr Lag(const string &column, const string &window_spec, const int &offset, - const string &default_value, const bool &ignore_nulls, - const string &projected_columns); - std::unique_ptr LastValue(const string &column, const string &window_spec = "", - const string &projected_columns = ""); - std::unique_ptr Lead(const string &column, const string &window_spec, const int &offset, - const string &default_value, const bool &ignore_nulls, - const string &projected_columns); - - std::unique_ptr NthValue(const string &column, const string &window_spec, const int &offset, - const bool &ignore_nulls, const string &projected_columns); - - std::unique_ptr Distinct(); + unique_ptr RowNumber(const string &window_spec, const string &projected_columns); + unique_ptr Rank(const string &window_spec, const string &projected_columns); + unique_ptr DenseRank(const string &window_spec, const string &projected_columns); + unique_ptr PercentRank(const string &window_spec, const string &projected_columns); + unique_ptr CumeDist(const string &window_spec, const string &projected_columns); + unique_ptr FirstValue(const string &column, const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr NTile(const string &window_spec, const int &num_buckets, + const string &projected_columns); + unique_ptr Lag(const string &column, const string &window_spec, const int &offset, + const string &default_value, const bool &ignore_nulls, + const string &projected_columns); + unique_ptr LastValue(const string &column, const string &window_spec = "", + const string &projected_columns = ""); + unique_ptr Lead(const string &column, const string &window_spec, const int &offset, + const string &default_value, const bool &ignore_nulls, + const string &projected_columns); + + unique_ptr NthValue(const string &column, const string &window_spec, const int &offset, + const bool &ignore_nulls, const string &projected_columns); + + unique_ptr Distinct(); PandasDataFrame FetchDF(bool date_as_object); @@ -200,16 +198,16 @@ struct DuckDBPyRelation { duckdb::pyarrow::RecordBatchReader ToRecordBatch(idx_t batch_size); - std::unique_ptr Union(DuckDBPyRelation *other); + unique_ptr Union(DuckDBPyRelation *other); - std::unique_ptr Except(DuckDBPyRelation *other); + unique_ptr Except(DuckDBPyRelation *other); - std::unique_ptr Intersect(DuckDBPyRelation *other); + unique_ptr Intersect(DuckDBPyRelation *other); - std::unique_ptr Map(py::function fun, Optional schema); + unique_ptr Map(py::function fun, Optional schema); - std::unique_ptr Join(DuckDBPyRelation *other, const py::object &condition, const string &type); - std::unique_ptr Cross(DuckDBPyRelation *other); + unique_ptr Join(DuckDBPyRelation *other, const py::object &condition, const string &type); + unique_ptr Cross(DuckDBPyRelation *other); void ToParquet(const string &filename, const py::object &compression = py::none(), const py::object &field_ids = py::none(), const py::object &row_group_size_bytes = py::none(), @@ -229,9 +227,9 @@ struct DuckDBPyRelation { const py::object &write_partition_columns = py::none()); // should this return a rel with the new view? - std::unique_ptr CreateView(const string &view_name, bool replace = true); + unique_ptr CreateView(const string &view_name, bool replace = true); - std::unique_ptr Query(const string &view_name, const string &sql_query); + unique_ptr Query(const string &view_name, const string &sql_query); // Update the internal result of the relation DuckDBPyRelation &Execute(); @@ -252,7 +250,7 @@ struct DuckDBPyRelation { const Optional &max_col_width, const Optional &null_value, const py::object &render_mode); - string Explain(ExplainType type, const string &format = ""); + string Explain(ExplainType type); static bool IsRelation(const py::object &object); @@ -265,8 +263,8 @@ struct DuckDBPyRelation { bool ContainsColumnByName(const string &name) const; void SetConnectionOwner(py::object owner); - std::unique_ptr DeriveRelation(shared_ptr new_rel); - std::unique_ptr DeriveRelation(std::shared_ptr result); + unique_ptr DeriveRelation(shared_ptr new_rel); + unique_ptr DeriveRelation(shared_ptr result); private: string ToStringInternal(const BoxRendererConfig &config, bool invalidate_cache = false); @@ -278,10 +276,10 @@ struct DuckDBPyRelation { const string &groups = "", const string &function_parameter = "", bool ignore_nulls = false, const string &projected_columns = "", const string &window_spec = ""); - std::unique_ptr ApplyAggOrWin(const string &function_name, const string &agg_columns, - const string &function_parameters = "", const string &groups = "", - const string &window_spec = "", - const string &projected_columns = "", bool ignore_nulls = false); + unique_ptr ApplyAggOrWin(const string &function_name, const string &agg_columns, + const string &function_parameters = "", const string &groups = "", + const string &window_spec = "", const string &projected_columns = "", + bool ignore_nulls = false); void AssertResult() const; void AssertResultOpen() const; @@ -298,7 +296,7 @@ struct DuckDBPyRelation { shared_ptr rel; vector types; vector names; - std::shared_ptr result; + shared_ptr result; std::string rendered_result; }; diff --git a/src/duckdb_py/include/duckdb_python/pyresult.hpp b/src/duckdb_py/include/duckdb_python/pyresult.hpp index 1a014824..d7da83cc 100644 --- a/src/duckdb_py/include/duckdb_python/pyresult.hpp +++ b/src/duckdb_py/include/duckdb_python/pyresult.hpp @@ -32,7 +32,7 @@ struct DuckDBPyResult { py::dict FetchNumpy(); py::dict FetchNumpyInternal(bool stream = false, idx_t vectors_per_chunk = 1, - std::unique_ptr conversion = nullptr); + unique_ptr conversion = nullptr); PandasDataFrame FetchDF(bool date_as_object); @@ -67,7 +67,7 @@ struct DuckDBPyResult { void ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_object) const; unique_ptr FetchNext(QueryResult &result); unique_ptr FetchNextRaw(QueryResult &result); - std::unique_ptr InitializeNumpyConversion(bool pandas = false); + unique_ptr InitializeNumpyConversion(bool pandas = false); //! Re-feed an already-MATERIALIZED result (a ColumnDataCollection, e.g. from //! rel.execute()) back through the engine on the user's own context. The eager diff --git a/src/duckdb_py/include/duckdb_python/python_conversion.hpp b/src/duckdb_py/include/duckdb_python/python_conversion.hpp index 05715cbe..bad518ef 100644 --- a/src/duckdb_py/include/duckdb_python/python_conversion.hpp +++ b/src/duckdb_py/include/duckdb_python/python_conversion.hpp @@ -47,9 +47,8 @@ PythonObjectType GetPythonObjectType(py::handle &ele); LogicalType SniffPythonIntegerType(py::handle ele); bool DictionaryHasMapFormat(const PyDictionary &dict); -void TransformPythonObject(optional_ptr context, py::handle ele, Vector &vector, idx_t result_offset, +void TransformPythonObject(py::handle ele, Vector &vector, idx_t result_offset, bool nan_as_null = true); +Value TransformPythonValue(py::handle ele, const LogicalType &target_type = LogicalType::UNKNOWN, bool nan_as_null = true); -Value TransformPythonValue(optional_ptr context, py::handle ele, - const LogicalType &target_type = LogicalType::UNKNOWN, bool nan_as_null = true); } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index 87f56836..a6e13dfd 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -21,7 +21,7 @@ class PyUnionType : public py::object { static bool check_(const py::handle &object); }; -class DuckDBPyType : public std::enable_shared_from_this { +class DuckDBPyType : public enable_shared_from_this { public: explicit DuckDBPyType(LogicalType type); @@ -29,9 +29,9 @@ class DuckDBPyType : public std::enable_shared_from_this { static void Initialize(py::handle &m); public: - bool Equals(const std::shared_ptr &other) const; + bool Equals(const shared_ptr &other) const; bool EqualsString(const string &type_str) const; - std::shared_ptr GetAttribute(const string &name) const; + shared_ptr GetAttribute(const string &name) const; py::list Children() const; string ToString() const; const LogicalType &Type() const; diff --git a/src/duckdb_py/map.cpp b/src/duckdb_py/map.cpp index 10ea9774..9864f2de 100644 --- a/src/duckdb_py/map.cpp +++ b/src/duckdb_py/map.cpp @@ -23,10 +23,10 @@ struct MapFunctionData : public TableFunctionData { } PyObject *function; vector in_types, out_types; - vector in_names, out_names; + vector in_names, out_names; }; -static py::object FunctionCall(NumpyResultConversion &conversion, const vector &names, PyObject *function) { +static py::object FunctionCall(NumpyResultConversion &conversion, const vector &names, PyObject *function) { py::dict in_numpy_dict; for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { in_numpy_dict[names[col_idx].c_str()] = conversion.ToArray(col_idx); @@ -71,8 +71,8 @@ static bool ContainsNullType(const vector &types) { return false; } -static void OverrideNullType(vector &return_types, const vector &return_names, - const vector &original_types, const vector &original_names) { +static void OverrideNullType(vector &return_types, const vector &return_names, + const vector &original_types, const vector &original_names) { if (!ContainsNullType(return_types)) { // Nothing to override, none of the returned types are NULL return; @@ -115,15 +115,13 @@ unique_ptr BindExplicitSchema(unique_ptr function for (auto &item : schema) { auto name = item.first; auto type_p = item.second; - names.push_back(string(py::str(name))); + names.push_back(std::string(py::str(name))); // TODO: replace with py::try_cast so we can catch the error and throw a better exception - auto type = py::cast>(type_p); + auto type = py::cast>(type_p); types.push_back(type->Type()); } - for (auto &name : names) { - function_data->out_names.push_back(Identifier(name)); - } + function_data->out_names = names; function_data->out_types = types; return std::move(function_data); @@ -151,15 +149,10 @@ unique_ptr MapFunction::MapFunctionBind(ClientContext &context, Ta vector pandas_bind_data; // unused Pandas::Bind(context, df, pandas_bind_data, return_types, names); - // Build the Identifier names only after Pandas::Bind has populated 'names'. - vector name_identifiers(names.size()); - std::transform(names.begin(), names.end(), name_identifiers.begin(), - [](const string &name) { return Identifier(name); }); - // output types are potentially NULL, this happens for types that map to 'object' dtype - OverrideNullType(return_types, name_identifiers, data.in_types, data.in_names); + OverrideNullType(return_types, names, data.in_types, data.in_names); - data.out_names = name_identifiers; + data.out_names = names; data.out_types = return_types; return std::move(data_uptr); } @@ -198,10 +191,7 @@ OperatorResultType MapFunction::MapFunctionExec(ExecutionContext &context, Table throw InvalidInputException("UDF column type mismatch, expected [%s], got [%s]", TypeVectorToString(data.out_types), TypeVectorToString(pandas_return_types)); } - vector pandas_name_identifiers(pandas_names.size()); - std::transform(pandas_names.begin(), pandas_names.end(), pandas_name_identifiers.begin(), - [](const string &name) { return Identifier(name); }); - if (pandas_name_identifiers != data.out_names) { + if (pandas_names != data.out_names) { throw InvalidInputException("UDF column name mismatch, expected [%s], got [%s]", StringUtil::Join(data.out_names, ", "), StringUtil::Join(pandas_names, ", ")); } @@ -216,9 +206,9 @@ OperatorResultType MapFunction::MapFunctionExec(ExecutionContext &context, Table for (idx_t col_idx = 0; col_idx < output.ColumnCount(); col_idx++) { auto &bind_data = pandas_bind_data[col_idx]; - PandasScanFunction::PandasBackendScanSwitch(context.client, bind_data, row_count, 0, output.data[col_idx]); + PandasScanFunction::PandasBackendScanSwitch(bind_data, row_count, 0, output.data[col_idx]); } - output.SetChildCardinality(row_count); + output.SetCardinality(row_count); return OperatorResultType::NEED_MORE_INPUT; } diff --git a/src/duckdb_py/native/python_conversion.cpp b/src/duckdb_py/native/python_conversion.cpp index 722d85c2..a56ea73f 100644 --- a/src/duckdb_py/native/python_conversion.cpp +++ b/src/duckdb_py/native/python_conversion.cpp @@ -3,24 +3,16 @@ #include "duckdb_python/pyrelation.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyresult.hpp" #include "duckdb/common/types.hpp" -#include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/array_vector.hpp" -#include "duckdb/common/vector/struct_vector.hpp" #include "duckdb/common/exception/conversion_exception.hpp" +#include "datetime.h" //From Python + #include "duckdb/common/limits.hpp" namespace duckdb { -// Proxy to LogicalType::ForceMaxLogicalType that switches based on the presence of a client context. -static LogicalType ProxiedForceMaxLogicalType(optional_ptr context, const LogicalType &left, - const LogicalType &right) { - return context ? LogicalType::ForceMaxLogicalType(*context, left, right) - : LogicalType::DefaultForceMaxLogicalType(left, right); -} - // Like DefaultCastAs, but handles UNION targets by finding the first compatible member. DefaultCastAs raises a // Conversion Error when multiple UNION members have the same type (e.g. UNION(u1 DOUBLE, u2 DOUBLE)), so for UNION // targets we resolve the member ourselves. @@ -56,11 +48,11 @@ static Value EmptyMapValue() { return Value::MAP(ListType::GetChildType(map_type), vector()); } -vector TransformStructKeys(py::handle keys, idx_t size, const LogicalType &type = LogicalType::UNKNOWN) { - vector res; +vector TransformStructKeys(py::handle keys, idx_t size, const LogicalType &type = LogicalType::UNKNOWN) { + vector res; res.reserve(size); for (idx_t i = 0; i < size; i++) { - res.emplace_back(Identifier(py::str(keys.attr("__getitem__")(i)))); + res.emplace_back(py::str(keys.attr("__getitem__")(i))); } return res; } @@ -113,8 +105,7 @@ bool DictionaryHasMapFormat(const PyDictionary &dict) { return true; } -Value TransformDictionaryToStruct(optional_ptr context, const PyDictionary &dict, - const LogicalType &target_type = LogicalType::UNKNOWN) { +Value TransformDictionaryToStruct(const PyDictionary &dict, const LogicalType &target_type = LogicalType::UNKNOWN) { auto struct_keys = TransformStructKeys(dict.keys, dict.len, target_type); bool struct_target = target_type.id() == LogicalTypeId::STRUCT; @@ -123,7 +114,7 @@ Value TransformDictionaryToStruct(optional_ptr context, const PyD dict.ToString(), target_type.ToString()); } - identifier_map_t key_mapping; + case_insensitive_map_t key_mapping; for (idx_t i = 0; i < struct_keys.size(); i++) { key_mapping[struct_keys[i]] = i; } @@ -133,14 +124,13 @@ Value TransformDictionaryToStruct(optional_ptr context, const PyD auto &key = struct_target ? StructType::GetChildName(target_type, i) : struct_keys[i]; auto value_index = struct_target ? key_mapping[key] : i; auto &child_type = struct_target ? StructType::GetChildType(target_type, i) : LogicalType::UNKNOWN; - auto val = TransformPythonValue(context, dict.values.attr("__getitem__")(value_index), child_type); + auto val = TransformPythonValue(dict.values.attr("__getitem__")(value_index), child_type); struct_values.emplace_back(make_pair(std::move(key), std::move(val))); } return Value::STRUCT(std::move(struct_values)); } -Value TransformStructFormatDictionaryToMap(optional_ptr context, const PyDictionary &dict, - const LogicalType &target_type) { +Value TransformStructFormatDictionaryToMap(const PyDictionary &dict, const LogicalType &target_type) { if (dict.len == 0) { return EmptyMapValue(); } @@ -165,11 +155,11 @@ Value TransformStructFormatDictionaryToMap(optional_ptr context, vector elements; for (idx_t i = 0; i < size; i++) { - Value new_key = TransformPythonValue(context, dict.keys.attr("__getitem__")(i), key_target); - Value new_value = TransformPythonValue(context, dict.values.attr("__getitem__")(i), value_target); + Value new_key = TransformPythonValue(dict.keys.attr("__getitem__")(i), key_target); + Value new_value = TransformPythonValue(dict.values.attr("__getitem__")(i), value_target); - key_type = ProxiedForceMaxLogicalType(context, key_type, new_key.type()); - value_type = ProxiedForceMaxLogicalType(context, value_type, new_value.type()); + key_type = LogicalType::ForceMaxLogicalType(key_type, new_key.type()); + value_type = LogicalType::ForceMaxLogicalType(value_type, new_value.type()); child_list_t struct_values; struct_values.emplace_back(make_pair("key", std::move(new_key))); @@ -189,11 +179,10 @@ Value TransformStructFormatDictionaryToMap(optional_ptr context, return Value::MAP(ListType::GetChildType(map_type), std::move(elements)); } -Value TransformDictionaryToMap(optional_ptr context, const PyDictionary &dict, - const LogicalType &target_type = LogicalType::UNKNOWN) { +Value TransformDictionaryToMap(const PyDictionary &dict, const LogicalType &target_type = LogicalType::UNKNOWN) { if (target_type.id() != LogicalTypeId::UNKNOWN && !DictionaryHasMapFormat(dict)) { // dict == { 'k1': v1, 'k2': v2, ..., 'kn': vn } - return TransformStructFormatDictionaryToMap(context, dict, target_type); + return TransformStructFormatDictionaryToMap(dict, target_type); } auto keys = dict.values.attr("__getitem__")(0); @@ -220,8 +209,8 @@ Value TransformDictionaryToMap(optional_ptr context, const PyDict value_target = LogicalType::LIST(MapType::ValueType(target_type)); } - auto key_list = TransformPythonValue(context, keys, key_target); - auto value_list = TransformPythonValue(context, values, value_target); + auto key_list = TransformPythonValue(keys, key_target); + auto value_list = TransformPythonValue(values, value_target); LogicalType key_type = LogicalType::SQLNULL; LogicalType value_type = LogicalType::SQLNULL; @@ -232,8 +221,8 @@ Value TransformDictionaryToMap(optional_ptr context, const PyDict Value new_key = ListValue::GetChildren(key_list)[i]; Value new_value = ListValue::GetChildren(value_list)[i]; - key_type = ProxiedForceMaxLogicalType(context, key_type, new_key.type()); - value_type = ProxiedForceMaxLogicalType(context, value_type, new_value.type()); + key_type = LogicalType::ForceMaxLogicalType(key_type, new_key.type()); + value_type = LogicalType::ForceMaxLogicalType(value_type, new_value.type()); child_list_t struct_values; struct_values.emplace_back(make_pair("key", std::move(new_key))); @@ -247,8 +236,7 @@ Value TransformDictionaryToMap(optional_ptr context, const PyDict return Value::MAP(ListType::GetChildType(map_type), std::move(elements)); } -Value TransformTupleToStruct(optional_ptr context, py::handle ele, - const LogicalType &target_type = LogicalType::UNKNOWN) { +Value TransformTupleToStruct(py::handle ele, const LogicalType &target_type = LogicalType::UNKNOWN) { auto tuple = py::cast(ele); auto size = py::len(tuple); @@ -265,7 +253,7 @@ Value TransformTupleToStruct(optional_ptr context, py::handle ele auto &type = child_types[i].second; auto &name = StructType::GetChildName(target_type, i); auto element = py::handle(tuple[i]); - auto converted_value = TransformPythonValue(context, element, type); + auto converted_value = TransformPythonValue(element, type); children.emplace_back(make_pair(name, std::move(converted_value))); } auto result = Value::STRUCT(std::move(children)); @@ -376,7 +364,7 @@ LogicalType SniffPythonIntegerType(py::handle ele) { return res.type(); } -Value TransformDictionary(optional_ptr context, const PyDictionary &dict) { +Value TransformDictionary(const PyDictionary &dict) { //! DICT -> MAP FORMAT // keys() = [key, value] // values() = [ [n keys] ], [ [n values] ] @@ -390,9 +378,9 @@ Value TransformDictionary(optional_ptr context, const PyDictionar } if (DictionaryHasMapFormat(dict)) { - return TransformDictionaryToMap(context, dict); + return TransformDictionaryToMap(dict); } - return TransformDictionaryToStruct(context, dict); + return TransformDictionaryToStruct(dict); } PythonObjectType GetPythonObjectType(py::handle &ele) { @@ -530,8 +518,7 @@ struct PythonValueConversion { } } - static void HandleList(optional_ptr context, Value &result, const LogicalType &target_type, - py::handle ele, idx_t list_size) { + static void HandleList(Value &result, const LogicalType &target_type, py::handle ele, idx_t list_size) { vector values; values.reserve(list_size); @@ -545,8 +532,8 @@ struct PythonValueConversion { } LogicalType element_type = LogicalType::SQLNULL; for (idx_t i = 0; i < list_size; i++) { - Value new_value = TransformPythonValue(context, ele.attr("__getitem__")(i), child_type); - element_type = ProxiedForceMaxLogicalType(context, element_type, new_value.type()); + Value new_value = TransformPythonValue(ele.attr("__getitem__")(i), child_type); + element_type = LogicalType::ForceMaxLogicalType(element_type, new_value.type()); values.push_back(std::move(new_value)); } if (is_array) { @@ -556,17 +543,16 @@ struct PythonValueConversion { } } - static void HandleTuple(optional_ptr context, Value &result, const LogicalType &target_type, - py::handle ele, idx_t list_size) { + static void HandleTuple(Value &result, const LogicalType &target_type, py::handle ele, idx_t list_size) { if (target_type.id() == LogicalTypeId::STRUCT) { - result = TransformTupleToStruct(context, ele, target_type); + result = TransformTupleToStruct(ele, target_type); return; } - HandleList(context, result, target_type, ele, list_size); + HandleList(result, target_type, ele, list_size); } - static Value HandleObjectInternal(optional_ptr context, py::handle ele, PythonObjectType object_type, - const LogicalType &target_type, bool nan_as_null) { + static Value HandleObjectInternal(py::handle ele, PythonObjectType object_type, const LogicalType &target_type, + bool nan_as_null) { switch (object_type) { case PythonObjectType::Decimal: { PyDecimal decimal(ele); @@ -584,32 +570,32 @@ struct PythonValueConversion { PyDictionary dict = PyDictionary(py::reinterpret_borrow(ele)); switch (target_type.id()) { case LogicalTypeId::STRUCT: - return TransformDictionaryToStruct(context, dict, target_type); + return TransformDictionaryToStruct(dict, target_type); case LogicalTypeId::MAP: - return TransformDictionaryToMap(context, dict, target_type); + return TransformDictionaryToMap(dict, target_type); default: - return TransformDictionary(context, dict); + return TransformDictionary(dict); } } case PythonObjectType::Value: { // Extract the internal object and the type from the Value instance auto object = ele.attr("object"); auto type = ele.attr("type"); - std::shared_ptr internal_type; - if (!py::try_cast>(type, internal_type)) { + shared_ptr internal_type; + if (!py::try_cast>(type, internal_type)) { string actual_type = py::str(py::type::of(type)); throw InvalidInputException("The 'type' of a Value should be of type DuckDBPyType, not '%s'", actual_type); } - return TransformPythonValue(context, object, internal_type->Type()); + return TransformPythonValue(object, internal_type->Type()); } default: throw InternalException("Unsupported fallback"); } } - static void HandleObject(optional_ptr context, py::handle ele, PythonObjectType object_type, - Value &result, const LogicalType &target_type, bool nan_as_null) { - result = HandleObjectInternal(context, ele, object_type, target_type, nan_as_null); + static void HandleObject(py::handle ele, PythonObjectType object_type, Value &result, + const LogicalType &target_type, bool nan_as_null) { + result = HandleObjectInternal(ele, object_type, target_type, nan_as_null); } }; @@ -626,16 +612,16 @@ struct PythonVectorConversion { LogicalType::BOOLEAN, result.GetType(), "Python Conversion Failure: Expected a value of type %s, but got a value of type boolean"); } - FlatVector::GetDataMutable(result)[result_offset] = val; + FlatVector::GetData(result)[result_offset] = val; } static void HandleDouble(Vector &result, const idx_t &result_offset, double val) { switch (result.GetType().id()) { case LogicalTypeId::DOUBLE: { - FlatVector::GetDataMutable(result)[result_offset] = val; + FlatVector::GetData(result)[result_offset] = val; break; } case LogicalTypeId::FLOAT: { - FlatVector::GetDataMutable(result)[result_offset] = static_cast(val); + FlatVector::GetData(result)[result_offset] = static_cast(val); break; } default: @@ -651,13 +637,13 @@ struct PythonVectorConversion { // this code path is only called for values in the range of [INT64_MAX...UINT64_MAX] switch (result.GetType().id()) { case LogicalTypeId::HUGEINT: - FlatVector::GetDataMutable(result)[result_offset] = Hugeint::Convert(value); + FlatVector::GetData(result)[result_offset] = Hugeint::Convert(value); break; case LogicalTypeId::UHUGEINT: - FlatVector::GetDataMutable(result)[result_offset] = Uhugeint::Convert(value); + FlatVector::GetData(result)[result_offset] = Uhugeint::Convert(value); break; case LogicalTypeId::UBIGINT: - FlatVector::GetDataMutable(result)[result_offset] = value; + FlatVector::GetData(result)[result_offset] = value; break; default: FallbackValueConversion(result, result_offset, CastToTarget(Value::UBIGINT(value), result.GetType())); @@ -667,67 +653,67 @@ struct PythonVectorConversion { static void HandleBigint(Vector &result, const idx_t &result_offset, int64_t value) { switch (result.GetType().id()) { case LogicalTypeId::HUGEINT: { - FlatVector::GetDataMutable(result)[result_offset] = Hugeint::Convert(value); + FlatVector::GetData(result)[result_offset] = Hugeint::Convert(value); break; } case LogicalTypeId::UHUGEINT: { if (value < 0) { throw InvalidInputException("Python Conversion Failure: Value out of range for type UHUGEINT"); } - FlatVector::GetDataMutable(result)[result_offset] = Uhugeint::Convert(value); + FlatVector::GetData(result)[result_offset] = Uhugeint::Convert(value); break; } case LogicalTypeId::BIGINT: { - FlatVector::GetDataMutable(result)[result_offset] = value; + FlatVector::GetData(result)[result_offset] = value; break; } case LogicalTypeId::INTEGER: { if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type INT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::SMALLINT: { if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type SMALLINT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::TINYINT: { if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type TINYINT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::UBIGINT: { if (value < 0) { throw InvalidInputException("Python Conversion Failure: Value out of range for type UBIGINT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::UINTEGER: { if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type UINTEGER"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::USMALLINT: { if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type USMALLINT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } case LogicalTypeId::UTINYINT: { if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { throw InvalidInputException("Python Conversion Failure: Value out of range for type UTINYINT"); } - FlatVector::GetDataMutable(result)[result_offset] = static_cast(value); + FlatVector::GetData(result)[result_offset] = static_cast(value); break; } default: @@ -739,7 +725,7 @@ struct PythonVectorConversion { static void HandleString(Vector &result, const idx_t &result_offset, const string &value) { auto &result_type = result.GetType(); if (result_type.id() == LogicalTypeId::VARCHAR) { - FlatVector::GetDataMutable(result)[result_offset] = StringVector::AddString(result, value); + FlatVector::GetData(result)[result_offset] = StringVector::AddString(result, value); return; } Value result_val; @@ -751,7 +737,7 @@ struct PythonVectorConversion { auto &result_type = result.GetType(); switch (result_type.id()) { case LogicalTypeId::DATE: - FlatVector::GetDataMutable(result)[result_offset] = date.ToDate(); + FlatVector::GetData(result)[result_offset] = date.ToDate(); break; default: { auto value = date.ToDuckValue(); @@ -765,7 +751,7 @@ struct PythonVectorConversion { auto &result_type = result.GetType(); switch (result_type.id()) { case LogicalTypeId::TIME: - FlatVector::GetDataMutable(result)[result_offset] = time.ToDuckTime(); + FlatVector::GetData(result)[result_offset] = time.ToDuckTime(); break; default: { auto value = time.ToDuckValue(); @@ -779,7 +765,7 @@ struct PythonVectorConversion { auto &result_type = result.GetType(); switch (result_type.id()) { case LogicalTypeId::BLOB: - FlatVector::GetDataMutable(result)[result_offset] = + FlatVector::GetData(result)[result_offset] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(blob), blob_size); break; default: { @@ -794,13 +780,13 @@ struct PythonVectorConversion { auto &result_type = result.GetType(); switch (result_type.id()) { case LogicalTypeId::TIMESTAMP: - FlatVector::GetDataMutable(result)[result_offset] = datetime.ToTimestamp(); + FlatVector::GetData(result)[result_offset] = datetime.ToTimestamp(); break; case LogicalTypeId::TIME: - FlatVector::GetDataMutable(result)[result_offset] = datetime.ToDuckTime(); + FlatVector::GetData(result)[result_offset] = datetime.ToDuckTime(); break; case LogicalTypeId::DATE: - FlatVector::GetDataMutable(result)[result_offset] = datetime.ToDate(); + FlatVector::GetData(result)[result_offset] = datetime.ToDate(); break; default: { auto value = datetime.ToDuckValue(result_type); @@ -811,8 +797,7 @@ struct PythonVectorConversion { } template - static void HandleListFast(optional_ptr context, Vector &result, const idx_t &result_offset, - py::handle ele, idx_t list_size) { + static void HandleListFast(Vector &result, const idx_t &result_offset, py::handle ele, idx_t list_size) { auto &result_type = result.GetType(); if (result_type.id() == LogicalTypeId::ARRAY) { idx_t array_size = ArrayType::GetSize(result_type); @@ -821,11 +806,11 @@ struct PythonVectorConversion { "size %d, but got a list of size %d", array_size, list_size); } - auto &child_array = ArrayVector::GetChildMutable(result); + auto &child_array = ArrayVector::GetEntry(result); idx_t start_offset = result_offset * array_size; for (idx_t i = 0; i < list_size; i++) { auto child_ele = IS_LIST ? PyList_GetItem(ele.ptr(), i) : PyTuple_GetItem(ele.ptr(), i); - TransformPythonObject(context, child_ele, child_array, start_offset + i); + TransformPythonObject(child_ele, child_array, start_offset + i); } return; } @@ -835,15 +820,15 @@ struct PythonVectorConversion { ListVector::Reserve(result, start_offset + list_size); // set up the list entry - auto &list_entry = FlatVector::GetDataMutable(result)[result_offset]; + auto &list_entry = FlatVector::GetData(result)[result_offset]; list_entry.offset = start_offset; list_entry.length = list_size; // convert the child elements - auto &child_vector = ListVector::GetChildMutable(result); + auto &child_vector = ListVector::GetEntry(result); for (idx_t i = 0; i < list_size; i++) { auto child_ele = IS_LIST ? PyList_GetItem(ele.ptr(), i) : PyTuple_GetItem(ele.ptr(), i); - TransformPythonObject(context, child_ele, child_vector, start_offset + i); + TransformPythonObject(child_ele, child_vector, start_offset + i); } ListVector::SetListSize(result, start_offset + list_size); return; @@ -851,21 +836,19 @@ struct PythonVectorConversion { throw InternalException("Unsupported type for HandleListFast"); } - static void HandleList(optional_ptr context, Vector &result, const idx_t &result_offset, - py::handle ele, idx_t list_size) { + static void HandleList(Vector &result, const idx_t &result_offset, py::handle ele, idx_t list_size) { auto &result_type = result.GetType(); if (result_type.id() == LogicalTypeId::ARRAY || result_type.id() == LogicalTypeId::LIST) { - HandleListFast(context, result, result_offset, ele, list_size); + HandleListFast(result, result_offset, ele, list_size); return; } // fallback to value conversion Value result_val; - PythonValueConversion::HandleList(context, result_val, result_type, ele, list_size); + PythonValueConversion::HandleList(result_val, result_type, ele, list_size); FallbackValueConversion(result, result_offset, std::move(result_val)); } - static void ConvertTupleToStruct(optional_ptr context, Vector &result, const idx_t &result_offset, - py::handle ele, idx_t size) { + static void ConvertTupleToStruct(Vector &result, const idx_t &result_offset, py::handle ele, idx_t size) { auto &child_types = StructType::GetChildTypes(result.GetType()); auto child_count = child_types.size(); if (size != child_count) { @@ -877,20 +860,19 @@ struct PythonVectorConversion { auto &struct_children = StructVector::GetEntries(result); for (idx_t i = 0; i < child_count; i++) { auto child_ele = PyTuple_GetItem(ele.ptr(), i); - TransformPythonObject(context, child_ele, struct_children[i], result_offset); + TransformPythonObject(child_ele, *struct_children[i], result_offset); } } - static void HandleTuple(optional_ptr context, Vector &result, const idx_t &result_offset, - py::handle ele, idx_t tuple_size) { + static void HandleTuple(Vector &result, const idx_t &result_offset, py::handle ele, idx_t tuple_size) { auto &result_type = result.GetType(); switch (result_type.id()) { case LogicalTypeId::STRUCT: - ConvertTupleToStruct(context, result, result_offset, ele, tuple_size); + ConvertTupleToStruct(result, result_offset, ele, tuple_size); break; case LogicalTypeId::ARRAY: case LogicalTypeId::LIST: - HandleListFast(context, result, result_offset, ele, tuple_size); + HandleListFast(result, result_offset, ele, tuple_size); break; default: throw InternalException("Unsupported type for HandleTuple"); @@ -900,17 +882,16 @@ struct PythonVectorConversion { static void FallbackValueConversion(Vector &result, const idx_t &result_offset, Value val) { result.SetValue(result_offset, val); } - static void HandleObject(optional_ptr context, py::handle ele, PythonObjectType object_type, - Vector &result, const idx_t &result_offset, bool nan_as_null) { + static void HandleObject(py::handle ele, PythonObjectType object_type, Vector &result, const idx_t &result_offset, + bool nan_as_null) { Value result_val; - PythonValueConversion::HandleObject(context, ele, object_type, result_val, result.GetType(), nan_as_null); + PythonValueConversion::HandleObject(ele, object_type, result_val, result.GetType(), nan_as_null); result.SetValue(result_offset, result_val); } }; template -void TransformPythonObjectInternal(optional_ptr context, py::handle ele, A &result, const B ¶m, - bool nan_as_null) { +void TransformPythonObjectInternal(py::handle ele, A &result, const B ¶m, bool nan_as_null) { auto object_type = GetPythonObjectType(ele); switch (object_type) { @@ -973,7 +954,7 @@ void TransformPythonObjectInternal(optional_ptr context, py::hand } case PythonObjectType::List: { auto list_size = py::len(ele); - OP::HandleList(context, result, param, ele, list_size); + OP::HandleList(result, param, ele, list_size); break; } case PythonObjectType::Tuple: { @@ -984,7 +965,7 @@ void TransformPythonObjectInternal(optional_ptr context, py::hand case LogicalTypeId::UNKNOWN: case LogicalTypeId::LIST: case LogicalTypeId::ARRAY: - OP::HandleTuple(context, result, param, ele, list_size); + OP::HandleTuple(result, param, ele, list_size); break; default: throw InvalidInputException("Can't convert tuple to a Value of type %s", conversion_target); @@ -1041,14 +1022,14 @@ void TransformPythonObjectInternal(optional_ptr context, py::hand } case PythonObjectType::NdArray: case PythonObjectType::NdDatetime: - TransformPythonObjectInternal(context, ele.attr("tolist")(), result, param, nan_as_null); + TransformPythonObjectInternal(ele.attr("tolist")(), result, param, nan_as_null); break; case PythonObjectType::Uuid: case PythonObjectType::Timedelta: case PythonObjectType::Dict: case PythonObjectType::Value: case PythonObjectType::Decimal: { - OP::HandleObject(context, ele, object_type, result, param, nan_as_null); + OP::HandleObject(ele, object_type, result, param, nan_as_null); break; } case PythonObjectType::Other: @@ -1059,15 +1040,13 @@ void TransformPythonObjectInternal(optional_ptr context, py::hand } } -void TransformPythonObject(optional_ptr context, py::handle ele, Vector &vector, idx_t result_offset, - bool nan_as_null) { - TransformPythonObjectInternal(context, ele, vector, result_offset, nan_as_null); +void TransformPythonObject(py::handle ele, Vector &vector, idx_t result_offset, bool nan_as_null) { + TransformPythonObjectInternal(ele, vector, result_offset, nan_as_null); } -Value TransformPythonValue(optional_ptr context, py::handle ele, const LogicalType &target_type, - bool nan_as_null) { +Value TransformPythonValue(py::handle ele, const LogicalType &target_type, bool nan_as_null) { Value result; - TransformPythonObjectInternal(context, ele, result, target_type, nan_as_null); + TransformPythonObjectInternal(ele, result, target_type, nan_as_null); return result; } diff --git a/src/duckdb_py/native/python_objects.cpp b/src/duckdb_py/native/python_objects.cpp index d34cf28f..c59bd76d 100644 --- a/src/duckdb_py/native/python_objects.cpp +++ b/src/duckdb_py/native/python_objects.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/cast_helpers.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" #include "duckdb/common/operator/add.hpp" @@ -12,7 +13,7 @@ #include "datetime.h" // Python datetime initialize #1 -#include +#include #include namespace duckdb { @@ -107,6 +108,7 @@ bool PyDecimal::TryGetType(LogicalType &type) { throw NotImplementedException("case not implemented for type PyDecimalExponentType"); } // LCOV_EXCL_STOP } + return true; } // LCOV_EXCL_START static void ExponentNotRecognized() { @@ -437,7 +439,6 @@ static bool KeyIsHashable(const LogicalType &type) { case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_TZ_NS: case LogicalTypeId::TIME_TZ: case LogicalTypeId::TIME: case LogicalTypeId::DATE: @@ -462,9 +463,6 @@ static bool KeyIsHashable(const LogicalType &type) { } case LogicalTypeId::STRUCT: return false; - case LogicalTypeId::SQLNULL: - // A SQLNULL key is always NULL, and Python's None is hashable. - return true; default: throw NotImplementedException("Unsupported type: \"%s\"", type.ToString()); } @@ -522,8 +520,7 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_TZ_NS: { + case LogicalTypeId::TIMESTAMP_TZ: { D_ASSERT(type.InternalType() == PhysicalType::INT64); auto timestamp = val.GetValueUnsafe(); @@ -537,7 +534,7 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, if (type.id() == LogicalTypeId::TIMESTAMP_MS) { timestamp = Timestamp::FromEpochMs(timestamp.value); - } else if (type.id() == LogicalTypeId::TIMESTAMP_NS || type.id() == LogicalTypeId::TIMESTAMP_TZ_NS) { + } else if (type.id() == LogicalTypeId::TIMESTAMP_NS) { timestamp = Timestamp::FromEpochNanoSeconds(timestamp.value); } else if (type.id() == LogicalTypeId::TIMESTAMP_SEC) { timestamp = Timestamp::FromEpochSeconds(timestamp.value); @@ -560,7 +557,7 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, // Failed to convert, fall back to str return py::str(val.ToString()); } - if (type.id() == LogicalTypeId::TIMESTAMP_TZ || type.id() == LogicalTypeId::TIMESTAMP_TZ_NS) { + if (type.id() == LogicalTypeId::TIMESTAMP_TZ) { // We have to add the timezone info auto tz_utc = import_cache.pytz.timezone()("UTC"); auto timestamp_utc = tz_utc.attr("localize")(py_timestamp); @@ -620,7 +617,7 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, D_ASSERT(type.InternalType() == PhysicalType::INT32); auto date = val.GetValueUnsafe(); int32_t year, month, day; - if (!Value::IsFinite(date)) { + if (!duckdb::Date::IsFinite(date)) { if (date == date_t::infinity()) { return py::reinterpret_borrow(import_cache.datetime.date.max()); } @@ -710,9 +707,9 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, py::arg("microseconds") = interval_value.micros); } case LogicalTypeId::VARIANT: { - Vector tmp(val, count_t(1)); + Vector tmp(val); RecursiveUnifiedVectorFormat format; - Vector::RecursiveToUnifiedFormat(tmp, format); + Vector::RecursiveToUnifiedFormat(tmp, 1, format); UnifiedVariantVectorData vector_data(format); auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, 0, 0); return FromValue(variant_val, variant_val.type(), client_properties); diff --git a/src/duckdb_py/numpy/array_wrapper.cpp b/src/duckdb_py/numpy/array_wrapper.cpp index 7cf38f6d..5b3a372b 100644 --- a/src/duckdb_py/numpy/array_wrapper.cpp +++ b/src/duckdb_py/numpy/array_wrapper.cpp @@ -8,9 +8,8 @@ #include "duckdb_python/pyrelation.hpp" #include "duckdb_python/python_objects.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyresult.hpp" #include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/array_vector.hpp" #include @@ -36,7 +35,7 @@ struct TimestampConvert { template static int64_t ConvertValue(timestamp_t val, NumpyAppendData &append_data) { (void)append_data; - if (!Value::IsFinite(val)) { + if (!Timestamp::IsFinite(val)) { return val.value; } return Timestamp::GetEpochNanoSeconds(val); @@ -53,7 +52,7 @@ struct TimestampConvertSec { template static int64_t ConvertValue(timestamp_t val, NumpyAppendData &append_data) { (void)append_data; - if (!Value::IsFinite(val)) { + if (!Timestamp::IsFinite(val)) { return val.value; } return Timestamp::GetEpochNanoSeconds(Timestamp::FromEpochSeconds(val.value)); @@ -70,7 +69,7 @@ struct TimestampConvertMilli { template static int64_t ConvertValue(timestamp_t val, NumpyAppendData &append_data) { (void)append_data; - if (!Value::IsFinite(val)) { + if (!Timestamp::IsFinite(val)) { return val.value; } return Timestamp::GetEpochNanoSeconds(Timestamp::FromEpochMs(val.value)); @@ -180,25 +179,6 @@ struct StringConvert { } }; -struct NullConvert { - template - static PyObject *ConvertValue(DUCKDB_T val, NumpyAppendData &append_data) { - // A SQLNULL column contains only NULLs, so ConvertValue is never reached; every row takes NullValue. - (void)val; - (void)append_data; - Py_RETURN_NONE; - } - template - static NUMPY_T NullValue(bool &set_mask) { - if (PANDAS) { - set_mask = false; - Py_RETURN_NONE; - } - set_mask = true; - return nullptr; - } -}; - struct BlobConvert { template static PyObject *ConvertValue(string_t val, NumpyAppendData &append_data) { @@ -257,6 +237,7 @@ static py::object InternalCreateList(Vector &input, idx_t total_size, idx_t offs struct ListConvert { static py::object ConvertValue(Vector &input, idx_t chunk_offset, NumpyAppendData &append_data) { + auto &client_properties = append_data.client_properties; auto &list_data = append_data.idata; // Get the list entry information from the parent @@ -268,7 +249,7 @@ struct ListConvert { auto list_size = list_entry.length; auto list_offset = list_entry.offset; auto child_size = ListVector::GetListSize(input); - auto &child_vector = ListVector::GetChildMutable(input); + auto &child_vector = ListVector::GetEntry(input); return InternalCreateList(child_vector, child_size, list_offset, list_size, append_data); } @@ -288,7 +269,7 @@ struct ArrayConvert { auto array_size = ArrayType::GetSize(array_type); auto array_offset = array_index * array_size; auto child_size = ArrayVector::GetTotalSize(input); - auto &child_vector = ArrayVector::GetChildMutable(input); + auto &child_vector = ArrayVector::GetEntry(input); return InternalCreateList(child_vector, child_size, array_offset, array_size, append_data); } @@ -327,9 +308,9 @@ struct VariantConvert { static py::object ConvertValue(Vector &input, idx_t chunk_offset, NumpyAppendData &append_data) { auto &client_properties = append_data.client_properties; auto val = input.GetValue(chunk_offset); - Vector tmp(val, count_t(1)); + Vector tmp(val); RecursiveUnifiedVectorFormat format; - Vector::RecursiveToUnifiedFormat(tmp, format); + Vector::RecursiveToUnifiedFormat(tmp, 1, format); UnifiedVariantVectorData vector_data(format); auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, 0, 0); return PythonObject::FromValue(variant_val, variant_val.type(), client_properties); @@ -410,16 +391,22 @@ static bool ConvertColumn(NumpyAppendData &append_data) { auto src_ptr = UnifiedVectorFormat::GetData(idata); auto out_ptr = reinterpret_cast(target_data); - if (!idata.validity.CannotHaveNull()) { + if (!idata.validity.AllValid()) { if (append_data.pandas) { return ConvertColumnTemplated(append_data); + } else { + return ConvertColumnTemplated( + append_data); + } + } else { + if (append_data.pandas) { + return ConvertColumnTemplated( + append_data); + } else { + return ConvertColumnTemplated( + append_data); } - return ConvertColumnTemplated(append_data); - } - if (append_data.pandas) { - return ConvertColumnTemplated(append_data); } - return ConvertColumnTemplated(append_data); } template @@ -432,23 +419,23 @@ static bool ConvertColumnCategoricalTemplate(NumpyAppendData &append_data) { auto src_ptr = UnifiedVectorFormat::GetData(idata); auto out_ptr = reinterpret_cast(target_data); - if (!idata.validity.CannotHaveNull()) { + if (!idata.validity.AllValid()) { for (idx_t i = 0; i < count; i++) { idx_t src_idx = idata.sel->get_index(i + source_offset); idx_t offset = target_offset + i; if (!idata.validity.RowIsValidUnsafe(src_idx)) { out_ptr[offset] = static_cast(-1); } else { - out_ptr[offset] = - duckdb_py_convert::RegularConvert::ConvertValue(src_ptr[src_idx], append_data); + out_ptr[offset] = duckdb_py_convert::RegularConvert::template ConvertValue( + src_ptr[src_idx], append_data); } } } else { for (idx_t i = 0; i < count; i++) { idx_t src_idx = idata.sel->get_index(i + source_offset); idx_t offset = target_offset + i; - out_ptr[offset] = - duckdb_py_convert::RegularConvert::ConvertValue(src_ptr[src_idx], append_data); + out_ptr[offset] = duckdb_py_convert::RegularConvert::template ConvertValue( + src_ptr[src_idx], append_data); } } // Null values are encoded in the data itself @@ -462,11 +449,12 @@ static bool ConvertNested(NumpyAppendData &append_data) { auto target_mask = append_data.target_mask; auto &input = append_data.input; auto &idata = append_data.idata; + auto &client_properties = append_data.client_properties; auto count = append_data.count; auto source_offset = append_data.source_offset; auto out_ptr = reinterpret_cast(target_data); - if (!idata.validity.CannotHaveNull()) { + if (!idata.validity.AllValid()) { bool requires_mask = false; for (idx_t i = 0; i < count; i++) { idx_t index = i + source_offset; @@ -526,7 +514,7 @@ static bool ConvertDecimalInternal(NumpyAppendData &append_data, double division auto src_ptr = UnifiedVectorFormat::GetData(idata); auto out_ptr = reinterpret_cast(target_data); - if (!idata.validity.CannotHaveNull()) { + if (!idata.validity.AllValid()) { bool requires_mask = false; for (idx_t i = 0; i < count; i++) { idx_t src_idx = idata.sel->get_index(i + source_offset); @@ -575,8 +563,8 @@ static bool ConvertDecimal(NumpyAppendData &append_data) { ArrayWrapper::ArrayWrapper(const LogicalType &type, const ClientProperties &client_properties_p, bool pandas) : requires_mask(false), client_properties(client_properties_p), pandas(pandas) { - data = std::make_unique(type); - mask = std::make_unique(LogicalType::BOOLEAN); + data = make_uniq(type); + mask = make_uniq(LogicalType::BOOLEAN); } void ArrayWrapper::Initialize(idx_t capacity) { @@ -598,7 +586,7 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size bool may_have_null; UnifiedVectorFormat idata; - input.ToUnifiedFormat(idata); + input.ToUnifiedFormat(source_size, idata); if (count == DConstants::INVALID_INDEX) { D_ASSERT(source_size != DConstants::INVALID_INDEX); @@ -673,7 +661,6 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size break; case LogicalTypeId::TIMESTAMP: case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_TZ_NS: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_NS: @@ -722,11 +709,6 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size case LogicalTypeId::UUID: may_have_null = ConvertColumn(append_data); break; - case LogicalTypeId::SQLNULL: - // An all-NULL column (e.g. an untyped NULL literal): emit an object column of None. SQLNULL's physical - // type is INT32, but its data is never read since every row is NULL. - may_have_null = ConvertColumn(append_data); - break; default: throw NotImplementedException("Unsupported type \"%s\"", input.GetType().ToString()); @@ -739,15 +721,15 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size } py::object ArrayWrapper::ToArray() const { - D_ASSERT(data->array.GetArray() && mask->array.GetArray()); + D_ASSERT(data->array && mask->array); data->Resize(data->count); if (!requires_mask) { - return std::move(data->array.GetArray()); + return std::move(data->array); } mask->Resize(mask->count); // construct numpy arrays from the data and the mask - auto values = std::move(data->array.GetArray()); - auto nullmask = std::move(mask->array.GetArray()); + auto values = std::move(data->array); + auto nullmask = std::move(mask->array); // create masked array and return it auto masked_array = py::module::import("numpy.ma").attr("masked_array")(values, nullmask); diff --git a/src/duckdb_py/numpy/numpy_bind.cpp b/src/duckdb_py/numpy/numpy_bind.cpp index c197e4ba..e2a4a83f 100644 --- a/src/duckdb_py/numpy/numpy_bind.cpp +++ b/src/duckdb_py/numpy/numpy_bind.cpp @@ -1,6 +1,5 @@ #include "duckdb_python/numpy/numpy_bind.hpp" #include "duckdb_python/numpy/array_wrapper.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb_python/pandas/pandas_analyzer.hpp" #include "duckdb_python/pandas/column/pandas_numpy_column.hpp" #include "duckdb_python/pandas/pandas_bind.hpp" @@ -9,7 +8,7 @@ namespace duckdb { -void NumpyBind::Bind(ClientContext &context, py::handle df, vector &bind_columns, +void NumpyBind::Bind(const ClientContext &context, py::handle df, vector &bind_columns, vector &return_types, vector &names) { auto df_columns = py::list(df.attr("keys")()); @@ -35,7 +34,7 @@ void NumpyBind::Bind(ClientContext &context, py::handle df, vector(NumpyArray(column.attr("astype")("float32"))); + bind_data.pandas_col = make_uniq(py::array(column.attr("astype")("float32"))); bind_data.numpy_type.type = NumpyNullableType::FLOAT_32; duckdb_col_type = NumpyToLogicalType(bind_data.numpy_type); } else if (bind_data.numpy_type.type == NumpyNullableType::STRING) { @@ -47,16 +46,16 @@ void NumpyBind::Bind(ClientContext &context, py::handle df, vector enum_entries = py::cast>(uniq.attr("__getitem__")(0)); idx_t size = enum_entries.size(); Vector enum_entries_vec(LogicalType::VARCHAR, size); - auto enum_entries_ptr = FlatVector::GetDataMutable(enum_entries_vec); + auto enum_entries_ptr = FlatVector::GetData(enum_entries_vec); for (idx_t i = 0; i < size; i++) { enum_entries_ptr[i] = StringVector::AddStringOrBlob(enum_entries_vec, enum_entries[i]); } duckdb_col_type = LogicalType::ENUM(enum_entries_vec, size); auto pandas_col = uniq.attr("__getitem__")(1); bind_data.internal_categorical_type = string(py::str(pandas_col.attr("dtype"))); - bind_data.pandas_col = std::make_unique(NumpyArray(pandas_col)); + bind_data.pandas_col = make_uniq(pandas_col); } else { - bind_data.pandas_col = std::make_unique(NumpyArray(column)); + bind_data.pandas_col = make_uniq(column); duckdb_col_type = NumpyToLogicalType(bind_data.numpy_type); } diff --git a/src/duckdb_py/numpy/numpy_scan.cpp b/src/duckdb_py/numpy/numpy_scan.cpp index 9c965968..b1cd6e60 100644 --- a/src/duckdb_py/numpy/numpy_scan.cpp +++ b/src/duckdb_py/numpy/numpy_scan.cpp @@ -4,8 +4,6 @@ #include "duckdb_python/python_conversion.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector/struct_vector.hpp" -#include "duckdb/common/vector/map_vector.hpp" #include "utf8proc_wrapper.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb_python/pandas/pandas_bind.hpp" @@ -14,16 +12,15 @@ #include "duckdb_python/numpy/numpy_type.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb_python/numpy/numpy_scan.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb_python/pandas/column/pandas_numpy_column.hpp" namespace duckdb { template -void ScanNumpyColumn(NumpyArray &numpy_col, idx_t stride, idx_t offset, Vector &out, idx_t count) { - auto src_ptr = (T *)numpy_col.Data(); +void ScanNumpyColumn(py::array &numpy_col, idx_t stride, idx_t offset, Vector &out, idx_t count) { + auto src_ptr = (T *)numpy_col.data(); if (stride == sizeof(T)) { - FlatVector::SetData(out, data_ptr_cast(src_ptr + offset), count_t(count)); + FlatVector::SetData(out, data_ptr_cast(src_ptr + offset)); } else { auto tgt_ptr = (T *)FlatVector::GetData(out); for (idx_t i = 0; i < count; i++) { @@ -33,10 +30,10 @@ void ScanNumpyColumn(NumpyArray &numpy_col, idx_t stride, idx_t offset, Vector & } template -void ScanNumpyCategoryTemplated(NumpyArray &column, idx_t offset, Vector &out, idx_t count) { - auto src_ptr = (T *)column.Data(); +void ScanNumpyCategoryTemplated(py::array &column, idx_t offset, Vector &out, idx_t count) { + auto src_ptr = (T *)column.data(); auto tgt_ptr = (V *)FlatVector::GetData(out); - auto &tgt_mask = FlatVector::ValidityMutable(out); + auto &tgt_mask = FlatVector::Validity(out); for (idx_t i = 0; i < count; i++) { if (src_ptr[i + offset] == -1) { // Null value @@ -48,7 +45,7 @@ void ScanNumpyCategoryTemplated(NumpyArray &column, idx_t offset, Vector &out, i } template -void ScanNumpyCategory(NumpyArray &column, idx_t count, idx_t offset, Vector &out, string &src_type) { +void ScanNumpyCategory(py::array &column, idx_t count, idx_t offset, Vector &out, string &src_type) { if (src_type == "int8") { ScanNumpyCategoryTemplated(column, offset, out, count); } else if (src_type == "int16") { @@ -64,7 +61,7 @@ void ScanNumpyCategory(NumpyArray &column, idx_t count, idx_t offset, Vector &ou static void ApplyMask(PandasColumnBindData &bind_data, ValidityMask &validity, idx_t count, idx_t offset) { D_ASSERT(bind_data.mask); - auto mask = reinterpret_cast(bind_data.mask->numpy_array.Data()); + auto mask = reinterpret_cast(bind_data.mask->numpy_array.data()); for (idx_t i = 0; i < count; i++) { auto is_null = mask[offset + i]; if (is_null) { @@ -79,7 +76,7 @@ void ScanNumpyMasked(PandasColumnBindData &bind_data, idx_t count, idx_t offset, auto &numpy_col = reinterpret_cast(*bind_data.pandas_col); ScanNumpyColumn(numpy_col.array, numpy_col.stride, offset, out, count); if (bind_data.mask) { - auto &result_mask = FlatVector::ValidityMutable(out); + auto &result_mask = FlatVector::Validity(out); ApplyMask(bind_data, result_mask, count, offset); } } @@ -87,26 +84,27 @@ void ScanNumpyMasked(PandasColumnBindData &bind_data, idx_t count, idx_t offset, template void ScanNumpyFpColumn(PandasColumnBindData &bind_data, const T *src_ptr, idx_t stride, idx_t count, idx_t offset, Vector &out) { + auto &mask = FlatVector::Validity(out); if (stride == sizeof(T)) { - FlatVector::SetData(out, (data_ptr_t)(src_ptr + offset), count_t(count)); // NOLINT + FlatVector::SetData(out, (data_ptr_t)(src_ptr + offset)); // NOLINT // Turn NaN values into NULL auto tgt_ptr = FlatVector::GetData(out); for (idx_t i = 0; i < count; i++) { if (Value::IsNan(tgt_ptr[i])) { - FlatVector::ValidityMutable(out).SetInvalid(i); + mask.SetInvalid(i); } } } else { - auto tgt_ptr = FlatVector::GetDataMutable(out); + auto tgt_ptr = FlatVector::GetData(out); for (idx_t i = 0; i < count; i++) { tgt_ptr[i] = src_ptr[stride / sizeof(T) * (i + offset)]; if (Value::IsNan(tgt_ptr[i])) { - FlatVector::ValidityMutable(out).SetInvalid(i); + mask.SetInvalid(i); } } } if (bind_data.mask) { - auto &result_mask = FlatVector::ValidityMutable(out); + auto &result_mask = FlatVector::Validity(out); ApplyMask(bind_data, result_mask, count, offset); } } @@ -133,26 +131,26 @@ static string_t DecodePythonUnicode(T *codepoints, idx_t codepoint_count, Vector } static void SetInvalidRecursive(Vector &out, idx_t index) { - auto &validity = FlatVector::ValidityMutable(out); + auto &validity = FlatVector::Validity(out); validity.SetInvalid(index); if (out.GetType().InternalType() == PhysicalType::STRUCT) { auto &children = StructVector::GetEntries(out); for (idx_t i = 0; i < children.size(); i++) { - SetInvalidRecursive(children[i], index); + SetInvalidRecursive(*children[i], index); } } } //! 'count' is the amount of rows in the 'out' vector //! 'offset' is the current row number within this vector -void ScanNumpyObject(optional_ptr context, PyObject *object, idx_t offset, Vector &out) { +void ScanNumpyObject(PyObject *object, idx_t offset, Vector &out) { // handle None if (object == Py_None) { SetInvalidRecursive(out, offset); return; } - TransformPythonObject(context, object, out, offset); + TransformPythonObject(object, out, offset); } static void VerifyMapConstraints(Vector &vec, idx_t count) { @@ -180,8 +178,7 @@ void VerifyTypeConstraints(Vector &vec, idx_t count) { } } -void NumpyScan::ScanObjectColumn(ClientContext &context, PyObject **col, idx_t stride, idx_t count, idx_t offset, - Vector &out) { +void NumpyScan::ScanObjectColumn(PyObject **col, idx_t stride, idx_t count, idx_t offset, Vector &out) { // numpy_col is a sequential list of objects, that make up one "column" (Vector) out.SetVectorType(VectorType::FLAT_VECTOR); PythonGILWrapper gil; // We're creating python objects here, so we need the GIL @@ -189,12 +186,12 @@ void NumpyScan::ScanObjectColumn(ClientContext &context, PyObject **col, idx_t s if (stride == sizeof(PyObject *)) { auto src_ptr = col + offset; for (idx_t i = 0; i < count; i++) { - ScanNumpyObject(context, src_ptr[i], i, out); + ScanNumpyObject(src_ptr[i], i, out); } } else { for (idx_t i = 0; i < count; i++) { auto src_ptr = col[stride / sizeof(PyObject *) * (i + offset)]; - ScanNumpyObject(context, src_ptr, i, out); + ScanNumpyObject(src_ptr, i, out); } } VerifyTypeConstraints(out, count); @@ -202,7 +199,7 @@ void NumpyScan::ScanObjectColumn(ClientContext &context, PyObject **col, idx_t s //! 'offset' is the offset within the column //! 'count' is the amount of values we will convert in this batch -void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, idx_t count, idx_t offset, Vector &out) { +void NumpyScan::Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, Vector &out) { D_ASSERT(bind_data.pandas_col->Backend() == PandasColumnBackend::NUMPY); auto &numpy_col = reinterpret_cast(*bind_data.pandas_col); auto &array = numpy_col.array; @@ -237,19 +234,20 @@ void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, id ScanNumpyMasked(bind_data, count, offset, out); break; case NumpyNullableType::FLOAT_32: - ScanNumpyFpColumn(bind_data, reinterpret_cast(array.Data()), numpy_col.stride, count, + ScanNumpyFpColumn(bind_data, reinterpret_cast(array.data()), numpy_col.stride, count, offset, out); break; case NumpyNullableType::FLOAT_64: - ScanNumpyFpColumn(bind_data, reinterpret_cast(array.Data()), numpy_col.stride, count, + ScanNumpyFpColumn(bind_data, reinterpret_cast(array.data()), numpy_col.stride, count, offset, out); break; case NumpyNullableType::DATETIME_NS: case NumpyNullableType::DATETIME_MS: case NumpyNullableType::DATETIME_US: case NumpyNullableType::DATETIME_S: { - auto src_ptr = reinterpret_cast(array.Data()); - auto tgt_ptr = FlatVector::GetDataMutable(out); + auto src_ptr = reinterpret_cast(array.data()); + auto tgt_ptr = FlatVector::GetData(out); + auto &mask = FlatVector::Validity(out); using timestamp_convert_func = std::function; timestamp_convert_func convert_func; @@ -290,13 +288,13 @@ void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, id auto source_idx = stride / sizeof(int64_t) * (row + offset); if (src_ptr[source_idx] <= NumericLimits::Minimum()) { // pandas Not a Time (NaT) - FlatVector::ValidityMutable(out).SetInvalid(row); + mask.SetInvalid(row); continue; } // Direct conversion, we've already matched the numpy type with the equivalent duckdb type auto input = timestamp_t(src_ptr[source_idx]); - if (Value::IsFinite(input)) { + if (Timestamp::IsFinite(input)) { tgt_ptr[row] = convert_func(src_ptr[source_idx]); } else { tgt_ptr[row] = input; @@ -308,9 +306,9 @@ void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, id case NumpyNullableType::TIMEDELTA_US: case NumpyNullableType::TIMEDELTA_MS: case NumpyNullableType::TIMEDELTA_S: { - auto src_ptr = reinterpret_cast(array.Data()); - auto tgt_ptr = FlatVector::GetDataMutable(out); - auto &mask = FlatVector::ValidityMutable(out); + auto src_ptr = reinterpret_cast(array.data()); + auto tgt_ptr = FlatVector::GetData(out); + auto &mask = FlatVector::Validity(out); for (idx_t row = 0; row < count; row++) { auto source_idx = stride / sizeof(int64_t) * (row + offset); @@ -353,17 +351,17 @@ void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, id case NumpyNullableType::STRING: case NumpyNullableType::OBJECT: { // Get the source pointer of the numpy array - auto src_ptr = (PyObject **)array.Data(); // NOLINT + auto src_ptr = (PyObject **)array.data(); // NOLINT const bool is_object_col = bind_data.numpy_type.type == NumpyNullableType::OBJECT; if (is_object_col && out.GetType().id() != LogicalTypeId::VARCHAR) { //! We have determined the underlying logical type of this object column - return NumpyScan::ScanObjectColumn(context, src_ptr, numpy_col.stride, count, offset, out); + return NumpyScan::ScanObjectColumn(src_ptr, numpy_col.stride, count, offset, out); } // Get the data pointer and the validity mask of the result vector - auto tgt_ptr = FlatVector::GetDataMutable(out); - auto &out_mask = FlatVector::ValidityMutable(out); - std::unique_ptr gil; + auto tgt_ptr = FlatVector::GetData(out); + auto &out_mask = FlatVector::Validity(out); + unique_ptr gil; auto &import_cache = *DuckDBPyConnection::ImportCache(); // Loop over every row of the arrays contents @@ -400,7 +398,7 @@ void NumpyScan::Scan(ClientContext &context, PandasColumnBindData &bind_data, id } if (!py::isinstance(val)) { if (!gil) { - gil = std::make_unique(); + gil = make_uniq(); } bind_data.object_str_val.Push(std::move(py::str(val))); val = reinterpret_cast(bind_data.object_str_val.LastAddedObject().ptr()); diff --git a/src/duckdb_py/numpy/raw_array_wrapper.cpp b/src/duckdb_py/numpy/raw_array_wrapper.cpp index df89a0f6..c6c1f8d2 100644 --- a/src/duckdb_py/numpy/raw_array_wrapper.cpp +++ b/src/duckdb_py/numpy/raw_array_wrapper.cpp @@ -46,7 +46,6 @@ static idx_t GetNumpyTypeWidth(const LogicalType &type) { case LogicalTypeId::DATE: case LogicalTypeId::INTERVAL: case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_TZ_NS: return sizeof(int64_t); case LogicalTypeId::TIME: case LogicalTypeId::TIME_NS: @@ -63,7 +62,6 @@ static idx_t GetNumpyTypeWidth(const LogicalType &type) { case LogicalTypeId::ARRAY: case LogicalTypeId::VARIANT: case LogicalTypeId::GEOMETRY: - case LogicalTypeId::SQLNULL: return sizeof(PyObject *); default: throw NotImplementedException("Unsupported type \"%s\" for DuckDB -> NumPy conversion", type.ToString()); @@ -104,7 +102,6 @@ string RawArrayWrapper::DuckDBToNumpyDtype(const LogicalType &type) { return "datetime64[us]"; case LogicalTypeId::TIMESTAMP_TZ: return "datetime64[us]"; - case LogicalTypeId::TIMESTAMP_TZ_NS: case LogicalTypeId::TIMESTAMP_NS: return "datetime64[ns]"; case LogicalTypeId::TIMESTAMP_MS: @@ -129,7 +126,6 @@ string RawArrayWrapper::DuckDBToNumpyDtype(const LogicalType &type) { case LogicalTypeId::ARRAY: case LogicalTypeId::VARIANT: case LogicalTypeId::GEOMETRY: - case LogicalTypeId::SQLNULL: return "object"; case LogicalTypeId::ENUM: { auto size = EnumType::GetSize(type); @@ -151,14 +147,14 @@ string RawArrayWrapper::DuckDBToNumpyDtype(const LogicalType &type) { void RawArrayWrapper::Initialize(idx_t capacity) { string dtype = DuckDBToNumpyDtype(type); - array = NumpyArray::Allocate(py::dtype(dtype), capacity); - data = data_ptr_cast(array.MutableData()); + array = py::array(py::dtype(dtype), capacity); + data = data_ptr_cast(array.mutable_data()); } void RawArrayWrapper::Resize(idx_t new_capacity) { vector new_shape {py::ssize_t(new_capacity)}; - array.GetArray().resize(new_shape, false); - data = data_ptr_cast(array.MutableData()); + array.resize(new_shape, false); + data = data_ptr_cast(array.mutable_data()); } } // namespace duckdb diff --git a/src/duckdb_py/pandas/analyzer.cpp b/src/duckdb_py/pandas/analyzer.cpp index a0fbeaf3..a91bff51 100644 --- a/src/duckdb_py/pandas/analyzer.cpp +++ b/src/duckdb_py/pandas/analyzer.cpp @@ -1,7 +1,10 @@ #include "duckdb_python/pyrelation.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyresult.hpp" #include "duckdb_python/pandas/pandas_analyzer.hpp" #include "duckdb_python/python_conversion.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/helper.hpp" namespace duckdb { @@ -41,7 +44,7 @@ static bool SameTypeRealm(const LogicalType &a, const LogicalType &b) { return true; } -static bool UpgradeType(ClientContext &context, LogicalType &left, const LogicalType &right); +static bool UpgradeType(LogicalType &left, const LogicalType &right); static bool CheckTypeCompatibility(const LogicalType &left, const LogicalType &right) { if (!SameTypeRealm(left, right)) { @@ -69,12 +72,13 @@ static bool IsStructColumnValid(const LogicalType &left, const LogicalType &righ return false; } //! Compare keys of struct case-insensitively + auto compare = CaseInsensitiveStringEquality(); for (idx_t i = 0; i < left_children.size(); i++) { auto &left_child = left_children[i]; auto &right_child = right_children[i]; // keys in left and right don't match - if (left_child.first != right_child.first) { + if (!compare(left_child.first, right_child.first)) { return false; } // Types are not compatible with each other @@ -85,25 +89,24 @@ static bool IsStructColumnValid(const LogicalType &left, const LogicalType &righ return true; } -static bool CombineStructTypes(ClientContext &context, LogicalType &result, const LogicalType &input) { +static bool CombineStructTypes(LogicalType &result, const LogicalType &input) { D_ASSERT(input.id() == LogicalTypeId::STRUCT); auto &children = StructType::GetChildTypes(input); for (auto &type : children) { - if (!UpgradeType(context, result, type.second)) { + if (!UpgradeType(result, type.second)) { return false; } } return true; } -static bool SatisfiesMapConstraints(ClientContext &context, const LogicalType &left, const LogicalType &right, - LogicalType &map_value_type) { +static bool SatisfiesMapConstraints(const LogicalType &left, const LogicalType &right, LogicalType &map_value_type) { D_ASSERT(left.id() == LogicalTypeId::STRUCT && left.id() == right.id()); - if (!CombineStructTypes(context, map_value_type, left)) { + if (!CombineStructTypes(map_value_type, left)) { return false; } - if (!CombineStructTypes(context, map_value_type, right)) { + if (!CombineStructTypes(map_value_type, right)) { return false; } return true; @@ -116,7 +119,7 @@ static LogicalType ConvertStructToMap(LogicalType &map_value_type) { // This is similar to ForceMaxLogicalType but we have custom rules around combining STRUCT types // And because of that we have to avoid ForceMaxLogicalType for every nested type -static bool UpgradeType(ClientContext &context, LogicalType &left, const LogicalType &right) { +static bool UpgradeType(LogicalType &left, const LogicalType &right) { if (left.id() == LogicalTypeId::SQLNULL) { // Early out for upgrading null left = right; @@ -135,10 +138,10 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical return false; } LogicalType child_type = LogicalType::SQLNULL; - if (!UpgradeType(context, child_type, ListType::GetChildType(left))) { + if (!UpgradeType(child_type, ListType::GetChildType(left))) { return false; } - if (!UpgradeType(context, child_type, ListType::GetChildType(right))) { + if (!UpgradeType(child_type, ListType::GetChildType(right))) { return false; } left = LogicalType::LIST(child_type); @@ -160,7 +163,7 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical auto new_child = StructType::GetChildType(left, i); auto child_name = StructType::GetChildName(left, i); - if (!UpgradeType(context, new_child, right_child)) { + if (!UpgradeType(new_child, right_child)) { return false; } children.push_back(std::make_pair(child_name, new_child)); @@ -168,7 +171,7 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical left = LogicalType::STRUCT(std::move(children)); } else { LogicalType value_type = LogicalType::SQLNULL; - if (SatisfiesMapConstraints(context, left, right, value_type)) { + if (SatisfiesMapConstraints(left, right, value_type)) { // Combine all the child types together, becoming the value_type for the resulting MAP left = ConvertStructToMap(value_type); } else { @@ -179,7 +182,7 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical // Left: STRUCT, Right: MAP // Combine all the child types of the STRUCT into the value type of the MAP auto value_type = MapType::ValueType(right); - if (!CombineStructTypes(context, value_type, left)) { + if (!CombineStructTypes(value_type, left)) { return false; } left = LogicalType::MAP(LogicalType::VARCHAR, value_type); @@ -195,25 +198,25 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical if (right.id() == LogicalTypeId::MAP) { // Key Type LogicalType key_type = LogicalType::SQLNULL; - if (!UpgradeType(context, key_type, MapType::KeyType(left))) { + if (!UpgradeType(key_type, MapType::KeyType(left))) { return false; } - if (!UpgradeType(context, key_type, MapType::KeyType(right))) { + if (!UpgradeType(key_type, MapType::KeyType(right))) { return false; } // Value Type LogicalType value_type = LogicalType::SQLNULL; - if (!UpgradeType(context, value_type, MapType::ValueType(left))) { + if (!UpgradeType(value_type, MapType::ValueType(left))) { return false; } - if (!UpgradeType(context, value_type, MapType::ValueType(right))) { + if (!UpgradeType(value_type, MapType::ValueType(right))) { return false; } left = LogicalType::MAP(key_type, value_type); } else if (right.id() == LogicalTypeId::STRUCT) { auto value_type = MapType::ValueType(left); - if (!CombineStructTypes(context, value_type, right)) { + if (!CombineStructTypes(value_type, right)) { return false; } left = LogicalType::MAP(LogicalType::VARCHAR, value_type); @@ -226,7 +229,7 @@ static bool UpgradeType(ClientContext &context, LogicalType &left, const Logical if (!CheckTypeCompatibility(left, right)) { return false; } - left = LogicalType::ForceMaxLogicalType(context, left, right); + left = LogicalType::ForceMaxLogicalType(left, right); return true; } } @@ -247,7 +250,7 @@ LogicalType PandasAnalyzer::GetListType(py::object &ele, bool &can_convert) { if (!i) { list_type = item_type; } else { - if (!UpgradeType(context, list_type, item_type)) { + if (!UpgradeType(list_type, item_type)) { can_convert = false; } } @@ -270,7 +273,7 @@ static bool StructKeysAreEqual(idx_t row, const child_list_t &refer for (idx_t i = 0; i < reference.size(); i++) { auto &ref = reference[i].first; auto &comp = compare[i].first; - if (ref != comp) { + if (!duckdb::CaseInsensitiveStringEquality()(ref, comp)) { return false; } } @@ -338,7 +341,7 @@ LogicalType PandasAnalyzer::DictToStruct(const PyDictionary &dict, bool &can_con auto dict_key = dict.keys.attr("__getitem__")(i); //! Have to already transform here because the child_list needs a string as key - auto key = Identifier(py::str(dict_key)); + auto key = string(py::str(dict_key)); auto dict_val = dict.values.attr("__getitem__")(i); auto val = GetItemType(dict_val, can_convert); @@ -480,7 +483,7 @@ LogicalType PandasAnalyzer::InnerAnalyze(py::object column, bool &can_convert, i auto next_item_type = GetItemType(obj, can_convert); types.push_back(next_item_type); - if (!can_convert || !UpgradeType(context, item_type, next_item_type)) { + if (!can_convert || !UpgradeType(item_type, next_item_type)) { can_convert = false; return next_item_type; } diff --git a/src/duckdb_py/pandas/bind.cpp b/src/duckdb_py/pandas/bind.cpp index edc85132..4e40c20e 100644 --- a/src/duckdb_py/pandas/bind.cpp +++ b/src/duckdb_py/pandas/bind.cpp @@ -1,7 +1,6 @@ #include "duckdb_python/pandas/pandas_bind.hpp" #include "duckdb_python/pandas/pandas_analyzer.hpp" #include "duckdb_python/pandas/column/pandas_numpy_column.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" namespace duckdb { @@ -45,7 +44,8 @@ struct PandasDataFrameBind { }; // namespace -static LogicalType BindColumn(ClientContext &context, PandasBindColumn &column_p, PandasColumnBindData &bind_data) { +static LogicalType BindColumn(PandasBindColumn &column_p, PandasColumnBindData &bind_data, + const ClientContext &context) { LogicalType column_type; auto &column = column_p.handle; @@ -54,54 +54,54 @@ static LogicalType BindColumn(ClientContext &context, PandasBindColumn &column_p if (column_has_mask) { // masked object, fetch the internal data and mask array - bind_data.mask = std::make_unique(NumpyArray(column.attr("array").attr("_mask"))); + bind_data.mask = make_uniq(column.attr("array").attr("_mask")); } if (bind_data.numpy_type.type == NumpyNullableType::CATEGORY) { // for category types, we create an ENUM type for string or use the converted numpy type for the rest D_ASSERT(py::hasattr(column, "cat")); D_ASSERT(py::hasattr(column.attr("cat"), "categories")); - NumpyArray categories(column.attr("cat").attr("categories")); - auto categories_pd_type = ConvertNumpyType(categories.GetArray().attr("dtype")); + auto categories = py::array(column.attr("cat").attr("categories")); + auto categories_pd_type = ConvertNumpyType(categories.attr("dtype")); if (categories_pd_type.type == NumpyNullableType::OBJECT) { // Let's hope the object type is a string. bind_data.numpy_type.type = NumpyNullableType::CATEGORY; - vector enum_entries = py::cast>(categories.GetArray()); + vector enum_entries = py::cast>(categories); idx_t size = enum_entries.size(); Vector enum_entries_vec(LogicalType::VARCHAR, size); - auto enum_entries_ptr = FlatVector::GetDataMutable(enum_entries_vec); + auto enum_entries_ptr = FlatVector::GetData(enum_entries_vec); for (idx_t i = 0; i < size; i++) { enum_entries_ptr[i] = StringVector::AddStringOrBlob(enum_entries_vec, enum_entries[i]); } D_ASSERT(py::hasattr(column.attr("cat"), "codes")); column_type = LogicalType::ENUM(enum_entries_vec, size); - NumpyArray pandas_col(column.attr("cat").attr("codes")); - bind_data.internal_categorical_type = string(py::str(pandas_col.GetArray().attr("dtype"))); - bind_data.pandas_col = std::make_unique(std::move(pandas_col)); + auto pandas_col = py::array(column.attr("cat").attr("codes")); + bind_data.internal_categorical_type = string(py::str(pandas_col.attr("dtype"))); + bind_data.pandas_col = make_uniq(pandas_col); } else { - NumpyArray pandas_col(column.attr("to_numpy")()); - auto numpy_type = pandas_col.GetArray().attr("dtype"); - bind_data.pandas_col = std::make_unique(std::move(pandas_col)); + auto pandas_col = py::array(column.attr("to_numpy")()); + auto numpy_type = pandas_col.attr("dtype"); + bind_data.pandas_col = make_uniq(pandas_col); // for category types (non-strings), we use the converted numpy type bind_data.numpy_type = ConvertNumpyType(numpy_type); column_type = NumpyToLogicalType(bind_data.numpy_type); } } else if (bind_data.numpy_type.type == NumpyNullableType::FLOAT_16) { auto pandas_array = column.attr("array"); - bind_data.pandas_col = std::make_unique(NumpyArray(column.attr("to_numpy")("float32"))); + bind_data.pandas_col = make_uniq(py::array(column.attr("to_numpy")("float32"))); bind_data.numpy_type.type = NumpyNullableType::FLOAT_32; column_type = NumpyToLogicalType(bind_data.numpy_type); } else { auto pandas_array = column.attr("array"); if (py::hasattr(pandas_array, "_data")) { // This means we can access the numpy array directly - bind_data.pandas_col = std::make_unique(NumpyArray(column.attr("array").attr("_data"))); + bind_data.pandas_col = make_uniq(column.attr("array").attr("_data")); } else if (py::hasattr(pandas_array, "asi8")) { // This is a datetime object, has the option to get the array as int64_t's - bind_data.pandas_col = std::make_unique(NumpyArray(pandas_array.attr("asi8"))); + bind_data.pandas_col = make_uniq(py::array(pandas_array.attr("asi8"))); } else { // Otherwise we have to get it through 'to_numpy()' - bind_data.pandas_col = std::make_unique(NumpyArray(column.attr("to_numpy")())); + bind_data.pandas_col = make_uniq(py::array(column.attr("to_numpy")())); } column_type = NumpyToLogicalType(bind_data.numpy_type); } @@ -115,7 +115,7 @@ static LogicalType BindColumn(ClientContext &context, PandasBindColumn &column_p return column_type; } -void Pandas::Bind(ClientContext &context, py::handle df_p, vector &bind_columns, +void Pandas::Bind(const ClientContext &context, py::handle df_p, vector &bind_columns, vector &return_types, vector &names) { PandasDataFrameBind df(df_p); @@ -140,7 +140,7 @@ void Pandas::Bind(ClientContext &context, py::handle df_p, vectorBackend(); switch (backend) { case PandasColumnBackend::NUMPY: { - NumpyScan::Scan(context, bind_data, count, offset, out); + NumpyScan::Scan(bind_data, count, offset, out); break; } default: { @@ -194,13 +194,13 @@ void PandasScanFunction::PandasScanFunc(ClientContext &context, TableFunctionInp } } idx_t this_count = std::min((idx_t)STANDARD_VECTOR_SIZE, state.end - state.start); - output.SetChildCardinality(this_count); + output.SetCardinality(this_count); for (idx_t idx = 0; idx < state.column_ids.size(); idx++) { auto col_idx = state.column_ids[idx]; if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { output.data[idx].Sequence(state.start, 1, this_count); } else { - PandasBackendScanSwitch(context, data.pandas_bind_data[col_idx], this_count, state.start, output.data[idx]); + PandasBackendScanSwitch(data.pandas_bind_data[col_idx], this_count, state.start, output.data[idx]); } } state.start += this_count; diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 6ad90bce..bfab3f37 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -1,6 +1,10 @@ #include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb/catalog/default/default_types.hpp" #include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/enums/profiler_format.hpp" +#include "duckdb/common/printer.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/function/table/read_csv.hpp" @@ -14,9 +18,12 @@ #include "duckdb/main/relation/read_json_relation.hpp" #include "duckdb/main/relation/value_relation.hpp" #include "duckdb/main/relation/view_relation.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" +#include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/parser/parser.hpp" #include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb_python/arrow/arrow_array_stream.hpp" #include "duckdb_python/map.hpp" @@ -26,42 +33,45 @@ #include "duckdb_python/pyresult.hpp" #include "duckdb_python/python_conversion.hpp" #include "duckdb_python/numpy/numpy_type.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" +#include "duckdb/main/prepared_statement.hpp" #include "duckdb_python/jupyter_progress_bar_display.hpp" #include "duckdb_python/pyfilesystem.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/function/table/read_csv.hpp" +#include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/catalog/default/default_types.hpp" +#include "duckdb/main/relation/value_relation.hpp" +#include "duckdb_python/filesystem_object.hpp" #include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" #include "duckdb/function/scalar_function.hpp" +#include "duckdb_python/pandas/pandas_scan.hpp" #include "duckdb_python/python_objects.hpp" #include "duckdb/function/function.hpp" #include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb/parser/parsed_data/drop_info.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/main/pending_query_result.hpp" +#include "duckdb/parser/keyword_helper.hpp" #include "duckdb_python/python_replacement_scan.hpp" #include "duckdb/common/shared_ptr.hpp" #include "duckdb/main/materialized_query_result.hpp" #include "duckdb/main/stream_query_result.hpp" #include "duckdb/main/relation/materialized_relation.hpp" +#include "duckdb/main/relation/query_relation.hpp" #include "duckdb/parser/statement/load_statement.hpp" #include "duckdb_python/expression/pyexpression.hpp" -#include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" -namespace duckdb { +#include -// All process-global module state lives in one struct, reached only through GetModuleState(). -// This is the single seam to retarget for PEP 489 multi-phase init (per-module state via -// PyModule_GetState); call sites never touch the storage directly. -struct DuckDBPyModuleState { - DefaultConnectionHolder default_connection; - DBInstanceCache instance_cache; - std::shared_ptr import_cache; - PythonEnvironmentType environment = PythonEnvironmentType::NORMAL; - std::string formatted_python_version; -}; +#include "duckdb/common/printer.hpp" -static DuckDBPyModuleState &GetModuleState() { - static DuckDBPyModuleState state; // NOLINT: allow global - sole module-state seam (future: PyModule_GetState) - return state; -} +namespace duckdb { + +DefaultConnectionHolder DuckDBPyConnection::default_connection; // NOLINT: allow global +DBInstanceCache instance_cache; // NOLINT: allow global +shared_ptr DuckDBPyConnection::import_cache = nullptr; // NOLINT: allow global +PythonEnvironmentType DuckDBPyConnection::environment = PythonEnvironmentType::NORMAL; // NOLINT: allow global +std::string DuckDBPyConnection::formatted_python_version = ""; DuckDBPyConnection::~DuckDBPyConnection() { try { @@ -81,15 +91,15 @@ DuckDBPyConnection::~DuckDBPyConnection() { } } -std::unique_ptr DuckDBPyConnection::CreateRelation(shared_ptr rel) { - auto py_rel = std::make_unique(std::move(rel)); +unique_ptr DuckDBPyConnection::CreateRelation(shared_ptr rel) { + auto py_rel = make_uniq(std::move(rel)); py::gil_scoped_acquire gil; py_rel->SetConnectionOwner(py::cast(shared_from_this())); return py_rel; } -std::unique_ptr DuckDBPyConnection::CreateRelation(std::shared_ptr result) { - auto py_rel = std::make_unique(std::move(result)); +unique_ptr DuckDBPyConnection::CreateRelation(shared_ptr result) { + auto py_rel = make_uniq(std::move(result)); py::gil_scoped_acquire gil; py_rel->SetConnectionOwner(py::cast(shared_from_this())); return py_rel; @@ -101,14 +111,14 @@ void DuckDBPyConnection::DetectEnvironment() { py::object version_info = sys.attr("version_info"); int major = py::cast(version_info.attr("major")); int minor = py::cast(version_info.attr("minor")); - GetModuleState().formatted_python_version = std::to_string(major) + "." + std::to_string(minor); + DuckDBPyConnection::formatted_python_version = std::to_string(major) + "." + std::to_string(minor); // If __main__ does not have a __file__ attribute, we are in interactive mode auto main_module = py::module_::import("__main__"); if (py::hasattr(main_module, "__file__")) { return; } - GetModuleState().environment = PythonEnvironmentType::INTERACTIVE; + DuckDBPyConnection::environment = PythonEnvironmentType::INTERACTIVE; if (!ModuleIsLoaded()) { return; } @@ -126,7 +136,7 @@ void DuckDBPyConnection::DetectEnvironment() { } py::dict ipython_config = ipython.attr("config"); if (ipython_config.contains("IPKernelApp")) { - GetModuleState().environment = PythonEnvironmentType::JUPYTER; + DuckDBPyConnection::environment = PythonEnvironmentType::JUPYTER; } return; } @@ -137,17 +147,17 @@ bool DuckDBPyConnection::DetectAndGetEnvironment() { } bool DuckDBPyConnection::IsJupyter() { - return GetModuleState().environment == PythonEnvironmentType::JUPYTER; + return DuckDBPyConnection::environment == PythonEnvironmentType::JUPYTER; } std::string DuckDBPyConnection::FormattedPythonVersion() { - return GetModuleState().formatted_python_version; + return DuckDBPyConnection::formatted_python_version; } // NOTE: this function is generated by tools/pythonpkg/scripts/generate_connection_methods.py. // Do not edit this function manually, your changes will be overwritten! -static void InitializeConnectionMethods(py::class_> &m) { +static void InitializeConnectionMethods(py::class_> &m) { m.def("cursor", &DuckDBPyConnection::Cursor, "Create a duplicate of the current connection"); m.def("register_filesystem", &DuckDBPyConnection::RegisterFilesystem, "Register a fsspec compliant filesystem", py::arg("filesystem")); @@ -355,23 +365,21 @@ py::list DuckDBPyConnection::ListFilesystems() { return names; } -py::str DuckDBPyConnection::GetProfilingInformation(const string &format) { +py::str DuckDBPyConnection::GetProfilingInformation(const py::str &format) { // We want to expose ProfilerPrintFormat as a string to Python users ProfilerPrintFormat format_enum; - if (format == "html") { - format_enum = ProfilerPrintFormat::HTML(); + if (format == "query_tree") { + format_enum = ProfilerPrintFormat::QUERY_TREE; } else if (format == "json") { - format_enum = ProfilerPrintFormat::JSON(); + format_enum = ProfilerPrintFormat::JSON; + } else if (format == "query_tree_optimizer") { + format_enum = ProfilerPrintFormat::QUERY_TREE_OPTIMIZER; + } else if (format == "no_output") { + format_enum = ProfilerPrintFormat::NO_OUTPUT; + } else if (format == "html") { + format_enum = ProfilerPrintFormat::HTML; } else if (format == "graphviz") { - format_enum = ProfilerPrintFormat::Graphviz(); - } else if (format == "default") { - format_enum = ProfilerPrintFormat::Default(); - } else if (format == "mermaid") { - format_enum = ProfilerPrintFormat::Mermaid(); - } else if (format == "text") { - format_enum = ProfilerPrintFormat::Text(); - } else if (format == "yaml") { - format_enum = ProfilerPrintFormat::YAML(); + format_enum = ProfilerPrintFormat::GRAPHVIZ; } else { throw InvalidInputException( "Invalid ProfilerPrintFormat string: " + std::string(format) + @@ -397,7 +405,7 @@ py::list DuckDBPyConnection::ExtractStatements(const string &query) { auto &connection = con.GetConnection(); auto statements = connection.ExtractStatements(query); for (auto &statement : statements) { - result.append(std::make_unique(std::move(statement))); + result.append(make_uniq(std::move(statement))); } return result; } @@ -408,7 +416,7 @@ bool DuckDBPyConnection::FileSystemIsRegistered(const string &name) { return std::find(subsystems.begin(), subsystems.end(), name) != subsystems.end(); } -std::shared_ptr DuckDBPyConnection::UnregisterUDF(const string &name) { +shared_ptr DuckDBPyConnection::UnregisterUDF(const string &name) { auto entry = registered_functions.find(name); if (entry == registered_functions.end()) { // Not registered or already unregistered @@ -424,7 +432,7 @@ std::shared_ptr DuckDBPyConnection::UnregisterUDF(const stri auto &catalog = Catalog::GetCatalog(context, SYSTEM_CATALOG); DropInfo info; info.type = CatalogType::SCALAR_FUNCTION_ENTRY; - info.NameMutable() = Identifier(name); + info.name = name; info.allow_drop_internal = true; info.cascade = false; info.if_not_found = OnEntryNotFound::THROW_EXCEPTION; @@ -435,9 +443,9 @@ std::shared_ptr DuckDBPyConnection::UnregisterUDF(const stri return shared_from_this(); } -std::shared_ptr +shared_ptr DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &udf, const py::object ¶meters_p, - const std::shared_ptr &return_type_p, PythonUDFType type, + const shared_ptr &return_type_p, PythonUDFType type, FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects) { auto &connection = con.GetConnection(); @@ -466,7 +474,7 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud void DuckDBPyConnection::Initialize(py::handle &m) { auto connection_module = - py::class_>(m, "DuckDBPyConnection"); + py::class_>(m, "DuckDBPyConnection", py::module_local()); connection_module.def("__enter__", &DuckDBPyConnection::Enter) .def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback")); @@ -480,7 +488,7 @@ void DuckDBPyConnection::Initialize(py::handle &m) { DuckDBPyConnection::ImportCache(); } -std::shared_ptr DuckDBPyConnection::ExecuteMany(const py::object &query, py::object params_p) { +shared_ptr DuckDBPyConnection::ExecuteMany(const py::object &query, py::object params_p) { py::gil_scoped_acquire gil; ConnectionLockGuard conn_lock(*this); con.SetResult(nullptr); @@ -520,7 +528,7 @@ std::shared_ptr DuckDBPyConnection::ExecuteMany(const py::ob if (query_result) { // Don't use CreateRelation here — the result is stored inside the connection, // so setting connection_owner would create a ref cycle (connection → result → connection). - con.SetResult(std::make_unique(std::make_shared(std::move(query_result)))); + con.SetResult(make_uniq(make_shared_ptr(std::move(query_result)))); } return shared_from_this(); @@ -581,9 +589,9 @@ py::list TransformNamedParameters(const case_insensitive_map_t &named_par return new_params; } -identifier_map_t TransformPreparedParameters(ClientContext &context, const py::object ¶ms, - optional_ptr prep = {}) { - identifier_map_t named_values; +case_insensitive_map_t TransformPreparedParameters(const py::object ¶ms, + optional_ptr prep = {}) { + case_insensitive_map_t named_values; if (py::is_list_like(params)) { if (prep && prep->named_param_map.size() != py::len(params)) { if (py::len(params) == 0) { @@ -593,15 +601,15 @@ identifier_map_t TransformPreparedParameters(ClientContext & throw InvalidInputException("Prepared statement needs %d parameters, %d given", prep->named_param_map.size(), py::len(params)); } - auto unnamed_values = DuckDBPyConnection::TransformPythonParamList(context, params); + auto unnamed_values = DuckDBPyConnection::TransformPythonParamList(params); for (idx_t i = 0; i < unnamed_values.size(); i++) { auto &value = unnamed_values[i]; - auto identifier = Identifier(std::to_string(i + 1)); + auto identifier = std::to_string(i + 1); named_values[identifier] = BoundParameterData(std::move(value)); } } else if (py::is_dict_like(params)) { auto dict = py::cast(params); - named_values = DuckDBPyConnection::TransformPythonParamDict(context, dict); + named_values = DuckDBPyConnection::TransformPythonParamDict(dict); } else { throw InvalidInputException("Prepared parameters can only be passed as a list or a dictionary"); } @@ -628,10 +636,9 @@ unique_ptr DuckDBPyConnection::ExecuteInternal(PreparedStatement &p if (params.is_none()) { params = py::list(); } - auto &context = *con.GetConnection().context; // Execute the prepared statement with the prepared parameters - auto named_values = TransformPreparedParameters(context, params, prep); + auto named_values = TransformPreparedParameters(params, prep); unique_ptr res; { D_ASSERT(py::gil_check()); @@ -656,9 +663,8 @@ unique_ptr DuckDBPyConnection::PrepareAndExecuteInternal(unique_ptr if (params.is_none()) { params = py::list(); } - auto &context = *con.GetConnection().context; - auto named_values = TransformPreparedParameters(context, params); + auto named_values = TransformPreparedParameters(params); unique_ptr res; { @@ -682,10 +688,10 @@ unique_ptr DuckDBPyConnection::PrepareAndExecuteInternal(unique_ptr } vector> DuckDBPyConnection::GetStatements(const py::object &query) { - if (py::isinstance(query)) { - auto &statement_obj = py::cast(query); + shared_ptr statement_obj; + if (py::try_cast(query, statement_obj)) { vector> result; - result.push_back(statement_obj.GetStatement()); + result.push_back(statement_obj->GetStatement()); return result; } if (py::isinstance(query)) { @@ -697,11 +703,11 @@ vector> DuckDBPyConnection::GetStatements(const py::obj throw InvalidInputException("Please provide either a DuckDBPyStatement or a string representing the query"); } -std::shared_ptr DuckDBPyConnection::ExecuteFromString(const string &query) { +shared_ptr DuckDBPyConnection::ExecuteFromString(const string &query) { return Execute(py::str(query)); } -std::shared_ptr DuckDBPyConnection::Execute(const py::object &query, py::object params) { +shared_ptr DuckDBPyConnection::Execute(const py::object &query, py::object params) { py::gil_scoped_acquire gil; ConnectionLockGuard conn_lock(*this); con.SetResult(nullptr); @@ -724,13 +730,13 @@ std::shared_ptr DuckDBPyConnection::Execute(const py::object if (res) { // Don't use CreateRelation here — the result is stored inside the connection, // so setting connection_owner would create a ref cycle (connection → result → connection). - con.SetResult(std::make_unique(std::make_shared(std::move(res)))); + con.SetResult(make_uniq(make_shared_ptr(std::move(res)))); } return shared_from_this(); } -std::shared_ptr DuckDBPyConnection::Append(const string &name, const PandasDataFrame &value, - bool by_name) { +shared_ptr DuckDBPyConnection::Append(const string &name, const PandasDataFrame &value, + bool by_name) { RegisterPythonObject("__append_df", value); string columns = ""; if (by_name) { @@ -754,29 +760,29 @@ std::shared_ptr DuckDBPyConnection::Append(const string &nam return Execute(py::str(sql_query)); } -std::shared_ptr DuckDBPyConnection::RegisterPythonObject(const string &name, - const py::object &python_object) { +shared_ptr DuckDBPyConnection::RegisterPythonObject(const string &name, + const py::object &python_object) { auto &connection = con.GetConnection(); auto &client = *connection.context; auto object = PythonReplacementScan::ReplacementObject(python_object, name, client); auto view_rel = make_shared_ptr(connection.context, std::move(object), name); bool replace = registered_objects.count(name); - view_rel->CreateView(Identifier(name), replace, true); + view_rel->CreateView(name, replace, true); registered_objects.insert(name); return shared_from_this(); } -static void ParseMultiFileOptions(ClientContext &context, named_parameter_map_t &options, - const Optional &filename, const Optional &hive_partitioning, +static void ParseMultiFileOptions(named_parameter_map_t &options, const Optional &filename, + const Optional &hive_partitioning, const Optional &union_by_name, const Optional &hive_types, const Optional &hive_types_autocast) { if (!py::none().is(filename)) { - auto val = TransformPythonValue(context, filename); + auto val = TransformPythonValue(filename); options["filename"] = val; } if (!py::none().is(hive_types)) { - auto val = TransformPythonValue(context, hive_types); + auto val = TransformPythonValue(hive_types); options["hive_types"] = val; } @@ -785,7 +791,7 @@ static void ParseMultiFileOptions(ClientContext &context, named_parameter_map_t string actual_type = py::str(py::type::of(hive_partitioning)); throw BinderException("read_json only accepts 'hive_partitioning' as a boolean, not '%s'", actual_type); } - auto val = TransformPythonValue(context, hive_partitioning, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(hive_partitioning, LogicalTypeId::BOOLEAN); options["hive_partitioning"] = val; } @@ -794,7 +800,7 @@ static void ParseMultiFileOptions(ClientContext &context, named_parameter_map_t string actual_type = py::str(py::type::of(union_by_name)); throw BinderException("read_json only accepts 'union_by_name' as a boolean, not '%s'", actual_type); } - auto val = TransformPythonValue(context, union_by_name, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(union_by_name, LogicalTypeId::BOOLEAN); options["union_by_name"] = val; } @@ -803,12 +809,12 @@ static void ParseMultiFileOptions(ClientContext &context, named_parameter_map_t string actual_type = py::str(py::type::of(hive_types_autocast)); throw BinderException("read_json only accepts 'hive_types_autocast' as a boolean, not '%s'", actual_type); } - auto val = TransformPythonValue(context, hive_types_autocast, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(hive_types_autocast, LogicalTypeId::BOOLEAN); options["hive_types_autocast"] = val; } } -std::unique_ptr DuckDBPyConnection::ReadJSON( +unique_ptr DuckDBPyConnection::ReadJSON( const py::object &name_p, const Optional &columns, const Optional &sample_size, const Optional &maximum_depth, const Optional &records, const Optional &format, const Optional &date_format, const Optional ×tamp_format, @@ -822,13 +828,11 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( named_parameter_map_t options; auto &connection = con.GetConnection(); - auto &context = *connection.context; auto path_like = GetPathLike(name_p); auto &name = path_like.files; auto file_like_object_wrapper = std::move(path_like.dependency); - ParseMultiFileOptions(context, options, filename, hive_partitioning, union_by_name, hive_types, - hive_types_autocast); + ParseMultiFileOptions(options, filename, hive_partitioning, union_by_name, hive_types, hive_types_autocast); if (!py::none().is(columns)) { if (!py::is_dict_like(columns)) { @@ -926,7 +930,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( throw BinderException("read_json only accepts 'maximum_object_size' as an unsigned integer, not '%s'", actual_type); } - auto val = TransformPythonValue(context, maximum_object_size, LogicalTypeId::UINTEGER); + auto val = TransformPythonValue(maximum_object_size, LogicalTypeId::UINTEGER); options["maximum_object_size"] = val; } @@ -935,7 +939,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( string actual_type = py::str(py::type::of(ignore_errors)); throw BinderException("read_json only accepts 'ignore_errors' as a boolean, not '%s'", actual_type); } - auto val = TransformPythonValue(context, ignore_errors, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(ignore_errors, LogicalTypeId::BOOLEAN); options["ignore_errors"] = val; } @@ -945,7 +949,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( throw BinderException("read_json only accepts 'convert_strings_to_integers' as a boolean, not '%s'", actual_type); } - auto val = TransformPythonValue(context, convert_strings_to_integers, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(convert_strings_to_integers, LogicalTypeId::BOOLEAN); options["convert_strings_to_integers"] = val; } @@ -955,7 +959,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( throw BinderException("read_json only accepts 'field_appearance_threshold' as a float, not '%s'", actual_type); } - auto val = TransformPythonValue(context, field_appearance_threshold, LogicalTypeId::DOUBLE); + auto val = TransformPythonValue(field_appearance_threshold, LogicalTypeId::DOUBLE); options["field_appearance_threshold"] = val; } @@ -965,7 +969,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( throw BinderException("read_json only accepts 'map_inference_threshold' as an integer, not '%s'", actual_type); } - auto val = TransformPythonValue(context, map_inference_threshold, LogicalTypeId::BIGINT); + auto val = TransformPythonValue(map_inference_threshold, LogicalTypeId::BIGINT); options["map_inference_threshold"] = val; } @@ -974,7 +978,7 @@ std::unique_ptr DuckDBPyConnection::ReadJSON( string actual_type = py::str(py::type::of(maximum_sample_files)); throw BinderException("read_json only accepts 'maximum_sample_files' as an integer, not '%s'", actual_type); } - auto val = TransformPythonValue(context, maximum_sample_files, LogicalTypeId::BIGINT); + auto val = TransformPythonValue(maximum_sample_files, LogicalTypeId::BIGINT); options["maximum_sample_files"] = val; } @@ -1071,11 +1075,11 @@ void ConvertBooleanValue(const py::object &value, string param_name, named_param } else { throw InvalidInputException("read_csv only accepts '%s' as an integer, or a boolean", param_name); } - bind_parameters[Identifier(param_name)] = Value::BOOLEAN(converted_value); + bind_parameters[param_name] = Value::BOOLEAN(converted_value); } } -std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_p, py::kwargs &kwargs) { +unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_p, py::kwargs &kwargs) { py::object header = py::none(); py::object strict_mode = py::none(); py::object auto_detect = py::none(); @@ -1208,15 +1212,13 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & } auto &connection = con.GetConnection(); - auto &context = *connection.context; CSVReaderOptions options; auto path_like = GetPathLike(name_p); auto &name = path_like.files; auto file_like_object_wrapper = std::move(path_like.dependency); named_parameter_map_t bind_parameters; - ParseMultiFileOptions(context, bind_parameters, filename, hive_partitioning, union_by_name, hive_types, - hive_types_autocast); + ParseMultiFileOptions(bind_parameters, filename, hive_partitioning, union_by_name, hive_types, hive_types_autocast); // First check if the header is explicitly set // when false this affects the returned types, so it needs to be known at initialization of the relation @@ -1235,7 +1237,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & child_list_t struct_fields; py::dict dtype_dict = dtype; for (auto &kv : dtype_dict) { - std::shared_ptr sql_type; + shared_ptr sql_type; if (!py::try_cast(kv.second, sql_type)) { struct_fields.emplace_back(py::str(kv.first), py::str(kv.second)); } else { @@ -1248,7 +1250,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & vector list_values; py::list dtype_list = dtype; for (auto &child : dtype_list) { - std::shared_ptr sql_type; + shared_ptr sql_type; if (!py::try_cast(child, sql_type)) { list_values.push_back(Value(py::str(child))); } else { @@ -1439,7 +1441,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & throw BinderException("read_csv only accepts 'max_line_size' as a string or an integer, not '%s'", actual_type); } - auto val = TransformPythonValue(context, max_line_size, LogicalTypeId::VARCHAR); + auto val = TransformPythonValue(max_line_size, LogicalTypeId::VARCHAR); bind_parameters["max_line_size"] = val; } @@ -1448,7 +1450,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(auto_type_candidates)); throw BinderException("read_csv only accepts 'auto_type_candidates' as a list[str], not '%s'", actual_type); } - auto val = TransformPythonValue(context, auto_type_candidates, LogicalType::LIST(LogicalTypeId::VARCHAR)); + auto val = TransformPythonValue(auto_type_candidates, LogicalType::LIST(LogicalTypeId::VARCHAR)); bind_parameters["auto_type_candidates"] = val; } @@ -1457,7 +1459,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(ignore_errors)); throw BinderException("read_csv only accepts 'ignore_errors' as a bool, not '%s'", actual_type); } - auto val = TransformPythonValue(context, ignore_errors, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(ignore_errors, LogicalTypeId::BOOLEAN); bind_parameters["ignore_errors"] = val; } @@ -1466,7 +1468,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(store_rejects)); throw BinderException("read_csv only accepts 'store_rejects' as a bool, not '%s'", actual_type); } - auto val = TransformPythonValue(context, store_rejects, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(store_rejects, LogicalTypeId::BOOLEAN); bind_parameters["store_rejects"] = val; } @@ -1475,7 +1477,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(rejects_table)); throw BinderException("read_csv only accepts 'rejects_table' as a string, not '%s'", actual_type); } - auto val = TransformPythonValue(context, rejects_table, LogicalTypeId::VARCHAR); + auto val = TransformPythonValue(rejects_table, LogicalTypeId::VARCHAR); bind_parameters["rejects_table"] = val; } @@ -1484,7 +1486,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(rejects_scan)); throw BinderException("read_csv only accepts 'rejects_scan' as a string, not '%s'", actual_type); } - auto val = TransformPythonValue(context, rejects_scan, LogicalTypeId::VARCHAR); + auto val = TransformPythonValue(rejects_scan, LogicalTypeId::VARCHAR); bind_parameters["rejects_scan"] = val; } @@ -1493,7 +1495,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(rejects_limit)); throw BinderException("read_csv only accepts 'rejects_limit' as an int, not '%s'", actual_type); } - auto val = TransformPythonValue(context, rejects_limit, LogicalTypeId::BIGINT); + auto val = TransformPythonValue(rejects_limit, LogicalTypeId::BIGINT); bind_parameters["rejects_limit"] = val; } @@ -1502,7 +1504,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(force_not_null)); throw BinderException("read_csv only accepts 'force_not_null' as a list[str], not '%s'", actual_type); } - auto val = TransformPythonValue(context, force_not_null, LogicalType::LIST(LogicalTypeId::VARCHAR)); + auto val = TransformPythonValue(force_not_null, LogicalType::LIST(LogicalTypeId::VARCHAR)); bind_parameters["force_not_null"] = val; } @@ -1511,7 +1513,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(buffer_size)); throw BinderException("read_csv only accepts 'buffer_size' as a list[str], not '%s'", actual_type); } - auto val = TransformPythonValue(context, buffer_size, LogicalTypeId::UBIGINT); + auto val = TransformPythonValue(buffer_size, LogicalTypeId::UBIGINT); bind_parameters["buffer_size"] = val; } @@ -1520,7 +1522,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(decimal)); throw BinderException("read_csv only accepts 'decimal' as a string, not '%s'", actual_type); } - auto val = TransformPythonValue(context, decimal, LogicalTypeId::VARCHAR); + auto val = TransformPythonValue(decimal, LogicalTypeId::VARCHAR); bind_parameters["decimal_separator"] = val; } @@ -1529,7 +1531,7 @@ std::unique_ptr DuckDBPyConnection::ReadCSV(const py::object & string actual_type = py::str(py::type::of(allow_quoted_nulls)); throw BinderException("read_csv only accepts 'allow_quoted_nulls' as a bool, not '%s'", actual_type); } - auto val = TransformPythonValue(context, allow_quoted_nulls, LogicalTypeId::BOOLEAN); + auto val = TransformPythonValue(allow_quoted_nulls, LogicalTypeId::BOOLEAN); bind_parameters["allow_quoted_nulls"] = val; } @@ -1595,8 +1597,7 @@ void DuckDBPyConnection::ExecuteImmediately(vector> sta } } -std::unique_ptr DuckDBPyConnection::RunQuery(const py::object &query, string alias, - py::object params) { +unique_ptr DuckDBPyConnection::RunQuery(const py::object &query, string alias, py::object params) { auto &connection = con.GetConnection(); if (alias.empty()) { alias = "unnamed_relation_" + StringUtil::GenerateRandomName(16); @@ -1651,28 +1652,24 @@ std::unique_ptr DuckDBPyConnection::RunQuery(const py::object res = stream_result.Materialize(); } auto &materialized_result = res->Cast(); - vector col_names(res->names.size()); - std::transform(res->names.begin(), res->names.end(), col_names.begin(), - [](string &name) { return Identifier(name); }); relation = make_shared_ptr(connection.context, materialized_result.TakeCollection(), - col_names, Identifier(alias)); + res->names, alias); } return CreateRelation(std::move(relation)); } -std::unique_ptr DuckDBPyConnection::Table(const string &tname) { +unique_ptr DuckDBPyConnection::Table(const string &tname) { auto &connection = con.GetConnection(); auto qualified_name = QualifiedName::Parse(tname); - if (qualified_name.Schema().empty()) { - qualified_name.SchemaMutable() = DEFAULT_SCHEMA; + if (qualified_name.schema.empty()) { + qualified_name.schema = DEFAULT_SCHEMA; } try { - return CreateRelation( - connection.Table(qualified_name.Catalog(), qualified_name.Schema(), qualified_name.Name())); + return CreateRelation(connection.Table(qualified_name.catalog, qualified_name.schema, qualified_name.name)); } catch (const CatalogException &) { // CatalogException will be of the type '... is not a table' // Not a table in the database, make a query relation that can perform replacement scans - auto sql_query = StringUtil::Format("from %s", SQLIdentifier::ToString(tname)); + auto sql_query = StringUtil::Format("from %s", KeywordHelper::WriteOptionallyQuoted(tname)); return RunQuery(py::str(sql_query), tname); } } @@ -1686,8 +1683,8 @@ static vector> ValueListFromExpressions(const py::a for (idx_t i = 0; i < arg_count; i++) { py::handle arg = expressions[i]; - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { throw InvalidInputException("Please provide arguments of type Expression!"); } auto expr = py_expr->GetExpression().Copy(); @@ -1722,9 +1719,8 @@ static vector>> ValueListsFromTuples(const p return result; } -std::unique_ptr DuckDBPyConnection::Values(const py::args &args) { +unique_ptr DuckDBPyConnection::Values(const py::args &args) { auto &connection = con.GetConnection(); - auto &context = *connection.context; auto arg_count = args.size(); if (arg_count == 0) { @@ -1734,7 +1730,7 @@ std::unique_ptr DuckDBPyConnection::Values(const py::args &arg D_ASSERT(py::gil_check()); py::handle first_arg = args[0]; if (arg_count == 1 && py::isinstance(first_arg)) { - vector> values {DuckDBPyConnection::TransformPythonParamList(context, first_arg)}; + vector> values {DuckDBPyConnection::TransformPythonParamList(first_arg)}; return CreateRelation(connection.Values(values)); } else { vector>> expressions; @@ -1748,14 +1744,13 @@ std::unique_ptr DuckDBPyConnection::Values(const py::args &arg } } -std::unique_ptr DuckDBPyConnection::View(const string &vname) { +unique_ptr DuckDBPyConnection::View(const string &vname) { auto &connection = con.GetConnection(); - return CreateRelation(connection.View(Identifier(vname))); + return CreateRelation(connection.View(vname)); } -std::unique_ptr DuckDBPyConnection::TableFunction(const string &fname, py::object params) { +unique_ptr DuckDBPyConnection::TableFunction(const string &fname, py::object params) { auto &connection = con.GetConnection(); - auto &context = *connection.context; if (params.is_none()) { params = py::list(); } @@ -1763,11 +1758,10 @@ std::unique_ptr DuckDBPyConnection::TableFunction(const string throw InvalidInputException("'params' has to be a list of parameters"); } - return CreateRelation( - connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(context, params))); + return CreateRelation(connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(params))); } -std::unique_ptr DuckDBPyConnection::FromDF(const PandasDataFrame &value) { +unique_ptr DuckDBPyConnection::FromDF(const PandasDataFrame &value) { auto &connection = con.GetConnection(); string name = "df_" + StringUtil::GenerateRandomName(); if (PandasDataFrame::IsPyArrowBacked(value)) { @@ -1780,10 +1774,10 @@ std::unique_ptr DuckDBPyConnection::FromDF(const PandasDataFra return CreateRelation(std::move(rel)); } -std::unique_ptr DuckDBPyConnection::FromParquet(const py::object &path_or_buffer, - bool binary_as_string, bool file_row_number, - bool filename, bool hive_partitioning, - bool union_by_name, const py::object &compression) { +unique_ptr DuckDBPyConnection::FromParquet(const py::object &path_or_buffer, bool binary_as_string, + bool file_row_number, bool filename, + bool hive_partitioning, bool union_by_name, + const py::object &compression) { auto &connection = con.GetConnection(); auto path_like = GetPathLike(path_or_buffer); auto file_like_object_wrapper = std::move(path_like.dependency); @@ -1816,7 +1810,7 @@ std::unique_ptr DuckDBPyConnection::FromParquet(const py::obje return CreateRelation(parquet_relation->Alias(name)); } -std::unique_ptr DuckDBPyConnection::FromArrow(py::object &arrow_object) { +unique_ptr DuckDBPyConnection::FromArrow(py::object &arrow_object) { auto &connection = con.GetConnection(); string name = "arrow_object_" + StringUtil::GenerateRandomName(); if (!IsAcceptedArrowObject(arrow_object)) { @@ -1834,7 +1828,7 @@ unordered_set DuckDBPyConnection::GetTableNames(const string &query, boo return connection.GetTableNames(query, qualified); } -std::shared_ptr DuckDBPyConnection::UnregisterPythonObject(const string &name) { +shared_ptr DuckDBPyConnection::UnregisterPythonObject(const string &name) { auto &connection = con.GetConnection(); if (!registered_objects.count(name)) { return shared_from_this(); @@ -1842,18 +1836,18 @@ std::shared_ptr DuckDBPyConnection::UnregisterPythonObject(c D_ASSERT(py::gil_check()); py::gil_scoped_release release; // FIXME: DROP TEMPORARY VIEW? doesn't exist? - const auto quoted_name = SQLQuotedIdentifier::ToString(name); + const auto quoted_name = KeywordHelper::WriteOptionallyQuoted(name, '\"'); connection.Query("DROP VIEW " + quoted_name + ""); registered_objects.erase(name); return shared_from_this(); } -std::shared_ptr DuckDBPyConnection::Begin() { +shared_ptr DuckDBPyConnection::Begin() { ExecuteFromString("BEGIN TRANSACTION"); return shared_from_this(); } -std::shared_ptr DuckDBPyConnection::Commit() { +shared_ptr DuckDBPyConnection::Commit() { auto &connection = con.GetConnection(); if (connection.context->transaction.IsAutoCommit()) { return shared_from_this(); @@ -1862,12 +1856,12 @@ std::shared_ptr DuckDBPyConnection::Commit() { return shared_from_this(); } -std::shared_ptr DuckDBPyConnection::Rollback() { +shared_ptr DuckDBPyConnection::Rollback() { ExecuteFromString("ROLLBACK"); return shared_from_this(); } -std::shared_ptr DuckDBPyConnection::Checkpoint() { +shared_ptr DuckDBPyConnection::Checkpoint() { ExecuteFromString("CHECKPOINT"); return shared_from_this(); } @@ -1964,11 +1958,10 @@ void DuckDBPyConnection::InstallExtension(const string &extension, bool force_in void DuckDBPyConnection::LoadExtension(const string &extension) { auto &connection = con.GetConnection(); - const ExtensionLoadOptions extension_opts = {extension}; - ExtensionHelper::LoadExternalExtension(*connection.context, extension_opts); + ExtensionHelper::LoadExternalExtension(*connection.context, extension); } -std::shared_ptr DefaultConnectionHolder::Get() { +shared_ptr DefaultConnectionHolder::Get() { lock_guard guard(l); if (!connection || connection->con.ConnectionIsClosed()) { py::dict config_dict; @@ -1977,16 +1970,16 @@ std::shared_ptr DefaultConnectionHolder::Get() { return connection; } -void DefaultConnectionHolder::Set(std::shared_ptr conn) { +void DefaultConnectionHolder::Set(shared_ptr conn) { lock_guard guard(l); connection = conn; } -void DuckDBPyConnection::Cursors::AddCursor(std::shared_ptr conn) { +void DuckDBPyConnection::Cursors::AddCursor(shared_ptr conn) { lock_guard l(lock); // Clean up previously created cursors - vector> compacted_cursors; + vector> compacted_cursors; bool needs_compaction = false; for (auto &cur_p : cursors) { auto cur = cur_p.lock(); @@ -2023,8 +2016,8 @@ void DuckDBPyConnection::Cursors::ClearCursors() { cursors.clear(); } -std::shared_ptr DuckDBPyConnection::Cursor() { - auto res = std::make_shared(); +shared_ptr DuckDBPyConnection::Cursor() { + auto res = make_shared_ptr(); res->con.SetDatabase(con); res->con.SetConnection(make_uniq(res->con.GetDatabase())); cursors.AddCursor(res); @@ -2185,12 +2178,12 @@ void InstantiateNewInstance(DuckDB &db) { MapFunction map_fun; TableFunctionSet map_set(map_fun.name); - map_set.AddFunction(static_cast(std::move(map_fun))); + map_set.AddFunction(std::move(map_fun)); CreateTableFunctionInfo map_info(std::move(map_set)); map_info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; TableFunctionSet scan_set(scan_fun.name); - scan_set.AddFunction(static_cast(std::move(scan_fun))); + scan_set.AddFunction(std::move(scan_fun)); CreateTableFunctionInfo scan_info(std::move(scan_set)); scan_info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; @@ -2201,16 +2194,16 @@ void InstantiateNewInstance(DuckDB &db) { system_catalog.CreateFunction(transaction, scan_info); } -static std::shared_ptr FetchOrCreateInstance(const string &database_path, DBConfig &config) { - auto res = std::make_shared(); +static shared_ptr FetchOrCreateInstance(const string &database_path, DBConfig &config) { + auto res = make_shared_ptr(); bool cache_instance = database_path != ":memory:" && !database_path.empty(); config.replacement_scans.emplace_back(PythonReplacementScan::Replace); { D_ASSERT(py::gil_check()); py::gil_scoped_release release; unique_lock lock(res->py_connection_lock); - auto database = GetModuleState().instance_cache.GetOrCreateInstance(database_path, config, cache_instance, - InstantiateNewInstance); + auto database = + instance_cache.GetOrCreateInstance(database_path, config, cache_instance, InstantiateNewInstance); res->con.SetDatabase(std::move(database)); res->con.SetConnection(make_uniq(res->con.GetDatabase())); } @@ -2239,8 +2232,8 @@ static string GetPathString(const py::object &path) { throw InvalidInputException("Please provide either a str or a pathlib.Path, not %s", actual_type); } -std::shared_ptr DuckDBPyConnection::Connect(const py::object &database_p, bool read_only, - const py::dict &config_options) { +shared_ptr DuckDBPyConnection::Connect(const py::object &database_p, bool read_only, + const py::dict &config_options) { auto config_dict = TransformPyConfigDict(config_options); auto database = GetPathString(database_p); if (IsDefaultConnectionString(database, read_only, config_dict)) { @@ -2271,41 +2264,38 @@ std::shared_ptr DuckDBPyConnection::Connect(const py::object return res; } -vector DuckDBPyConnection::TransformPythonParamList(ClientContext &context, const py::handle ¶ms) { +vector DuckDBPyConnection::TransformPythonParamList(const py::handle ¶ms) { vector args; args.reserve(py::len(params)); for (auto param : params) { - args.emplace_back(TransformPythonValue(context, param, LogicalType::UNKNOWN, false)); + args.emplace_back(TransformPythonValue(param, LogicalType::UNKNOWN, false)); } return args; } -identifier_map_t DuckDBPyConnection::TransformPythonParamDict(ClientContext &context, - const py::dict ¶ms) { - identifier_map_t args; +case_insensitive_map_t DuckDBPyConnection::TransformPythonParamDict(const py::dict ¶ms) { + case_insensitive_map_t args; for (auto pair : params) { auto &key = pair.first; auto &value = pair.second; - args[Identifier(py::str(key))] = - BoundParameterData(TransformPythonValue(context, value, LogicalType::UNKNOWN, false)); + args[std::string(py::str(key))] = BoundParameterData(TransformPythonValue(value, LogicalType::UNKNOWN, false)); } return args; } -std::shared_ptr DuckDBPyConnection::DefaultConnection() { - return GetModuleState().default_connection.Get(); +shared_ptr DuckDBPyConnection::DefaultConnection() { + return default_connection.Get(); } -void DuckDBPyConnection::SetDefaultConnection(std::shared_ptr connection) { - return GetModuleState().default_connection.Set(std::move(connection)); +void DuckDBPyConnection::SetDefaultConnection(shared_ptr connection) { + return default_connection.Set(std::move(connection)); } PythonImportCache *DuckDBPyConnection::ImportCache() { - auto &import_cache = GetModuleState().import_cache; if (!import_cache) { - import_cache = std::make_shared(); + import_cache = make_shared_ptr(); } return import_cache.get(); } @@ -2319,7 +2309,7 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { throw InvalidInputException( "This operation could not be completed because required module 'fsspec' is not installed"); } - internal_object_filesystem = std::make_shared(modified_memory_fs()); + internal_object_filesystem = make_shared_ptr(modified_memory_fs()); auto &abstract_fs = reinterpret_cast(*internal_object_filesystem); RegisterFilesystem(abstract_fs); } @@ -2327,10 +2317,10 @@ ModifiedMemoryFileSystem &DuckDBPyConnection::GetObjectFileSystem() { } bool DuckDBPyConnection::IsInteractive() { - return GetModuleState().environment != PythonEnvironmentType::NORMAL; + return DuckDBPyConnection::environment != PythonEnvironmentType::NORMAL; } -std::shared_ptr DuckDBPyConnection::Enter() { +shared_ptr DuckDBPyConnection::Enter() { return shared_from_this(); } @@ -2345,8 +2335,8 @@ void DuckDBPyConnection::Exit(DuckDBPyConnection &self, const py::object &exc_ty } void DuckDBPyConnection::Cleanup() { - GetModuleState().default_connection.Set(nullptr); - GetModuleState().import_cache.reset(); + default_connection.Set(nullptr); + import_cache.reset(); } bool DuckDBPyConnection::IsPandasDataframe(const py::object &object) { @@ -2364,7 +2354,7 @@ bool IsValidNumpyDimensions(const py::handle &object, int &dim) { if (!py::isinstance(object, import_cache.numpy.ndarray())) { return false; } - auto shape = NumpyArray(py::reinterpret_borrow(object)).GetArray().attr("shape"); + auto shape = (py::cast(object)).attr("shape"); if (py::len(shape) != 1) { return false; } @@ -2376,9 +2366,9 @@ NumpyObjectType DuckDBPyConnection::IsAcceptedNumpyObject(const py::object &obje if (!ModuleIsLoaded()) { return NumpyObjectType::INVALID; } - auto import_cache_ = ImportCache(); - if (py::isinstance(object, import_cache_->numpy.ndarray())) { - auto len = py::len(NumpyArray(object).GetArray().attr("shape")); + auto &import_cache = *DuckDBPyConnection::ImportCache(); + if (py::isinstance(object, import_cache.numpy.ndarray())) { + auto len = py::len((py::cast(object)).attr("shape")); switch (len) { case 1: return NumpyObjectType::NDARRAY1D; @@ -2423,17 +2413,17 @@ PyArrowObjectType DuckDBPyConnection::GetArrowType(const py::handle &obj) { } if (ModuleIsLoaded()) { - auto import_cache_ = ImportCache(); + auto &import_cache = *DuckDBPyConnection::ImportCache(); // MessageReader requires nanoarrow, separate scan function - if (py::isinstance(obj, import_cache_->pyarrow.ipc.MessageReader())) { + if (py::isinstance(obj, import_cache.pyarrow.ipc.MessageReader())) { return PyArrowObjectType::MessageReader; } if (ModuleIsLoaded()) { // Scanner/Dataset don't have __arrow_c_stream__, need dedicated handling - if (py::isinstance(obj, import_cache_->pyarrow.dataset.Scanner())) { + if (py::isinstance(obj, import_cache.pyarrow.dataset.Scanner())) { return PyArrowObjectType::Scanner; - } else if (py::isinstance(obj, import_cache_->pyarrow.dataset.Dataset())) { + } else if (py::isinstance(obj, import_cache.pyarrow.dataset.Dataset())) { return PyArrowObjectType::Dataset; } } diff --git a/src/duckdb_py/pyconnection/type_creation.cpp b/src/duckdb_py/pyconnection/type_creation.cpp index 0a98adbb..71e6c610 100644 --- a/src/duckdb_py/pyconnection/type_creation.cpp +++ b/src/duckdb_py/pyconnection/type_creation.cpp @@ -2,20 +2,20 @@ namespace duckdb { -std::shared_ptr DuckDBPyConnection::MapType(const std::shared_ptr &key_type, - const std::shared_ptr &value_type) { +shared_ptr DuckDBPyConnection::MapType(const shared_ptr &key_type, + const shared_ptr &value_type) { auto map_type = LogicalType::MAP(key_type->Type(), value_type->Type()); - return std::make_shared(map_type); + return make_shared_ptr(map_type); } -std::shared_ptr DuckDBPyConnection::ListType(const std::shared_ptr &type) { +shared_ptr DuckDBPyConnection::ListType(const shared_ptr &type) { auto array_type = LogicalType::LIST(type->Type()); - return std::make_shared(array_type); + return make_shared_ptr(array_type); } -std::shared_ptr DuckDBPyConnection::ArrayType(const std::shared_ptr &type, idx_t size) { +shared_ptr DuckDBPyConnection::ArrayType(const shared_ptr &type, idx_t size) { auto array_type = LogicalType::ARRAY(type->Type(), size); - return std::make_shared(array_type); + return make_shared_ptr(array_type); } static child_list_t GetChildList(const py::object &container) { @@ -24,12 +24,12 @@ static child_list_t GetChildList(const py::object &container) { const py::list &fields = container; idx_t i = 1; for (auto &item : fields) { - std::shared_ptr pytype; - if (!py::try_cast>(item, pytype)) { + shared_ptr pytype; + if (!py::try_cast>(item, pytype)) { string actual_type = py::str(py::type::of(item)); throw InvalidInputException("object has to be a list of DuckDBPyType's, not '%s'", actual_type); } - types.push_back(std::make_pair(Identifier(StringUtil::Format("v%d", i++)), pytype->Type())); + types.push_back(std::make_pair(StringUtil::Format("v%d", i++), pytype->Type())); } return types; } else if (py::isinstance(container)) { @@ -37,9 +37,9 @@ static child_list_t GetChildList(const py::object &container) { for (auto &item : fields) { auto &name_p = item.first; auto &type_p = item.second; - auto name = Identifier(py::str(name_p)); - std::shared_ptr pytype; - if (!py::try_cast>(type_p, pytype)) { + string name = py::str(name_p); + shared_ptr pytype; + if (!py::try_cast>(type_p, pytype)) { string actual_type = py::str(py::type::of(type_p)); throw InvalidInputException("object has to be a list of DuckDBPyType's, not '%s'", actual_type); } @@ -53,51 +53,51 @@ static child_list_t GetChildList(const py::object &container) { } } -std::shared_ptr DuckDBPyConnection::StructType(const py::object &fields) { +shared_ptr DuckDBPyConnection::StructType(const py::object &fields) { child_list_t types = GetChildList(fields); if (types.empty()) { throw InvalidInputException("Can not create an empty struct type!"); } auto struct_type = LogicalType::STRUCT(std::move(types)); - return std::make_shared(struct_type); + return make_shared_ptr(struct_type); } -std::shared_ptr DuckDBPyConnection::UnionType(const py::object &members) { +shared_ptr DuckDBPyConnection::UnionType(const py::object &members) { child_list_t types = GetChildList(members); if (types.empty()) { throw InvalidInputException("Can not create an empty union type!"); } auto union_type = LogicalType::UNION(std::move(types)); - return std::make_shared(union_type); + return make_shared_ptr(union_type); } -std::shared_ptr -DuckDBPyConnection::EnumType(const string &name, const std::shared_ptr &type, const py::list &values_p) { +shared_ptr DuckDBPyConnection::EnumType(const string &name, const shared_ptr &type, + const py::list &values_p) { throw NotImplementedException("enum_type creation method is not implemented yet"); } -std::shared_ptr DuckDBPyConnection::DecimalType(int width, int scale) { +shared_ptr DuckDBPyConnection::DecimalType(int width, int scale) { auto decimal_type = LogicalType::DECIMAL(width, scale); - return std::make_shared(decimal_type); + return make_shared_ptr(decimal_type); } -std::shared_ptr DuckDBPyConnection::StringType(const string &collation) { +shared_ptr DuckDBPyConnection::StringType(const string &collation) { LogicalType type; if (collation.empty()) { type = LogicalType::VARCHAR; } else { type = LogicalType::VARCHAR_COLLATION(collation); } - return std::make_shared(type); + return make_shared_ptr(type); } -std::shared_ptr DuckDBPyConnection::Type(const string &type_str) { +shared_ptr DuckDBPyConnection::Type(const string &type_str) { auto &connection = con.GetConnection(); auto &context = *connection.context; - std::shared_ptr result; + shared_ptr result; context.RunFunctionInTransaction([&result, &type_str, &context]() { - result = std::make_shared(TransformStringToLogicalType(type_str, context)); + result = make_shared_ptr(TransformStringToLogicalType(type_str, context)); }); return result; } diff --git a/src/duckdb_py/pyexpression.cpp b/src/duckdb_py/pyexpression.cpp index 4d984b36..0703389b 100644 --- a/src/duckdb_py/pyexpression.cpp +++ b/src/duckdb_py/pyexpression.cpp @@ -24,7 +24,7 @@ DuckDBPyExpression::DuckDBPyExpression(unique_ptr expr_p, Orde } string DuckDBPyExpression::Type() const { - return ExpressionTypeToString(expression->GetExpressionType()); + return ExpressionTypeToString(expression->type); } string DuckDBPyExpression::ToString() const { @@ -32,7 +32,7 @@ string DuckDBPyExpression::ToString() const { } string DuckDBPyExpression::GetName() const { - return expression->GetName().GetIdentifierName(); + return expression->GetName(); } void DuckDBPyExpression::Print() const { @@ -43,35 +43,35 @@ const ParsedExpression &DuckDBPyExpression::GetExpression() const { return *expression; } -std::shared_ptr DuckDBPyExpression::Copy() const { +shared_ptr DuckDBPyExpression::Copy() const { auto expr = GetExpression().Copy(); - return std::make_shared(std::move(expr), order_type, null_order); + return make_shared_ptr(std::move(expr), order_type, null_order); } -std::shared_ptr DuckDBPyExpression::SetAlias(const string &name) const { +shared_ptr DuckDBPyExpression::SetAlias(const string &name) const { auto copied_expression = GetExpression().Copy(); - copied_expression->SetAlias(Identifier(name)); - return std::make_shared(std::move(copied_expression)); + copied_expression->alias = name; + return make_shared_ptr(std::move(copied_expression)); } -std::shared_ptr DuckDBPyExpression::Cast(const DuckDBPyType &type) const { +shared_ptr DuckDBPyExpression::Cast(const DuckDBPyType &type) const { auto copied_expression = GetExpression().Copy(); auto case_expr = make_uniq(type.Type(), std::move(copied_expression)); - return std::make_shared(std::move(case_expr)); + return make_shared_ptr(std::move(case_expr)); } -std::shared_ptr DuckDBPyExpression::Between(const DuckDBPyExpression &lower, - const DuckDBPyExpression &upper) { +shared_ptr DuckDBPyExpression::Between(const DuckDBPyExpression &lower, + const DuckDBPyExpression &upper) { auto copied_expression = GetExpression().Copy(); auto between_expr = make_uniq(std::move(copied_expression), lower.GetExpression().Copy(), upper.GetExpression().Copy()); - return std::make_shared(std::move(between_expr)); + return make_shared_ptr(std::move(between_expr)); } -std::shared_ptr DuckDBPyExpression::Collate(const string &collation) { +shared_ptr DuckDBPyExpression::Collate(const string &collation) { auto copied_expression = GetExpression().Copy(); auto collation_expression = make_uniq(collation, std::move(copied_expression)); - return std::make_shared(std::move(collation_expression)); + return make_shared_ptr(std::move(collation_expression)); } // Case Expression modifiers @@ -82,18 +82,18 @@ void DuckDBPyExpression::AssertCaseExpression() const { } } -std::shared_ptr DuckDBPyExpression::InternalWhen(unique_ptr expr, - const DuckDBPyExpression &condition, - const DuckDBPyExpression &value) { +shared_ptr DuckDBPyExpression::InternalWhen(unique_ptr expr, + const DuckDBPyExpression &condition, + const DuckDBPyExpression &value) { CaseCheck check; check.when_expr = condition.GetExpression().Copy(); check.then_expr = value.GetExpression().Copy(); - expr->CaseChecksMutable().push_back(std::move(check)); - return std::make_shared(std::move(expr)); + expr->case_checks.push_back(std::move(check)); + return make_shared_ptr(std::move(expr)); } -std::shared_ptr DuckDBPyExpression::When(const DuckDBPyExpression &condition, - const DuckDBPyExpression &value) { +shared_ptr DuckDBPyExpression::When(const DuckDBPyExpression &condition, + const DuckDBPyExpression &value) { AssertCaseExpression(); auto expr_p = expression->Copy(); auto expr = unique_ptr_cast(std::move(expr_p)); @@ -101,99 +101,99 @@ std::shared_ptr DuckDBPyExpression::When(const DuckDBPyExpre return InternalWhen(std::move(expr), condition, value); } -std::shared_ptr DuckDBPyExpression::Else(const DuckDBPyExpression &value) { +shared_ptr DuckDBPyExpression::Else(const DuckDBPyExpression &value) { AssertCaseExpression(); auto expr_p = expression->Copy(); auto expr = unique_ptr_cast(std::move(expr_p)); - expr->ElseMutable() = value.GetExpression().Copy(); - return std::make_shared(std::move(expr)); + expr->else_expr = value.GetExpression().Copy(); + return make_shared_ptr(std::move(expr)); } // Binary operators -std::shared_ptr DuckDBPyExpression::Add(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Add(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("+", *this, other); } -std::shared_ptr DuckDBPyExpression::Subtract(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Subtract(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("-", *this, other); } -std::shared_ptr DuckDBPyExpression::Multiply(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Multiply(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("*", *this, other); } -std::shared_ptr DuckDBPyExpression::Division(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Division(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("/", *this, other); } -std::shared_ptr DuckDBPyExpression::FloorDivision(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::FloorDivision(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("//", *this, other); } -std::shared_ptr DuckDBPyExpression::Modulo(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Modulo(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("%", *this, other); } -std::shared_ptr DuckDBPyExpression::Power(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Power(const DuckDBPyExpression &other) const { return DuckDBPyExpression::BinaryOperator("**", *this, other); } // Comparison expressions -std::shared_ptr DuckDBPyExpression::Equality(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::Equality(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_EQUAL, *this, other); } -std::shared_ptr DuckDBPyExpression::Inequality(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::Inequality(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_NOTEQUAL, *this, other); } -std::shared_ptr DuckDBPyExpression::GreaterThan(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::GreaterThan(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, *this, other); } -std::shared_ptr DuckDBPyExpression::GreaterThanOrEqual(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::GreaterThanOrEqual(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, *this, other); } -std::shared_ptr DuckDBPyExpression::LessThan(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::LessThan(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, *this, other); } -std::shared_ptr DuckDBPyExpression::LessThanOrEqual(const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::LessThanOrEqual(const DuckDBPyExpression &other) { return ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, *this, other); } // AND, OR and NOT -std::shared_ptr DuckDBPyExpression::Not() { +shared_ptr DuckDBPyExpression::Not() { return DuckDBPyExpression::InternalUnaryOperator(ExpressionType::OPERATOR_NOT, *this); } -std::shared_ptr DuckDBPyExpression::And(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::And(const DuckDBPyExpression &other) const { return DuckDBPyExpression::InternalConjunction(ExpressionType::CONJUNCTION_AND, *this, other); } -std::shared_ptr DuckDBPyExpression::Or(const DuckDBPyExpression &other) const { +shared_ptr DuckDBPyExpression::Or(const DuckDBPyExpression &other) const { return DuckDBPyExpression::InternalConjunction(ExpressionType::CONJUNCTION_OR, *this, other); } // NULL -std::shared_ptr DuckDBPyExpression::IsNull() { +shared_ptr DuckDBPyExpression::IsNull() { return DuckDBPyExpression::InternalUnaryOperator(ExpressionType::OPERATOR_IS_NULL, *this); } -std::shared_ptr DuckDBPyExpression::IsNotNull() { +shared_ptr DuckDBPyExpression::IsNotNull() { return DuckDBPyExpression::InternalUnaryOperator(ExpressionType::OPERATOR_IS_NOT_NULL, *this); } // IN / NOT IN -std::shared_ptr DuckDBPyExpression::CreateCompareExpression(ExpressionType compare_type, - const py::args &args) { +shared_ptr DuckDBPyExpression::CreateCompareExpression(ExpressionType compare_type, + const py::args &args) { D_ASSERT(args.size() >= 1); vector> expressions; @@ -201,25 +201,25 @@ std::shared_ptr DuckDBPyExpression::CreateCompareExpression( expressions.push_back(GetExpression().Copy()); for (auto arg : args) { - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { throw InvalidInputException("Please provide arguments of type Expression!"); } auto expr = py_expr->GetExpression().Copy(); expressions.push_back(std::move(expr)); } auto operator_expr = make_uniq(compare_type, std::move(expressions)); - return std::make_shared(std::move(operator_expr)); + return make_shared_ptr(std::move(operator_expr)); } -std::shared_ptr DuckDBPyExpression::In(const py::args &args) { +shared_ptr DuckDBPyExpression::In(const py::args &args) { if (args.size() == 0) { throw InvalidInputException("Incorrect amount of parameters to 'isin', needs at least 1 parameter"); } return CreateCompareExpression(ExpressionType::COMPARE_IN, args); } -std::shared_ptr DuckDBPyExpression::NotIn(const py::args &args) { +shared_ptr DuckDBPyExpression::NotIn(const py::args &args) { if (args.size() == 0) { throw InvalidInputException("Incorrect amount of parameters to 'isnotin', needs at least 1 parameter"); } @@ -228,13 +228,13 @@ std::shared_ptr DuckDBPyExpression::NotIn(const py::args &ar // COALESCE -std::shared_ptr DuckDBPyExpression::Coalesce(const py::args &args) { +shared_ptr DuckDBPyExpression::Coalesce(const py::args &args) { vector> expressions; expressions.reserve(args.size()); for (auto arg : args) { - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { throw InvalidInputException("Please provide arguments of type Expression!"); } auto expr = py_expr->GetExpression().Copy(); @@ -244,18 +244,18 @@ std::shared_ptr DuckDBPyExpression::Coalesce(const py::args throw InvalidInputException("Please provide at least one argument"); } auto operator_expr = make_uniq(ExpressionType::OPERATOR_COALESCE, std::move(expressions)); - return std::make_shared(std::move(operator_expr)); + return make_shared_ptr(std::move(operator_expr)); } // Order modifiers -std::shared_ptr DuckDBPyExpression::Ascending() { +shared_ptr DuckDBPyExpression::Ascending() { auto py_expr = Copy(); py_expr->order_type = OrderType::ASCENDING; return py_expr; } -std::shared_ptr DuckDBPyExpression::Descending() { +shared_ptr DuckDBPyExpression::Descending() { auto py_expr = Copy(); py_expr->order_type = OrderType::DESCENDING; return py_expr; @@ -263,13 +263,13 @@ std::shared_ptr DuckDBPyExpression::Descending() { // Null order modifiers -std::shared_ptr DuckDBPyExpression::NullsFirst() { +shared_ptr DuckDBPyExpression::NullsFirst() { auto py_expr = Copy(); py_expr->null_order = OrderByNullType::NULLS_FIRST; return py_expr; } -std::shared_ptr DuckDBPyExpression::NullsLast() { +shared_ptr DuckDBPyExpression::NullsLast() { auto py_expr = Copy(); py_expr->null_order = OrderByNullType::NULLS_LAST; return py_expr; @@ -277,7 +277,7 @@ std::shared_ptr DuckDBPyExpression::NullsLast() { // Unary operators -std::shared_ptr DuckDBPyExpression::Negate() { +shared_ptr DuckDBPyExpression::Negate() { vector> children; children.push_back(GetExpression().Copy()); return DuckDBPyExpression::InternalFunctionExpression("-", std::move(children), true); @@ -297,7 +297,7 @@ static void PopulateExcludeList(qualified_column_set_t &exclude, py::object list exclude.insert(qname); continue; } - std::shared_ptr expr; + shared_ptr expr; if (!py::try_cast(item, expr)) { throw py::value_error("Items in the exclude list should either be 'str' or Expression"); } @@ -309,15 +309,15 @@ static void PopulateExcludeList(qualified_column_set_t &exclude, py::object list } } -std::shared_ptr DuckDBPyExpression::StarExpression(py::object exclude_list) { +shared_ptr DuckDBPyExpression::StarExpression(py::object exclude_list) { case_insensitive_set_t exclude; auto star = make_uniq(); - PopulateExcludeList(star->ExcludeListMutable(), std::move(exclude_list)); - return std::make_shared(std::move(star)); + PopulateExcludeList(star->exclude_list, std::move(exclude_list)); + return make_shared_ptr(std::move(star)); } -std::shared_ptr DuckDBPyExpression::ColumnExpression(const py::args &names) { - vector column_names; +shared_ptr DuckDBPyExpression::ColumnExpression(const py::args &names) { + vector column_names; if (names.size() == 1) { string column_name = std::string(py::str(names[0])); if (column_name == "*") { @@ -325,28 +325,28 @@ std::shared_ptr DuckDBPyExpression::ColumnExpression(const p } auto qualified_name = QualifiedName::Parse(column_name); - if (!qualified_name.Catalog().empty()) { - column_names.push_back(qualified_name.Catalog()); + if (!qualified_name.catalog.empty()) { + column_names.push_back(qualified_name.catalog); } - if (!qualified_name.Schema().empty()) { - column_names.push_back(qualified_name.Schema()); + if (!qualified_name.schema.empty()) { + column_names.push_back(qualified_name.schema); } - column_names.push_back(qualified_name.Name()); + column_names.push_back(qualified_name.name); } else { for (auto &part : names) { - column_names.push_back(Identifier(py::str(part))); + column_names.push_back(std::string(py::str(part))); } } auto column_ref = make_uniq(std::move(column_names)); - return std::make_shared(std::move(column_ref)); + return make_shared_ptr(std::move(column_ref)); } -std::shared_ptr DuckDBPyExpression::DefaultExpression() { - return std::make_shared(make_uniq()); +shared_ptr DuckDBPyExpression::DefaultExpression() { + return make_shared_ptr(make_uniq()); } -std::shared_ptr DuckDBPyExpression::ConstantExpression(const py::object &value) { - auto val = TransformPythonValue(nullptr, value); +shared_ptr DuckDBPyExpression::ConstantExpression(const py::object &value) { + auto val = TransformPythonValue(value); return InternalConstantExpression(std::move(val)); } @@ -358,8 +358,8 @@ static py::args CreateArgsFromItem(py::handle item) { } } -std::shared_ptr DuckDBPyExpression::LambdaExpression(const py::object &lhs_p, - const DuckDBPyExpression &rhs) { +shared_ptr DuckDBPyExpression::LambdaExpression(const py::object &lhs_p, + const DuckDBPyExpression &rhs) { unique_ptr lhs; if (py::isinstance(lhs_p)) { // LambdaExpression(lhs=(, , )) @@ -369,7 +369,7 @@ std::shared_ptr DuckDBPyExpression::LambdaExpression(const p unique_ptr column; if (py::isinstance(item)) { // 'item' is already an Expression, check its type and use it - auto column_expr = py::cast>(item); + auto column_expr = py::cast>(item); if (column_expr->GetExpression().GetExpressionType() != ExpressionType::COLUMN_REF) { throw py::value_error("'lhs' was provided as a tuple of columns, but one of the columns is not of " "type ColumnExpression"); @@ -400,7 +400,7 @@ std::shared_ptr DuckDBPyExpression::LambdaExpression(const p } else if (py::isinstance(lhs_p)) { // LambdaExpression(lhs=Expression) // 'lhs_p' is already an Expression, check its type and use it - auto column_expr = py::cast>(lhs_p); + auto column_expr = py::cast>(lhs_p); if (column_expr->GetExpression().GetExpressionType() != ExpressionType::COLUMN_REF) { throw py::value_error("'lhs' was an Expression, but is not of type ColumnExpression"); } @@ -409,14 +409,10 @@ std::shared_ptr DuckDBPyExpression::LambdaExpression(const p throw py::value_error("Please provide 'lhs' as either a tuple containing strings, or a single string"); } auto lambda_expression = make_uniq(std::move(lhs), rhs.GetExpression().Copy()); - // Use the modern `lambda x, y: ...` syntax. The lhs we built (a column ref, or a `row` function for multiple - // parameters) is identical to what the named-parameter constructor produces; only the syntax type differs, and - // the single-arrow form is now deprecated and errors by default. - lambda_expression->GetLambdaSyntaxTypeMutable() = LambdaSyntaxType::LAMBDA_KEYWORD; - return std::make_shared(std::move(lambda_expression)); + return make_shared_ptr(std::move(lambda_expression)); } -std::shared_ptr DuckDBPyExpression::SQLExpression(string sql) { +shared_ptr DuckDBPyExpression::SQLExpression(string sql) { auto conn = DuckDBPyConnection::DefaultConnection(); auto &context = *conn->con.GetConnection().context; vector> expressions; @@ -432,14 +428,14 @@ std::shared_ptr DuckDBPyExpression::SQLExpression(string sql expressions.size()); } - return std::make_shared(std::move(expressions[0])); + return make_shared_ptr(std::move(expressions[0])); } // Private methods -std::shared_ptr DuckDBPyExpression::BinaryOperator(const string &function_name, - const DuckDBPyExpression &arg_one, - const DuckDBPyExpression &arg_two) { +shared_ptr DuckDBPyExpression::BinaryOperator(const string &function_name, + const DuckDBPyExpression &arg_one, + const DuckDBPyExpression &arg_two) { vector> children; children.push_back(arg_one.GetExpression().Copy()); @@ -447,63 +443,63 @@ std::shared_ptr DuckDBPyExpression::BinaryOperator(const str return InternalFunctionExpression(function_name, std::move(children), true); } -std::shared_ptr +shared_ptr DuckDBPyExpression::InternalFunctionExpression(const string &function_name, vector> children, bool is_operator) { - auto function_expression = make_uniq(Identifier(function_name), std::move(children), - nullptr, nullptr, false, is_operator); - return std::make_shared(std::move(function_expression)); + auto function_expression = + make_uniq(function_name, std::move(children), nullptr, nullptr, false, is_operator); + return make_shared_ptr(std::move(function_expression)); } -std::shared_ptr DuckDBPyExpression::InternalUnaryOperator(ExpressionType type, - const DuckDBPyExpression &arg) { +shared_ptr DuckDBPyExpression::InternalUnaryOperator(ExpressionType type, + const DuckDBPyExpression &arg) { auto expr = arg.GetExpression().Copy(); auto operator_expression = make_uniq(type, std::move(expr)); - return std::make_shared(std::move(operator_expression)); + return make_shared_ptr(std::move(operator_expression)); } -std::shared_ptr DuckDBPyExpression::InternalConjunction(ExpressionType type, - const DuckDBPyExpression &arg, - const DuckDBPyExpression &other) { +shared_ptr DuckDBPyExpression::InternalConjunction(ExpressionType type, + const DuckDBPyExpression &arg, + const DuckDBPyExpression &other) { vector> children; children.reserve(2); children.push_back(arg.GetExpression().Copy()); children.push_back(other.GetExpression().Copy()); auto operator_expression = make_uniq(type, std::move(children)); - return std::make_shared(std::move(operator_expression)); + return make_shared_ptr(std::move(operator_expression)); } -std::shared_ptr DuckDBPyExpression::InternalConstantExpression(Value val) { - return std::make_shared(make_uniq(std::move(val))); +shared_ptr DuckDBPyExpression::InternalConstantExpression(Value val) { + return make_shared_ptr(make_uniq(std::move(val))); } -std::shared_ptr DuckDBPyExpression::ComparisonExpression(ExpressionType type, - const DuckDBPyExpression &left_p, - const DuckDBPyExpression &right_p) { +shared_ptr DuckDBPyExpression::ComparisonExpression(ExpressionType type, + const DuckDBPyExpression &left_p, + const DuckDBPyExpression &right_p) { auto left = left_p.GetExpression().Copy(); auto right = right_p.GetExpression().Copy(); - return std::make_shared( + return make_shared_ptr( make_uniq(type, std::move(left), std::move(right))); } -std::shared_ptr DuckDBPyExpression::CaseExpression(const DuckDBPyExpression &condition, - const DuckDBPyExpression &value) { +shared_ptr DuckDBPyExpression::CaseExpression(const DuckDBPyExpression &condition, + const DuckDBPyExpression &value) { auto expr = make_uniq(); auto case_expr = InternalWhen(std::move(expr), condition, value); // Add NULL as default Else expression auto &internal_expression = reinterpret_cast(*case_expr->expression); - internal_expression.ElseMutable() = make_uniq(Value(LogicalTypeId::SQLNULL)); + internal_expression.else_expr = make_uniq(Value(LogicalTypeId::SQLNULL)); return case_expr; } -std::shared_ptr DuckDBPyExpression::FunctionExpression(const string &function_name, - const py::args &args) { +shared_ptr DuckDBPyExpression::FunctionExpression(const string &function_name, + const py::args &args) { vector> expressions; for (auto arg : args) { - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { string actual_type = py::str(py::type::of(arg)); throw InvalidInputException("Expected argument of type Expression, received '%s' instead", actual_type); } diff --git a/src/duckdb_py/pyexpression/initialize.cpp b/src/duckdb_py/pyexpression/initialize.cpp index 1ea38136..11cf5dc3 100644 --- a/src/duckdb_py/pyexpression/initialize.cpp +++ b/src/duckdb_py/pyexpression/initialize.cpp @@ -47,7 +47,7 @@ void InitializeStaticMethods(py::module_ &m) { m.def("SQLExpression", &DuckDBPyExpression::SQLExpression, docs, py::arg("expression")); } -static void InitializeDunderMethods(py::class_> &m) { +static void InitializeDunderMethods(py::class_> &m) { const char *docs; docs = R"( @@ -287,13 +287,13 @@ static void InitializeDunderMethods(py::class_> &m) { +static void InitializeImplicitConversion(py::class_> &m) { m.def(py::init<>([](const string &name) { auto names = py::make_tuple(py::str(name)); return DuckDBPyExpression::ColumnExpression(names); })); m.def(py::init<>([](const py::object &obj) { - auto val = TransformPythonValue(nullptr, obj); + auto val = TransformPythonValue(obj); return DuckDBPyExpression::InternalConstantExpression(std::move(val)); })); py::implicitly_convertible(); @@ -301,7 +301,8 @@ static void InitializeImplicitConversion(py::class_>(m, "Expression"); + auto expression = + py::class_>(m, "Expression", py::module_local()); InitializeStaticMethods(m); InitializeDunderMethods(expression); diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp index a991c71e..23040a53 100644 --- a/src/duckdb_py/pyrelation.cpp +++ b/src/duckdb_py/pyrelation.cpp @@ -5,9 +5,11 @@ #include "duckdb_python/pyresult.hpp" #include "duckdb/parser/qualified_name.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb_python/numpy/numpy_type.hpp" #include "duckdb/main/relation/query_relation.hpp" #include "duckdb/main/relation/join_relation.hpp" #include "duckdb/parser/parser.hpp" +#include "duckdb/main/relation/view_relation.hpp" #include "duckdb/function/pragma/pragma_functions.hpp" #include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/common/box_renderer.hpp" @@ -16,6 +18,7 @@ #include "duckdb/parser/statement/explain_statement.hpp" #include "duckdb/catalog/default/default_types.hpp" #include "duckdb/main/relation/value_relation.hpp" +#include "duckdb/main/relation/filter_relation.hpp" #include "duckdb_python/expression/pyexpression.hpp" #include "duckdb/common/arrow/physical_arrow_collector.hpp" #include "duckdb_python/arrow/arrow_export_utils.hpp" @@ -29,7 +32,7 @@ DuckDBPyRelation::DuckDBPyRelation(shared_ptr rel_p) : rel(std::move(r this->executed = false; auto &columns = rel->Columns(); for (auto &col : columns) { - names.push_back(col.Name().GetIdentifierName()); + names.push_back(col.GetName()); types.push_back(col.GetType()); } } @@ -63,8 +66,7 @@ DuckDBPyRelation::~DuckDBPyRelation() { rel.reset(); } -DuckDBPyRelation::DuckDBPyRelation(std::shared_ptr result_p) - : rel(nullptr), result(std::move(result_p)) { +DuckDBPyRelation::DuckDBPyRelation(shared_ptr result_p) : rel(nullptr), result(std::move(result_p)) { if (!result) { throw InternalException("DuckDBPyRelation created without a result"); } @@ -73,7 +75,7 @@ DuckDBPyRelation::DuckDBPyRelation(std::shared_ptr result_p) this->names = result->GetNames(); } -std::unique_ptr DuckDBPyRelation::ProjectFromExpression(const string &expression) { +unique_ptr DuckDBPyRelation::ProjectFromExpression(const string &expression) { auto projected_relation = DeriveRelation(rel->Project(expression)); for (auto &dep : this->rel->external_dependencies) { projected_relation->rel->AddExternalDependency(dep); @@ -81,7 +83,7 @@ std::unique_ptr DuckDBPyRelation::ProjectFromExpression(const return projected_relation; } -std::unique_ptr DuckDBPyRelation::Project(const py::args &args, const string &groups) { +unique_ptr DuckDBPyRelation::Project(const py::args &args, const string &groups) { if (!rel) { return nullptr; } @@ -96,8 +98,8 @@ std::unique_ptr DuckDBPyRelation::Project(const py::args &args } else { vector> expressions; for (auto arg : args) { - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { throw InvalidInputException("Please provide arguments of type Expression!"); } auto expr = py_expr->GetExpression().Copy(); @@ -112,7 +114,7 @@ std::unique_ptr DuckDBPyRelation::Project(const py::args &args } } -std::unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::object &obj) { +unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::object &obj) { if (!rel) { return nullptr; } @@ -152,7 +154,7 @@ std::unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::o if (!projection.empty()) { projection += ", "; } - projection += SQLIdentifier(names[i]); + projection += KeywordHelper::WriteOptionallyQuoted(names[i]); } } if (projection.empty()) { @@ -161,9 +163,8 @@ std::unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::o return ProjectFromExpression(projection); } -std::unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr &context, - const vector &types, - vector names) { +unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr &context, + const vector &types, vector names) { vector dummy_values; D_ASSERT(types.size() == names.size()); dummy_values.reserve(types.size()); @@ -173,12 +174,12 @@ std::unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr } vector> single_row(1, dummy_values); auto values_relation = - std::make_unique(make_shared_ptr(context, single_row, std::move(names))); + make_uniq(make_shared_ptr(context, single_row, std::move(names))); // Add a filter on an impossible condition return values_relation->FilterFromExpression("true = false"); } -std::unique_ptr DuckDBPyRelation::SetAlias(const string &expr) { +unique_ptr DuckDBPyRelation::SetAlias(const string &expr) { return DeriveRelation(rel->Alias(expr)); } @@ -186,12 +187,12 @@ py::str DuckDBPyRelation::GetAlias() { return py::str(string(rel->GetAlias())); } -std::unique_ptr DuckDBPyRelation::Filter(const py::object &expr) { +unique_ptr DuckDBPyRelation::Filter(const py::object &expr) { if (py::isinstance(expr)) { string expression = py::cast(expr); return FilterFromExpression(expression); } - std::shared_ptr expression; + shared_ptr expression; if (!py::try_cast(expr, expression)) { throw InvalidInputException("Please provide either a string or a DuckDBPyExpression object to 'filter'"); } @@ -199,25 +200,25 @@ std::unique_ptr DuckDBPyRelation::Filter(const py::object &exp return DeriveRelation(rel->Filter(std::move(expr_p))); } -std::unique_ptr DuckDBPyRelation::FilterFromExpression(const string &expr) { +unique_ptr DuckDBPyRelation::FilterFromExpression(const string &expr) { return DeriveRelation(rel->Filter(expr)); } -std::unique_ptr DuckDBPyRelation::Limit(int64_t n, int64_t offset) { +unique_ptr DuckDBPyRelation::Limit(int64_t n, int64_t offset) { return DeriveRelation(rel->Limit(n, offset)); } -std::unique_ptr DuckDBPyRelation::Order(const string &expr) { +unique_ptr DuckDBPyRelation::Order(const string &expr) { return DeriveRelation(rel->Order(expr)); } -std::unique_ptr DuckDBPyRelation::Sort(const py::args &args) { +unique_ptr DuckDBPyRelation::Sort(const py::args &args) { vector order_nodes; order_nodes.reserve(args.size()); for (auto arg : args) { - std::shared_ptr py_expr; - if (!py::try_cast>(arg, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(arg, py_expr)) { string actual_type = py::str(py::type::of(arg)); throw InvalidInputException("Expected argument of type Expression, received '%s' instead", actual_type); } @@ -235,12 +236,12 @@ vector> GetExpressions(ClientContext &context, cons vector> expressions; auto aggregate_list = py::list(expr); for (auto &item : aggregate_list) { - std::shared_ptr py_expr; - if (!py::try_cast>(item, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(item, py_expr)) { throw InvalidInputException("Please provide arguments of type Expression!"); } - auto expr_ = py_expr->GetExpression().Copy(); - expressions.push_back(std::move(expr_)); + auto expr = py_expr->GetExpression().Copy(); + expressions.push_back(std::move(expr)); } return expressions; } else if (py::isinstance(expr)) { @@ -254,7 +255,7 @@ vector> GetExpressions(ClientContext &context, cons } } -std::unique_ptr DuckDBPyRelation::Aggregate(const py::object &expr, const string &groups) { +unique_ptr DuckDBPyRelation::Aggregate(const py::object &expr, const string &groups) { AssertRelation(); auto expressions = GetExpressions(*rel->context->GetContext(), expr); if (!groups.empty()) { @@ -331,7 +332,7 @@ vector CreateExpressionList(const vector &columns, } expr += aggregates[i].name; expr += "("; - expr += SQLIdentifier(col.GetName()); + expr += KeywordHelper::WriteOptionallyQuoted(col.GetName()); expr += ")"; if (col.GetType().IsNumeric()) { expr += "::DOUBLE"; @@ -340,13 +341,13 @@ vector CreateExpressionList(const vector &columns, } } expr += "])"; - expr += " AS " + SQLIdentifier(col.GetName()); + expr += " AS " + KeywordHelper::WriteOptionallyQuoted(col.GetName()); expressions.push_back(expr); } return expressions; } -std::unique_ptr DuckDBPyRelation::Describe() { +unique_ptr DuckDBPyRelation::Describe() { auto &columns = rel->Columns(); vector aggregates; aggregates = {DescribeAggregateInfo("count"), DescribeAggregateInfo("mean", true), @@ -399,9 +400,6 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec // We parse the input as an expression to validate it. auto trimmed_input = input[i]; StringUtil::Trim(trimmed_input); - if (trimmed_input.empty()) { - throw ParserException("Invalid column expression: '%s'", input[i]); - } unique_ptr expression; try { @@ -411,7 +409,7 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec } } catch (const ParserException &) { // First attempt at parsing failed, the input might be a column name that needs quoting. - auto quoted_input = SQLQuotedIdentifier::ToString(trimmed_input); + auto quoted_input = KeywordHelper::WriteQuoted(trimmed_input, '"'); auto expressions = Parser::ParseExpressionList(quoted_input); if (expressions.size() == 1 && expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) { expression = std::move(expressions[0]); @@ -441,9 +439,10 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec /* General aggregate functions */ -std::unique_ptr -DuckDBPyRelation::GenericAggregator(const string &function_name, const string &aggregated_columns, const string &groups, - const string &function_parameter, const string &projected_columns) { +unique_ptr DuckDBPyRelation::GenericAggregator(const string &function_name, + const string &aggregated_columns, const string &groups, + const string &function_parameter, + const string &projected_columns) { //! Construct Aggregation Expression auto expr = GenerateExpressionList(function_name, aggregated_columns, groups, function_parameter, false, @@ -451,7 +450,7 @@ DuckDBPyRelation::GenericAggregator(const string &function_name, const string &a return Aggregate(py::str(expr), groups); } -std::unique_ptr +unique_ptr DuckDBPyRelation::GenericWindowFunction(const string &function_name, const string &function_parameters, const string &aggr_columns, const string &window_spec, const bool &ignore_nulls, const string &projected_columns) { @@ -460,11 +459,10 @@ DuckDBPyRelation::GenericWindowFunction(const string &function_name, const strin return DeriveRelation(rel->Project(expr)); } -std::unique_ptr DuckDBPyRelation::ApplyAggOrWin(const string &function_name, - const string &agg_columns, - const string &function_parameters, - const string &groups, const string &window_spec, - const string &projected_columns, bool ignore_nulls) { +unique_ptr DuckDBPyRelation::ApplyAggOrWin(const string &function_name, const string &agg_columns, + const string &function_parameters, const string &groups, + const string &window_spec, const string &projected_columns, + bool ignore_nulls) { if (!groups.empty() && !window_spec.empty()) { throw InvalidInputException("Either groups or window must be set (can't be both at the same time)"); } @@ -476,54 +474,52 @@ std::unique_ptr DuckDBPyRelation::ApplyAggOrWin(const string & } } -std::unique_ptr DuckDBPyRelation::AnyValue(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::AnyValue(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("any_value", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::ArgMax(const std::string &arg_column, - const std::string &value_column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::ArgMax(const std::string &arg_column, const std::string &value_column, + const std::string &groups, const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("arg_max", arg_column, value_column, groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::ArgMin(const std::string &arg_column, - const std::string &value_column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::ArgMin(const std::string &arg_column, const std::string &value_column, + const std::string &groups, const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("arg_min", arg_column, value_column, groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Avg(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Avg(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("avg", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::BitAnd(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BitAnd(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("bit_and", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::BitOr(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BitOr(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("bit_or", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::BitXor(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BitXor(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("bit_xor", column, "", groups, window_spec, projected_columns); } -std::unique_ptr -DuckDBPyRelation::BitStringAgg(const std::string &column, const Optional &min, - const Optional &max, const std::string &groups, - const std::string &window_spec, const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BitStringAgg(const std::string &column, const Optional &min, + const Optional &max, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { if ((min.is_none() && !max.is_none()) || (!min.is_none() && max.is_none())) { throw InvalidInputException("Both min and max values must be set"); } @@ -537,117 +533,116 @@ DuckDBPyRelation::BitStringAgg(const std::string &column, const Optional DuckDBPyRelation::BoolAnd(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BoolAnd(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("bool_and", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::BoolOr(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::BoolOr(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("bool_or", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::ValueCounts(const std::string &column, const std::string &groups) { +unique_ptr DuckDBPyRelation::ValueCounts(const std::string &column, const std::string &groups) { return Count(column, groups, "", column); } -std::unique_ptr DuckDBPyRelation::Count(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Count(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("count", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::FAvg(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::FAvg(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("favg", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::First(const string &column, const std::string &groups, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::First(const string &column, const std::string &groups, + const string &projected_columns) { return GenericAggregator("first", column, groups, "", projected_columns); } -std::unique_ptr DuckDBPyRelation::FSum(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::FSum(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("fsum", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::GeoMean(const std::string &column, const std::string &groups, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::GeoMean(const std::string &column, const std::string &groups, + const std::string &projected_columns) { return GenericAggregator("geomean", column, groups, "", projected_columns); } -std::unique_ptr DuckDBPyRelation::Histogram(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Histogram(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("histogram", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::List(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::List(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("list", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Last(const std::string &column, const std::string &groups, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Last(const std::string &column, const std::string &groups, + const std::string &projected_columns) { return GenericAggregator("last", column, groups, "", projected_columns); } -std::unique_ptr DuckDBPyRelation::Max(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Max(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("max", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Min(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Min(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("min", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Product(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Product(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("product", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep, - const std::string &groups, const std::string &window_spec, - const std::string &projected_columns) { - auto string_agg_params = SQLString::ToString(sep); +unique_ptr DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep, + const std::string &groups, const std::string &window_spec, + const std::string &projected_columns) { + auto string_agg_params = KeywordHelper::WriteOptionallyQuoted(sep, '\''); return ApplyAggOrWin("string_agg", column, string_agg_params, groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Sum(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Sum(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("sum", column, "", groups, window_spec, projected_columns); } /* TODO: Approximate aggregate functions */ /* TODO: Statistical aggregate functions */ -std::unique_ptr DuckDBPyRelation::Median(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Median(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("median", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::Mode(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::Mode(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("mode", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::QuantileCont(const std::string &column, const py::object &q, - const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::QuantileCont(const std::string &column, const py::object &q, + const std::string &groups, const std::string &window_spec, + const std::string &projected_columns) { string quantile_params = ""; if (py::isinstance(q)) { quantile_params = std::to_string(q.cast()); @@ -667,10 +662,9 @@ std::unique_ptr DuckDBPyRelation::QuantileCont(const std::stri return ApplyAggOrWin("quantile_cont", column, quantile_params, groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::QuantileDisc(const std::string &column, const py::object &q, - const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::QuantileDisc(const std::string &column, const py::object &q, + const std::string &groups, const std::string &window_spec, + const std::string &projected_columns) { string quantile_params = ""; if (py::isinstance(q)) { quantile_params = std::to_string(q.cast()); @@ -690,27 +684,27 @@ std::unique_ptr DuckDBPyRelation::QuantileDisc(const std::stri return ApplyAggOrWin("quantile_disc", column, quantile_params, groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::StdPop(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::StdPop(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("stddev_pop", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::StdSamp(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::StdSamp(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("stddev_samp", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::VarPop(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::VarPop(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("var_pop", column, "", groups, window_spec, projected_columns); } -std::unique_ptr DuckDBPyRelation::VarSamp(const std::string &column, const std::string &groups, - const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::VarSamp(const std::string &column, const std::string &groups, + const std::string &window_spec, + const std::string &projected_columns) { return ApplyAggOrWin("var_samp", column, "", groups, window_spec, projected_columns); } @@ -727,49 +721,45 @@ py::tuple DuckDBPyRelation::Shape() { return py::make_tuple(length, rel->Columns().size()); } -std::unique_ptr DuckDBPyRelation::Unique(const string &std_columns) { +unique_ptr DuckDBPyRelation::Unique(const string &std_columns) { return DeriveRelation(rel->Project(std_columns)->Distinct()); } /* General-purpose window functions */ -std::unique_ptr DuckDBPyRelation::RowNumber(const string &window_spec, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::RowNumber(const string &window_spec, const string &projected_columns) { return GenericWindowFunction("row_number", "", "*", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::Rank(const string &window_spec, const string &projected_columns) { +unique_ptr DuckDBPyRelation::Rank(const string &window_spec, const string &projected_columns) { return GenericWindowFunction("rank", "", "*", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::DenseRank(const string &window_spec, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::DenseRank(const string &window_spec, const string &projected_columns) { return GenericWindowFunction("dense_rank", "", "*", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::PercentRank(const string &window_spec, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::PercentRank(const string &window_spec, const string &projected_columns) { return GenericWindowFunction("percent_rank", "", "*", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::CumeDist(const string &window_spec, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::CumeDist(const string &window_spec, const string &projected_columns) { return GenericWindowFunction("cume_dist", "", "*", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::FirstValue(const string &column, const string &window_spec, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::FirstValue(const string &column, const string &window_spec, + const string &projected_columns) { return GenericWindowFunction("first_value", "", column, window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::NTile(const string &window_spec, const int &num_buckets, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::NTile(const string &window_spec, const int &num_buckets, + const string &projected_columns) { return GenericWindowFunction("ntile", std::to_string(num_buckets), "", window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::Lag(const string &column, const string &window_spec, - const int &offset, const string &default_value, - const bool &ignore_nulls, const string &projected_columns) { +unique_ptr DuckDBPyRelation::Lag(const string &column, const string &window_spec, const int &offset, + const string &default_value, const bool &ignore_nulls, + const string &projected_columns) { string lag_params = ""; if (offset != 0) { lag_params += std::to_string(offset); @@ -780,14 +770,14 @@ std::unique_ptr DuckDBPyRelation::Lag(const string &column, co return GenericWindowFunction("lag", lag_params, column, window_spec, ignore_nulls, projected_columns); } -std::unique_ptr DuckDBPyRelation::LastValue(const std::string &column, const std::string &window_spec, - const std::string &projected_columns) { +unique_ptr DuckDBPyRelation::LastValue(const std::string &column, const std::string &window_spec, + const std::string &projected_columns) { return GenericWindowFunction("last_value", "", column, window_spec, false, projected_columns); } -std::unique_ptr DuckDBPyRelation::Lead(const string &column, const string &window_spec, - const int &offset, const string &default_value, - const bool &ignore_nulls, const string &projected_columns) { +unique_ptr DuckDBPyRelation::Lead(const string &column, const string &window_spec, const int &offset, + const string &default_value, const bool &ignore_nulls, + const string &projected_columns) { string lead_params = ""; if (offset != 0) { lead_params += std::to_string(offset); @@ -798,14 +788,14 @@ std::unique_ptr DuckDBPyRelation::Lead(const string &column, c return GenericWindowFunction("lead", lead_params, column, window_spec, ignore_nulls, projected_columns); } -std::unique_ptr DuckDBPyRelation::NthValue(const string &column, const string &window_spec, - const int &offset, const bool &ignore_nulls, - const string &projected_columns) { +unique_ptr DuckDBPyRelation::NthValue(const string &column, const string &window_spec, + const int &offset, const bool &ignore_nulls, + const string &projected_columns) { return GenericWindowFunction("nth_value", std::to_string(offset), column, window_spec, ignore_nulls, projected_columns); } -std::unique_ptr DuckDBPyRelation::Distinct() { +unique_ptr DuckDBPyRelation::Distinct() { return DeriveRelation(rel->Distinct()); } @@ -840,7 +830,7 @@ void DuckDBPyRelation::ExecuteOrThrow(bool stream_result) { if (query_result->HasError()) { query_result->ThrowError(); } - result = std::make_unique(std::move(query_result)); + result = make_uniq(std::move(query_result)); } PandasDataFrame DuckDBPyRelation::FetchDF(bool date_as_object) { @@ -1035,7 +1025,7 @@ PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) { ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties); py::list batches; // Now we create an empty arrow table - auto empty_table = pyarrow::ToArrowTable(types, result_names, batches, client_properties); + auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), client_properties); // And we extract the polars schema from the arrow table auto polars_df = py::cast(pybind11::module_::import("polars").attr("DataFrame")(empty_table)); @@ -1081,47 +1071,47 @@ void DuckDBPyRelation::SetConnectionOwner(py::object owner) { connection_owner = std::move(owner); } -std::unique_ptr DuckDBPyRelation::DeriveRelation(shared_ptr new_rel) { - auto result_ = std::make_unique(std::move(new_rel)); - result_->connection_owner = connection_owner; - return result_; +unique_ptr DuckDBPyRelation::DeriveRelation(shared_ptr new_rel) { + auto result = make_uniq(std::move(new_rel)); + result->connection_owner = connection_owner; + return result; } -std::unique_ptr DuckDBPyRelation::DeriveRelation(std::shared_ptr result_p) { - auto result_ = std::make_unique(std::move(result_p)); - result_->connection_owner = connection_owner; - return result_; +unique_ptr DuckDBPyRelation::DeriveRelation(shared_ptr result_p) { + auto result = make_uniq(std::move(result_p)); + result->connection_owner = connection_owner; + return result; } static bool ContainsStructFieldByName(LogicalType &type, const string &name) { if (type.id() != LogicalTypeId::STRUCT) { return false; } - const auto name_identifier = Identifier(name); - const auto count = StructType::GetChildCount(type); + auto count = StructType::GetChildCount(type); for (idx_t i = 0; i < count; i++) { - if (StructType::GetChildName(type, i) == name) { + auto &field_name = StructType::GetChildName(type, i); + if (StringUtil::CIEquals(name, field_name)) { return true; } } return false; } -std::unique_ptr DuckDBPyRelation::GetAttribute(const string &name) { +unique_ptr DuckDBPyRelation::GetAttribute(const string &name) { // TODO: support fetching a result containing only column 'name' from a value_relation if (!rel) { throw py::attribute_error( StringUtil::Format("This relation does not contain a column by the name of '%s'", name)); } - vector column_names; + vector column_names; if (names.size() == 1 && ContainsStructFieldByName(types[0], name)) { // e.g 'rel['my_struct']['my_field']: // first 'my_struct' is selected by the bottom condition // then 'my_field' is accessed on the result of this - column_names.push_back(Identifier(names[0])); - column_names.push_back(Identifier(name)); + column_names.push_back(names[0]); + column_names.push_back(name); } else if (ContainsColumnByName(name)) { - column_names.push_back(Identifier(name)); + column_names.push_back(name); } if (column_names.empty()) { @@ -1136,15 +1126,15 @@ std::unique_ptr DuckDBPyRelation::GetAttribute(const string &n return DeriveRelation(rel->Project(std::move(expressions), aliases)); } -std::unique_ptr DuckDBPyRelation::Union(DuckDBPyRelation *other) { +unique_ptr DuckDBPyRelation::Union(DuckDBPyRelation *other) { return DeriveRelation(rel->Union(other->rel)); } -std::unique_ptr DuckDBPyRelation::Except(DuckDBPyRelation *other) { +unique_ptr DuckDBPyRelation::Except(DuckDBPyRelation *other) { return DeriveRelation(rel->Except(other->rel)); } -std::unique_ptr DuckDBPyRelation::Intersect(DuckDBPyRelation *other) { +unique_ptr DuckDBPyRelation::Intersect(DuckDBPyRelation *other) { return DeriveRelation(rel->Intersect(other->rel)); } @@ -1187,8 +1177,8 @@ static JoinType ParseJoinType(const string &type) { throw InvalidInputException("Unsupported join type %s, try one of: %s", provided, options); } -std::unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other, const py::object &condition, - const string &type) { +unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other, const py::object &condition, + const string &type) { if (!other) { throw InvalidInputException("No relation provided for join"); } @@ -1211,14 +1201,15 @@ std::unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other auto condition_string = std::string(py::cast(condition)); return DeriveRelation(rel->Join(other->rel, condition_string, join_type)); } - vector using_list; + vector using_list; if (py::is_list_like(condition)) { - for (auto &item : py::list(condition)) { + auto using_list_p = py::list(condition); + for (auto &item : using_list_p) { if (!py::isinstance(item)) { string actual_type = py::str(py::type::of(item)); throw InvalidInputException("Using clause should be a list of strings, not %s", actual_type); } - using_list.push_back(Identifier(std::string(py::str(item)))); + using_list.push_back(std::string(py::str(item))); } if (using_list.empty()) { throw InvalidInputException("Please provide at least one string in the condition to create a USING clause"); @@ -1226,7 +1217,7 @@ std::unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other auto join_relation = make_shared_ptr(rel, other->rel, std::move(using_list), join_type); return DeriveRelation(std::move(join_relation)); } - std::shared_ptr condition_expr; + shared_ptr condition_expr; if (!py::try_cast(condition, condition_expr)) { throw InvalidInputException( "Please provide condition as an expression either in string form or as an Expression object"); @@ -1236,7 +1227,7 @@ std::unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other return DeriveRelation(rel->Join(other->rel, std::move(conditions), join_type)); } -std::unique_ptr DuckDBPyRelation::Cross(DuckDBPyRelation *other) { +unique_ptr DuckDBPyRelation::Cross(DuckDBPyRelation *other) { return DeriveRelation(rel->CrossProduct(other->rel)); } @@ -1255,13 +1246,11 @@ static Value NestedDictToStruct(const py::object &dictionary) { throw InvalidInputException("NestedDictToStruct only accepts a dictionary with string keys"); } - auto item_key_str = string(py::str(item_key)); - if (py::isinstance(item_value)) { int32_t item_value_int = py::int_(item_value); - children.push_back(std::make_pair(Identifier(item_key_str), Value(item_value_int))); + children.push_back(std::make_pair(py::str(item_key), Value(item_value_int))); } else if (py::isinstance(item_value)) { - children.push_back(std::make_pair(Identifier(item_key_str), NestedDictToStruct(item_value))); + children.push_back(std::make_pair(py::str(item_key), NestedDictToStruct(item_value))); } else { throw InvalidInputException( "NestedDictToStruct only accepts a dictionary with integer values or nested dictionaries"); @@ -1326,7 +1315,7 @@ void DuckDBPyRelation::ToParquet(const string &filename, const py::object &compr if (!py::isinstance(field)) { throw InvalidInputException("to_parquet only accepts 'partition_by' as a list of strings"); } - partition_by_values.emplace_back(py::str(field)); + partition_by_values.emplace_back(Value(py::str(field))); } options["partition_by"] = {partition_by_values}; } @@ -1516,7 +1505,7 @@ void DuckDBPyRelation::ToCSV(const string &filename, const py::object &sep, cons if (!py::isinstance(field)) { throw InvalidInputException("to_csv only accepts 'partition_by' as a list of strings"); } - partition_by_values.emplace_back(py::str(field)); + partition_by_values.emplace_back(Value(py::str(field))); } options["partition_by"] = {partition_by_values}; } @@ -1533,8 +1522,8 @@ void DuckDBPyRelation::ToCSV(const string &filename, const py::object &sep, cons } // should this return a rel with the new view? -std::unique_ptr DuckDBPyRelation::CreateView(const string &view_name, bool replace) { - rel->CreateView(Identifier(view_name), replace); +unique_ptr DuckDBPyRelation::CreateView(const string &view_name, bool replace) { + rel->CreateView(view_name, replace); return DeriveRelation(rel); } @@ -1549,8 +1538,8 @@ static bool IsDescribeStatement(SQLStatement &statement) { return true; } -std::unique_ptr DuckDBPyRelation::Query(const string &view_name, const string &sql_query) { - rel->CreateView(Identifier(view_name), /*replace=*/true, /*temporary=*/true); +unique_ptr DuckDBPyRelation::Query(const string &view_name, const string &sql_query) { + rel->CreateView(view_name, /*replace=*/true, /*temporary=*/true); auto all_dependencies = rel->GetAllDependencies(); Parser parser(rel->context->GetContext()->GetParserOptions()); @@ -1562,7 +1551,7 @@ std::unique_ptr DuckDBPyRelation::Query(const string &view_nam if (statement.type == StatementType::SELECT_STATEMENT) { auto select_statement = unique_ptr_cast(std::move(parser.statements[0])); auto query_relation = make_shared_ptr(rel->context->GetContext(), std::move(select_statement), - "query_relation_" + StringUtil::GenerateRandomName(16), sql_query); + sql_query, "query_relation"); return DeriveRelation(std::move(query_relation)); } else if (IsDescribeStatement(statement)) { auto query = PragmaShow(view_name); @@ -1591,7 +1580,7 @@ DuckDBPyRelation &DuckDBPyRelation::Execute() { void DuckDBPyRelation::InsertInto(const string &table) { AssertRelation(); auto parsed_info = QualifiedName::Parse(table); - auto insert = rel->InsertRel(parsed_info.Catalog(), parsed_info.Schema(), parsed_info.Name()); + auto insert = rel->InsertRel(parsed_info.catalog, parsed_info.schema, parsed_info.name); PyExecuteRelation(insert); } @@ -1599,8 +1588,8 @@ void DuckDBPyRelation::Update(const py::object &set_p, const py::object &where) AssertRelation(); unique_ptr condition; if (!py::none().is(where)) { - std::shared_ptr py_expr; - if (!py::try_cast>(where, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(where, py_expr)) { throw InvalidInputException("Please provide an Expression to 'condition'"); } condition = py_expr->GetExpression().Copy(); @@ -1610,7 +1599,7 @@ void DuckDBPyRelation::Update(const py::object &set_p, const py::object &where) throw InvalidInputException("Please provide 'set' as a dictionary of column name to Expression"); } - vector names_; + vector names; vector> expressions; py::dict set = py::dict(set_p); @@ -1626,17 +1615,17 @@ void DuckDBPyRelation::Update(const py::object &set_p, const py::object &where) if (!py::isinstance(item_key)) { throw InvalidInputException("Please provide the column name as the key of the dictionary"); } - std::shared_ptr py_expr; - if (!py::try_cast>(item_value, py_expr)) { + shared_ptr py_expr; + if (!py::try_cast>(item_value, py_expr)) { string actual_type = py::str(py::type::of(item_value)); throw InvalidInputException("Please provide an object of type Expression as the value, not %s", actual_type); } - names_.push_back(std::string(py::str(item_key))); + names.push_back(std::string(py::str(item_key))); expressions.push_back(py_expr->GetExpression().Copy()); } - return rel->Update(std::move(names_), std::move(expressions), std::move(condition)); + return rel->Update(std::move(names), std::move(expressions), std::move(condition)); } void DuckDBPyRelation::Insert(const py::object ¶ms) const { @@ -1644,8 +1633,7 @@ void DuckDBPyRelation::Insert(const py::object ¶ms) const { if (this->rel->type != RelationType::TABLE_RELATION) { throw InvalidInputException("'DuckDBPyRelation.insert' can only be used on a table relation"); } - vector> values { - DuckDBPyConnection::TransformPythonParamList(*this->rel->context->GetContext(), params)}; + vector> values {DuckDBPyConnection::TransformPythonParamList(params)}; D_ASSERT(py::gil_check()); py::gil_scoped_release release; @@ -1655,11 +1643,11 @@ void DuckDBPyRelation::Insert(const py::object ¶ms) const { void DuckDBPyRelation::Create(const string &table) { AssertRelation(); auto parsed_info = QualifiedName::Parse(table); - auto create = rel->CreateRel(parsed_info.Schema(), parsed_info.Name(), false); + auto create = rel->CreateRel(parsed_info.schema, parsed_info.name, false); PyExecuteRelation(create); } -std::unique_ptr DuckDBPyRelation::Map(py::function fun, Optional schema) { +unique_ptr DuckDBPyRelation::Map(py::function fun, Optional schema) { AssertRelation(); vector params; params.emplace_back(Value::POINTER(CastPointerToValue(fun.ptr()))); @@ -1678,8 +1666,9 @@ string DuckDBPyRelation::ToStringInternal(const BoxRendererConfig &config, bool BoxRenderer renderer; auto limit = Limit(config.limit, 0); auto res = limit->ExecuteInternal(); - auto context = ClientBoxRendererContext(*rel->context->GetContext()); - rendered_result = res->ToBox(context, config); + + auto context = rel->context->GetContext(); + rendered_result = res->ToBox(*context, config); } return rendered_result; } @@ -1734,11 +1723,12 @@ void DuckDBPyRelation::Print(const Optional &max_width, const Optional py::print(py::str(ToStringInternal(config, invalidate_cache))); } -static ProfilerPrintFormat GetExplainFormat(ExplainType type) { +static ExplainFormat GetExplainFormat(ExplainType type) { if (DuckDBPyConnection::IsJupyter() && type != ExplainType::EXPLAIN_ANALYZE) { - return ProfilerPrintFormat::HTML(); + return ExplainFormat::HTML; + } else { + return ExplainFormat::DEFAULT; } - return ProfilerPrintFormat::Default(); } static void DisplayHTML(const string &html) { @@ -1750,35 +1740,30 @@ static void DisplayHTML(const string &html) { display_attr(html_object); } -string DuckDBPyRelation::Explain(ExplainType type, const string &format) { +string DuckDBPyRelation::Explain(ExplainType type) { AssertRelation(); D_ASSERT(py::gil_check()); py::gil_scoped_release release; - // An empty format means "auto": the default format, or HTML when running under Jupyter. - const bool auto_format = format.empty(); - auto explain_format = auto_format ? GetExplainFormat(type) : ProfilerPrintFormat(format); + auto explain_format = GetExplainFormat(type); auto res = rel->Explain(type, explain_format); D_ASSERT(res->type == duckdb::QueryResultType::MATERIALIZED_RESULT); auto &materialized = res->Cast(); auto &coll = materialized.Collection(); - // Only the implicit Jupyter path renders HTML inline; an explicitly requested format always returns a string. - const bool jupyter_html = - auto_format && explain_format == ProfilerPrintFormat::HTML() && DuckDBPyConnection::IsJupyter(); - if (!jupyter_html) { - string result_; + if (explain_format != ExplainFormat::HTML || !DuckDBPyConnection::IsJupyter()) { + string result; for (auto &row : coll.Rows()) { // Skip the first column because it just contains 'physical plan' for (idx_t col_idx = 1; col_idx < coll.ColumnCount(); col_idx++) { if (col_idx > 1) { - result_ += "\t"; + result += "\t"; } auto val = row.GetValue(col_idx); - result_ += val.IsNull() ? "NULL" : StringUtil::Replace(val.ToString(), string("\0", 1), "\\0"); + result += val.IsNull() ? "NULL" : StringUtil::Replace(val.ToString(), string("\0", 1), "\\0"); } - result_ += "\n"; + result += "\n"; } - return result_; + return result; } auto chunk = materialized.Fetch(); diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 154a1b80..4393889a 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -1,7 +1,6 @@ #include "duckdb_python/pyrelation.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" #include "duckdb_python/pyresult.hpp" -#include "duckdb_python/pybind11/conversions/explain_enum.hpp" #include "duckdb/parser/qualified_name.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb_python/numpy/numpy_type.hpp" @@ -263,18 +262,11 @@ static void InitializeSetOperators(py::class_ &m) { static void InitializeMetaQueries(py::class_ &m) { m.def("describe", &DuckDBPyRelation::Describe, "Gives basic statistics (e.g., min, max) and if NULL exists for each column of the relation.") - .def( - "explain", - [](DuckDBPyRelation &self, ExplainType type, const py::object &format) { - // An omitted format (None) maps to "" = auto-select (default, or HTML under Jupyter). - string format_str = format.is_none() ? string() : string(py::str(format)); - return self.Explain(type, format_str); - }, - py::arg("type") = ExplainType::EXPLAIN_STANDARD, py::arg("format") = py::none()); + .def("explain", &DuckDBPyRelation::Explain, py::arg("type") = "standard"); } void DuckDBPyRelation::Initialize(py::handle &m) { - auto relation_module = py::class_(m, "DuckDBPyRelation"); + auto relation_module = py::class_(m, "DuckDBPyRelation", py::module_local()); InitializeReadOnlyProperties(relation_module); InitializeAggregates(relation_module); InitializeWindowOperators(relation_module); diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index ed7d0481..270c1625 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -10,7 +10,13 @@ #include "duckdb/common/arrow/arrow_converter.hpp" #include "duckdb/common/arrow/arrow_wrapper.hpp" #include "duckdb/common/arrow/result_arrow_wrapper.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/uuid.hpp" +#include "duckdb_python/numpy/array_wrapper.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/enums/stream_execution_result.hpp" #include "duckdb_python/arrow/arrow_export_utils.hpp" @@ -220,7 +226,7 @@ void InsertCategory(QueryResult &result, unordered_map &categor } } -std::unique_ptr DuckDBPyResult::InitializeNumpyConversion(bool pandas) { +unique_ptr DuckDBPyResult::InitializeNumpyConversion(bool pandas) { if (!result) { throw InvalidInputException("result closed"); } @@ -233,12 +239,12 @@ std::unique_ptr DuckDBPyResult::InitializeNumpyConversion } auto conversion = - std::make_unique(result->types, initial_capacity, result->client_properties, pandas); + make_uniq(result->types, initial_capacity, result->client_properties, pandas); return conversion; } py::dict DuckDBPyResult::FetchNumpyInternal(bool stream, idx_t vectors_per_chunk, - std::unique_ptr conversion_p) { + unique_ptr conversion_p) { if (!result) { throw InvalidInputException("result closed"); } @@ -430,13 +436,7 @@ static unique_ptr MakeColumnDataScanStatement(unique_ptr expected_names; - expected_names.reserve(deduplicated_names.size()); - for (auto &name : deduplicated_names) { - expected_names.emplace_back(std::move(name)); - } - auto table_ref = make_uniq(std::move(collection), std::move(expected_names)); + auto table_ref = make_uniq(std::move(collection), std::move(deduplicated_names)); table_ref->alias = "materialized"; // binding asserts on an unset alias auto select_node = make_uniq(); select_node->select_list.push_back(make_uniq()); diff --git a/src/duckdb_py/pystatement.cpp b/src/duckdb_py/pystatement.cpp index c58df10d..7e84df7e 100644 --- a/src/duckdb_py/pystatement.cpp +++ b/src/duckdb_py/pystatement.cpp @@ -4,7 +4,7 @@ namespace duckdb { enum class ExpectedResultType : uint8_t { QUERY_RESULT, NOTHING, CHANGED_ROWS, UNKNOWN }; -static void InitializeReadOnlyProperties(py::class_> &m) { +static void InitializeReadOnlyProperties(py::class_> &m) { m.def_property_readonly("type", &DuckDBPyStatement::Type, "Get the type of the statement.") .def_property_readonly("query", &DuckDBPyStatement::Query, "Get the query equivalent to this statement.") .def_property_readonly("named_parameters", &DuckDBPyStatement::NamedParameters, @@ -15,7 +15,8 @@ static void InitializeReadOnlyProperties(py::class_>(m, "Statement"); + auto relation_module = + py::class_>(m, "Statement", py::module_local()); InitializeReadOnlyProperties(relation_module); } @@ -36,7 +37,7 @@ py::set DuckDBPyStatement::NamedParameters() const { py::set result; auto &named_parameters = statement->named_param_map; for (auto ¶m : named_parameters) { - result.add(param.first.GetIdentifierName()); + result.add(param.first); } return result; } diff --git a/src/duckdb_py/python_replacement_scan.cpp b/src/duckdb_py/python_replacement_scan.cpp index cef37cd1..8bff9e8f 100644 --- a/src/duckdb_py/python_replacement_scan.cpp +++ b/src/duckdb_py/python_replacement_scan.cpp @@ -3,7 +3,6 @@ #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb/main/client_properties.hpp" #include "duckdb_python/numpy/numpy_type.hpp" -#include "duckdb_python/numpy/numpy_array.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" #include "duckdb_python/pybind11/dataframe.hpp" @@ -167,15 +166,13 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec case NumpyObjectType::NDARRAY1D: data["column0"] = entry; break; - case NumpyObjectType::NDARRAY2D: { + case NumpyObjectType::NDARRAY2D: idx = 0; - NumpyArray ndarray(entry); - for (auto item : ndarray.GetArray()) { + for (auto item : py::cast(entry)) { data[("column" + std::to_string(idx)).c_str()] = item; idx++; } break; - } case NumpyObjectType::LIST: idx = 0; for (auto item : py::cast(entry)) { diff --git a/src/duckdb_py/python_udf.cpp b/src/duckdb_py/python_udf.cpp index c8199c05..9af66b37 100644 --- a/src/duckdb_py/python_udf.cpp +++ b/src/duckdb_py/python_udf.cpp @@ -94,7 +94,7 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie children.push_back(Value::POINTER(CastPointerToValue(stream_factory_get_schema))); named_parameter_map_t named_params; vector input_types; - vector input_names; + vector input_names; TableFunctionRef empty; TableFunction dummy_table_function; @@ -130,8 +130,8 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie } VectorOperations::Cast(context, result.data[0], out, count); - out.Flatten(); - out.Verify(); + out.Flatten(count); + out.Verify(count); } static string NullHandlingError() { @@ -150,7 +150,7 @@ static ValidityMask &GetResultValidity(Vector &result) { if (vector_type == VectorType::CONSTANT_VECTOR) { return ConstantVector::Validity(result); } else if (vector_type == VectorType::FLAT_VECTOR) { - return FlatVector::ValidityMutable(result); + return FlatVector::Validity(result); } else { throw InternalException("VectorType %s was not expected here (GetResultValidity)", EnumUtil::ToString(vector_type)); @@ -158,8 +158,9 @@ static ValidityMask &GetResultValidity(Vector &result) { } static void VerifyVectorizedNullHandling(Vector &result, idx_t count) { + auto &validity = GetResultValidity(result); - if (const auto &validity = GetResultValidity(result); validity.CannotHaveNull()) { + if (validity.AllValid()) { return; } @@ -185,12 +186,13 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce auto options = context.GetClientProperties(); // } + auto result_validity = FlatVector::Validity(result); SelectionVector selvec(input.size()); idx_t input_size = input.size(); if (default_null_handling) { vector vec_data(input.ColumnCount()); for (idx_t i = 0; i < input.ColumnCount(); i++) { - input.data[i].ToUnifiedFormat(vec_data[i]); + input.data[i].ToUnifiedFormat(input.size(), vec_data[i]); } idx_t index = 0; @@ -204,6 +206,7 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce } } if (any_null) { + result_validity.SetInvalid(i); continue; } selvec.set_index(index++, i); @@ -261,14 +264,13 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce } if (count) { SelectionVector inverted(input_size); - // Map each target row back to a source row in temp. Non-null target rows map to - // their UDF output; null target rows point at the next non-null source row (their - // data is later masked out by SetNull). - // example: input_size: 6, null_indices: 1,3 - // selvec (non-null indices): [0, 2, 4, 5] - // inverted selvec: [0, 1, 1, 2, 2, 3] + // Create a SelVec that inverts the filtering + // example: count: 6, null_indices: 1,3 + // input selvec: [0, 2, 4, 5] + // inverted selvec: [0, 0, 1, 1, 2, 3] idx_t src_index = 0; for (idx_t i = 0; i < input_size; i++) { + // Fill the gaps with the previous index inverted.set_index(i, src_index); if (src_index + 1 < count && selvec.get_index(src_index) == i) { src_index++; @@ -276,18 +278,10 @@ static scalar_function_t CreateVectorizedFunction(PyObject *function, PythonExce } VectorOperations::Copy(temp, result, inverted, count, 0, 0, input_size); } - // Apply the null mask: any position not present in selvec was a null input row. - // VectorOperations::Copy unconditionally overwrites the result's validity from - // the source's, so we must do this after the Copy. - idx_t sel_idx = 0; for (idx_t i = 0; i < input_size; i++) { - if (sel_idx < count && selvec.get_index(sel_idx) == i) { - sel_idx++; - } else { - FlatVector::SetNull(result, i, true); - } + FlatVector::SetNull(result, i, !result_validity.RowIsValid(i)); } - result.Verify(); + result.Verify(input_size); } else { ConvertArrowTableToVector(python_object, result, state.GetContext(), count); if (default_null_handling && !exception_occurred) { @@ -357,7 +351,7 @@ static scalar_function_t CreateNativeFunction(PyObject *function, PythonExceptio throw InvalidInputException(NullHandlingError()); } } - TransformPythonObject(state.GetContext(), ret, result, row); + TransformPythonObject(ret, result, row); } if (input.size() == 1) { @@ -423,7 +417,7 @@ struct PythonUDFData { } } - void OverrideReturnType(const std::shared_ptr &type) { + void OverrideReturnType(const shared_ptr &type) { if (!type) { return; } @@ -451,7 +445,7 @@ struct PythonUDFData { } idx_t i = 0; for (auto ¶m : params) { - auto type = py::cast>(param); + auto type = py::cast>(param); parameters[i++] = type->Type(); } } @@ -474,8 +468,8 @@ struct PythonUDFData { auto return_annotation = signature.attr("return_annotation"); auto empty = py::module_::import("inspect").attr("Signature").attr("empty"); if (!py::none().is(return_annotation) && !empty.is(return_annotation)) { - std::shared_ptr pytype; - if (py::try_cast>(return_annotation, pytype)) { + shared_ptr pytype; + if (py::try_cast>(return_annotation, pytype)) { return_type = pytype->Type(); } } @@ -484,8 +478,8 @@ struct PythonUDFData { auto params = py::dict(sig_params); for (auto &item : params) { auto &value = item.second; - std::shared_ptr pytype; - if (py::try_cast>(value.attr("annotation"), pytype)) { + shared_ptr pytype; + if (py::try_cast>(value.attr("annotation"), pytype)) { parameters.push_back(pytype->Type()); } else { std::string kind = py::str(value.attr("kind")); @@ -525,7 +519,7 @@ struct PythonUDFData { } FunctionStability function_side_effects = side_effects ? FunctionStability::VOLATILE : FunctionStability::CONSISTENT; - ScalarFunction scalar_function(Identifier(name), std::move(parameters), return_type, func, nullptr, nullptr, + ScalarFunction scalar_function(name, std::move(parameters), return_type, func, nullptr, nullptr, nullptr, nullptr, varargs, function_side_effects, null_handling); return scalar_function; } @@ -535,7 +529,7 @@ struct PythonUDFData { ScalarFunction DuckDBPyConnection::CreateScalarUDF(const string &name, const py::function &udf, const py::object ¶meters, - const std::shared_ptr &return_type, bool vectorized, + const shared_ptr &return_type, bool vectorized, FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects) { PythonUDFData data(name, vectorized, null_handling); diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index fef05918..5087de50 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -38,7 +38,7 @@ bool PyUnionType::check_(const py::handle &object) { DuckDBPyType::DuckDBPyType(LogicalType type) : type(std::move(type)) { } -bool DuckDBPyType::Equals(const std::shared_ptr &other) const { +bool DuckDBPyType::Equals(const shared_ptr &other) const { if (!other) { return false; } @@ -49,27 +49,26 @@ bool DuckDBPyType::EqualsString(const string &type_str) const { return StringUtil::CIEquals(type.ToString(), type_str); } -std::shared_ptr DuckDBPyType::GetAttribute(const string &name) const { - auto name_identifier = Identifier(name); +shared_ptr DuckDBPyType::GetAttribute(const string &name) const { if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { auto &children = StructType::GetChildTypes(type); for (idx_t i = 0; i < children.size(); i++) { auto &child = children[i]; - if (child.first == name) { - return std::make_shared(StructType::GetChildType(type, i)); + if (StringUtil::CIEquals(child.first, name)) { + return make_shared_ptr(StructType::GetChildType(type, i)); } } } if (type.id() == LogicalTypeId::LIST && StringUtil::CIEquals(name, "child")) { - return std::make_shared(ListType::GetChildType(type)); + return make_shared_ptr(ListType::GetChildType(type)); } if (type.id() == LogicalTypeId::MAP) { auto is_key = StringUtil::CIEquals(name, "key"); auto is_value = StringUtil::CIEquals(name, "value"); if (is_key) { - return std::make_shared(MapType::KeyType(type)); + return make_shared_ptr(MapType::KeyType(type)); } else if (is_value) { - return std::make_shared(MapType::ValueType(type)); + return make_shared_ptr(MapType::ValueType(type)); } else { throw py::attribute_error(StringUtil::Format("Tried to get a child from a map by the name of '%s', but " "this type only has 'key' and 'value' children", @@ -118,7 +117,7 @@ static PythonTypeObject GetTypeObjectType(const py::handle &type_object) { return PythonTypeObject::INVALID; } -static LogicalType FromString(const string &type_str, std::shared_ptr pycon) { +static LogicalType FromString(const string &type_str, shared_ptr pycon) { if (!pycon) { pycon = DuckDBPyConnection::DefaultConnection(); } @@ -229,7 +228,7 @@ static LogicalType FromUnionTypeInternal(const py::tuple &args) { child_list_t members; for (const auto &arg : args) { - auto name = Identifier(StringUtil::Format("u%d", index++)); + auto name = StringUtil::Format("u%d", index++); py::object object = py::reinterpret_borrow(arg); members.push_back(make_pair(name, FromObject(object))); } @@ -285,7 +284,7 @@ static LogicalType FromDictionary(const py::object &obj) { for (auto &item : dict) { auto &name_p = item.first; auto type_p = py::reinterpret_borrow(item.second); - auto name = Identifier(py::str(name_p)); + string name = py::str(name_p); auto type = FromObject(type_p); children.push_back(std::make_pair(name, std::move(type))); } @@ -312,8 +311,8 @@ static LogicalType FromObject(const py::object &object) { return FromString(string_value, nullptr); } case PythonTypeObject::TYPE: { - std::shared_ptr type_object; - if (!py::try_cast>(object, type_object)) { + shared_ptr type_object; + if (!py::try_cast>(object, type_object)) { string actual_type = py::str(py::type::of(object)); throw InvalidInputException("Expected argument of type DuckDBPyType, received '%s' instead", actual_type); } @@ -327,7 +326,7 @@ static LogicalType FromObject(const py::object &object) { } void DuckDBPyType::Initialize(py::handle &m) { - auto type_module = py::class_>(m, "DuckDBPyType"); + auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), @@ -337,21 +336,21 @@ void DuckDBPyType::Initialize(py::handle &m) { type_module.def("__hash__", [](const DuckDBPyType &type) { return py::hash(py::str(type.ToString())); }); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); - type_module.def(py::init<>([](const string &type_str, std::shared_ptr connection = nullptr) { + type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { auto ltype = FromString(type_str, std::move(connection)); - return std::make_shared(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const PyGenericAlias &obj) { auto ltype = FromGenericAlias(obj); - return std::make_shared(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const PyUnionType &obj) { auto ltype = FromUnionType(obj); - return std::make_shared(ltype); + return make_shared_ptr(ltype); })); type_module.def(py::init<>([](const py::object &obj) { auto ltype = FromObject(obj); - return std::make_shared(ltype); + return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), @@ -385,11 +384,11 @@ py::list DuckDBPyType::Children() const { py::list children; auto id = type.id(); if (id == LogicalTypeId::LIST) { - children.append(py::make_tuple("child", std::make_shared(ListType::GetChildType(type)))); + children.append(py::make_tuple("child", make_shared_ptr(ListType::GetChildType(type)))); return children; } if (id == LogicalTypeId::ARRAY) { - children.append(py::make_tuple("child", std::make_shared(ArrayType::GetChildType(type)))); + children.append(py::make_tuple("child", make_shared_ptr(ArrayType::GetChildType(type)))); children.append(py::make_tuple("size", ArrayType::GetSize(type))); return children; } @@ -408,13 +407,13 @@ py::list DuckDBPyType::Children() const { for (idx_t i = 0; i < struct_children.size(); i++) { auto &child = struct_children[i]; children.append( - py::make_tuple(child.first, std::make_shared(StructType::GetChildType(type, i)))); + py::make_tuple(child.first, make_shared_ptr(StructType::GetChildType(type, i)))); } return children; } if (id == LogicalTypeId::MAP) { - children.append(py::make_tuple("key", std::make_shared(MapType::KeyType(type)))); - children.append(py::make_tuple("value", std::make_shared(MapType::ValueType(type)))); + children.append(py::make_tuple("key", make_shared_ptr(MapType::KeyType(type)))); + children.append(py::make_tuple("value", make_shared_ptr(MapType::ValueType(type)))); return children; } if (id == LogicalTypeId::DECIMAL) { diff --git a/src/duckdb_py/typing/typing.cpp b/src/duckdb_py/typing/typing.cpp index 492dea23..c86f3712 100644 --- a/src/duckdb_py/typing/typing.cpp +++ b/src/duckdb_py/typing/typing.cpp @@ -4,40 +4,40 @@ namespace duckdb { static void DefineBaseTypes(py::handle &m) { - m.attr("SQLNULL") = std::make_shared(LogicalType::SQLNULL); - m.attr("BOOLEAN") = std::make_shared(LogicalType::BOOLEAN); - m.attr("TINYINT") = std::make_shared(LogicalType::TINYINT); - m.attr("UTINYINT") = std::make_shared(LogicalType::UTINYINT); - m.attr("SMALLINT") = std::make_shared(LogicalType::SMALLINT); - m.attr("USMALLINT") = std::make_shared(LogicalType::USMALLINT); - m.attr("INTEGER") = std::make_shared(LogicalType::INTEGER); - m.attr("UINTEGER") = std::make_shared(LogicalType::UINTEGER); - m.attr("BIGINT") = std::make_shared(LogicalType::BIGINT); - m.attr("UBIGINT") = std::make_shared(LogicalType::UBIGINT); - m.attr("HUGEINT") = std::make_shared(LogicalType::HUGEINT); - m.attr("UHUGEINT") = std::make_shared(LogicalType::UHUGEINT); - m.attr("UUID") = std::make_shared(LogicalType::UUID); - m.attr("FLOAT") = std::make_shared(LogicalType::FLOAT); - m.attr("DOUBLE") = std::make_shared(LogicalType::DOUBLE); - m.attr("DATE") = std::make_shared(LogicalType::DATE); - - m.attr("TIMESTAMP") = std::make_shared(LogicalType::TIMESTAMP); - m.attr("TIMESTAMP_MS") = std::make_shared(LogicalType::TIMESTAMP_MS); - m.attr("TIMESTAMP_NS") = std::make_shared(LogicalType::TIMESTAMP_NS); - m.attr("TIMESTAMP_S") = std::make_shared(LogicalType::TIMESTAMP_S); - - m.attr("TIME") = std::make_shared(LogicalType::TIME); - m.attr("TIME_NS") = std::make_shared(LogicalType::TIME_NS); - - m.attr("TIME_TZ") = std::make_shared(LogicalType::TIME_TZ); - m.attr("TIMESTAMP_TZ") = std::make_shared(LogicalType::TIMESTAMP_TZ); - - m.attr("VARCHAR") = std::make_shared(LogicalType::VARCHAR); - - m.attr("BLOB") = std::make_shared(LogicalType::BLOB); - m.attr("BIT") = std::make_shared(LogicalType::BIT); - m.attr("INTERVAL") = std::make_shared(LogicalType::INTERVAL); - m.attr("VARIANT") = std::make_shared(LogicalType::VARIANT()); + m.attr("SQLNULL") = make_shared_ptr(LogicalType::SQLNULL); + m.attr("BOOLEAN") = make_shared_ptr(LogicalType::BOOLEAN); + m.attr("TINYINT") = make_shared_ptr(LogicalType::TINYINT); + m.attr("UTINYINT") = make_shared_ptr(LogicalType::UTINYINT); + m.attr("SMALLINT") = make_shared_ptr(LogicalType::SMALLINT); + m.attr("USMALLINT") = make_shared_ptr(LogicalType::USMALLINT); + m.attr("INTEGER") = make_shared_ptr(LogicalType::INTEGER); + m.attr("UINTEGER") = make_shared_ptr(LogicalType::UINTEGER); + m.attr("BIGINT") = make_shared_ptr(LogicalType::BIGINT); + m.attr("UBIGINT") = make_shared_ptr(LogicalType::UBIGINT); + m.attr("HUGEINT") = make_shared_ptr(LogicalType::HUGEINT); + m.attr("UHUGEINT") = make_shared_ptr(LogicalType::UHUGEINT); + m.attr("UUID") = make_shared_ptr(LogicalType::UUID); + m.attr("FLOAT") = make_shared_ptr(LogicalType::FLOAT); + m.attr("DOUBLE") = make_shared_ptr(LogicalType::DOUBLE); + m.attr("DATE") = make_shared_ptr(LogicalType::DATE); + + m.attr("TIMESTAMP") = make_shared_ptr(LogicalType::TIMESTAMP); + m.attr("TIMESTAMP_MS") = make_shared_ptr(LogicalType::TIMESTAMP_MS); + m.attr("TIMESTAMP_NS") = make_shared_ptr(LogicalType::TIMESTAMP_NS); + m.attr("TIMESTAMP_S") = make_shared_ptr(LogicalType::TIMESTAMP_S); + + m.attr("TIME") = make_shared_ptr(LogicalType::TIME); + m.attr("TIME_NS") = make_shared_ptr(LogicalType::TIME_NS); + + m.attr("TIME_TZ") = make_shared_ptr(LogicalType::TIME_TZ); + m.attr("TIMESTAMP_TZ") = make_shared_ptr(LogicalType::TIMESTAMP_TZ); + + m.attr("VARCHAR") = make_shared_ptr(LogicalType::VARCHAR); + + m.attr("BLOB") = make_shared_ptr(LogicalType::BLOB); + m.attr("BIT") = make_shared_ptr(LogicalType::BIT); + m.attr("INTERVAL") = make_shared_ptr(LogicalType::INTERVAL); + m.attr("VARIANT") = make_shared_ptr(LogicalType::VARIANT()); } void DuckDBPyTyping::Initialize(py::module_ &parent) { diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index af4abb50..f431906b 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -37,14 +37,11 @@ def test_read_json_sample_size(self): assert res == (1, "O Brother, Where Art Thou?") def test_read_json_format(self): - # Use a dedicated connection: a binder error now invalidates the active transaction, and the shared - # default connection can carry an open transaction from a previous test's unexhausted fetchone(). - con = duckdb.connect() # Wrong option with pytest.raises(duckdb.BinderException, match=r"format must be one of .* not 'test'"): - rel = con.read_json(TestFile("example.json"), format="test") + rel = duckdb.read_json(TestFile("example.json"), format="test") - rel = con.read_json(TestFile("example.json"), format="unstructured") + rel = duckdb.read_json(TestFile("example.json"), format="unstructured") res = rel.fetchone() print(res) assert res == ( @@ -75,13 +72,11 @@ def test_read_filelike(self, duckdb_cursor): assert res[0][2] != res[1][2] def test_read_json_records(self): - # Dedicated connection (see test_read_json_format): avoid inheriting a poisoned transaction. - con = duckdb.connect() # Wrong option with pytest.raises(duckdb.BinderException, match="""read_json requires "records" to be one of"""): - rel = con.read_json(TestFile("example.json"), records="none") + rel = duckdb.read_json(TestFile("example.json"), records="none") - rel = con.read_json(TestFile("example.json"), records="true") + rel = duckdb.read_json(TestFile("example.json"), records="true") res = rel.fetchone() print(res) assert res == (1, "O Brother, Where Art Thou?") diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 6568e937..f82d0982 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -158,17 +158,12 @@ def test_connection_get_table_schema(duck_conn): ] ) - # The schema/tables created above must survive the rollback below, so commit them first. - duck_conn.commit() - # Test invalid catalog name with pytest.raises( adbc_driver_manager.InternalError, match=r'Catalog "bla" does not exist', ): duck_conn.adbc_get_table_schema("tableschema", catalog_filter="bla", db_schema_filter="test") - # The failed lookup aborts the (autocommit-off) transaction; roll it back before continuing. - duck_conn.rollback() # Catalog and DB Schema name assert duck_conn.adbc_get_table_schema( @@ -219,9 +214,6 @@ def test_insertion(duck_conn): cursor.execute("SELECT * FROM ingest") assert cursor.fetch_arrow_table() == table - # The created tables must survive the rollback below, so commit them first. - duck_conn.commit() - # Test Append with duck_conn.cursor() as cursor: with pytest.raises( @@ -229,8 +221,6 @@ def test_insertion(duck_conn): match=r"ALREADY_EXISTS", ): cursor.adbc_ingest("ingest_table", table, "create") - # The failed create aborts the (autocommit-off) transaction; roll it back before continuing. - duck_conn.rollback() cursor.adbc_ingest("ingest_table", table, "append") cursor.execute("SELECT count(*) FROM ingest_table") assert cursor.fetch_arrow_table().to_pydict() == {"count_star()": [8]} diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index a5cc7791..b6cff16c 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -191,6 +191,6 @@ def test_not_enough_parameters(self): statement.bind(array, schema) with pytest.raises( adbc_driver_manager.ProgrammingError, - match="Values were not provided for the following parameters: 2", + match="Values were not provided for the following prepared statement parameters: 2", ): statement.execute_query() diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 2ffab929..9bca8288 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -129,17 +129,17 @@ def test_executemany(self): duckdb.execute("drop table tbl") def test_pystatement(self): - with pytest.raises(duckdb.ParserException, match="syntax error"): + with pytest.raises(duckdb.ParserException, match="seledct"): statements = duckdb.extract_statements("seledct 42; select 21") statements = duckdb.extract_statements("select $1; select 21") assert len(statements) == 2 - assert statements[0].query.startswith("select $1") + assert statements[0].query == "select $1" assert statements[0].type == duckdb.StatementType.SELECT assert statements[0].named_parameters == set("1") assert statements[0].expected_result_type == [duckdb.ExpectedResultType.QUERY_RESULT] - assert statements[1].query.startswith("select 21") + assert statements[1].query == " select 21" assert statements[1].type == duckdb.StatementType.SELECT assert statements[1].named_parameters == set() @@ -157,7 +157,7 @@ def test_pystatement(self): with pytest.raises( duckdb.InvalidInputException, - match="Values were not provided for the following parameters: 1", + match="Values were not provided for the following prepared statement parameters: 1", ): duckdb.execute(statements[0]) assert duckdb.execute(statements[0], {"1": 42}).fetchall() == [(42,)] diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 78aea7a7..8be3287c 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -91,12 +91,9 @@ def test_named_param(self): def test_named_param_not_dict(self): con = duckdb.connect() - # Passing a list binds positionally as $1, $2, $3, which are excess against the named parameters in the - # query. The excess-parameter names come from an unordered set, so their order is not stable; match all - # three regardless of order. with pytest.raises( duckdb.InvalidInputException, - match=r"excess parameters: (?=.*\b1\b)(?=.*\b2\b)(?=.*\b3\b)", + match="Values were not provided for the following prepared statement parameters: name1, name2, name3", ): con.execute("select $name1, $name2, $name3", ["name1", "name2", "name3"]) @@ -113,7 +110,7 @@ def test_named_param_not_exhaustive(self): with pytest.raises( duckdb.InvalidInputException, - match="Invalid Input Error: Values were not provided for the following parameters: name3", + match="Invalid Input Error: Values were not provided for the following prepared statement parameters: name3", # noqa: E501 ): con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3}) @@ -122,18 +119,16 @@ def test_named_param_excessive(self): with pytest.raises( duckdb.InvalidInputException, - match="Parameter argument/count mismatch, identifiers of the excess parameters: not_a_named_param", + match="Values were not provided for the following prepared statement parameters: name3", ): con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "not_a_named_param": 5}) def test_named_param_not_named(self): con = duckdb.connect() - # Passing a dict for positional parameters makes name1/name2 excess. Unordered set: match both excess - # parameters regardless of order. with pytest.raises( duckdb.InvalidInputException, - match=r"excess parameters: (?=.*\bname1\b)(?=.*\bname2\b)", + match="Values were not provided for the following prepared statement parameters: 1, 2", ): con.execute("select $1, $1, $2", {"name1": 5, "name2": 3}) diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index edf335b6..6f4719fb 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -6,10 +6,7 @@ class TestWithPropagatingExceptions: def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' - with ( - pytest.raises(duckdb.CatalogException, match="Table with name invalid does not exist"), - duckdb.connect() as con, - ): + with pytest.raises(duckdb.ParserException, match=r"syntax error at or near *"), duckdb.connect() as con: con.execute("invalid") # Does not raise an exception diff --git a/tests/fast/arrow/_pushdown_helpers.py b/tests/fast/arrow/_pushdown_helpers.py deleted file mode 100644 index 48cfe78e..00000000 --- a/tests/fast/arrow/_pushdown_helpers.py +++ /dev/null @@ -1,176 +0,0 @@ -"""Shared helpers for filter pushdown tests. - -Used by both ``test_filter_pushdown.py`` (pyarrow) and -``test_polars_filter_pushdown.py``. The leading underscore prevents pytest -from collecting this module as a test file. - -The factories and parametrization data make the same set of comparison -correctness assertions reusable across every Arrow-shaped input the -replacement scan recognises (pyarrow Table, pyarrow Dataset, pyarrow-backed -pandas, polars LazyFrame, polars DataFrame). -""" - -from __future__ import annotations - -from typing import NamedTuple - -import pytest -from conftest import PANDAS_GE_3 - -pa_ds = pytest.importorskip("pyarrow.dataset") - - -# =========================================================================== -# Conversion factories — pyarrow side -# =========================================================================== - - -def to_arrow_table(rel): - return rel.to_arrow_table() - - -def to_arrow_via_pandas(rel): - if PANDAS_GE_3: - return rel.df() - return rel.df().convert_dtypes(dtype_backend="pyarrow") - - -def to_arrow_dataset(rel): - return pa_ds.dataset(rel.to_arrow_table()) - - -# Standard parametrization: every test that doesn't care about the conversion -# path runs against both the table and pandas factories. -ARROW_FACTORIES = [ - pytest.param(to_arrow_table, id="table"), - pytest.param(to_arrow_via_pandas, id="pandas"), -] - -ARROW_FACTORIES_WITH_DATASET = [ - *ARROW_FACTORIES, - pytest.param(to_arrow_dataset, id="dataset"), -] - - -# =========================================================================== -# Typed data fixtures -# -# For every type we test the same fixed 4-row layout: -# row 0: (low, low, low) -# row 1: (mid, mid, mid) -# row 2: (high, mid, high) -# row 3: (NULL, NULL, NULL) -# -# That layout makes the expected row counts for every comparison the same -# across all types, which is what lets us parametrize the comparison tests. -# =========================================================================== - - -class TypedCase(NamedTuple): - id: str # pytest id - sql_type: str - low: str # SQL literal for the smallest value - mid: str # SQL literal for the middle value (duplicated in col b row 2) - high: str # SQL literal for the largest value - - -COMPARABLE_TYPES: list[TypedCase] = [ - # numeric - TypedCase("tinyint", "TINYINT", "1", "10", "100"), - TypedCase("smallint", "SMALLINT", "1", "10", "100"), - TypedCase("integer", "INTEGER", "1", "10", "100"), - TypedCase("bigint", "BIGINT", "1", "10", "100"), - TypedCase("utinyint", "UTINYINT", "1", "10", "100"), - TypedCase("usmallint", "USMALLINT", "1", "10", "100"), - TypedCase("uinteger", "UINTEGER", "1", "10", "100"), - TypedCase("ubigint", "UBIGINT", "1", "10", "100"), - TypedCase("hugeint", "HUGEINT", "1", "10", "100"), - TypedCase("float", "FLOAT", "1.0", "10.0", "100.0"), - TypedCase("double", "DOUBLE", "1.0", "10.0", "100.0"), - TypedCase("decimal_4_1", "DECIMAL(4,1)", "1.0", "10.0", "100.0"), - TypedCase("decimal_9_1", "DECIMAL(9,1)", "1.0", "10.0", "100.0"), - TypedCase("decimal_18_4", "DECIMAL(18,4)", "1.0", "10.0", "100.0"), - TypedCase("decimal_30_12", "DECIMAL(30,12)", "1.0", "10.0", "100.0"), - # string / blob - TypedCase("varchar", "VARCHAR", "'1'", "'10'", "'100'"), - TypedCase("blob", "BLOB", r"'\x01'", r"'\x02'", r"'\x03'"), - # temporal - TypedCase("date", "DATE", "'2000-01-01'", "'2000-10-01'", "'2010-01-01'"), - TypedCase("time", "TIME", "'00:01:00'", "'00:10:00'", "'01:00:00'"), - TypedCase("timestamp", "TIMESTAMP", "'2008-01-01 00:00:01'", "'2010-01-01 10:00:01'", "'2020-03-01 10:00:01'"), - TypedCase("timestamptz", "TIMESTAMPTZ", "'2008-01-01 00:00:01'", "'2010-01-01 10:00:01'", "'2020-03-01 10:00:01'"), -] - - -def make_typed_table(con, factory, case: TypedCase) -> object: - """Create the standard table for `case`, convert via `factory`, and register it as ``arrow_table``.""" - name = f"_t_{case.id}" - con.execute(f"DROP TABLE IF EXISTS {name}") - con.execute(f"CREATE TABLE {name} (a {case.sql_type}, b {case.sql_type}, c {case.sql_type})") - con.execute( - f"""INSERT INTO {name} VALUES - ({case.low}, {case.low}, {case.low}), - ({case.mid}, {case.mid}, {case.mid}), - ({case.high}, {case.mid}, {case.high}), - (NULL, NULL, NULL)""" - ) - arrow_table = factory(con.table(name)) - con.register("arrow_table", arrow_table) - return arrow_table - - -def count(con, predicate: str) -> int: - return con.execute(f"SELECT count(*) FROM arrow_table WHERE {predicate}").fetchone()[0] - - -# =========================================================================== -# Predicate templates parametrized in `test_comparisons` -# =========================================================================== - - -# (predicate template, expected row count). Templates reference {low}, {mid}, {high}. -COMPARISON_CASES = [ - pytest.param("a = {low}", 1, id="eq"), - pytest.param("a != {low}", 2, id="ne"), - pytest.param("a > {low}", 2, id="gt"), - pytest.param("a >= {mid}", 2, id="ge"), - pytest.param("a < {mid}", 1, id="lt"), - pytest.param("a <= {mid}", 2, id="le"), - pytest.param("a IS NULL", 1, id="is_null"), - pytest.param("a IS NOT NULL", 3, id="is_not_null"), - pytest.param("a = {mid} AND b = {low}", 0, id="and_empty"), - pytest.param("a = {high} AND b = {mid} AND c = {high}", 1, id="and_match"), - pytest.param("a = {high} OR b = {low}", 2, id="or"), -] - - -# =========================================================================== -# Plan-inspection helpers -# =========================================================================== - - -def arrow_scan_block(plan: str) -> str | None: - """Return the ARROW_SCAN box (top border to bottom border) from an EXPLAIN plan. - - Works uniformly for pyarrow Tables, pyarrow Datasets, pl.LazyFrame, and - pl.DataFrame inputs — they all bind through ``arrow_scan`` in DuckDB and - render as ``ARROW_SCAN`` in the plan. - """ - lines = plan.splitlines() - scan_idx = next((i for i, line in enumerate(lines) if "ARROW_SCAN" in line), None) - if scan_idx is None: - return None - top = scan_idx - while top > 0 and "┌" not in lines[top]: - top -= 1 - bot = scan_idx - while bot < len(lines) and "└" not in lines[bot]: - bot += 1 - return "\n".join(lines[top : bot + 1]) - - -def was_pushed(con, query: str) -> bool: - """True if EXPLAIN of `query` shows a ``Filters:`` line in the ARROW_SCAN block.""" - plan = con.execute(f"EXPLAIN {query}").fetchone()[1] - block = arrow_scan_block(plan) - return block is not None and "Filters:" in block diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index 3a0b88ed..5f884f6a 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -13,21 +13,24 @@ def test_null_type(self, duckdb_cursor): arrow_table = pa.Table.from_arrays(inputs, schema=schema) duckdb_cursor.register("testarrow", arrow_table) rel = duckdb.from_arrow(arrow_table).to_arrow_table() - # NULL type now round-trips faithfully (previously it was coerced to int32) + # We turn it to an array of int32 nulls + schema = pa.schema([("data", pa.int32())]) + inputs = [pa.array([None, None, None], type=pa.null())] + arrow_table = pa.Table.from_arrays(inputs, schema=schema) + assert rel["data"] == arrow_table["data"] - def test_empty_struct(self, duckdb_cursor): - # Empty structs are now supported by DuckDB core. This previously raised - # "Attempted to convert a STRUCT with no fields to DuckDB which is not supported"; - # the core check was removed and empty structs now round-trip faithfully. + def test_invalid_struct(self, duckdb_cursor): empty_struct_type = pa.struct([]) - arrow_table = pa.Table.from_arrays( # noqa: F841 - [pa.array([None, None], type=empty_struct_type)], - schema=pa.schema([("data", empty_struct_type)]), - ) - result = duckdb_cursor.sql("select * from arrow_table").to_arrow_table() - assert result["data"].type == empty_struct_type - assert result["data"].to_pylist() == [None, None] + + # Create an empty array with the defined struct type + empty_array = pa.array([], type=empty_struct_type) + arrow_table = pa.Table.from_arrays([empty_array], schema=pa.schema([("data", empty_struct_type)])) # noqa: F841 + with pytest.raises( + duckdb.InvalidInputException, + match="Attempted to convert a STRUCT with no fields to DuckDB which is not supported", + ): + duckdb_cursor.sql("select * from arrow_table").fetchall() def test_invalid_union(self, duckdb_cursor): # Create a sparse union array from dense arrays diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index 42fda869..1dabdece 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -1,57 +1,8 @@ # ruff: noqa: F841 -"""Filter pushdown tests for the arrow scan integration. - -What's tested here: - -* **Comparison correctness** across every supported type: - ``=``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``IS NULL``, ``IS NOT NULL``, - ``AND``, ``OR``. -* **Optimizer pushdown decisions** — which predicate shapes get pushed into - the ``ARROW_SCAN`` operator and which the optimizer keeps above. Verified - by inspecting the EXPLAIN plan, not just row counts. -* **Special filter shapes** — ``IN``, ``LIKE``, ``CAST(ts AS DATE) = …``, - ``IS DISTINCT FROM NULL`` inside ``OR``, NaN ordering, struct extraction - (one- and multi-level), optional filter, dynamic top-N filter, join - filter pushdown. -* **Unsupported-type fallback** — ``UHUGEINT``, ``string_view``, - ``binary_view`` columns must not crash; the filter is applied above the - scan instead. -* **Regressions** — issues that were fixed previously and need to stay - fixed. -* **Canaries** — markers for behaviour we expect to change upstream - (pyarrow gaining view-filter support, DuckDB starting to push IS_NULL or - struct IN, etc.). - -Two conversion paths are exercised everywhere it makes sense — `.to_arrow_table()` -and `pandas`-via-pyarrow `df()`. Some tests also run through -`pyarrow.dataset` to cover the dataset-scanner code path. -""" - -from __future__ import annotations - -import datetime as dt -import re +import sys import pytest -from _pushdown_helpers import ( - ARROW_FACTORIES, - ARROW_FACTORIES_WITH_DATASET, - COMPARABLE_TYPES, - COMPARISON_CASES, - to_arrow_table, -) -from _pushdown_helpers import ( - arrow_scan_block as _arrow_scan_block, -) -from _pushdown_helpers import ( - count as _count, -) -from _pushdown_helpers import ( - make_typed_table as _make_typed_table, -) -from _pushdown_helpers import ( - was_pushed as _was_pushed, -) +from conftest import PANDAS_GE_3 from packaging.version import Version import duckdb @@ -62,408 +13,942 @@ pa_parquet = pytest.importorskip("pyarrow.parquet") pd = pytest.importorskip("pandas") np = pytest.importorskip("numpy") +re = pytest.importorskip("re") -# =========================================================================== -# 1. Comparison correctness across types -# =========================================================================== - - -@pytest.mark.parametrize("factory", ARROW_FACTORIES) -@pytest.mark.parametrize("case", COMPARABLE_TYPES, ids=lambda c: c.id) -@pytest.mark.parametrize(("predicate_tpl", "expected"), COMPARISON_CASES) -def test_comparisons(duckdb_cursor, factory, case, predicate_tpl, expected): - """Each (type, factory, predicate) tuple produces the expected row count.""" - _make_typed_table(duckdb_cursor, factory, case) - predicate = predicate_tpl.format(low=case.low, mid=case.mid, high=case.high) - assert _count(duckdb_cursor, predicate) == expected - - -# BOOL has no ordering, so it gets its own tiny suite. -@pytest.mark.parametrize("factory", ARROW_FACTORIES) -def test_bool_comparisons(duckdb_cursor, factory): - """Equality / IS NULL / AND / OR on BOOL columns.""" - duckdb_cursor.execute("CREATE TABLE _b (a BOOL, b BOOL)") - duckdb_cursor.execute("INSERT INTO _b VALUES (TRUE, TRUE), (TRUE, FALSE), (FALSE, TRUE), (NULL, NULL)") - arrow_table = factory(duckdb_cursor.table("_b")) - duckdb_cursor.register("arrow_table", arrow_table) - - assert _count(duckdb_cursor, "a = TRUE") == 2 - assert _count(duckdb_cursor, "a IS NULL") == 1 - assert _count(duckdb_cursor, "a IS NOT NULL") == 3 - assert _count(duckdb_cursor, "a = TRUE AND b = TRUE") == 1 - assert _count(duckdb_cursor, "a = TRUE OR b = TRUE") == 3 - - -# Integer boundary values are worth a separate test because the GetScalar path -# has to coerce each (DuckDB Value) -> (pyarrow scalar) at the limit. -@pytest.mark.parametrize("factory", ARROW_FACTORIES) -@pytest.mark.parametrize( - ("data_type", "max_value"), - [ - ("TINYINT", 127), - ("SMALLINT", 32767), - ("INTEGER", 2147483647), - ("BIGINT", 9223372036854775807), - ("UTINYINT", 255), - ("USMALLINT", 65535), - ("UINTEGER", 4294967295), - ("UBIGINT", 18446744073709551615), - ], -) -def test_integer_max_value(duckdb_cursor, factory, data_type, max_value): - """Pushdown round-trips through every integer's maximum representable value.""" - duckdb_cursor.execute(f"CREATE TABLE _t AS SELECT {max_value}::{data_type} AS i") - arrow_table = factory(duckdb_cursor.table("_t")) - duckdb_cursor.register("arrow_table", arrow_table) - expected = [(max_value,)] - assert duckdb_cursor.sql("SELECT * FROM arrow_table WHERE i > 0").fetchall() == expected - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE i > ?", (0,)).fetchall() == expected - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE i = ?", (max_value,)).fetchall() == expected - - -# =========================================================================== -# 2. OR pushdown decisions -# -# The optimizer pushes only a subset of OR shapes. These tests verify which -# shapes survive by inspecting the EXPLAIN plan. -# =========================================================================== - - -class TestOrPushdownDecisions: - """Same-column ORs push; multi-column ORs and OR-with-LIKE/NULL don't.""" - - @pytest.fixture(autouse=True) - def _arrow_table(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _o (a INTEGER, b INTEGER, c INTEGER)") - duckdb_cursor.execute("INSERT INTO _o VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") - duckdb_cursor.register("arrow_table", to_arrow_table(duckdb_cursor.table("_o"))) - - def test_single_column_or_pushes(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR a = 10") - - def test_single_column_or_with_and_does_not_push(self, duckdb_cursor): - # The optimizer does not currently push ``a = 1 OR (a > 3 AND a < 5)`` - # — the AND inside the OR keeps it as a filter node above the scan. - # The original test had a vacuous regex (``...|$``) that always matched; - # this is the real behavior on current DuckDB main. - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR (a > 3 AND a < 5)") - - def test_multiple_or_terms_push(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR a > 3 OR a < 5") - - def test_or_with_not_equal_pushes(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a != 1 OR a > 3 OR a < 2") - - def test_multi_column_or_does_not_push(self, duckdb_cursor): - # Optimizer refuses to push a root OR that references multiple columns. - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR b = 2 AND (a > 3 OR b < 5)") - - -class TestStringOrSpecifics: - """VARCHAR has stricter OR-pushdown rules than numeric types.""" - - @pytest.fixture(autouse=True) - def _arrow_table(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _v (a VARCHAR, b VARCHAR, c VARCHAR)") - duckdb_cursor.execute( - "INSERT INTO _v VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)" - ) - duckdb_cursor.register("arrow_table", to_arrow_table(duckdb_cursor.table("_v"))) +def create_pyarrow_pandas(rel): + if PANDAS_GE_3: + return rel.df() + else: + return rel.df().convert_dtypes(dtype_backend="pyarrow") - def test_string_range_or_pushes(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a >= '1' OR a <= '10'") - def test_or_with_is_null_does_not_push(self, duckdb_cursor): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a IS NULL OR a = '1'") +def create_pyarrow_table(rel): + return rel.to_arrow_table() - def test_or_with_is_not_null_does_not_push(self, duckdb_cursor): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a IS NOT NULL OR a = '1'") - def test_or_with_like_does_not_push(self, duckdb_cursor): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = '1' OR a LIKE '10%'") +def create_pyarrow_dataset(rel): + table = create_pyarrow_table(rel) + return pa_ds.dataset(table) -# =========================================================================== -# 3. IN-list pushdown -# =========================================================================== +def test_decimal_filter_pushdown(duckdb_cursor): + pl = pytest.importorskip("polars") + np = pytest.importorskip("numpy") + np.random.seed(10) + df = pl.DataFrame({"x": pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) + + query = """ + SELECT + x, + x > 0.05 AS is_x_good, + x::FLOAT > 0.05 AS is_float_x_good + FROM {} + WHERE + is_x_good + ORDER BY x ASC + """ + + assert len(duckdb_cursor.sql(query.format("df")).fetchall()) == 495 -class TestInPushdown: - """IN (...) pushdown test. - IN (...) reaches the walker as a ``COMPARE_IN`` ``BOUND_OPERATOR``, wrapped in an - ``__internal_tablefilter_optional`` function that the walker must unwrap before it - sees the operator. +def numeric_operators(connection, data_type, tbl_name, create_table): + connection.execute( + f""" + CREATE TABLE {tbl_name} ( + a {data_type}, + b {data_type}, + c {data_type} + ) """ + ) + connection.execute( + f""" + INSERT INTO {tbl_name} VALUES + (1,1,1), + (10,10,10), + (100,10,100), + (NULL,NULL,NULL) + """ + ) + duck_tbl = connection.table(tbl_name) + arrow_table = create_table(duck_tbl) + + # Try == + assert connection.execute("SELECT count(*) from arrow_table where a = 1").fetchone()[0] == 1 + # Try > + assert connection.execute("SELECT count(*) from arrow_table where a > 1").fetchone()[0] == 2 + # Try >= + assert connection.execute("SELECT count(*) from arrow_table where a >= 10").fetchone()[0] == 2 + # Try < + assert connection.execute("SELECT count(*) from arrow_table where a < 10").fetchone()[0] == 1 + # Try <= + assert connection.execute("SELECT count(*) from arrow_table where a <= 10").fetchone()[0] == 2 + + # Try Is Null + assert connection.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert connection.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert connection.execute("SELECT count(*) from arrow_table where a = 10 and b = 1").fetchone()[0] == 0 + assert ( + connection.execute("SELECT count(*) from arrow_table where a = 100 and b = 10 and c = 100").fetchone()[0] == 1 + ) - def test_basic(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _t AS SELECT range a FROM range(1000)") - duckdb_cursor.register("arrow_table", to_arrow_table(duckdb_cursor.table("_t"))) - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a = ANY([1, 999])").fetchall() == [(1,), (999,)] + # Try Or + assert connection.execute("SELECT count(*) from arrow_table where a = 100 or b = 1").fetchone()[0] == 2 - @pytest.mark.timeout(10) - def test_large_in_list_does_not_hang(self): - """Regression: https://github.com/duckdb/duckdb-python/issues/52.""" - duckdb.register("arrow_table", pa.table({"a": pa.array(range(5000))})) - in_list = ", ".join(str(i) for i in range(0, 5000, 2)) - result = duckdb.sql(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone() - assert result == (2500,) + connection.execute("EXPLAIN SELECT count(*) from arrow_table where a = 100 or b = 1") + print(connection.fetchall()) - def test_in_with_no_nulls_in_list(self): - duckdb.register("arrow_table", pa.table({"a": pa.array([1, 2, None, 4, None, 6])})) - result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4) ORDER BY a").fetchall() - assert result == [(1,), (4,)] - def test_in_with_null_in_list(self): - """SQL semantics: NULL in the IN list still doesn't match NULL rows.""" - duckdb.register("arrow_table", pa.table({"a": pa.array([1, 2, None, 4, None, 6])})) - result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4, NULL) ORDER BY a").fetchall() - assert result == [(1,), (4,)] +def numeric_check_or_pushdown(connection, tbl_name, create_table): + duck_tbl = connection.table(tbl_name) + arrow_table = create_table(duck_tbl) - def test_in_varchar(self): - duckdb.register("arrow_table", pa.table({"s": pa.array(["alice", "bob", "charlie", "dave", None])})) - result = duckdb.sql("SELECT s FROM arrow_table WHERE s IN ('bob', 'dave') ORDER BY s").fetchall() - assert result == [("bob",), ("dave",)] + # Multiple column in the root OR node, don't push down + query_res = connection.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE + a = 1 OR b = 2 AND (a > 3 OR b < 5) + """ + ).fetchall() + match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) + assert not match - def test_in_float(self): - duckdb.register("arrow_table", pa.table({"f": pa.array([1.0, 2.5, 3.75, 4.0, None], type=pa.float64())})) - result = duckdb.sql("SELECT f FROM arrow_table WHERE f IN (2.5, 4.0) ORDER BY f").fetchall() - assert result == [(2.5,), (4.0,)] + # Single column in the root OR node + query_res = connection.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE + a = 1 OR a = 10 + """ + ).fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a=10.*|$", query_res[0][1]) + assert match + + # Single column + root OR node with AND + query_res = connection.execute( + """ + EXPLAIN SELECT * FROM arrow_table + WHERE a = 1 OR (a > 3 AND a < 5) + """ + ).fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 AND a<5.*|$", query_res[0][1]) + assert match + + # Single column multiple ORs + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a=1 OR a>3 OR a<5").fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a=1 OR a>3 OR a<5.*|$", query_res[0][1]) + assert match + + # Testing not equal + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a!=1 OR a>3 OR a<2").fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a!=1 OR a>3 OR a<2.*|$", query_res[0][1]) + assert match + + # Multiple OR filters connected with ANDs + query_res = connection.execute( + "EXPLAIN SELECT * FROM arrow_table WHERE (a<2 OR a>3) AND (a=1 OR a=4) AND (b=1 OR b<5)" + ).fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a<2 OR a>3 AND a=1|\n.*OR a=4.*\n.*b=2 OR b<5.*|$", query_res[0][1]) + assert match + + +def string_check_or_pushdown(connection, tbl_name, create_table): + duck_tbl = connection.table(tbl_name) + arrow_table = create_table(duck_tbl) + + # Check string zonemap + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a >= '1' OR a <= '10'").fetchall() + match = re.search(".*ARROW_SCAN.*Filters: a>=1 OR a<=10.*|$", query_res[0][1]) + assert match + + # No support for OR with is null + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a IS NULL or a = '1'").fetchall() + match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) + assert not match + + # No support for OR with is not null + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a IS NOT NULL OR a = '1'").fetchall() + match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) + assert not match + # OR with the like operator + query_res = connection.execute("EXPLAIN SELECT * FROM arrow_table WHERE a = 1 OR a LIKE '10%'").fetchall() + match = re.search(".*ARROW_SCAN.*Filters:.*", query_res[0][1]) + assert not match -# =========================================================================== -# 4. NaN pushdown -# -# DuckDB intentionally violates IEEE-754: NaN is the greatest value. -# The pyarrow_filter_pushdown special-cases this so the pyarrow side gets -# is_nan() / its inverse / constant(true|false) depending on the operator. -# =========================================================================== +class TestArrowFilterPushdown: + @pytest.mark.parametrize( + "data_type", + [ + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", + ], + ) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_numeric(self, data_type, duckdb_cursor, create_table): + tbl_name = "tbl" + numeric_operators(duckdb_cursor, data_type, tbl_name, create_table) + numeric_check_or_pushdown(duckdb_cursor, tbl_name, create_table) + + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_varchar ( + a VARCHAR, + b VARCHAR, + c VARCHAR + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_varchar VALUES + ('1','1','1'), + ('10','10','10'), + ('100','10','100'), + (NULL, NULL, NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_varchar") + arrow_table = create_table(duck_tbl) + + # Try == + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '1'").fetchone()[0] == 1 + # Try > + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a > '1'").fetchone()[0] == 2 + # Try >= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a >= '10'").fetchone()[0] == 2 + # Try < + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a < '10'").fetchone()[0] == 1 + # Try <= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a <= '10'").fetchone()[0] == 2 + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '10' and b = '1'").fetchone()[0] == 0 + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '100' and b = '10' and c = '100'" + ).fetchone()[0] + == 1 + ) + # Try Or + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '100' or b ='1'").fetchone()[0] == 2 -class TestNaNPushdown: - """Six comparison operators against a NaN constant on a DOUBLE column.""" + # More complex tests for OR pushed down on string + string_check_or_pushdown(duckdb_cursor, "test_varchar", create_table) + + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_bool(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_bool ( + a BOOL, + b BOOL + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_bool VALUES + (TRUE,TRUE), + (TRUE,FALSE), + (FALSE,TRUE), + (NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_bool") + arrow_table = create_table(duck_tbl) + + # Try == + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = True").fetchone()[0] == 2 + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a= True and b = True").fetchone()[0] == 1 + # Try Or + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = True or b = True").fetchone()[0] == 3 + + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_time(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_time ( + a TIME, + b TIME, + c TIME + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_time VALUES + ('00:01:00','00:01:00','00:01:00'), + ('00:10:00','00:10:00','00:10:00'), + ('01:00:00','00:10:00','01:00:00'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_time") + arrow_table = create_table(duck_tbl) + + # Try == + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a ='00:01:00'").fetchone()[0] == 1 + # Try > + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a >'00:01:00'").fetchone()[0] == 2 + # Try >= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a >='00:10:00'").fetchone()[0] == 2 + # Try < + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a <'00:10:00'").fetchone()[0] == 1 + # Try <= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a <='00:10:00'").fetchone()[0] == 2 + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a='00:10:00' and b ='00:01:00'").fetchone()[0] + == 0 + ) + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a ='01:00:00' and b = '00:10:00' and c = '01:00:00'" + ).fetchone()[0] + == 1 + ) + # Try Or + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '01:00:00' or b ='00:01:00'").fetchone()[ + 0 + ] + == 2 + ) - @pytest.fixture(autouse=True) - def _nan_arrow_table(self, duckdb_cursor): + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): duckdb_cursor.execute( - "CREATE TABLE _n AS SELECT a::DOUBLE a FROM VALUES " - "('inf'), ('nan'), ('0.34234'), ('34234234.00005'), ('-nan') t(a)" + """ + CREATE TABLE test_timestamp ( + a TIMESTAMP, + b TIMESTAMP, + c TIMESTAMP + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_timestamp VALUES + ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'), + ('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'), + ('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_timestamp") + arrow_table = create_table(duck_tbl) + + # Try == + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a ='2008-01-01 00:00:01'").fetchone()[0] == 1 + ) + # Try > + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a >'2008-01-01 00:00:01'").fetchone()[0] == 2 + ) + # Try >= + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a >='2010-01-01 10:00:01'").fetchone()[0] == 2 + ) + # Try < + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a <'2010-01-01 10:00:01'").fetchone()[0] == 1 + ) + # Try <= + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a <='2010-01-01 10:00:01'").fetchone()[0] == 2 ) - arrow_table = to_arrow_table(duckdb_cursor.table("_n")) - duckdb_cursor.register("arrow_table", arrow_table) + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a='2010-01-01 10:00:01' and b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 0 + ) + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a ='2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" # noqa: E501 + ).fetchone()[0] + == 1 + ) + # Try Or + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 2 + ) + + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_timestamptz ( + a TIMESTAMPTZ, + b TIMESTAMPTZ, + c TIMESTAMPTZ + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_timestamptz VALUES + ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'), + ('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'), + ('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_timestamptz") + arrow_table = create_table(duck_tbl) + + # Try == + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '2008-01-01 00:00:01'").fetchone()[0] == 1 + ) + # Try > + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a > '2008-01-01 00:00:01'").fetchone()[0] == 2 + ) + # Try >= + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a >= '2010-01-01 10:00:01'").fetchone()[0] + == 2 + ) + # Try < + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a < '2010-01-01 10:00:01'").fetchone()[0] == 1 + ) + # Try <= + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a <= '2010-01-01 10:00:01'").fetchone()[0] + == 2 + ) + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2010-01-01 10:00:01' and b = '2008-01-01 00:00:01'" + ).fetchone()[0] + == 0 + ) + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2020-03-01 10:00:01' and b = '2010-01-01 10:00:01' and c = '2020-03-01 10:00:01'" # noqa: E501 + ).fetchone()[0] + == 1 + ) + # Try Or + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2020-03-01 10:00:01' or b ='2008-01-01 00:00:01'" + ).fetchone()[0] + == 2 + ) + + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) @pytest.mark.parametrize( - "op", - ["=", "!=", "<", "<=", ">", ">="], + ("data_type", "value"), + [ + ("TINYINT", 127), + ("SMALLINT", 32767), + ("INTEGER", 2147483647), + ("BIGINT", 9223372036854775807), + ("UTINYINT", 255), + ("USMALLINT", 65535), + ("UINTEGER", 4294967295), + ("UBIGINT", 18446744073709551615), + ], + ) + def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_table): + duckdb_cursor.execute( + f""" + CREATE TABLE tbl as select {value}::{data_type} as i + """ + ) + expected = duckdb_cursor.table("tbl").fetchall() + filter = "i > 0" + rel = duckdb_cursor.table("tbl") + arrow_table = create_table(rel) + actual = duckdb_cursor.sql(f"select * from arrow_table where {filter}").fetchall() + assert expected == actual + + # Test with equivalent prepared statement + actual = duckdb_cursor.execute("select * from arrow_table where i > ?", (0,)).fetchall() + assert expected == actual + # Test equality + actual = duckdb_cursor.execute("select * from arrow_table where i = ?", (value,)).fetchall() + assert expected == actual + + @pytest.mark.skipif( + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) - def test_nan_comparison_matches_duckdb(self, duckdb_cursor, op): - """Each NaN comparison through the arrow scan agrees with DuckDB's own answer.""" - q_arrow = f"SELECT count(*) FROM arrow_table WHERE a {op} 'NaN'::FLOAT" - q_duck = f"SELECT count(*) FROM _n WHERE a {op} 'NaN'::FLOAT" - assert duckdb_cursor.execute(q_arrow).fetchone() == duckdb_cursor.execute(q_duck).fetchone() + def test_9371(self, duckdb_cursor, tmp_path): + import datetime + # connect to an in-memory database + duckdb_cursor.execute("SET TimeZone='UTC';") + base_path = tmp_path / "parquet_folder" + base_path.mkdir(exist_ok=True) + file_path = base_path / "test.parquet" -# =========================================================================== -# 5. Struct extract pushdown -# =========================================================================== + duckdb_cursor.execute("SET TimeZone='UTC';") + # Example data + dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc) -class TestOneLevelStruct: - """``struct_extract`` chains build the path inside ``ResolveColumn``. + my_arrow_table = pa.Table.from_pydict({"ts": [dt, dt, dt], "value": [1, 2, 3]}) + df = my_arrow_table.to_pandas() + df = df.set_index("ts") # SET INDEX! (It all works correctly when the index is not set) + df.to_parquet(str(file_path)) - The EXPLAIN plan renders the predicate using the function form - ``(struct_extract(s, 'a') < 2)`` rather than the dot form ``s.a < 2``. - """ + my_arrow_dataset = pa_ds.dataset(str(file_path)) + res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).to_arrow_table() + output = duckdb_cursor.sql("select * from res").fetchall() + expected = [(1, dt), (2, dt), (3, dt)] + assert output == expected - @pytest.fixture(autouse=True) - def _one_level_struct(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _s (s STRUCT(a INTEGER, b BOOL))") + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_date(self, duckdb_cursor, create_table): duckdb_cursor.execute( - "INSERT INTO _s VALUES " - "({'a': 1, 'b': true}), ({'a': 2, 'b': false}), (NULL), " - "({'a': 3, 'b': true}), ({'a': NULL, 'b': NULL})" + """ + CREATE TABLE test_date ( + a DATE, + b DATE, + c DATE + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_date VALUES + ('2000-01-01','2000-01-01','2000-01-01'), + ('2000-10-01','2000-10-01','2000-10-01'), + ('2010-01-01','2000-10-01','2010-01-01'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_date") + arrow_table = create_table(duck_tbl) + + # Try == + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '2000-01-01'").fetchone()[0] == 1 + # Try > + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a > '2000-01-01'").fetchone()[0] == 2 + # Try >= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a >= '2000-10-01'").fetchone()[0] == 2 + # Try < + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a < '2000-10-01'").fetchone()[0] == 1 + # Try <= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a <= '2000-10-01'").fetchone()[0] == 2 + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2000-10-01' and b = '2000-01-01'" + ).fetchone()[0] + == 0 + ) + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2010-01-01' and b = '2000-10-01' and c = '2010-01-01'" + ).fetchone()[0] + == 1 + ) + # Try Or + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '2010-01-01' or b = '2000-01-01'" + ).fetchone()[0] + == 2 ) - arrow_table = to_arrow_table(duckdb_cursor.table("_s")) - duckdb_cursor.register("arrow_table", arrow_table) - def test_one_level_comparison_is_pushed(self, duckdb_cursor): - plan = duckdb_cursor.execute("EXPLAIN SELECT * FROM arrow_table WHERE s.a < 2").fetchone()[1] - assert re.search(r"struct_extract\(s,\s*'a'\)\s*<", plan) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_blob(self, duckdb_cursor, create_table): + import pandas - def test_one_level_comparison_correct(self, duckdb_cursor): - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a < 2").fetchone()[0] == {"a": 1, "b": True} + df = pandas.DataFrame( + { + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], + } + ) + rel = duckdb.from_df(df) + arrow_table = create_table(rel) + + # Try == + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '\x01'").fetchone()[0] == 1 + # # Try > + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a > '\x01'").fetchone()[0] == 2 + # Try >= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a >= '\x02'").fetchone()[0] == 2 + # Try < + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a < '\x02'").fetchone()[0] == 1 + # Try <= + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a <= '\x02'").fetchone()[0] == 2 + + # Try Is Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NULL").fetchone()[0] == 1 + # Try Is Not Null + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a IS NOT NULL").fetchone()[0] == 3 + + # Try And + assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a='\x02' and b ='\x01'").fetchone()[0] == 0 + assert ( + duckdb_cursor.execute( + "SELECT count(*) from arrow_table where a = '\x02' and b = '\x02' and c = '\x02'" + ).fetchone()[0] + == 1 + ) + # Try Or + assert ( + duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '\x01' or b = '\x02'").fetchone()[0] == 2 + ) - def test_one_level_and_across_fields_is_pushed(self, duckdb_cursor): - # Strip box-drawing/padding so we can pattern-match across line wraps. - plan = duckdb_cursor.execute("EXPLAIN SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone()[1] - block = _arrow_scan_block(plan) - assert block is not None - flat = re.sub(r"[│|\s]+", " ", block) - assert re.search(r"struct_extract\(s, 'a'\).*struct_extract\(s, 'b'\)", flat) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) + def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_int ( + a INTEGER, + b INTEGER, + c INTEGER + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_int VALUES + (1,1,1), + (10,10,10), + (100,10,100), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_int") + arrow_table = create_table(duck_tbl) - def test_one_level_and_correct(self, duckdb_cursor): - assert duckdb_cursor.execute("SELECT count(*) FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone() == (1,) - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone()[0] == { - "a": 1, - "b": True, - } + assert duckdb_cursor.execute("SELECT * FROM arrow_table VALUES where a = 1").fetchall() == [(1, 1, 1)] + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): + import pandas -class TestNestedStruct: - """Two-level ``struct_extract`` chains.""" + date1 = pandas.date_range("2018-01-01", "2018-12-31", freq="B") + df1 = pandas.DataFrame(np.random.randn(date1.shape[0], 5), columns=list("ABCDE")) + df1["date"] = date1 - @pytest.fixture(autouse=True) - def _nested_struct(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _n (s STRUCT(a STRUCT(b INTEGER, c BOOL), d STRUCT(e INTEGER, f VARCHAR)))") + date2 = pandas.date_range("2019-01-01", "2019-12-31", freq="B") + df2 = pandas.DataFrame(np.random.randn(date2.shape[0], 5), columns=list("ABCDE")) + df2["date"] = date2 + + data1 = tmp_path / "data1.parquet" + data2 = tmp_path / "data2.parquet" + duckdb_cursor.execute(f"copy (select * from df1) to '{data1.as_posix()}'") + duckdb_cursor.execute(f"copy (select * from df2) to '{data2.as_posix()}'") + + glob_pattern = tmp_path / "data*.parquet" + table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).to_arrow_table() + + output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() + expected_df = duckdb.from_parquet(glob_pattern.as_posix()).filter("date > '2019-01-01'").df() + pandas.testing.assert_frame_equal(expected_df, output_df) + + @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( - "INSERT INTO _n VALUES " - "({'a': {'b': 1, 'c': false}, 'd': {'e': 2, 'f': 'foo'}}), " - "(NULL), " - "({'a': {'b': 3, 'c': true}, 'd': {'e': 4, 'f': 'bar'}}), " - "({'a': {'b': NULL, 'c': true}, 'd': {'e': 5, 'f': 'qux'}}), " - "({'a': NULL, 'd': NULL})" - ) - arrow_table = to_arrow_table(duckdb_cursor.table("_n")) - duckdb_cursor.register("arrow_table", arrow_table) - - def test_nested_two_level_is_pushed(self, duckdb_cursor): - plan = duckdb_cursor.execute("EXPLAIN SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[1] - # Outer struct_extract(_, 'b') around inner struct_extract(s, 'a'). - assert re.search( - r"struct_extract.*\(struct_extract\(s,\s*'a'\),.*'b'\)\s*<\s*2", - plan, - flags=re.DOTALL, + """ + CREATE TABLE test_structs (s STRUCT(a integer, b bool)) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_structs VALUES + ({'a': 1, 'b': true}), + ({'a': 2, 'b': false}), + (NULL), + ({'a': 3, 'b': true}), + ({'a': NULL, 'b': NULL}); + """ ) - def test_nested_two_level_correct(self, duckdb_cursor): - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { - "a": {"b": 1, "c": False}, - "d": {"e": 2, "f": "foo"}, - } + duck_tbl = duckdb_cursor.table("test_structs") + arrow_table = create_table(duck_tbl) + + # Ensure that the filter is pushed down + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE + s.a < 2 + """ + ).fetchall() + + input = query_res[0][1] + if "PANDAS_SCAN" in input: + pytest.skip(reason="This version of pandas does not produce an Arrow object") + match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a<2.*", input, flags=re.DOTALL) + assert match + + # Check that the filter is applied correctly + assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a < 2").fetchone()[0] == {"a": 1, "b": True} - def test_nested_and_across_branches(self, duckdb_cursor): - assert duckdb_cursor.execute( - "SELECT count(*) FROM arrow_table WHERE s.a.c = true AND s.d.e = 5" - ).fetchone() == (1,) + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true + """ + ).fetchall() - def test_nested_varchar_comparison(self, duckdb_cursor): - plan = duckdb_cursor.execute("EXPLAIN SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[1] - assert re.search( - r"struct_extract.*\(struct_extract\(s,\s*'d'\),.*'f'\)\s*=\s*'bar'", - plan, + # the explain-output is pretty cramped, so just make sure we see both struct references. + match = re.search( + r".*ARROW_SCAN.*Filters:.*s\.a<3.*AND s\.b=true.*", + query_res[0][1], flags=re.DOTALL, ) - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { - "a": {"b": 3, "c": True}, - "d": {"e": 4, "f": "bar"}, + assert match + + # Check that the filter is applied correctly + assert duckdb_cursor.execute("SELECT COUNT(*) FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone()[0] == 1 + assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone()[0] == { + "a": 1, + "b": True, } + # This should not produce a pushdown + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE + s.a IS NULL + """ + ).fetchall() -# =========================================================================== -# 6. LIKE pushdown -# =========================================================================== + match = re.search(".*ARROW_SCAN.*Filters: s\\.a IS NULL.*", query_res[0][1], flags=re.DOTALL) + assert not match + @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) + def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): + duckdb_cursor.execute( + """ + CREATE TABLE test_nested_structs(s STRUCT(a STRUCT(b integer, c bool), d STRUCT(e integer, f varchar))); + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_nested_structs VALUES + ({'a': {'b': 1, 'c': false}, 'd': {'e': 2, 'f': 'foo'}}), + (NULL), + ({'a': {'b': 3, 'c': true}, 'd': {'e': 4, 'f': 'bar'}}), + ({'a': {'b': NULL, 'c': true}, 'd': {'e': 5, 'f': 'qux'}}), + ({'a': NULL, 'd': NULL}); + """ + ) -class TestLikePushdown: - """Test LIKE filter pushdown. + duck_tbl = duckdb_cursor.table("test_nested_structs") + arrow_table = create_table(duck_tbl) - LIKE with a fixed prefix decomposes into ``>= prefix AND < prefix+1``; a - constant LIKE (no wildcards) decomposes into ``=``. Both produce regular - comparison ExpressionFilters that the walker handles. - """ + # Ensure that the filter is pushed down + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE s.a.b < 2; + """ + ).fetchall() - @pytest.fixture(autouse=True) - def _s_arrow_table(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _l AS SELECT 'str_' || lpad(i::VARCHAR, 4, '0') AS s FROM range(100) t(i)") - arrow_table = to_arrow_table(duckdb_cursor.table("_l")) - duckdb_cursor.register("arrow_table", arrow_table) + input = query_res[0][1] + if "PANDAS_SCAN" in input: + pytest.skip(reason="This version of pandas does not produce an Arrow object") + match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a\.b<2.*", input, flags=re.DOTALL) + assert match - def test_like_with_prefix_is_pushed(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT s FROM arrow_table WHERE s LIKE 'str_001%'") + # Check that the filter is applied correctly + assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { + "a": {"b": 1, "c": False}, + "d": {"e": 2, "f": "foo"}, + } - def test_like_constant_is_pushed(self, duckdb_cursor): - assert _was_pushed(duckdb_cursor, "SELECT s FROM arrow_table WHERE s LIKE 'str_0042'") + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5 + """ + ).fetchall() - def test_like_with_prefix_correct(self, duckdb_cursor): - rows = duckdb_cursor.execute("SELECT s FROM arrow_table WHERE s LIKE 'str_001%' ORDER BY s").fetchall() - assert rows == [(f"str_001{d}",) for d in "0123456789"] + # the explain-output is pretty cramped, so just make sure we see both struct references. + match = re.search( + r".*ARROW_SCAN.*Filters:.*s\.a\.c=true.*AND s\.d\.e=5.*", + query_res[0][1], + flags=re.DOTALL, + ) + assert match + # Check that the filter is applied correctly + assert duckdb_cursor.execute("SELECT COUNT(*) FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == 1 + assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == { + "a": {"b": None, "c": True}, + "d": {"e": 5, "f": "qux"}, + } -# =========================================================================== -# 7. CAST temporal pushdown -# =========================================================================== + query_res = duckdb_cursor.execute( + """ + EXPLAIN SELECT * FROM arrow_table WHERE s.d.f = 'bar'; + """ + ) + res = query_res.fetchone()[1] + match = re.search( + r".*ARROW_SCAN.*Filters:.*s\.d\.f='bar'.*", + res, + flags=re.DOTALL, + ) -class TestTemporalCastPushdown: - """``CAST(timestamp_col AS DATE) = …`` pushes an optional relaxed range filter. + assert match - See `TryPushdownTemporalCastFilter`. - """ + # Check that the filter is applied correctly + assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { + "a": {"b": 3, "c": True}, + "d": {"e": 4, "f": "bar"}, + } - def test_cast_timestamp_to_date_is_pushed(self, duckdb_cursor): - duckdb_cursor.execute( - "CREATE TABLE _ct AS " - "SELECT TIMESTAMP '2024-01-01 00:00:00' + INTERVAL (i) SECOND AS ts FROM range(86400) t(i)" - ) - arrow_table = to_arrow_table(duckdb_cursor.table("_ct")) - duckdb_cursor.register("arrow_table", arrow_table) - assert _was_pushed( - duckdb_cursor, - "SELECT * FROM arrow_table WHERE CAST(ts AS DATE) = DATE '2024-01-01'", + def test_filter_pushdown_not_supported(self): + con = duckdb.connect() + con.execute( + "CREATE TABLE T as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d FROM range(5) tbl(i)" ) + arrow_tbl = con.execute("FROM T").to_arrow_table() + # No projection just unsupported filter + assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, "3", 3, 3)] -# =========================================================================== -# 8. IS DISTINCT FROM NULL inside OR -# -# This is the one realistic SQL path that produces an -# ExpressionFilter(OPERATOR_IS_NULL / IS_NOT_NULL) for the walker — see -# filter_combiner.cpp:615-625. -# =========================================================================== + # No projection unsupported + supported filter + assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, "3", 3, 3)] + # No projection supported + unsupported + supported filter + assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3", 3, 3)] + assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '0' ").fetchall() == [] -class TestDistinctFromNullOrPushdown: - """``IS DISTINCT FROM NULL OR ...`` produces an IS_NOT_NULL ExpressionFilter.""" + # Projection with unsupported filter column + unsupported + supported filter + assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, "3")] + assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, "3")] - @pytest.fixture(autouse=True) - def _with_nulls(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _d AS SELECT * FROM (VALUES (1), (NULL), (5), (10)) t(a)") - arrow_table = to_arrow_table(duckdb_cursor.table("_d")) - duckdb_cursor.register("arrow_table", arrow_table) + # Projection without unsupported filter column + unsupported + supported filter + assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3")] - def test_distinct_from_null_or_eq_is_pushed(self, duckdb_cursor): - assert _was_pushed( - duckdb_cursor, - "SELECT a FROM arrow_table WHERE a IS DISTINCT FROM NULL OR a = 5", + # Lets also experiment with multiple unpush-able filters + con.execute( + "CREATE TABLE T_2 as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d , i::uhugeint e, i::smallint f, i::uhugeint g FROM range(50) tbl(i)" # noqa: E501 ) - def test_distinct_from_null_or_eq_correct(self, duckdb_cursor): - rows = duckdb_cursor.execute( - "SELECT a FROM arrow_table WHERE a IS DISTINCT FROM NULL OR a = 5 ORDER BY a" - ).fetchall() - assert rows == [(1,), (5,), (10,)] + arrow_tbl = con.execute("FROM T_2").to_arrow_table() - def test_not_distinct_from_null_or_eq_correct(self, duckdb_cursor): - rows = duckdb_cursor.execute( - "SELECT a FROM arrow_table WHERE a IS NOT DISTINCT FROM NULL OR a = 5 ORDER BY a NULLS FIRST" - ).fetchall() - # NULL row + a=5 row - assert rows == [(None,), (5,)] + assert con.execute( + "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" + ).fetchall() == [(28, "28")] + def test_join_filter_pushdown(self, duckdb_cursor): + duckdb_conn = duckdb.connect() + duckdb_conn.execute("CREATE TABLE probe as select range a from range(10000);") + duckdb_conn.execute("CREATE TABLE build as select (random()*9999)::INT b from range(20);") + duck_probe = duckdb_conn.table("probe") + duck_build = duckdb_conn.table("build") + duck_probe_arrow = duck_probe.to_arrow_table() + duck_build_arrow = duck_build.to_arrow_table() + duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) + duckdb_conn.register("duck_build_arrow", duck_build_arrow) + assert duckdb_conn.execute("SELECT count(*) from duck_probe_arrow, duck_build_arrow where a=b").fetchall() == [ + (20,) + ] + + def test_in_filter_pushdown(self, duckdb_cursor): + duckdb_conn = duckdb.connect() + duckdb_conn.execute("CREATE TABLE probe as select range a from range(1000);") + duck_probe = duckdb_conn.table("probe") + duck_probe_arrow = duck_probe.to_arrow_table() + duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) + assert duckdb_conn.execute("SELECT * from duck_probe_arrow where a = any([1,999])").fetchall() == [(1,), (999,)] -# =========================================================================== -# 9. Special-shape filters: optional, dynamic top-N, join -# =========================================================================== + @pytest.mark.timeout(10) + def test_in_filter_pushdown_large_list(self, duckdb_cursor): + """Large IN lists must not hang. Regression test for https://github.com/duckdb/duckdb-python/issues/52.""" + arrow_table = pa.table({"a": pa.array(range(5000))}) + in_list = ", ".join(str(i) for i in range(0, 5000, 2)) + result = duckdb.sql(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone() + assert result == (2500,) + def test_in_filter_pushdown_with_nulls(self, duckdb_cursor): + arrow_table = pa.table({"a": pa.array([1, 2, None, 4, None, 6])}) + # IN list without NULL: null rows should not match + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4) ORDER BY a").fetchall() + assert result == [(1,), (4,)] + # IN list with NULL: null rows still should not match (SQL semantics) + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4, NULL) ORDER BY a").fetchall() + assert result == [(1,), (4,)] -class TestOptionalFilter: - """An OptionalFilter is allowed to silently fail. + def test_in_filter_pushdown_varchar(self, duckdb_cursor): + arrow_table = pa.table({"s": pa.array(["alice", "bob", "charlie", "dave", None])}) + result = duckdb.sql("SELECT s FROM arrow_table WHERE s IN ('bob', 'dave') ORDER BY s").fetchall() + assert result == [("bob",), ("dave",)] - The engine reapplies it above the scan. The result must remain correct. - """ + def test_in_filter_pushdown_float(self, duckdb_cursor): + arrow_table = pa.table({"f": pa.array([1.0, 2.5, 3.75, 4.0, None], type=pa.float64())}) + result = duckdb.sql("SELECT f FROM arrow_table WHERE f IN (2.5, 4.0) ORDER BY f").fetchall() + assert result == [(2.5,), (4.0,)] - def test_no_crash_correct_result(self): + def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( { "column_name": [ @@ -480,10 +965,17 @@ def test_no_crash_correct_result(self): "cardinality": [100, 100, 100, 45, 5, 3, 6, 39, 5], } ) + result = duckdb.query( - "SELECT * FROM cardinality_table WHERE cardinality > 1 ORDER BY cardinality ASC" - ).fetchall() - assert result == [ + """ + SELECT * + FROM cardinality_table + WHERE cardinality > 1 + ORDER BY cardinality ASC + """ + ) + res = result.fetchall() + assert res == [ ("is_available", 3), ("category", 5), ("color", 5), @@ -495,55 +987,39 @@ def test_no_crash_correct_result(self): ("price", 100), ] + # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the + # greatest value + def test_nan_filter_pushdown(self, duckdb_cursor): + duckdb_cursor.execute( + """ + create table test as select a::DOUBLE a from VALUES + ('inf'), + ('nan'), + ('0.34234'), + ('34234234.00005'), + ('-nan') + t(a); + """ + ) -class TestDynamicFilter: - """The top-N optimization installs a dynamic filter. + def assert_equal_results(con, arrow_table, query) -> None: + duckdb_res = con.sql(query.format(table="test")).fetchall() + arrow_res = con.sql(query.format(table="arrow_table")).fetchall() + assert len(duckdb_res) == len(arrow_res) - The walker returns ``py::none()`` for those (DuckDB applies them above the scan). - """ + arrow_table = duckdb_cursor.table("test").to_arrow_table() + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a <= 'NaN'::FLOAT") + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a = 'NaN'::FLOAT") + assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a != 'NaN'::FLOAT") - def test_topn_dynamic_filter(self, duckdb_cursor): + def test_dynamic_filter(self, duckdb_cursor): t = pa.Table.from_pydict({"a": [3, 24, 234, 234, 234, 234, 234, 234, 234, 45, 2, 5, 2, 45]}) duckdb_cursor.register("t", t) - rows = duckdb_cursor.sql("SELECT a FROM t ORDER BY a LIMIT 11").fetchall() - assert len(rows) == 11 - - -class TestJoinFilterPushdown: - """Join pushdown between two arrow tables must produce the right count. - - The join's runtime filters take a separate code path that doesn't - reach the static-filter walker — see TestCanaries below. - """ - - def test_two_arrow_tables(self): - con = duckdb.connect() - con.execute("CREATE TABLE probe AS SELECT range a FROM range(10000)") - con.execute("CREATE TABLE build AS SELECT (random()*9999)::INT b FROM range(20)") - con.register("probe_arrow", to_arrow_table(con.table("probe"))) - con.register("build_arrow", to_arrow_table(con.table("build"))) - assert con.execute("SELECT count(*) FROM probe_arrow, build_arrow WHERE a = b").fetchall() == [(20,)] - - -# =========================================================================== -# 10. Unsupported-type fallback (filter applied above the scan) -# =========================================================================== - - -class TestUnsupportedTypes: - """``UHUGEINT``, ``string_view``, ``binary_view`` filters must not crash. - - The filter is applied above the scan instead of being pushed down. - """ - - def test_uhugeint_single_filter(self): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, i::INTEGER d FROM range(5) tbl(i)" - ) - arrow_tbl = to_arrow_table(con.table("t")) - con.register("arrow_tbl", arrow_tbl) - assert con.execute("FROM arrow_tbl WHERE c = 3").fetchall() == [(3, "3", 3, 3)] + res = duckdb_cursor.sql("SELECT a FROM t ORDER BY a LIMIT 11").fetchall() + assert len(res) == 11 def test_dynamic_filter_nulls_first_pyarrow(self, duckdb_cursor): # Regression for #460(a): TOP_N with ASC NULLS FIRST pushes an @@ -565,304 +1041,39 @@ def test_dynamic_filter_nulls_first_polars_dataframe(self, duckdb_cursor): res = duckdb_cursor.sql("SELECT * FROM src ORDER BY x ASC NULLS FIRST LIMIT 1").fetchall() assert res == [(1,)] - def test_uhugeint_mixed_with_supported(self): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, i::INTEGER d FROM range(5) tbl(i)" - ) - arrow_tbl = to_arrow_table(con.table("t")) - con.register("arrow_tbl", arrow_tbl) - assert con.execute("FROM arrow_tbl WHERE c < 4 AND a > 2").fetchall() == [(3, "3", 3, 3)] - assert con.execute("FROM arrow_tbl WHERE a > 2 AND c < 4 AND b = '3'").fetchall() == [(3, "3", 3, 3)] - assert con.execute("FROM arrow_tbl WHERE a > 2 AND c < 4 AND b = '0'").fetchall() == [] - - def test_uhugeint_with_projection(self): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, i::INTEGER d FROM range(5) tbl(i)" - ) - arrow_tbl = to_arrow_table(con.table("t")) - con.register("arrow_tbl", arrow_tbl) - assert con.execute("SELECT c, b FROM arrow_tbl WHERE c < 4 AND b = '3' AND a > 2").fetchall() == [(3, "3")] - # Projection list doesn't include the unpushable column - assert con.execute("SELECT a, b FROM arrow_tbl WHERE a > 2 AND c < 4 AND b = '3'").fetchall() == [(3, "3")] - - def test_multiple_unpushable_filters(self): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, " - "i::INTEGER d, i::UHUGEINT e, i::SMALLINT f, i::UHUGEINT g " - "FROM range(50) tbl(i)" - ) - arrow_tbl = to_arrow_table(con.table("t")) - con.register("arrow_tbl", arrow_tbl) - assert con.execute( - "SELECT a, b FROM arrow_tbl WHERE a > 2 AND c < 40 AND b = '28' AND g > 15 AND e < 30" - ).fetchall() == [(28, "28")] - - def test_binary_view_filter_does_not_crash(self): - """Binary view filters cannot be pushed (pyarrow limitation). - - Results must still be correct. - """ + def test_binary_view_filter(self, duckdb_cursor): + """Filters on a view column work (without pushdown because pyarrow does not support view filters yet).""" table = pa.table({"col": pa.array([b"abc", b"efg"], type=pa.binary_view())}) dset = pa_ds.dataset(table) - res = duckdb.sql("SELECT * FROM dset WHERE col = 'abc'::BINARY").fetchall() - assert res == [(b"abc",)] - - def test_string_view_filter_does_not_crash(self): - """String view filters cannot be pushed (pyarrow limitation). + res = duckdb_cursor.sql("select * from dset where col = 'abc'::binary") + assert len(res) == 1 - Results must still be correct. - """ + def test_string_view_filter(self, duckdb_cursor): + """Filters on a view column work (without pushdown because pyarrow does not support view filters yet).""" table = pa.table({"col": pa.array(["abc", "efg"], type=pa.string_view())}) dset = pa_ds.dataset(table) - res = duckdb.sql("SELECT * FROM dset WHERE col = 'abc'").fetchall() - assert res == [("abc",)] - - -# =========================================================================== -# 11. Projection / scanner-path interactions -# =========================================================================== - - -@pytest.mark.parametrize("factory", ARROW_FACTORIES_WITH_DATASET) -def test_filter_without_projection(duckdb_cursor, factory): - """Filter applied when no projection is specified. - - Covers all three conversion paths including the dataset scanner. - """ - duckdb_cursor.execute("CREATE TABLE _np (a INTEGER, b INTEGER, c INTEGER)") - duckdb_cursor.execute("INSERT INTO _np VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") - arrow_table = factory(duckdb_cursor.table("_np")) - duckdb_cursor.register("arrow_table", arrow_table) - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a = 1").fetchall() == [(1, 1, 1)] - - -# =========================================================================== -# 12. Decimal pushdown via polars (the only path that exercises decimal -# scalar coercion in the walker) -# =========================================================================== - - -def test_decimal_filter_pushdown_via_polars(duckdb_cursor): - """Polars decimal frames stress GetScalar's decimal branch.""" - pl = pytest.importorskip("polars") - np.random.seed(10) - df = pl.DataFrame({"x": pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) - rows = duckdb_cursor.sql( - """ - SELECT x, x > 0.05 AS is_x_good, x::FLOAT > 0.05 AS is_float_x_good - FROM df - WHERE is_x_good - ORDER BY x ASC - """ - ).fetchall() - assert len(rows) == 495 + res = duckdb_cursor.sql("select * from dset where col = 'abc'") + assert len(res) == 1 - -# =========================================================================== -# 13. Regressions -# =========================================================================== - - -class TestRegressions: - """Issues that were fixed and need to stay fixed.""" - - @pytest.mark.skipif( - Version(pa.__version__) < Version("15.0.0"), - reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning", - ) - def test_9371_arrow_dataset_with_tz_parameter(self, duckdb_cursor, tmp_path): - """Parameterized timestamp filter against a pandas-indexed parquet dataset. - - https://github.com/duckdb/duckdb/issues/9371 - """ - duckdb_cursor.execute("SET TimeZone='UTC'") - file_path = tmp_path / "test.parquet" - timestamp = dt.datetime(2023, 8, 29, 1, tzinfo=dt.timezone.utc) - my_arrow_table = pa.Table.from_pydict({"ts": [timestamp] * 3, "value": [1, 2, 3]}) - df = my_arrow_table.to_pandas().set_index("ts") - df.to_parquet(str(file_path)) - - my_arrow_dataset = pa_ds.dataset(str(file_path)) - res = duckdb_cursor.execute( - "SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[timestamp] - ).to_arrow_table() - assert duckdb_cursor.sql("SELECT * FROM res").fetchall() == [(1, timestamp), (2, timestamp), (3, timestamp)] - - def test_2145_parquet_glob_through_arrow(self, duckdb_cursor, tmp_path): - """Filter pushdown into a parquet glob via the arrow scan. - - https://github.com/duckdb/duckdb/issues/2145 - """ - date1 = pd.date_range("2018-01-01", "2018-12-31", freq="B") - df1 = pd.DataFrame(np.random.randn(date1.shape[0], 5), columns=list("ABCDE")) - df1["date"] = date1 - date2 = pd.date_range("2019-01-01", "2019-12-31", freq="B") - df2 = pd.DataFrame(np.random.randn(date2.shape[0], 5), columns=list("ABCDE")) - df2["date"] = date2 - - data1 = tmp_path / "data1.parquet" - data2 = tmp_path / "data2.parquet" - duckdb_cursor.execute(f"COPY (SELECT * FROM df1) TO '{data1.as_posix()}'") - duckdb_cursor.execute(f"COPY (SELECT * FROM df2) TO '{data2.as_posix()}'") - - glob_pattern = (tmp_path / "data*.parquet").as_posix() - table = duckdb_cursor.read_parquet(glob_pattern).to_arrow_table() - output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() - expected_df = duckdb.from_parquet(glob_pattern).filter("date > '2019-01-01'").df() - pd.testing.assert_frame_equal(expected_df, output_df) - - -# =========================================================================== -# 14. Canaries -# -# Each canary documents a current limitation or an expected future change. -# When upstream behaviour shifts, a canary either xpasses (failing the suite -# and forcing us to update) or fails outright. -# =========================================================================== - - -class TestCanaries: - """Markers for behaviours we expect to change upstream eventually.""" - - # ----- pyarrow capabilities ---------------------------------------- - - @pytest.mark.xfail( - raises=pa_lib.ArrowNotImplementedError, - reason="pyarrow does not yet implement string_view filter compare kernels", - strict=True, - ) - def test_pyarrow_gains_string_view_filter_support(self): - """When pyarrow adds string_view comparison kernels this will xpass. - - At that point we should remove the post-scan fallback in TestUnsupportedTypes. - """ + @pytest.mark.xfail(raises=pa_lib.ArrowNotImplementedError) + def test_canary_for_pyarrow_string_view_filter_support(self, duckdb_cursor): + """This canary will xpass when pyarrow implements string view filter support.""" + # predicate: field == "string value" filter_expr = pa_ds.field("col") == pa_ds.scalar("val1") + # dataset with a string view column table = pa.table({"col": pa.array(["val1", "val2"], type=pa.string_view())}) - pa_ds.dataset(table).scanner(columns=["col"], filter=filter_expr) - - @pytest.mark.xfail( - raises=pa_lib.ArrowNotImplementedError, - reason="pyarrow does not yet implement binary_view filter compare kernels", - strict=True, - ) - def test_pyarrow_gains_binary_view_filter_support(self): - """When pyarrow adds binary_view comparison kernels this will xpass.""" - filter_expr = pa_ds.field("col") == pa_ds.scalar(pa.scalar(b"bin1", pa.binary_view())) + dset = pa_ds.dataset(table) + # creating the scanner fails + dset.scanner(columns=["col"], filter=filter_expr) + + @pytest.mark.xfail(raises=pa_lib.ArrowNotImplementedError) + def test_canary_for_pyarrow_binary_view_filter_support(self, duckdb_cursor): + """This canary will xpass when pyarrow implements binary view filter support.""" + # predicate: field == const + const = pa_ds.scalar(pa.scalar(b"bin1", pa.binary_view())) + filter_expr = pa_ds.field("col") == const + # dataset with a string view column table = pa.table({"col": pa.array([b"bin1", b"bin2"], type=pa.binary_view())}) - pa_ds.dataset(table).scanner(columns=["col"], filter=filter_expr) - - # ----- DuckDB optimizer decisions we expect to change -------------- - - @pytest.mark.xfail( - reason="DuckDB does not currently push IS NULL into the arrow scan", - strict=True, - ) - def test_is_null_pushes_into_arrow_scan(self, duckdb_cursor): - """If the optimizer starts pushing standalone IS NULL into arrow scans, this canary xpasses. - - The walker already has the OPERATOR_IS_NULL arm. - """ - duckdb_cursor.execute("CREATE TABLE _t AS SELECT * FROM (VALUES (1), (NULL), (3)) v(a)") - arrow_table = to_arrow_table(duckdb_cursor.table("_t")) - duckdb_cursor.register("arrow_table", arrow_table) - assert _was_pushed(duckdb_cursor, "SELECT a FROM arrow_table WHERE a IS NULL") - - @pytest.mark.xfail( - reason="DuckDB does not currently push IS NULL on struct fields into the arrow scan", - strict=True, - ) - def test_struct_is_null_pushes(self, duckdb_cursor): - """If the optimizer starts pushing struct-field IS NULL, this canary xpasses.""" - duckdb_cursor.execute("CREATE TABLE _s (s STRUCT(a INTEGER))") - duckdb_cursor.execute("INSERT INTO _s VALUES ({'a': 1}), ({'a': NULL}), (NULL)") - arrow_table = to_arrow_table(duckdb_cursor.table("_s")) - duckdb_cursor.register("arrow_table", arrow_table) - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE s.a IS NULL") - - @pytest.mark.xfail( - reason="DuckDB does not currently push struct.a IN (...) into the arrow scan; " - "TryPushdownInFilter requires a bare BoundColumnRef (filter_combiner.cpp:505-508)", - strict=True, - ) - def test_struct_in_pushes(self, duckdb_cursor): - """When DuckDB extends TryPushdownInFilter to allow struct_extract column sides, this canary xpasses. - - ResolveColumn already handles the path. - """ - duckdb_cursor.execute("CREATE TABLE _s (s STRUCT(a INTEGER))") - duckdb_cursor.execute("INSERT INTO _s VALUES ({'a': 1}), ({'a': 2}), ({'a': 42}), (NULL)") - arrow_table = to_arrow_table(duckdb_cursor.table("_s")) - duckdb_cursor.register("arrow_table", arrow_table) - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE s.a IN (1, 42, 99)") - - # ----- Join filter pushdown ---------------------------------------- - # - # BLOOM_FILTER, PERFECT_HASH_JOIN_FILTER, and PREFIX_RANGE_FILTER are - # generated by hash joins. Today they take a separate runtime path - # (`info.dynamic_filters->PushFilter(...)` in physical_hash_join.cpp) that - # does NOT reach PyArrowFilterPushdown::TransformFilter. These canaries - # document the current behaviour: the join runs to the correct answer but - # the walker is not invoked, so the new filter types never surface to the - # bindings. The day arrow_array_stream.cpp starts receiving them via the - # static filter set, we'll fail here. - - @pytest.mark.xfail( - reason="BLOOM_FILTER from joins reaches arrow scans via runtime dynamic_filters, " - "not PyArrowFilterPushdown::TransformFilter", - strict=True, - ) - def test_bloom_filter_reaches_walker(self, duckdb_cursor): - """When join bloom filters start flowing through the static filter set, this canary xpasses. - - Used by arrow_array_stream.cpp. We'll need to add a BLOOM_FILTER case in - TransformFilterRecursive. - """ - duckdb_cursor.execute("CREATE TABLE _probe AS SELECT range AS k FROM range(100_000)") - duckdb_cursor.execute("CREATE TABLE _build AS SELECT (i*2)::BIGINT AS k FROM range(50_000) t(i)") - probe_arrow = to_arrow_table(duckdb_cursor.table("_probe")) - duckdb_cursor.register("probe_arrow", probe_arrow) - assert _was_pushed(duckdb_cursor, "SELECT count(*) FROM probe_arrow JOIN _build USING(k)") - - @pytest.mark.xfail( - reason="PERFECT_HASH_JOIN_FILTER from joins reaches arrow scans via runtime dynamic_filters, " - "not PyArrowFilterPushdown::TransformFilter", - strict=True, - ) - def test_perfect_hash_join_filter_reaches_walker(self, duckdb_cursor): - duckdb_cursor.execute("CREATE TABLE _probe AS SELECT range AS k FROM range(10_000)") - duckdb_cursor.execute("CREATE TABLE _build AS SELECT i AS k FROM range(100) t(i)") - probe_arrow = to_arrow_table(duckdb_cursor.table("_probe")) - duckdb_cursor.register("probe_arrow", probe_arrow) - assert _was_pushed(duckdb_cursor, "SELECT count(*) FROM probe_arrow JOIN _build USING(k)") - - @pytest.mark.xfail( - reason="PREFIX_RANGE_FILTER from joins reaches arrow scans via runtime dynamic_filters, " - "not PyArrowFilterPushdown::TransformFilter", - strict=True, - ) - def test_prefix_range_filter_reaches_walker(self, duckdb_cursor): - duckdb_cursor.execute( - "CREATE TABLE _probe AS SELECT 'str_' || lpad(i::VARCHAR, 4, '0') AS k FROM range(10_000) t(i)" - ) - duckdb_cursor.execute( - "CREATE TABLE _build AS SELECT 'str_' || lpad((i*2)::VARCHAR, 4, '0') AS k FROM range(500) t(i)" - ) - probe_arrow = to_arrow_table(duckdb_cursor.table("_probe")) - duckdb_cursor.register("probe_arrow", probe_arrow) - assert _was_pushed(duckdb_cursor, "SELECT count(*) FROM probe_arrow JOIN _build USING(k)") - - # ----- Optimizer canonicalization expected to stay ----------------- - - def test_not_in_does_not_push(self, duckdb_cursor): - """``NOT IN`` is rewritten to AND-of-!= by the optimizer rather than being pushed as a single operator. - - If this ever changes the assertion flips to ``_was_pushed`` and the walker's - BOUND_OPERATOR arm needs an ``OPERATOR_NOT`` case. - """ - duckdb_cursor.execute("CREATE TABLE _t AS SELECT range AS a FROM range(1000)") - arrow_table = to_arrow_table(duckdb_cursor.table("_t")) - duckdb_cursor.register("arrow_table", arrow_table) - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a NOT IN (1, 5, 100)") + dset = pa_ds.dataset(table) + # creating the scanner fails + dset.scanner(columns=["col"], filter=filter_expr) diff --git a/tests/fast/arrow/test_polars_filter_pushdown.py b/tests/fast/arrow/test_polars_filter_pushdown.py index 8b3f4acf..20e63819 100644 --- a/tests/fast/arrow/test_polars_filter_pushdown.py +++ b/tests/fast/arrow/test_polars_filter_pushdown.py @@ -1,519 +1,155 @@ # ruff: noqa: F841 -"""Filter pushdown tests for polars-backed scans. - -What's tested here: - -* **Comparison correctness** across every supported type, run against two - polars factories (`pl.LazyFrame` and `pl.DataFrame`) — see below. -* **Optimizer pushdown decisions** — same EXPLAIN-based checks as the pyarrow - test file; verified that polars LazyFrame and DataFrame both bind through - ``arrow_scan`` and render as ``ARROW_SCAN`` in the plan. -* **Special filter shapes** — same coverage as the pyarrow file: IN, LIKE, - CAST temporal, ``IS DISTINCT FROM NULL`` inside OR, NaN ordering, struct - extraction (single and multi-level), OptionalFilter, top-N dynamic, join. -* **Produce-path** tests specific to the LazyFrame branch in - ``arrow_array_stream.cpp`` (cached materialised table, repeated filtered - scans, empty result, etc.). -* **Regressions** and **canaries** mirroring the pyarrow file. - -Two factories cover the two distinct C++ scan paths that polars inputs -exercise: - -* ``to_polars_lazyframe(rel) → pl.LazyFrame`` — routes through - ``PyArrowObjectType::PolarsLazyFrame`` and is handled by - ``polars_filter_pushdown.cpp``. This is the path with the current gaps. -* ``to_polars_dataframe(rel) → pl.DataFrame`` — converted to a pyarrow Table - via ``.to_arrow()`` at registration and then handled by - ``pyarrow_filter_pushdown.cpp``. Acts as a "reference implementation": - whatever passes here defines the expected polars-side behaviour. - -The pre-Phase-3 expectation is that ``dataframe`` factory cases mostly pass -while ``lazyframe`` factory cases fail wherever the polars walker has a gap -(EXPRESSION_FILTER, deep struct, decimal IN, OR-bail). Those failures drive -the C++ implementation phases. -""" - -from __future__ import annotations - import math -import re import pytest -from _pushdown_helpers import ( - COMPARABLE_TYPES, - COMPARISON_CASES, -) -from _pushdown_helpers import ( - arrow_scan_block as _arrow_scan_block, -) -from _pushdown_helpers import ( - count as _count, -) -from _pushdown_helpers import ( - make_typed_table as _make_typed_table, -) -from _pushdown_helpers import ( - was_pushed as _was_pushed, -) import duckdb pl = pytest.importorskip("polars") -pa = pytest.importorskip("pyarrow") - - -# =========================================================================== -# Conversion factories — polars side -# =========================================================================== - - -def to_polars_lazyframe(rel): - return rel.pl().lazy() - - -def to_polars_dataframe(rel): - return rel.pl() - - -POLARS_FACTORIES = [ - pytest.param(to_polars_lazyframe, id="lazyframe"), - pytest.param(to_polars_dataframe, id="dataframe"), -] - - -# =========================================================================== -# 1. Comparison correctness across types -# =========================================================================== - +pytest.importorskip("pyarrow") -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -@pytest.mark.parametrize("case", COMPARABLE_TYPES, ids=lambda c: c.id) -@pytest.mark.parametrize(("predicate_tpl", "expected"), COMPARISON_CASES) -def test_comparisons(duckdb_cursor, factory, case, predicate_tpl, expected): - """Each (type, factory, predicate) tuple produces the expected row count.""" - _make_typed_table(duckdb_cursor, factory, case) - predicate = predicate_tpl.format(low=case.low, mid=case.mid, high=case.high) - assert _count(duckdb_cursor, predicate) == expected +class TestPolarsLazyFrameFilterPushdown: + """Tests for filter pushdown on LazyFrames. -# BOOL has no ordering, so it gets its own tiny suite. -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -def test_bool_comparisons(duckdb_cursor, factory): - """Equality / IS NULL / AND / OR on BOOL columns.""" - duckdb_cursor.execute("CREATE TABLE _b (a BOOL, b BOOL)") - duckdb_cursor.execute("INSERT INTO _b VALUES (TRUE, TRUE), (TRUE, FALSE), (FALSE, TRUE), (NULL, NULL)") - arrow_table = factory(duckdb_cursor.table("_b")) - duckdb_cursor.register("arrow_table", arrow_table) - - assert _count(duckdb_cursor, "a = TRUE") == 2 - assert _count(duckdb_cursor, "a IS NULL") == 1 - assert _count(duckdb_cursor, "a IS NOT NULL") == 3 - assert _count(duckdb_cursor, "a = TRUE AND b = TRUE") == 1 - assert _count(duckdb_cursor, "a = TRUE OR b = TRUE") == 3 + All tests use pl.LazyFrame (the target of this change). DuckDB pushes filters and projections into the Polars lazy + plan before collection, so only surviving rows are ever materialized. + """ + ##### CONSTANT_COMPARISON: all six comparison operators -# Integer boundary values are worth a separate test because GetScalar / FromValue -# has to coerce each (DuckDB Value) -> (target backend scalar) at the limit. -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -@pytest.mark.parametrize( - ("data_type", "max_value"), - [ - ("TINYINT", 127), - ("SMALLINT", 32767), - ("INTEGER", 2147483647), - ("BIGINT", 9223372036854775807), - ("UTINYINT", 255), - ("USMALLINT", 65535), - ("UINTEGER", 4294967295), - ("UBIGINT", 18446744073709551615), - ], -) -def test_integer_max_value(duckdb_cursor, factory, data_type, max_value): - """Pushdown round-trips through every integer's maximum representable value.""" - duckdb_cursor.execute(f"CREATE TABLE _t AS SELECT {max_value}::{data_type} AS i") - arrow_table = factory(duckdb_cursor.table("_t")) - duckdb_cursor.register("arrow_table", arrow_table) - expected = [(max_value,)] - assert duckdb_cursor.sql("SELECT * FROM arrow_table WHERE i > 0").fetchall() == expected - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE i > ?", (0,)).fetchall() == expected - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE i = ?", (max_value,)).fetchall() == expected - - -# =========================================================================== -# 2. OR pushdown decisions -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestOrPushdownDecisions: - """Same-column ORs push; multi-column ORs and OR-with-AND-child don't.""" - - @pytest.fixture(autouse=True) - def _arrow_table(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _o (a INTEGER, b INTEGER, c INTEGER)") - duckdb_cursor.execute("INSERT INTO _o VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_o"))) - - def test_single_column_or_pushes(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR a = 10") - - def test_single_column_or_with_and_does_not_push(self, duckdb_cursor, factory): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR (a > 3 AND a < 5)") - - def test_multiple_or_terms_push(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR a > 3 OR a < 5") - - def test_or_with_not_equal_pushes(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a != 1 OR a > 3 OR a < 2") - - def test_multi_column_or_does_not_push(self, duckdb_cursor, factory): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a = 1 OR b = 2 AND (a > 3 OR b < 5)") - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestStringOrSpecifics: - """VARCHAR has stricter OR-pushdown rules than numeric types.""" - - @pytest.fixture(autouse=True) - def _arrow_table(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _v (a VARCHAR, b VARCHAR, c VARCHAR)") - duckdb_cursor.execute( - "INSERT INTO _v VALUES ('1','1','1'),('10','10','10'),('100','10','100'),(NULL,NULL,NULL)" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_v"))) - - def test_string_range_or_pushes(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a >= '1' OR a <= '10'") - - def test_or_with_is_null_does_not_push(self, duckdb_cursor, factory): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a IS NULL OR a = '1'") - - def test_or_with_is_not_null_does_not_push(self, duckdb_cursor, factory): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a IS NOT NULL OR a = '1'") - - def test_or_with_like_does_not_push(self, duckdb_cursor, factory): - assert not _was_pushed(duckdb_cursor, "SELECT * FROM arrow_table WHERE a LIKE '1%' OR a = '10'") - - -# =========================================================================== -# 3. IN pushdown -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestInPushdown: - """Small IN lowers to OR of equalities (CONJUNCTION_OR); large IN forces a real IN_FILTER.""" - - def test_basic(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _i AS SELECT i AS a FROM range(10) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - assert sorted(duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a IN (2, 5, 7)").fetchall()) == [ - (2,), - (5,), - (7,), - ] - - def test_large_in_list_does_not_hang(self, duckdb_cursor, factory): - """A 200-element IN list forces the IN_FILTER path (not the OR-of-eq rewrite).""" - duckdb_cursor.execute("CREATE TABLE _i AS SELECT i AS a FROM range(500) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - in_list = ", ".join(str(i) for i in range(200)) - rows = duckdb_cursor.execute(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone() - assert rows == (200,) - - def test_in_varchar(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _i AS SELECT 'str_' || i::VARCHAR AS s FROM range(10) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - rows = sorted(duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s IN ('str_2', 'str_5')").fetchall()) - assert rows == [("str_2",), ("str_5",)] - - def test_in_float(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _i AS SELECT i::DOUBLE AS a FROM range(10) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - rows = sorted(duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a IN (2.0, 5.0)").fetchall()) - assert rows == [(2.0,), (5.0,)] - - def test_in_with_no_nulls_in_list(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _i AS SELECT i AS a FROM range(10) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - assert sorted(duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a IN (1, 2, 3)").fetchall()) == [ - (1,), - (2,), - (3,), - ] - - def test_in_with_null_in_list(self, duckdb_cursor, factory): - """``a IN (NULL, …)`` returns no rows for the NULL entry; non-null matches survive.""" - duckdb_cursor.execute("CREATE TABLE _i AS SELECT * FROM (VALUES (1), (NULL), (2), (3)) t(a)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - assert sorted(duckdb_cursor.execute("SELECT * FROM arrow_table WHERE a IN (1, NULL, 3)").fetchall()) == [ - (1,), - (3,), - ] - - def test_in_decimal_large_list(self, duckdb_cursor, factory): - """Large IN list on a decimal column must force the IN_FILTER path and pass. - - For LazyFrame, the current C++ walker constructs a plain Python list of - ``Decimal(...)`` values for ``pl.col(d).is_in(...)``. Polars infers a - higher-precision dtype for the literal list and refuses to compare against - the column's actual ``Decimal(precision, scale)``. The post-Phase-3 - PolarsBackend builds a typed Series matching the column to close this. - """ - duckdb_cursor.execute("CREATE TABLE _i AS SELECT i::DECIMAL(18,4) AS d FROM range(500) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_i"))) - in_list = ", ".join(f"{i}.0000" for i in range(200)) - rows = duckdb_cursor.execute(f"SELECT count(*) FROM arrow_table WHERE d IN ({in_list})").fetchone() - assert rows == (200,) - - -# =========================================================================== -# 4. NaN pushdown -# =========================================================================== -# -# DuckDB intentionally violates IEEE-754: NaN is the greatest value. -# Each backend has to translate that to its target operators (is_nan / lit). -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestNaNPushdown: - """Six comparison operators against a NaN constant on a DOUBLE column.""" - - @pytest.fixture(autouse=True) - def _nan_arrow_table(self, duckdb_cursor, factory): - duckdb_cursor.execute( - "CREATE TABLE _n AS SELECT a::DOUBLE a FROM VALUES " - "('inf'), ('nan'), ('0.34234'), ('34234234.00005'), ('-nan') t(a)" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_n"))) - - @pytest.mark.parametrize( - "op", - ["=", "!=", "<", "<=", ">", ">="], - ) - def test_nan_comparison_matches_duckdb(self, duckdb_cursor, factory, op): - """Each NaN comparison through the arrow scan agrees with DuckDB's own answer.""" - q_arrow = f"SELECT count(*) FROM arrow_table WHERE a {op} 'NaN'::FLOAT" - q_duck = f"SELECT count(*) FROM _n WHERE a {op} 'NaN'::FLOAT" - assert duckdb_cursor.execute(q_arrow).fetchone() == duckdb_cursor.execute(q_duck).fetchone() - - -# =========================================================================== -# 5. Struct extract pushdown -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestOneLevelStruct: - """``struct_extract`` chains build the path inside ``ResolveColumn``.""" - - @pytest.fixture(autouse=True) - def _one_level_struct(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _s (s STRUCT(a INTEGER, b BOOL))") - duckdb_cursor.execute( - "INSERT INTO _s VALUES " - "({'a': 1, 'b': true}), ({'a': 2, 'b': false}), (NULL), " - "({'a': 3, 'b': true}), ({'a': NULL, 'b': NULL})" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_s"))) - - def test_one_level_comparison_is_pushed(self, duckdb_cursor, factory): - plan = duckdb_cursor.execute("EXPLAIN SELECT * FROM arrow_table WHERE s.a < 2").fetchone()[1] - assert re.search(r"struct_extract\(s,\s*'a'\)\s*<", plan) - - def test_one_level_comparison_correct(self, duckdb_cursor, factory): - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a < 2").fetchone()[0] == {"a": 1, "b": True} - - def test_one_level_and_correct(self, duckdb_cursor, factory): - assert duckdb_cursor.execute("SELECT count(*) FROM arrow_table WHERE s.a < 3 AND s.b = true").fetchone() == (1,) - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestNestedStruct: - """Multi-level ``struct_extract`` chains. - - Polars supports arbitrary depth via chained ``struct.field``; pyarrow uses - tuple-path field references. Both are driven by the same ``ResolveColumn`` - recursion in the shared walker. - """ + def test_comparison_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a = 3").fetchall() == [(3,)] - @pytest.fixture(autouse=True) - def _nested_struct(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _n (s STRUCT(a STRUCT(b INTEGER, c BOOL), d STRUCT(e INTEGER, f VARCHAR)))") - duckdb_cursor.execute( - "INSERT INTO _n VALUES " - "({'a': {'b': 1, 'c': false}, 'd': {'e': 2, 'f': 'foo'}}), " - "(NULL), " - "({'a': {'b': 3, 'c': true}, 'd': {'e': 4, 'f': 'bar'}}), " - "({'a': {'b': NULL, 'c': true}, 'd': {'e': 5, 'f': 'qux'}}), " - "({'a': NULL, 'd': NULL})" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_n"))) - - def test_nested_two_level_correct(self, duckdb_cursor, factory): - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { - "a": {"b": 1, "c": False}, - "d": {"e": 2, "f": "foo"}, - } - - def test_nested_and_across_branches(self, duckdb_cursor, factory): - assert duckdb_cursor.execute( - "SELECT count(*) FROM arrow_table WHERE s.a.c = true AND s.d.e = 5" - ).fetchone() == (1,) - - def test_nested_varchar_comparison(self, duckdb_cursor, factory): - assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { - "a": {"b": 3, "c": True}, - "d": {"e": 4, "f": "bar"}, - } - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestThreeLevelStruct: - """Three-level ``struct_extract`` — verifies depth-N support, not just one or two.""" - - @pytest.fixture(autouse=True) - def _three_level(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _3 (s STRUCT(a STRUCT(b STRUCT(c INTEGER, d VARCHAR), e INTEGER), f BOOL))") - duckdb_cursor.execute( - "INSERT INTO _3 VALUES " - "({'a': {'b': {'c': 1, 'd': 'one'}, 'e': 10}, 'f': true}), " - "({'a': {'b': {'c': 2, 'd': 'two'}, 'e': 20}, 'f': false}), " - "({'a': {'b': {'c': 3, 'd': 'three'}, 'e': 30}, 'f': true})" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_3"))) - - def test_three_level_comparison_correct(self, duckdb_cursor, factory): - rows = duckdb_cursor.execute("SELECT count(*) FROM arrow_table WHERE s.a.b.c > 1").fetchone() - assert rows == (2,) - - def test_three_level_varchar_correct(self, duckdb_cursor, factory): - rows = duckdb_cursor.execute("SELECT count(*) FROM arrow_table WHERE s.a.b.d = 'two'").fetchone() - assert rows == (1,) - - -# =========================================================================== -# 6. LIKE pushdown -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestLikePushdown: - """LIKE pushdown decomposes into comparison filters. - - A LIKE with a fixed prefix decomposes into ``>= prefix AND < prefix+1``; a - constant LIKE (no wildcards) decomposes into ``=``. Both produce regular - comparison ExpressionFilters that the walker handles. - """ + def test_comparison_not_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a != 3").fetchall() == [(1,), (2,), (4,), (5,)] - @pytest.fixture(autouse=True) - def _s_arrow_table(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _l AS SELECT 'str_' || lpad(i::VARCHAR, 4, '0') AS s FROM range(100) t(i)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_l"))) + def test_comparison_less_than(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a < 3").fetchall() == [(1,), (2,)] - def test_like_with_prefix_is_pushed(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT s FROM arrow_table WHERE s LIKE 'str_001%'") + def test_comparison_less_than_or_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a <= 3").fetchall() == [(1,), (2,), (3,)] - def test_like_constant_is_pushed(self, duckdb_cursor, factory): - assert _was_pushed(duckdb_cursor, "SELECT s FROM arrow_table WHERE s LIKE 'str_0042'") + def test_comparison_greater_than(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a > 3").fetchall() == [(4,), (5,)] - def test_like_with_prefix_correct(self, duckdb_cursor, factory): - rows = duckdb_cursor.execute("SELECT s FROM arrow_table WHERE s LIKE 'str_001%' ORDER BY s").fetchall() - assert rows == [(f"str_001{d}",) for d in "0123456789"] + def test_comparison_greater_than_or_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a >= 3").fetchall() == [(3,), (4,), (5,)] + def test_string_comparison(self): + lf = pl.LazyFrame({"name": ["alice", "bob", "charlie"], "val": [1, 2, 3]}) + assert duckdb.sql("SELECT * FROM lf WHERE name = 'bob'").fetchall() == [("bob", 2)] -# =========================================================================== -# 7. CAST temporal pushdown -# =========================================================================== + ##### NaN comparisons (CONSTANT_COMPARISON with is_nan path) + def test_nan_equal(self): + """NaN = NaN is true in DuckDB; pushes is_nan().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 'NaN'::DOUBLE").fetchall() + assert len(result) == 1 + assert math.isnan(result[0][0]) -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestTemporalCastPushdown: - """``CAST(timestamp_col AS DATE) = …`` pushes an optional relaxed range filter.""" + def test_nan_greater_than_or_equal(self): + """NaN >= NaN is true; pushes is_nan().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a >= 'NaN'::DOUBLE").fetchall() + assert len(result) == 1 + assert math.isnan(result[0][0]) - def test_cast_timestamp_to_date_is_pushed(self, duckdb_cursor, factory): - duckdb_cursor.execute( - "CREATE TABLE _ct AS " - "SELECT TIMESTAMP '2024-01-01 00:00:00' + INTERVAL (i) SECOND AS ts FROM range(86400) t(i)" - ) - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_ct"))) - assert _was_pushed( - duckdb_cursor, - "SELECT * FROM arrow_table WHERE CAST(ts AS DATE) = DATE '2024-01-01'", - ) + def test_nan_less_than(self): + """X < NaN is true for non-NaN values; pushes is_nan().__invert__().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a < 'NaN'::DOUBLE").fetchall() + assert sorted(result) == [(1.0,), (3.0,)] + def test_nan_not_equal(self): + """X != NaN is true for non-NaN values; pushes is_nan().__invert__().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a != 'NaN'::DOUBLE").fetchall() + assert sorted(result) == [(1.0,), (3.0,)] -# =========================================================================== -# 8. IS DISTINCT FROM NULL inside OR -# =========================================================================== + def test_nan_greater_than(self): + """X > NaN is always false; pushes lit(false).""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a > 'NaN'::DOUBLE").fetchall() + assert result == [] + def test_nan_less_than_or_equal(self): + """X <= NaN is always true; pushes lit(true).""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a <= 'NaN'::DOUBLE").fetchall() + assert len(result) == 3 + + ##### IS_NULL / IS_NOT_NULL (triggered via DISTINCT FROM NULL inside OR) + + def test_is_null_filter(self): + """IS NOT DISTINCT FROM NULL inside an OR pushes IS_NULL as a child of CONJUNCTION_OR.""" + lf = pl.LazyFrame({"a": [1, None, 3, None, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a IS NOT DISTINCT FROM NULL").fetchall() + values = [row[0] for row in result] + assert values.count(None) == 2 + assert 1 in values + assert len(values) == 3 + + def test_is_not_null_filter(self): + """IS DISTINCT FROM NULL inside an OR pushes IS_NOT_NULL as a child of CONJUNCTION_OR.""" + lf = pl.LazyFrame({"a": [1, None, 3, None, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a IS DISTINCT FROM NULL").fetchall() + assert sorted(result) == [(1,), (3,), (5,)] + + # ── CONJUNCTION_AND ── + + def test_conjunction_and_range(self): + """BETWEEN on a single column pushes a CONJUNCTION_AND with GTE + LTE children.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a BETWEEN 2 AND 4").fetchall() + assert result == [(2,), (3,), (4,)] -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestDistinctFromNullOrPushdown: - """``IS DISTINCT FROM NULL OR ...`` produces an IS_NOT_NULL ExpressionFilter.""" + def test_conjunction_and_multi_column(self): + """Filters on two different columns combine via AND in TransformFilter.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": ["x", "y", "x", "y", "x"]}) + result = duckdb.sql("SELECT * FROM lf WHERE a > 2 AND b = 'x'").fetchall() + assert result == [(3, "x"), (5, "x")] - @pytest.fixture(autouse=True) - def _with_nulls(self, duckdb_cursor, factory): - duckdb_cursor.execute("CREATE TABLE _d AS SELECT * FROM (VALUES (1), (NULL), (5), (10)) t(a)") - duckdb_cursor.register("arrow_table", factory(duckdb_cursor.table("_d"))) + ##### CONJUNCTION_OR - def test_distinct_from_null_or_eq_is_pushed(self, duckdb_cursor, factory): - assert _was_pushed( - duckdb_cursor, - "SELECT a FROM arrow_table WHERE a IS DISTINCT FROM NULL OR a = 5", - ) + def test_conjunction_or(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a = 5").fetchall() + assert sorted(result) == [(1,), (5,)] - def test_distinct_from_null_or_eq_correct(self, duckdb_cursor, factory): - rows = duckdb_cursor.execute( - "SELECT a FROM arrow_table WHERE a IS DISTINCT FROM NULL OR a = 5 ORDER BY a" - ).fetchall() - assert rows == [(1,), (5,), (10,)] + ##### IN_FILTER - def test_not_distinct_from_null_or_eq_correct(self, duckdb_cursor, factory): - rows = duckdb_cursor.execute( - "SELECT a FROM arrow_table WHERE a IS NOT DISTINCT FROM NULL OR a = 5 ORDER BY a NULLS FIRST" - ).fetchall() - assert rows == [(None,), (5,)] + def test_in_filter(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a IN (2, 4)").fetchall() + assert sorted(result) == [(2,), (4,)] + ##### STRUCT_EXTRACT -# =========================================================================== -# 9. Special-shape filters: optional, dynamic top-N, join -# =========================================================================== + def test_struct_extract(self): + lf = pl.LazyFrame({"s": [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}]}) + result = duckdb.sql("SELECT * FROM lf WHERE s.x > 1").fetchall() + assert len(result) == 2 + assert all(row[0]["x"] > 1 for row in result) + ##### OPTIONAL_FILTER -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestOptionalFilter: - """An OptionalFilter is allowed to silently fail. + def test_optional_filter(self): + """OR filters are wrapped in OPTIONAL_FILTER by DuckDB's optimizer.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a = 3").fetchall() + assert sorted(result) == [(1,), (3,)] - The engine reapplies it above the scan. The result must remain correct. - """ + ##### DYNAMIC_FILTER via TOP_N (issue #460(a)) - def test_no_crash_correct_result(self, factory): - con = duckdb.connect() - con.execute( - "CREATE TABLE _t AS SELECT * FROM (VALUES " - "('id', 100), ('product_code', 100), ('price', 100), ('quantity', 45), ('category', 5), " - "('is_available', 3), ('rating', 6), ('discount', 39), ('color', 5)) t(column_name, cardinality)" - ) - cardinality_table = factory(con.table("_t")) - con.register("cardinality_table", cardinality_table) - result = con.execute( - "SELECT * FROM cardinality_table WHERE cardinality > 1 ORDER BY cardinality ASC" - ).fetchall() - assert result == [ - ("is_available", 3), - ("category", 5), - ("color", 5), - ("rating", 6), - ("discount", 39), - ("quantity", 45), - ("id", 100), - ("product_code", 100), - ("price", 100), - ] - - def test_top_n_nulls_first_includes_min(self, factory): + def test_top_n_nulls_first_includes_min(self): """ORDER BY x ASC NULLS FIRST LIMIT 1 pushes OPTIONAL(IS_NULL OR DYNAMIC_FILTER) into the scan. The OR branch must not be partially translated: dropping the @@ -525,100 +161,24 @@ def test_top_n_nulls_first_includes_min(self, factory): result = duckdb.sql("SELECT * FROM lf ORDER BY x ASC NULLS FIRST LIMIT 1").fetchall() assert result == [(1,)] - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestDynamicFilter: - """The top-N optimization installs a dynamic filter. - - The walker returns ``py::none()`` for those (DuckDB applies them above - the scan). - """ - - def test_topn_dynamic_filter(self, duckdb_cursor, factory): - duckdb_cursor.execute( - "CREATE TABLE _t AS SELECT * FROM (VALUES " - "(3), (24), (234), (234), (234), (234), (234), (234), (234), (45), (2), (5), (2), (45)) t(a)" - ) - duckdb_cursor.register("t", factory(duckdb_cursor.table("_t"))) - rows = duckdb_cursor.sql("SELECT a FROM t ORDER BY a LIMIT 11").fetchall() - assert len(rows) == 11 - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestJoinFilterPushdown: - """Join pushdown between two polars-backed scans must produce the right count. - - (The join's runtime filters take a separate code path that doesn't reach - the static-filter walker — see TestCanaries below.) - """ - - def test_two_polars_tables(self, factory): - con = duckdb.connect() - con.execute("CREATE TABLE probe AS SELECT range a FROM range(10000)") - con.execute("CREATE TABLE build AS SELECT (random()*9999)::INT b FROM range(20)") - con.register("probe_arrow", factory(con.table("probe"))) - con.register("build_arrow", factory(con.table("build"))) - assert con.execute("SELECT count(*) FROM probe_arrow, build_arrow WHERE a = b").fetchall() == [(20,)] - - -# =========================================================================== -# 10. Unsupported-type fallback (filter applied above the scan) -# =========================================================================== - - -@pytest.mark.parametrize("factory", POLARS_FACTORIES) -class TestUnsupportedTypes: - """``UHUGEINT`` filters must not crash. - - The filter is applied above the scan instead of being pushed down. - """ - - def test_uhugeint_single_filter(self, factory): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, i::INTEGER d FROM range(5) tbl(i)" - ) - con.register("arrow_tbl", factory(con.table("t"))) - assert con.execute("FROM arrow_tbl WHERE c = 3").fetchall() == [(3, "3", 3, 3)] - - def test_uhugeint_mixed_with_supported(self, factory): - con = duckdb.connect() - con.execute( - "CREATE TABLE t AS SELECT i::INTEGER a, i::VARCHAR b, i::UHUGEINT c, i::INTEGER d FROM range(5) tbl(i)" - ) - con.register("arrow_tbl", factory(con.table("t"))) - assert con.execute("FROM arrow_tbl WHERE c < 4 AND a > 2").fetchall() == [(3, "3", 3, 3)] - assert con.execute("FROM arrow_tbl WHERE a > 2 AND c < 4 AND b = '3'").fetchall() == [(3, "3", 3, 3)] - assert con.execute("FROM arrow_tbl WHERE a > 2 AND c < 4 AND b = '0'").fetchall() == [] - - -# =========================================================================== -# 11. Produce-path interactions (LazyFrame branch in arrow_array_stream.cpp) -# -# These exercise the cache reuse and repeated-scan behaviour of the -# LazyFrame produce path. Only meaningful for the LazyFrame factory. -# =========================================================================== - - -class TestLazyFrameProducePath: - """Behaviour specific to the LazyFrame branch in ``arrow_array_stream.cpp``.""" + ##### Produce path, no filters def test_unfiltered_scan(self): - con = duckdb.connect() lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - con.register("lf", lf) - result = con.sql("SELECT * FROM lf").fetchall() + result = duckdb.sql("SELECT * FROM lf").fetchall() assert result == [(1, 4), (2, 5), (3, 6)] + ##### Produce path, column projection + def test_column_projection(self): - con = duckdb.connect() lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - con.register("lf", lf) - result = con.sql("SELECT a, c FROM lf").fetchall() + result = duckdb.sql("SELECT a, c FROM lf").fetchall() assert result == [(1, 7), (2, 8), (3, 9)] + ##### Produce path, cached DataFrame reuse + def test_cached_dataframe_reuse(self): - """Repeated unfiltered scans on a registered LazyFrame reuse the cached materialised table.""" + """Repeated unfiltered scans on a registered LazyFrame reuse the cached DataFrame.""" con = duckdb.connect() lf = pl.LazyFrame({"a": [1, 2, 3]}) con.register("my_lf", lf) @@ -626,6 +186,8 @@ def test_cached_dataframe_reuse(self): r2 = con.sql("SELECT * FROM my_lf").fetchall() assert r1 == r2 == [(1,), (2,), (3,)] + ##### Produce path, filter + collect (no cache) + def test_filtered_scan_not_cached(self): """Filtered scans collect a new DataFrame each time (not cached).""" con = duckdb.connect() @@ -636,68 +198,9 @@ def test_filtered_scan_not_cached(self): assert sorted(r1) == [(4,), (5,)] assert sorted(r2) == [(1,), (2,)] + ##### Empty result + def test_empty_result(self): - con = duckdb.connect() lf = pl.LazyFrame({"a": [1, 2, 3]}) - con.register("lf", lf) - result = con.sql("SELECT * FROM lf WHERE a > 100").fetchall() + result = duckdb.sql("SELECT * FROM lf WHERE a > 100").fetchall() assert result == [] - - -# =========================================================================== -# 12. Regressions -# =========================================================================== - - -class TestRegressions: - """Bug-fix regressions specific to polars.""" - - def test_nan_comparison_uses_is_nan(self): - """NaN equality must produce ``is_nan()`` on the polars side, not literal-NaN equality.""" - lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) - result = duckdb.sql("SELECT * FROM lf WHERE a = 'NaN'::DOUBLE").fetchall() - assert len(result) == 1 - assert math.isnan(result[0][0]) - - -# =========================================================================== -# 13. Canaries — behaviour we expect to change upstream -# =========================================================================== - - -class TestCanaries: - """If any of these starts passing, the upstream behaviour changed.""" - - @pytest.mark.xfail(reason="DuckDB does not push IS_NULL as a root TableFilter into arrow scans") - def test_is_null_pushes_into_arrow_scan(self): - con = duckdb.connect() - lf = pl.LazyFrame({"a": [1, None, 3]}) - con.register("arrow_table", lf) - assert _was_pushed(con, "SELECT a FROM arrow_table WHERE a IS NULL") - - @pytest.mark.xfail(reason="DuckDB does not push struct IS_NULL into arrow scans") - def test_struct_is_null_pushes(self): - con = duckdb.connect() - lf = pl.LazyFrame({"s": [{"x": 1}, None, {"x": 3}]}) - con.register("arrow_table", lf) - assert _was_pushed(con, "SELECT s FROM arrow_table WHERE s.x IS NULL") - - @pytest.mark.xfail(reason="filter_combiner does not currently rewrite struct IN to IN_FILTER") - def test_struct_in_pushes(self): - con = duckdb.connect() - lf = pl.LazyFrame({"s": [{"x": 1}, {"x": 2}, {"x": 3}]}) - con.register("arrow_table", lf) - assert _was_pushed(con, "SELECT s FROM arrow_table WHERE s.x IN (1, 2)") - - @pytest.mark.xfail(reason="Bloom filters never reach PolarsFilterPushdown::TransformFilter") - def test_bloom_filter_reaches_walker(self): - # If this ever flips, BLOOM_FILTER reached the walker and should be handled. - con = duckdb.connect() - con.execute("CREATE TABLE build AS SELECT (i*2)::BIGINT AS k FROM range(50000) t(i)") - con.execute("CREATE TABLE probe_src AS SELECT i::BIGINT AS k FROM range(50000) t(i)") - con.register("probe", to_polars_lazyframe(con.table("probe_src"))) - plan = con.execute("EXPLAIN SELECT count(*) FROM probe JOIN build USING (k)").fetchone()[1] - block = _arrow_scan_block(plan) - assert block is not None - # If a bloom filter reaches the walker, it would show up as a Filters: line. - assert "Filters:" in block diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 0e4661bc..cec7c20d 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -37,17 +37,17 @@ def test_timestamp_timezone_overflow(self, duckdb_cursor): with pytest.raises(duckdb.ConversionException, match="Could not convert"): duckdb.from_arrow(arrow_table).execute().fetchall() - @pytest.mark.parametrize("precision", ["us", "s", "ns", "ms"]) - def test_timestamp_tz_to_arrow(self, duckdb_cursor, precision): + def test_timestamp_tz_to_arrow(self, duckdb_cursor): + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) con = duckdb.connect() - expected_precision = "ns" if precision == "ns" else "us" - for timezone in timezones: - con.execute("SET TimeZone = '" + timezone + "'") - arrow_table = generate_table(current_time, precision, timezone) - res = con.from_arrow(arrow_table).to_arrow_table() - assert res[0].type == pa.timestamp(expected_precision, tz=timezone) - assert res == generate_table(current_time, expected_precision, timezone) + for precision in precisions: + for timezone in timezones: + con.execute("SET TimeZone = '" + timezone + "'") + arrow_table = generate_table(current_time, precision, timezone) + res = con.from_arrow(arrow_table).to_arrow_table() + assert res[0].type == pa.timestamp("us", tz=timezone) + assert res == generate_table(current_time, "us", timezone) def test_timestamp_tz_with_null(self, duckdb_cursor): con = duckdb.connect() diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 66d508c5..3bf46c10 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -33,11 +33,12 @@ def list_test_cases(): ) ] }), - # An untyped NULL list now has child type SQLNULL (previously it defaulted to INTEGER), so it - # converts to an object array of None rather than a masked integer array. ("SELECT list_value(NULL,NULL,NULL) as a", { 'a': [ - np.array([None, None, None], dtype=object) + np.ma.array( + [0, 0, 0], + mask=[1, 1, 1], + ) ] }), ("SELECT list_value() as a", { diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 2fa34436..25f8c323 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -38,19 +38,6 @@ def test_query_chain(self, steps): result = rel.execute() assert len(result.fetchall()) == amount - def test_query_chain_using_alias(self): - # Test query chaining using DuckDBPyRelation.alias - con = duckdb.connect() - con.execute("CREATE TABLE raw (yr INT, facility VARCHAR, val DOUBLE)") - con.execute("INSERT INTO raw VALUES (2020, 'F001', 1.0), (2021, 'F001', 2.0)") - data = con.sql("SELECT * FROM raw") - step1 = data.query(data.alias, f"SELECT *, val * 2 AS val_doubled FROM {data.alias}") - step2 = step1.query(step1.alias, f"SELECT *, val_doubled + 1 AS val_incremented FROM {step1.alias}") - assert step2.fetchall() == [ - (2020, "F001", 1.0, 2.0, 3.0), - (2021, "F001", 2.0, 4.0, 5.0), - ] - @pytest.mark.parametrize("input", [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): con = duckdb.default_connection() diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index af26ec1e..c9bd12ee 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -32,7 +32,6 @@ TimeNTZType, TimestampMillisecondNTZType, TimestampNanosecondNTZType, - TimestampNanosecondType, TimestampNTZType, TimestampSecondNTZType, TimestampType, @@ -56,7 +55,6 @@ def test_all_types_schema(self, spark): medium_enum, large_enum, 'union', - empty_struct, fixed_int_array, fixed_varchar_array, fixed_nested_int_array, @@ -88,11 +86,10 @@ def test_all_types_schema(self, spark): StructField("time", TimeNTZType(), True), StructField("timestamp", TimestampNTZType(), True), StructField("timestamp_s", TimestampSecondNTZType(), True), - StructField("timestamp_ms", TimestampMillisecondNTZType(), True), - StructField("timestamp_ns", TimestampNanosecondNTZType(), True), + StructField("timestamp_ms", TimestampNanosecondNTZType(), True), + StructField("timestamp_ns", TimestampMillisecondNTZType(), True), StructField("time_tz", TimeType(), True), StructField("timestamp_tz", TimestampType(), True), - StructField("timestamp_tz_ns", TimestampNanosecondType(), True), StructField("float", FloatType(), True), StructField("double", DoubleType(), True), StructField("dec_4_1", DecimalType(4, 1), True), diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 84a94fc7..f99b994e 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -15,22 +15,20 @@ def test_case_alias(self, duckdb_cursor): assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - # An explicit column reference takes its output name from the casing as written in the query (COL2), - # unlike `select *` above which preserves the source column's casing (CoL2). r2 = con.from_df(df).query("df", "select COL1, COL2 from df").df() assert r2["COL1"][0] == "val1" assert r2["COL1"][1] == "val3" - assert r2["COL2"][0] == 1.05 - assert r2["COL2"][1] == 17 + assert r2["CoL2"][0] == 1.05 + assert r2["CoL2"][1] == 17 r3 = con.from_df(df).query("df", "select COL1, COL2 from df ORDER BY COL1").df() assert r3["COL1"][0] == "val1" assert r3["COL1"][1] == "val3" - assert r3["COL2"][0] == 1.05 - assert r3["COL2"][1] == 17 + assert r3["CoL2"][0] == 1.05 + assert r3["CoL2"][1] == 17 r4 = con.from_df(df).query("df", "select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1").df() assert r4["COL1"][0] == "val1" assert r4["COL1"][1] == "val3" - assert r4["COL2"][0] == 1.05 - assert r4["COL2"][1] == 17 + assert r4["CoL2"][0] == 1.05 + assert r4["CoL2"][1] == 17 diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index fef81f22..c7eee6c1 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -202,8 +202,8 @@ def test_column_expression_explain(self): res = rel.explain() assert "c0" in res assert "c1" in res - # the physical plan now renders projection column names (c0, c1, c2) rather than literal constant values - assert "c2" in res + # 'c2' is not in the explain result because it shows NULL instead + assert "NULL" in res res = rel.fetchall() assert res == [("a", 42, None)] diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index a134afad..758a243e 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -172,7 +172,7 @@ def test_database_attach(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): write_errors = intercept(monkeypatch, fsspec.implementations.local.LocalFileOpener, "write") conn.register_filesystem(fs) db_path_posix = str(PurePosixPath(tmp_path.as_posix()) / "hello.db") - conn.execute(f"ATTACH 'file:///{db_path_posix}'") + conn.execute(f"ATTACH 'file://{db_path_posix}'") conn.execute("INSERT INTO hello.t VALUES (1)") diff --git a/tests/fast/test_profiler.py b/tests/fast/test_profiler.py index b7538fda..9e023b0e 100644 --- a/tests/fast/test_profiler.py +++ b/tests/fast/test_profiler.py @@ -20,29 +20,26 @@ def test_profiler_matches_expected_format(self, profiling_connection, tmp_path_f profiling_info_json = profiling_info.to_json() assert isinstance(profiling_info_json, str) - # Test expected metrics are there and profiling is json loadable. The profiling output is now grouped - # into top-level sections (the flat per-metric keys moved underneath these, e.g. latency -> query.total_time, - # total_bytes_read -> io.total_bytes_read, system_peak_buffer_memory -> system.peak_buffer_memory). + # Test expected metrics are there and profiling is json loadable profiling_dict = profiling_info.to_pydict() expected_keys = { - "query", - "system", - "io", - "operator", - "optimizer", - "physical_planner", - "planner", - "parser", + "query_name", + "total_bytes_written", + "total_bytes_read", + "system_peak_temp_dir_size", + "system_peak_buffer_memory", + "rows_returned", + "result_set_size", + "latency", + "cumulative_rows_scanned", + "cumulative_cardinality", + "cpu_time", + "extra_info", + "blocked_thread_time", + "children", } assert expected_keys.issubset(profiling_dict.keys()) - @pytest.mark.xfail( - reason="query_graph HTML renderer (duckdb/query_graph/__main__.py) is not yet updated for the " - "restructured profiling output: it still walks the old flat children/operator_type/cpu_time tree and " - "reads flat metric keys (latency, total_bytes_read, ...) that are now grouped under " - "query/system/io/operator. Needs a renderer rewrite. See memory: project_query_graph_renderer_outdated.", - strict=False, - ) def test_profiler_html_output(self, profiling_connection, tmp_path_factory): tmp_dir = tmp_path_factory.mktemp("profiler", numbered=True) profiling_info = ProfilingInfo(profiling_connection) diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 8107ae5f..62bf7589 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -120,7 +120,7 @@ def test_conn_prepared_statement_error(self): conn.execute("create table integers (a integer, b integer)") with pytest.raises( duckdb.InvalidInputException, - match="Values were not provided for the following parameters: 2", + match="Values were not provided for the following prepared statement parameters: 2", ): conn.execute("select * from integers where a =? and b=?", [1])