Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class PyConnection {
py::object arrowTable);
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName);
// Creates an Arrow-backed relationship table named `tableName` between `srcTableName` and
// `dstTableName` from CSR-style inputs (an indices table plus an indptr table per direction).
// Exposed to Python as `create_arrow_csr_rel_table`.
//
// `fwdIndices` / `fwdIndptr` are required. `bwdIndices` / `bwdIndptr` default to None and must
// either both be supplied or both be None; the definition throws a RuntimeException otherwise.
// Each supplied input must be a pyarrow Table, polars DataFrame, or pandas DataFrame.
std::unique_ptr<PyQueryResult> createArrowCsrRelTable(const std::string& tableName,
const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices,
py::object fwdIndptr, py::object bwdIndices = py::none(),
py::object bwdIndptr = py::none());
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);

static Value transformPythonValue(const py::handle& val);
Expand Down
87 changes: 87 additions & 0 deletions src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ void PyConnection::initialize(py::handle& m) {
py::arg("arrow_table"))
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"))
.def("create_arrow_csr_rel_table", &PyConnection::createArrowCsrRelTable,
py::arg("table_name"), py::arg("src_table_name"), py::arg("dst_table_name"),
py::arg("fwd_indices"), py::arg("fwd_indptr"), py::arg("bwd_indices") = py::none(),
py::arg("bwd_indptr") = py::none())
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
PyDateTime_IMPORT;
}
Expand Down Expand Up @@ -1093,6 +1097,89 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::stri
return checkAndWrapQueryResult(result.queryResult, state);
}

// Exports a pyarrow table by calling the Python-side `_export_to_c` hooks.
//
// The table's schema is exported once into an ArrowSchemaWrapper. Every element returned by
// `tbl.to_batches()` is exported into its own ArrowArrayWrapper, in iteration order.
// Returns the schema together with the vector of exported batches.
static std::pair<ArrowSchemaWrapper, std::vector<ArrowArrayWrapper>> exportPyArrowTable(
    py::object& tbl) {
    ArrowSchemaWrapper exportedSchema;
    // `_export_to_c` receives the destination struct's address as an integer.
    tbl.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&exportedSchema));

    std::vector<ArrowArrayWrapper> exportedBatches;
    py::list recordBatches = tbl.attr("to_batches")();
    for (const auto& recordBatch : recordBatches) {
        // Create the destination slot first so the export writes directly into the vector.
        auto& slot = exportedBatches.emplace_back();
        recordBatch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&slot));
    }

    return {std::move(exportedSchema), std::move(exportedBatches)};
}

// Normalises a user-supplied tabular object to a pyarrow Table.
//
// - pandas DataFrame: converted with pyarrow's `Table.from_pandas` (looked up via `importCache`).
// - polars DataFrame: converted with its own `to_arrow()` method.
// - pyarrow Table:    returned as-is.
// Any other input raises a RuntimeException.
static py::object toPyArrow(const py::object& obj,
    const std::shared_ptr<PythonCachedImport>& importCache) {
    if (PyConnection::isPandasDataframe(obj)) {
        const auto fromPandas = importCache->pyarrow.lib.Table.from_pandas();
        return fromPandas(obj);
    } else if (PyConnection::isPolarsDataframe(obj)) {
        return obj.attr("to_arrow")();
    } else if (PyConnection::isPyArrowTable(obj)) {
        return obj;
    }

    throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
}

