Skip to content

Commit 119fd2c

Browse files
committed
implement icebug-memory
1 parent d5fdcc4 commit 119fd2c

5 files changed

Lines changed: 413 additions & 0 deletions

File tree

src_cpp/include/py_connection.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class PyConnection {
5757
py::object arrowTable);
5858
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
5959
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName);
60+
std::unique_ptr<PyQueryResult> createArrowCsrRelTable(const std::string& tableName,
61+
const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices,
62+
py::object fwdIndptr, py::object bwdIndices = py::none(),
63+
py::object bwdIndptr = py::none());
6064
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);
6165

6266
static Value transformPythonValue(const py::handle& val);

src_cpp/py_connection.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ void PyConnection::initialize(py::handle& m) {
5353
py::arg("arrow_table"))
5454
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
5555
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"))
56+
.def("create_arrow_csr_rel_table", &PyConnection::createArrowCsrRelTable,
57+
py::arg("table_name"), py::arg("src_table_name"), py::arg("dst_table_name"),
58+
py::arg("fwd_indices"), py::arg("fwd_indptr"), py::arg("bwd_indices") = py::none(),
59+
py::arg("bwd_indptr") = py::none())
5660
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
5761
PyDateTime_IMPORT;
5862
}
@@ -1093,6 +1097,89 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::stri
10931097
return checkAndWrapQueryResult(result.queryResult, state);
10941098
}
10951099

1100+
/**
 * Exports a pyarrow Table into C++-owned wrappers via pyarrow's `_export_to_c`.
 *
 * The table's schema is exported once; the table is then split with
 * `to_batches()` and every batch is exported into its own ArrowArrayWrapper.
 *
 * @param tbl Python object exposing `schema` and `to_batches()`; callers in this
 *            file pass the result of toPyArrow().
 * @return The exported schema paired with one array wrapper per record batch,
 *         in the order `to_batches()` produced them.
 */
static std::pair<ArrowSchemaWrapper, std::vector<ArrowArrayWrapper>> exportPyArrowTable(
    py::object& tbl) {
    ArrowSchemaWrapper schema;
    // `_export_to_c` takes the destination struct's address as an integer and fills it in place.
    tbl.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));

    py::list batches = tbl.attr("to_batches")();
    std::vector<ArrowArrayWrapper> arrays;
    // Size the vector up front so wrappers that were already exported into are
    // not relocated by later growth.
    arrays.reserve(batches.size());
    // The list iterator yields handles by value, so take each one by value.
    for (py::handle batch : batches) {
        arrays.emplace_back();
        batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
    }
    return {std::move(schema), std::move(arrays)};
}
1112+
1113+
/**
 * Normalizes a supported tabular Python object to a pyarrow Table.
 *
 * Checks are made in this order: pandas DataFrame (converted through
 * `pyarrow.lib.Table.from_pandas`), polars DataFrame (converted through its
 * `to_arrow()` method), pyarrow Table (returned unchanged).
 *
 * @param obj         Python object to convert.
 * @param importCache Cached Python imports used to reach `Table.from_pandas`.
 * @return A pyarrow Table object.
 * @throws RuntimeException if `obj` is none of the supported types.
 */
