Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions scripts/pgembed_ladybug_sample_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
#!/usr/bin/env python3

"""

uv run --python 3.12 --with pgembed --with 'psycopg[binary]' \
python scripts/pgembed_ladybug_sample_query.py
"""

from __future__ import annotations

import importlib.util
import os
import platform
import subprocess
import sys
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, unquote, urlparse


def is_musl() -> bool:
libc_name, _ = platform.libc_ver()
if libc_name == "musl":
return True
try:
return "musl" in os.confstr("CS_GNU_LIBC_VERSION").lower()
except (AttributeError, OSError, ValueError):
return False


def import_pgembed_or_bootstrap():
try:
import pgembed

return pgembed
except ModuleNotFoundError:
if is_musl():
raise
if os.environ.get("LBUG_PGEMBED_BOOTSTRAPPED") == "1":
raise
env = os.environ.copy()
env["LBUG_PGEMBED_BOOTSTRAPPED"] = "1"
os.execvpe(
"uv",
[
"uv",
"run",
"--python",
env.get("PGEMBED_PYTHON", "3.12"),
"--with",
"pgembed",
"--with",
"psycopg[binary]",
"python",
__file__,
],
env,
)
raise RuntimeError("unreachable")


def first_query_value(query: dict[str, list[str]], key: str) -> str | None:
values = query.get(key)
return values[0] if values else None


def quote_libpq_value(value: str) -> str:
if value and not any(char.isspace() or char in "\\'" for char in value):
return value
return "'" + value.replace("\\", "\\\\").replace("'", "\\'") + "'"


def uri_to_libpq_connection_string(uri: str, database_name: str, user: str) -> str:
parsed = urlparse(uri)
query = parse_qs(parsed.query)
values = {
"dbname": database_name,
"user": user,
"host": parsed.hostname or first_query_value(query, "host") or "localhost",
"password": "ci",
}
port = parsed.port or first_query_value(query, "port")
if port is not None:
values["port"] = str(port)
if parsed.password:
values["password"] = unquote(parsed.password)
return " ".join(f"{key}={quote_libpq_value(value)}" for key, value in values.items())


def repo_root() -> Path:
return Path(__file__).resolve().parent.parent


def import_ladybug_from_submodule(root: Path):
package_dir = root / "tools" / "python_api" / "src_py"
os.environ.setdefault(
"LBUG_C_API_LIB_PATH",
str(root / "build" / "relwithdebinfo" / "src" / "liblbug.dylib"),
)
spec = importlib.util.spec_from_file_location(
"ladybug",
package_dir / "__init__.py",
submodule_search_locations=[str(package_dir)],
)
if spec is None or spec.loader is None:
raise RuntimeError(f"failed to load ladybug package from {package_dir}")
module = importlib.util.module_from_spec(spec)
sys.modules["ladybug"] = module
spec.loader.exec_module(module)
return module


def run_ladybug_query(conn, query: str):
result = conn.execute(query)
return result.get_all()


def main() -> int:
pgembed = import_pgembed_or_bootstrap()
import psycopg

root = repo_root()
lb = import_ladybug_from_submodule(root)

extension_path = root / "extension" / "postgres" / "build" / "libpostgres.lbug_extension"
if not extension_path.exists():
raise RuntimeError(f"missing postgres extension: {extension_path}")

with tempfile.TemporaryDirectory(prefix="lbug_pgembed_repro_") as tmpdir:
with pgembed.get_server(tmpdir) as pg:
admin_uri = pg.get_uri("postgres")
repro_uri = pg.get_uri("pgscan")