std::unique_ptr<PyQueryResult> PyConnection::createArrowCsrRelTable(const std::string& tableName,
const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices,
py::object fwdIndptr, py::object bwdIndices, py::object bwdIndptr) {
auto& stateRef = refState();
py::gil_scoped_acquire acquire;

bool hasBwd = !bwdIndices.is_none();
if (hasBwd != !bwdIndptr.is_none()) {
throw RuntimeException("bwd_indices and bwd_indptr must both be provided or both be None");
}

fwdIndices = toPyArrow(fwdIndices, importCache);
fwdIndptr = toPyArrow(fwdIndptr, importCache);

py::list keepAlive;
keepAlive.append(fwdIndices);
keepAlive.append(fwdIndices.attr("to_batches")());
keepAlive.append(fwdIndptr);
keepAlive.append(fwdIndptr.attr("to_batches")());

auto [fwdIdxSchema, fwdIdxArrays] = exportPyArrowTable(fwdIndices);
auto [fwdIpSchema, fwdIpArrays] = exportPyArrowTable(fwdIndptr);

std::optional<ArrowSchemaWrapper> bwdIdxSchema;
std::optional<std::vector<ArrowArrayWrapper>> bwdIdxArrays;
std::optional<ArrowSchemaWrapper> bwdIpSchema;
std::optional<std::vector<ArrowArrayWrapper>> bwdIpArrays;
if (hasBwd) {
bwdIndices = toPyArrow(bwdIndices, importCache);
bwdIndptr = toPyArrow(bwdIndptr, importCache);
keepAlive.append(bwdIndices);
keepAlive.append(bwdIndices.attr("to_batches")());
keepAlive.append(bwdIndptr);
keepAlive.append(bwdIndptr.attr("to_batches")());
auto [bis, bia] = exportPyArrowTable(bwdIndices);
auto [bps, bpa] = exportPyArrowTable(bwdIndptr);
bwdIdxSchema = std::move(bis);
bwdIdxArrays = std::move(bia);
bwdIpSchema = std::move(bps);
bwdIpArrays = std::move(bpa);
}

auto result = ArrowTableSupport::createArrowCsrRelTable(stateRef.ref(), tableName, srcTableName,
dstTableName, std::move(fwdIdxSchema), std::move(fwdIdxArrays), std::move(fwdIpSchema),
std::move(fwdIpArrays), std::move(bwdIdxSchema), std::move(bwdIdxArrays),
std::move(bwdIpSchema), std::move(bwdIpArrays));
if (result.queryResult && result.queryResult->isSuccess()) {
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
}

return checkAndWrapQueryResult(result.queryResult, state);
}

std::unique_ptr<PyQueryResult> PyConnection::dropArrowTable(const std::string& tableName) {
auto& stateRef = refState();
auto result = ArrowTableSupport::unregisterArrowTable(stateRef.ref(), tableName);
Expand Down
80 changes: 80 additions & 0 deletions src_py/_lbug_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ def _setup_signatures() -> None:
]
_LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int

_LIB.lbug_connection_create_arrow_csr_rel_table.argtypes = [
ctypes.POINTER(_LbugConnection), # connection
ctypes.c_char_p, # table_name
ctypes.c_char_p, # src_table_name
ctypes.c_char_p, # dst_table_name
ctypes.POINTER(_ArrowSchema), # fwd_indices_schema
ctypes.POINTER(_ArrowArray), # fwd_indices_arrays
ctypes.c_uint64, # fwd_indices_num_arrays
ctypes.POINTER(_ArrowSchema), # fwd_indptr_schema
ctypes.POINTER(_ArrowArray), # fwd_indptr_arrays
ctypes.c_uint64, # fwd_indptr_num_arrays
ctypes.POINTER(_ArrowSchema), # bwd_indices_schema (nullable)
ctypes.POINTER(_ArrowArray), # bwd_indices_arrays (nullable)
ctypes.c_uint64, # bwd_indices_num_arrays
ctypes.POINTER(_ArrowSchema), # bwd_indptr_schema (nullable)
ctypes.POINTER(_ArrowArray), # bwd_indptr_arrays (nullable)
ctypes.c_uint64, # bwd_indptr_num_arrays
ctypes.POINTER(_LbugQueryResult), # out_query_result
]
_LIB.lbug_connection_create_arrow_csr_rel_table.restype = ctypes.c_int

