Skip to content

Commit eff5c67

Browse files
author
Henri Froese
committed
Expose register_listing_table
This lets users conveniently use `object_store` with Python DataFusion for partitioned datasets, e.g. in S3. Closes #617
1 parent 79cd69a commit eff5c67

2 files changed

Lines changed: 109 additions & 2 deletions

File tree

datafusion/tests/test_sql.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import pyarrow as pa
2222
import pyarrow.dataset as ds
2323
import pytest
24+
from datafusion.object_store import LocalFileSystem
2425

25-
from datafusion import udf
26+
from datafusion import udf, col
2627

2728
from . import generic as helpers
2829

@@ -374,3 +375,58 @@ def test_simple_select(ctx, tmp_path, arr):
374375
result = batches[0].column(0)
375376

376377
np.testing.assert_equal(result, arr)
378+
379+
380+
@pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]]))
@pytest.mark.parametrize("pass_schema", (True, False))
def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order):
    """Register a directory of partitioned Parquet files as a listing table and query it.

    Three `grp=<value>/date_id=<value>` directories are created under a fresh
    dataset root, each receiving one Parquet file holding a slice of a 7-row
    table. The root is then registered through `ctx.register_listing_table`,
    either with the explicit Arrow schema (`pass_schema=True`) or with no
    schema at all, and with or without a declared file sort order. Row counts
    per `grp` are checked for the whole table and for a single `date_id`.
    """
    dir_root = tmp_path / "dataset_parquet_partitioned"
    dir_root.mkdir(exist_ok=False)

    # One entry per partition:
    # (directory, parquet file, first row of the slice, requested slice length).
    partitions = (
        ("grp=a/date_id=20201005", "grp=a/date_id=20201005/file.parquet", 0, 3),
        ("grp=a/date_id=20211005", "grp=a/date_id=20211005/file.parquet", 3, 2),
        # The requested length overruns the 7-row table; the assertions below
        # expect only the 2 remaining rows to land in this partition.
        ("grp=b/date_id=20201005", "grp=b/date_id=20201005/file.parquet", 5, 10),
    )
    for partition_dir, _, _, _ in partitions:
        (dir_root / partition_dir).mkdir(exist_ok=False, parents=True)

    table = pa.Table.from_arrays(
        [
            [1, 2, 3, 4, 5, 6, 7],
            ["a", "b", "c", "d", "e", "f", "g"],
            [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
        ],
        names=["int", "str", "float"],
    )
    for _, parquet_file, offset, length in partitions:
        pa.parquet.write_table(table.slice(offset, length), dir_root / parquet_file)

    ctx.register_object_store("file://local", LocalFileSystem(), None)
    explicit_schema = table.schema if pass_schema else None
    ctx.register_listing_table(
        "my_table",
        f"file://{dir_root}/",
        table_partition_cols=[("grp", "string"), ("date_id", "int")],
        file_extension=".parquet",
        schema=explicit_schema,
        file_sort_order=file_sort_order,
    )
    assert ctx.tables() == {"my_table"}

    def counts_by_grp(query):
        # Run the query and fold the (grp, count) result columns into a dict.
        batches = ctx.sql(query).collect()
        columns = pa.Table.from_batches(batches).to_pydict()
        return dict(zip(columns["grp"], columns["count"]))

    # Whole table: both `grp=a` partitions (3 + 2 rows) and the `grp=b` one.
    all_rows = counts_by_grp(
        "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp"
    )
    assert all_rows == {"a": 5, "b": 2}

    # Filtering on the partition column keeps one partition per group.
    single_date = counts_by_grp(
        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp"
    )
    assert single_date == {"a": 3, "b": 2}

src/context.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ use crate::store::StorageContexts;
3939
use crate::udaf::PyAggregateUDF;
4040
use crate::udf::PyScalarUDF;
4141
use crate::utils::{get_tokio_runtime, wait_for_future};
42-
use datafusion::arrow::datatypes::{DataType, Schema};
42+
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
4343
use datafusion::arrow::pyarrow::PyArrowType;
4444
use datafusion::arrow::record_batch::RecordBatch;
4545
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
46+
use datafusion::datasource::file_format::parquet::ParquetFormat;
47+
use datafusion::datasource::listing::{
48+
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
49+
};
4650
use datafusion::datasource::MemTable;
4751
use datafusion::datasource::TableProvider;
4852
use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
@@ -278,6 +282,53 @@ impl PySessionContext {
278282
Ok(())
279283
}
280284

285+
#[allow(clippy::too_many_arguments)]
286+
#[pyo3(signature = (name, path, table_partition_cols=vec![],
287+
file_extension=".parquet",
288+
schema=None,
289+
file_sort_order=None))]
290+
pub fn register_listing_table(
291+
&mut self,
292+
name: &str,
293+
path: &str,
294+
table_partition_cols: Vec<(String, String)>,
295+
file_extension: &str,
296+
schema: Option<PyArrowType<Schema>>,
297+
file_sort_order: Option<Vec<Vec<PyExpr>>>,
298+
py: Python,
299+
) -> PyResult<()> {
300+
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
301+
.with_file_extension(file_extension)
302+
.with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
303+
.with_file_sort_order(
304+
file_sort_order
305+
.unwrap_or_default()
306+
.into_iter()
307+
.map(|e| e.into_iter().map(|f| f.into()).collect())
308+
.collect(),
309+
);
310+
let table_path = ListingTableUrl::parse(path)?;
311+
let resolved_schema: SchemaRef = match schema {
312+
Some(s) => Arc::new(s.0),
313+
None => {
314+
let state = self.ctx.state();
315+
let schema = options.infer_schema(&state, &table_path);
316+
wait_for_future(py, schema).map_err(DataFusionError::from)?
317+
}
318+
};
319+
let config = ListingTableConfig::new(table_path)
320+
.with_listing_options(options)
321+
.with_schema(resolved_schema);
322+
let table = ListingTable::try_new(config)?;
323+
self.register_table(
324+
name,
325+
&PyTable {
326+
table: Arc::new(table),
327+
},
328+
)?;
329+
Ok(())
330+
}
331+
281332
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
282333
pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
283334
let result = self.ctx.sql(query);

0 commit comments

Comments
 (0)