@@ -53,6 +53,10 @@ void PyConnection::initialize(py::handle& m) {
5353 py::arg (" arrow_table" ))
5454 .def (" create_arrow_rel_table" , &PyConnection::createArrowRelTable, py::arg (" table_name" ),
5555 py::arg (" arrow_table" ), py::arg (" src_table_name" ), py::arg (" dst_table_name" ))
56+ .def (" create_arrow_csr_rel_table" , &PyConnection::createArrowCsrRelTable,
57+ py::arg (" table_name" ), py::arg (" src_table_name" ), py::arg (" dst_table_name" ),
58+ py::arg (" fwd_indices" ), py::arg (" fwd_indptr" ), py::arg (" bwd_indices" ) = py::none (),
59+ py::arg (" bwd_indptr" ) = py::none ())
5660 .def (" drop_arrow_table" , &PyConnection::dropArrowTable, py::arg (" table_name" ));
5761 PyDateTime_IMPORT;
5862}
@@ -1093,6 +1097,89 @@ std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::stri
10931097 return checkAndWrapQueryResult (result.queryResult , state);
10941098}
10951099
1100+ static std::pair<ArrowSchemaWrapper, std::vector<ArrowArrayWrapper>> exportPyArrowTable (
1101+ py::object& tbl) {
1102+ ArrowSchemaWrapper schema;
1103+ tbl.attr (" schema" ).attr (" _export_to_c" )(reinterpret_cast <uint64_t >(&schema));
1104+ std::vector<ArrowArrayWrapper> arrays;
1105+ py::list batches = tbl.attr (" to_batches" )();
1106+ for (auto & batch : batches) {
1107+ arrays.emplace_back ();
1108+ batch.attr (" _export_to_c" )(reinterpret_cast <uint64_t >(&arrays.back ()));
1109+ }
1110+ return {std::move (schema), std::move (arrays)};
1111+ }
1112+
1113+ static py::object toPyArrow (const py::object& obj,
1114+ const std::shared_ptr<PythonCachedImport>& importCache) {
1115+ if (PyConnection::isPandasDataframe (obj)) {
1116+ return importCache->pyarrow .lib .Table .from_pandas ()(obj);
1117+ }
1118+
1119+ if (PyConnection::isPolarsDataframe (obj)) {
1120+ return obj.attr (" to_arrow" )();
1121+ }
1122+
1123+ if (PyConnection::isPyArrowTable (obj)) {
1124+ return obj;
1125+ }
1126+
1127+ throw RuntimeException (" Expected a pyarrow Table, polars DataFrame, or pandas DataFrame" );
1128+ }
1129+
1130+ std::unique_ptr<PyQueryResult> PyConnection::createArrowCsrRelTable (const std::string& tableName,
1131+ const std::string& srcTableName, const std::string& dstTableName, py::object fwdIndices,
1132+ py::object fwdIndptr, py::object bwdIndices, py::object bwdIndptr) {
1133+ auto & stateRef = refState ();
1134+ py::gil_scoped_acquire acquire;
1135+
1136+ bool hasBwd = !bwdIndices.is_none ();
1137+ if (hasBwd != !bwdIndptr.is_none ()) {
1138+ throw RuntimeException (" bwd_indices and bwd_indptr must both be provided or both be None" );
1139+ }
1140+
1141+ fwdIndices = toPyArrow (fwdIndices, importCache);
1142+ fwdIndptr = toPyArrow (fwdIndptr, importCache);
1143+
1144+ py::list keepAlive;
1145+ keepAlive.append (fwdIndices);
1146+ keepAlive.append (fwdIndices.attr (" to_batches" )());
1147+ keepAlive.append (fwdIndptr);
1148+ keepAlive.append (fwdIndptr.attr (" to_batches" )());
1149+
1150+ auto [fwdIdxSchema, fwdIdxArrays] = exportPyArrowTable (fwdIndices);
1151+ auto [fwdIpSchema, fwdIpArrays] = exportPyArrowTable (fwdIndptr);
1152+
1153+ std::optional<ArrowSchemaWrapper> bwdIdxSchema;
1154+ std::optional<std::vector<ArrowArrayWrapper>> bwdIdxArrays;
1155+ std::optional<ArrowSchemaWrapper> bwdIpSchema;
1156+ std::optional<std::vector<ArrowArrayWrapper>> bwdIpArrays;
1157+ if (hasBwd) {
1158+ bwdIndices = toPyArrow (bwdIndices, importCache);
1159+ bwdIndptr = toPyArrow (bwdIndptr, importCache);
1160+ keepAlive.append (bwdIndices);
1161+ keepAlive.append (bwdIndices.attr (" to_batches" )());
1162+ keepAlive.append (bwdIndptr);
1163+ keepAlive.append (bwdIndptr.attr (" to_batches" )());
1164+ auto [bis, bia] = exportPyArrowTable (bwdIndices);
1165+ auto [bps, bpa] = exportPyArrowTable (bwdIndptr);
1166+ bwdIdxSchema = std::move (bis);
1167+ bwdIdxArrays = std::move (bia);
1168+ bwdIpSchema = std::move (bps);
1169+ bwdIpArrays = std::move (bpa);
1170+ }
1171+
1172+ auto result = ArrowTableSupport::createArrowCsrRelTable (stateRef.ref (), tableName, srcTableName,
1173+ dstTableName, std::move (fwdIdxSchema), std::move (fwdIdxArrays), std::move (fwdIpSchema),
1174+ std::move (fwdIpArrays), std::move (bwdIdxSchema), std::move (bwdIdxArrays),
1175+ std::move (bwdIpSchema), std::move (bwdIpArrays));
1176+ if (result.queryResult && result.queryResult ->isSuccess ()) {
1177+ stateRef.arrowTableRefs [tableName] = std::move (keepAlive);
1178+ }
1179+
1180+ return checkAndWrapQueryResult (result.queryResult , state);
1181+ }
1182+
10961183std::unique_ptr<PyQueryResult> PyConnection::dropArrowTable (const std::string& tableName) {
10971184 auto & stateRef = refState ();
10981185 auto result = ArrowTableSupport::unregisterArrowTable (stateRef.ref (), tableName);
0 commit comments