_LIB.lbug_prepared_statement_destroy.argtypes = [
ctypes.POINTER(_LbugPreparedStatement)
]
Expand Down Expand Up @@ -2340,3 +2361,62 @@ def create_arrow_rel_table(
if state != _LBUG_SUCCESS and not result._query_result:
_check_state(state, "Failed to create Arrow relationship table")
return QueryResult(result)

def create_arrow_csr_rel_table(
    self,
    table_name: str,
    src_table_name: str,
    dst_table_name: str,
    fwd_indices: Any,
    fwd_indptr: Any,
    bwd_indices: Any = None,
    bwd_indptr: Any = None,
) -> QueryResult:
    """
    Create an Arrow CSR relationship table through the C API.

    The forward indices/indptr inputs are always exported and passed to
    ``lbug_connection_create_arrow_csr_rel_table``. The backward pair is optional:
    when omitted, ``None`` pointers and zero counts are passed instead.

    Raises ``ValueError`` when only one of ``bwd_indices`` / ``bwd_indptr`` is given.
    """
    if (bwd_indices is None) != (bwd_indptr is None):
        msg = "bwd_indices and bwd_indptr must both be provided or both be None"
        raise ValueError(msg)

    # Each export result is kept as a whole tuple so that every element it returned
    # stays referenced until the native call has finished. Index 1 is the schema and
    # index 2 the arrays; the remaining elements are presumably the source table and
    # its batches -- confirm against _export_arrow_table.
    fwd_idx_export = self._export_arrow_table(fwd_indices)
    fwd_ptr_export = self._export_arrow_table(fwd_indptr)

    # (schema pointer, arrays, array count) for each backward input.
    bwd_idx_args = (None, None, 0)
    bwd_ptr_args = (None, None, 0)
    if bwd_indices is not None:
        bwd_idx_export = self._export_arrow_table(bwd_indices)
        bwd_ptr_export = self._export_arrow_table(bwd_indptr)
        bwd_idx_args = (
            ctypes.byref(bwd_idx_export[1]),
            bwd_idx_export[2],
            len(bwd_idx_export[2]),
        )
        bwd_ptr_args = (
            ctypes.byref(bwd_ptr_export[1]),
            bwd_ptr_export[2],
            len(bwd_ptr_export[2]),
        )

    result = _LbugQueryResult()
    state = _LIB.lbug_connection_create_arrow_csr_rel_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        src_table_name.encode("utf-8"),
        dst_table_name.encode("utf-8"),
        ctypes.byref(fwd_idx_export[1]),
        fwd_idx_export[2],
        len(fwd_idx_export[2]),
        ctypes.byref(fwd_ptr_export[1]),
        fwd_ptr_export[2],
        len(fwd_ptr_export[2]),
        *bwd_idx_args,
        *bwd_ptr_args,
        ctypes.byref(result),
    )
    # A failure state with no query result attached is handed to _check_state;
    # otherwise the result object is wrapped and returned as-is.
    if state != _LBUG_SUCCESS and not result._query_result:
        _check_state(state, "Failed to create Arrow CSR relationship table")
    return QueryResult(result)
72 changes: 72 additions & 0 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,3 +857,75 @@ def create_arrow_rel_table(
if not query_result_internal.isSuccess():
raise RuntimeError(query_result_internal.getErrorMessage())
return QueryResult(self, query_result_internal)

def create_arrow_csr_rel_table(
    self,
    table_name: str,
    src_table_name: str,
    dst_table_name: str,
    fwd_indices: Any,
    fwd_indptr: Any,
    bwd_indices: Any = None,
    bwd_indptr: Any = None,
) -> QueryResult:
    """
    Register a relationship table backed by in-memory Arrow data in CSR layout.

    Parameters
    ----------
    table_name : str
        Name given to the new relationship table.
    src_table_name : str
        Name of the node table the relationships start from.
    dst_table_name : str
        Name of the node table the relationships point to.
    fwd_indices : Any
        Forward adjacency indices table (struct array: child[0] = UINT64 dst offsets,
        any further children are edge properties). May be a pandas, polars, or
        pyarrow object.
    fwd_indptr : Any
        Forward adjacency indptr table (struct array: child[0] = UINT64 row pointers).
    bwd_indices : Any, optional
        Backward adjacency indices table; only valid together with ``bwd_indptr``.
    bwd_indptr : Any, optional
        Backward adjacency indptr table; only valid together with ``bwd_indices``.

    Returns
    -------
    QueryResult
        Result of the table creation query.

    Raises
    ------
    ValueError
        If exactly one of ``bwd_indices`` / ``bwd_indptr`` is supplied.
    RuntimeError
        If the underlying query result reports failure.
    """
    if (bwd_indices is None) != (bwd_indptr is None):
        msg = "bwd_indices and bwd_indptr must both be provided or both be None"
        raise ValueError(msg)

    call_args = (
        table_name,
        src_table_name,
        dst_table_name,
        fwd_indices,
        fwd_indptr,
        bwd_indices,
        bwd_indptr,
    )

    self.init_connection()
    try:
        raw_result = self._connection.create_arrow_csr_rel_table(*call_args)
    except NotImplementedError:
        # The active backend lacks this operation: retry on the pybind connection
        # when one is available, and prefer it from now on.
        fallback = self._get_pybind_connection()
        if fallback is None:
            raise
        self._prefer_pybind = True
        raw_result = fallback.create_arrow_csr_rel_table(*call_args)

    if raw_result.isSuccess():
        return QueryResult(self, raw_result)
    raise RuntimeError(raw_result.getErrorMessage())
Loading
Loading