with psycopg.connect(admin_uri, autocommit=True) as conn:
conn.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'ci') THEN
CREATE ROLE ci WITH LOGIN SUPERUSER PASSWORD 'ci';
END IF;
END
$$;
"""
)
if (
conn.execute(
"SELECT 1 FROM pg_database WHERE datname = 'pgscan'"
).fetchone()
is None
):
conn.execute("CREATE DATABASE pgscan OWNER ci")

with psycopg.connect(repro_uri) as conn:
conn.execute(
"""
CREATE TABLE organisation (
id BIGINT PRIMARY KEY,
name TEXT NOT NULL,
score BIGINT NOT NULL,
mark DOUBLE PRECISION NOT NULL,
orgcode BIGINT NOT NULL
)
"""
)
conn.execute(
"""
INSERT INTO organisation VALUES
(1, 'ABFsUni', 4, 3.7, 325),
(4, 'CsWork', 7, 4.1, 934),
(6, 'DEsWork', 2, 4.1, 824)
"""
)
conn.commit()

conn_string = uri_to_libpq_connection_string(repro_uri, "pgscan", "ci")
db = lb.Database(":memory:", backend="capi")
conn = lb.Connection(db)
try:
run_ladybug_query(conn, f"LOAD EXTENSION '{extension_path}'")
run_ladybug_query(
conn,
f"ATTACH '{conn_string}' AS pg (dbtype POSTGRES)",
)

checks = [
(
"select-star-projection",
"CALL SQL_QUERY('pg', 'select * from organisation') "
"RETURN name, orgcode",
[["ABFsUni", 325], ["CsWork", 934], ["DEsWork", 824]],
),
(
"select-star-reordered-projection",
"CALL SQL_QUERY('pg', 'select * from organisation') "
"RETURN orgcode, name",
[[325, "ABFsUni"], [934, "CsWork"], [824, "DEsWork"]],
),
(
"select-star-filter-skipped-column",
"CALL SQL_QUERY('pg', 'select * from organisation') "
"WHERE score > 4 RETURN name, orgcode",
[["CsWork", 934]],
),
(
"explicit-query-filter-skipped-column",
"CALL SQL_QUERY('pg', "
"'select name, score, mark, orgcode from organisation') "
"WHERE score > 4 YIELD name, score, mark, orgcode AS code "
"RETURN name, code",
[["CsWork", 934]],
),
]

for name, query, expected in checks:
actual = run_ladybug_query(conn, query)
print(f"{name}: {actual}")
if actual != expected:
raise AssertionError(f"{name}: expected {expected}, got {actual}")
finally:
conn.close()
db.close()

print("postgres select-star repro passed")
return 0


if __name__ == "__main__":
if (
"uv" not in Path(sys.executable).name
and os.environ.get("LBUG_PGEMBED_BOOTSTRAPPED") != "1"
):
try:
import pgembed # noqa: F401
import psycopg # noqa: F401
except ModuleNotFoundError:
cmd = [
"uv",
"run",
"--python",
os.environ.get("PGEMBED_PYTHON", "3.12"),
"--with",
"pgembed",
"--with",
"psycopg[binary]",
"python",
__file__,
]
raise SystemExit(subprocess.run(cmd).returncode)
raise SystemExit(main())
28 changes: 26 additions & 2 deletions src/optimizer/filter_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <algorithm>
#include <array>
#include <functional>
#include <unordered_set>

#include "binder/expression/literal_expression.h"
#include "binder/expression/property_expression.h"
Expand Down Expand Up @@ -235,9 +236,32 @@ std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::visitScanNodeTableRepl
std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::visitTableFunctionCallReplace(
const std::shared_ptr<LogicalOperator>& op) {
auto& tableFunctionCall = op->cast<LogicalTableFunctionCall>();
auto columnPredicates = getColumnPredicateSets(tableFunctionCall.getBindData()->columns,
predicateSet.getAllPredicates());
if (!tableFunctionCall.getTableFunc().supportsPushDownFunc()) {
return finishPushDown(op);
}
std::vector<ColumnPredicateSet> columnPredicates;
std::unordered_set<const Expression*> pushedPredicates;
auto predicates = predicateSet.getAllPredicates();
for (auto& column : tableFunctionCall.getBindData()->columns) {
auto columnPredicateSet = ColumnPredicateSet();
for (auto& predicate : predicates) {
auto columnPredicate = ColumnPredicateUtil::tryConvert(*column, *predicate);
if (columnPredicate == nullptr) {
continue;
}
columnPredicateSet.addPredicate(std::move(columnPredicate));
pushedPredicates.insert(predicate.get());
}
columnPredicates.push_back(std::move(columnPredicateSet));
}
tableFunctionCall.setColumnPredicates(std::move(columnPredicates));
auto remainingPredicates = PredicateSet();
for (auto& predicate : predicates) {
if (!pushedPredicates.contains(predicate.get())) {
remainingPredicates.addPredicate(predicate);
}
}
predicateSet = std::move(remainingPredicates);
return finishPushDown(op);
}

Expand Down
41 changes: 39 additions & 2 deletions src/optimizer/foreign_join_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common/exception/runtime.h"
#include "main/database_manager.h"
#include "planner/operator/extend/logical_extend.h"
#include "planner/operator/logical_filter.h"
#include "planner/operator/logical_flatten.h"
#include "planner/operator/logical_hash_join.h"
#include "planner/operator/logical_table_function_call.h"
Expand Down Expand Up @@ -132,6 +133,7 @@ struct ForeignJoinPatternInfo {
// Intermediate operators
const LogicalHashJoin* outerHashJoin = nullptr;
const LogicalHashJoin* innerHashJoin = nullptr;
const LogicalFilter* relFilter = nullptr;
// Original output schema
const Schema* outputSchema = nullptr;
// Table names extracted from bind data
Expand Down Expand Up @@ -221,6 +223,13 @@ static std::optional<ForeignJoinPatternInfo> matchPattern(const LogicalOperator*

// Inner hash join probe side should be EXTEND
auto extendOp = probeOp->getChild(0).get();
if (extendOp != nullptr && extendOp->getOperatorType() == LogicalOperatorType::FILTER) {
info.relFilter = extendOp->constPtrCast<LogicalFilter>();
if (extendOp->getNumChildren() < 1) {
return std::nullopt;
}
extendOp = extendOp->getChild(0).get();
}
if (extendOp == nullptr || extendOp->getOperatorType() != LogicalOperatorType::EXTEND) {
return std::nullopt;
}
Expand Down Expand Up @@ -516,8 +525,16 @@ std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitHashJoinRepl
// "a.id" when a canonical pattern-variable expression "_N_a.id" exists.
auto allColumns = info.outputSchema->getExpressionsInScope();
expression_vector outputColumns;
std::unordered_set<std::string> outputColumnNames;
std::unordered_set<std::string> canonicalVarProps;

auto appendOutputColumn = [&](const std::shared_ptr<Expression>& column) {
if (!outputColumnNames.insert(column->getUniqueName()).second) {
return;
}
outputColumns.push_back(column);
};

auto extractCanonicalVarProp = [](const std::string& uniqueName) -> std::string {
// "_N_var.prop" -> "var.prop"
if (uniqueName.empty() || uniqueName[0] != '_') {
Expand Down Expand Up @@ -582,14 +599,30 @@ std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitHashJoinRepl
if (hasLowercaseID(uniqueName)) {
continue;
}
outputColumns.push_back(col);
appendOutputColumn(col);
}

// The foreign join rewrite runs before projection pushdown, so the matched
// hash join's schema can be narrower than parent FILTER/ORDER BY/PROJECTION
// requirements. Keep the available graph properties in the pushed-down scan;
// projection pushdown can prune unused columns later.
auto appendPatternProperties = [&](const std::shared_ptr<NodeOrRelExpression>& pattern) {
for (auto& property : pattern->getPropertyExpressions()) {
if (property->getPropertyName().starts_with("_")) {
continue;
}
appendOutputColumn(property);
}
};
appendPatternProperties(info.extend->getBoundNode());
appendPatternProperties(info.extend->getRel());
appendPatternProperties(info.extend->getNbrNode());

// Fallback: if no property/variable columns were identified, preserve
// original scope to avoid breaking operator replacement.
if (outputColumns.empty()) {
for (auto& col : allColumns) {
outputColumns.push_back(col);
appendOutputColumn(col);
}
}

Expand All @@ -603,6 +636,10 @@ std::shared_ptr<LogicalOperator> ForeignJoinPushDownOptimizer::visitHashJoinRepl
return op;
}

if (info.relFilter != nullptr) {
result = std::make_shared<LogicalFilter>(info.relFilter->getPredicate(), std::move(result));
result->computeFlatSchema();
}
return result;
}

Expand Down
Loading
Loading