static py::object toPyArrow(const py::object& obj,
    const std::shared_ptr<PythonCachedImport>& importCache) {
    if (PyConnection::isPandasDataframe(obj)) {
        auto fromPandas = importCache->pyarrow.lib.Table.from_pandas();
        return fromPandas(obj);
    } else if (PyConnection::isPolarsDataframe(obj)) {
        return obj.attr("to_arrow")();
    } else if (PyConnection::isPyArrowTable(obj)) {
        return obj;
    }
    throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
}
1129+
1130+
std::unique_ptr<PyQueryResult> PyConnection::createArrowCsrRelTable(const std::string& tableName,
1131+
const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices,
1132+
py::object fwdIndptr, py::object bwdIndices, py::object bwdIndptr) {
1133+
auto& stateRef = refState();
1134+
py::gil_scoped_acquire acquire;
1135+
1136+
bool hasBwd = !bwdIndices.is_none();
1137+
if (hasBwd != !bwdIndptr.is_none()) {
1138+
throw RuntimeException("bwd_indices and bwd_indptr must both be provided or both be None");
1139+
}
1140+
1141+
fwdIndices = toPyArrow(fwdIndices, importCache);
1142+
fwdIndptr = toPyArrow(fwdIndptr, importCache);
1143+
1144+
py::list keepAlive;
1145+
keepAlive.append(fwdIndices);
1146+
keepAlive.append(fwdIndices.attr("to_batches")());
1147+
keepAlive.append(fwdIndptr);
1148+
keepAlive.append(fwdIndptr.attr("to_batches")());
1149+
1150+
auto [fwdIdxSchema, fwdIdxArrays] = exportPyArrowTable(fwdIndices);
1151+
auto [fwdIpSchema, fwdIpArrays] = exportPyArrowTable(fwdIndptr);
1152+
1153+
std::optional<ArrowSchemaWrapper> bwdIdxSchema;
1154+
std::optional<std::vector<ArrowArrayWrapper>> bwdIdxArrays;
1155+
std::optional<ArrowSchemaWrapper> bwdIpSchema;
1156+
std::optional<std::vector<ArrowArrayWrapper>> bwdIpArrays;
1157+
if (hasBwd) {
1158+
bwdIndices = toPyArrow(bwdIndices, importCache);
1159+
bwdIndptr = toPyArrow(bwdIndptr, importCache);
1160+
keepAlive.append(bwdIndices);
1161+
keepAlive.append(bwdIndices.attr("to_batches")());
1162+
keepAlive.append(bwdIndptr);
1163+
keepAlive.append(bwdIndptr.attr("to_batches")());
1164+
auto [bis, bia] = exportPyArrowTable(bwdIndices);
1165+
auto [bps, bpa] = exportPyArrowTable(bwdIndptr);
1166+
bwdIdxSchema = std::move(bis);
1167+
bwdIdxArrays = std::move(bia);
1168+
bwdIpSchema = std::move(bps);
1169+
bwdIpArrays = std::move(bpa);
1170+
}
1171+
1172+
auto result = ArrowTableSupport::createArrowCsrRelTable(stateRef.ref(), tableName, srcTableName,
1173+
dstTableName, std::move(fwdIdxSchema), std::move(fwdIdxArrays), std::move(fwdIpSchema),
1174+
std::move(fwdIpArrays), std::move(bwdIdxSchema), std::move(bwdIdxArrays),
1175+
std::move(bwdIpSchema), std::move(bwdIpArrays));
1176+
if (result.queryResult && result.queryResult->isSuccess()) {
1177+
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
1178+
}
1179+
1180+
return checkAndWrapQueryResult(result.queryResult, state);
1181+
}
1182+
10961183
std::unique_ptr<PyQueryResult> PyConnection::dropArrowTable(const std::string& tableName) {
10971184
auto& stateRef = refState();
10981185
auto result = ArrowTableSupport::unregisterArrowTable(stateRef.ref(), tableName);

src_py/_lbug_capi.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,27 @@ def _setup_signatures() -> None:
339339
]
340340
_LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int
341341

342+
_LIB.lbug_connection_create_arrow_csr_rel_table.argtypes = [
343+
ctypes.POINTER(_LbugConnection), # connection
344+
ctypes.c_char_p, # table_name
345+
ctypes.c_char_p, # src_table_name
346+
ctypes.c_char_p, # dst_table_name
347+
ctypes.POINTER(_ArrowSchema), # fwd_indices_schema
348+
ctypes.POINTER(_ArrowArray), # fwd_indices_arrays
349+
ctypes.c_uint64, # fwd_indices_num_arrays
350+
ctypes.POINTER(_ArrowSchema), # fwd_indptr_schema
351+
ctypes.POINTER(_ArrowArray), # fwd_indptr_arrays
352+
ctypes.c_uint64, # fwd_indptr_num_arrays
353+
ctypes.POINTER(_ArrowSchema), # bwd_indices_schema (nullable)
354+
ctypes.POINTER(_ArrowArray), # bwd_indices_arrays (nullable)
355+
ctypes.c_uint64, # bwd_indices_num_arrays
356+
ctypes.POINTER(_ArrowSchema), # bwd_indptr_schema (nullable)
357+
ctypes.POINTER(_ArrowArray), # bwd_indptr_arrays (nullable)
358+
ctypes.c_uint64, # bwd_indptr_num_arrays
359+
ctypes.POINTER(_LbugQueryResult), # out_query_result
360+
]
361+
_LIB.lbug_connection_create_arrow_csr_rel_table.restype = ctypes.c_int
362+
342363
_LIB.lbug_prepared_statement_destroy.argtypes = [
343364
ctypes.POINTER(_LbugPreparedStatement)
344365
]
@@ -2340,3 +2361,63 @@ def create_arrow_rel_table(
23402361
if state != _LBUG_SUCCESS and not result._query_result:
23412362
_check_state(state, "Failed to create Arrow relationship table")
23422363
return QueryResult(result)
2364+
2365+
def create_arrow_csr_rel_table(
    self,
    table_name: str,
    src_table_name: str,
    dst_table_name: str,
    fwd_indices: Any,
    fwd_indptr: Any,
    bwd_indices: Any = None,
    bwd_indptr: Any = None,
) -> QueryResult:
    """
    Create an Arrow CSR relationship table through the C API.

    The forward tables are always exported; the backward tables are exported
    only when given, otherwise null pointers and zero counts are passed.

    Raises
    ------
    ValueError
        If exactly one of ``bwd_indices`` / ``bwd_indptr`` is None.
    """
    wants_bwd = bwd_indices is not None
    if wants_bwd != (bwd_indptr is not None):
        raise ValueError(
            "bwd_indices and bwd_indptr must both be provided or both be None"
        )

    # Hold every export tuple (table, schema, arrays, batches) in a local so
    # all four items stay referenced for the duration of the native call.
    fwd_idx_export = self._export_arrow_table(fwd_indices)
    fwd_ptr_export = self._export_arrow_table(fwd_indptr)
    bwd_idx_export = None
    bwd_ptr_export = None
    if wants_bwd:
        bwd_idx_export = self._export_arrow_table(bwd_indices)
        bwd_ptr_export = self._export_arrow_table(bwd_indptr)

    def _native_triple(export):
        # (schema pointer, arrays, array count); all-null when the side is absent.
        if export is None:
            return None, None, 0
        _table, schema, arrays, _batches = export
        return ctypes.byref(schema), arrays, len(arrays)

    result = _LbugQueryResult()
    state = _LIB.lbug_connection_create_arrow_csr_rel_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        src_table_name.encode("utf-8"),
        dst_table_name.encode("utf-8"),
        *_native_triple(fwd_idx_export),
        *_native_triple(fwd_ptr_export),
        *_native_triple(bwd_idx_export),
        *_native_triple(bwd_ptr_export),
        ctypes.byref(result),
    )
    # A failure that still produced a query result is surfaced through that result.
    if state != _LBUG_SUCCESS and not result._query_result:
        _check_state(state, "Failed to create Arrow CSR relationship table")
    return QueryResult(result)

src_py/connection.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,76 @@ def create_arrow_rel_table(
857857
if not query_result_internal.isSuccess():
858858
raise RuntimeError(query_result_internal.getErrorMessage())
859859
return QueryResult(self, query_result_internal)
860+
861+
def create_arrow_csr_rel_table(
    self,
    table_name: str,
    src_table_name: str,
    dst_table_name: str,
    fwd_indices: Any,
    fwd_indptr: Any,
    bwd_indices: Any = None,
    bwd_indptr: Any = None,
) -> QueryResult:
    """
    Create an Arrow CSR memory-backed relationship table.

    Parameters
    ----------
    table_name : str
        Name of the relationship table to create.
    src_table_name : str
        Source node table name.
    dst_table_name : str
        Destination node table name.
    fwd_indices : Any
        Forward adjacency indices table (struct array: child[0] = UINT64 dst offsets,
        optional further children are edge properties). Accepts pandas, polars, or pyarrow.
    fwd_indptr : Any
        Forward adjacency indptr table (struct array: child[0] = UINT64 row pointers).
    bwd_indices : Any, optional
        Backward adjacency indices table. Must be provided together with bwd_indptr.
    bwd_indptr : Any, optional
        Backward adjacency indptr table. Must be provided together with bwd_indices.

    Returns
    -------
    QueryResult
        Result of the table creation query.

    Raises
    ------
    ValueError
        If exactly one of ``bwd_indices`` / ``bwd_indptr`` is None.
    NotImplementedError
        If the active connection raises it and no pybind connection is available.
    RuntimeError
        If the creation query reports failure; carries the backend error message.
    """
    if (bwd_indices is None) != (bwd_indptr is None):
        raise ValueError(
            "bwd_indices and bwd_indptr must both be provided or both be None"
        )

    # Both backends take the same positional arguments.
    call_args = (
        table_name,
        src_table_name,
        dst_table_name,
        fwd_indices,
        fwd_indptr,
        bwd_indices,
        bwd_indptr,
    )

    self.init_connection()
    try:
        outcome = self._connection.create_arrow_csr_rel_table(*call_args)
    except NotImplementedError:
        # Retry on the pybind connection when one is available, and record
        # that preference on the instance.
        fallback = self._get_pybind_connection()
        if fallback is None:
            raise
        self._prefer_pybind = True
        outcome = fallback.create_arrow_csr_rel_table(*call_args)

    if outcome.isSuccess():
        return QueryResult(self, outcome)
    raise RuntimeError(outcome.getErrorMessage())

0 commit comments

Comments
 (0)