Commit 56f6437
feat: Support determining extensions from names like foo.parquet.snappy as well as foo.parquet (#7972)
* feat: read files based on the file extension
* fix: some file extensions may start with "." and some may not
* fix: rename extention to extension
* chore: use exec_err
* chore: rename extention to extension
* chore: rename extention to extension
* chore: simplify the code
* fix: check table is empty
* ci: fix test
* fix: add err info
* refactor: extract the logic to infer_types
* fix: add tests for different extensions
* fix: ci clippy
* fix: add more tests
* fix: simplify the logic
* fix: ci
1 parent 06fd26b
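For orientation before the diff: the commit lets read_parquet accept paths whose final suffix is not ".parquet" when the caller passes the suffix explicitly, as the new test below exercises. A minimal sketch of that usage, assuming a standalone binary with the datafusion and tokio crates and an illustrative path data.parquet.snappy that exists on disk:

    use datafusion::error::Result;
    use datafusion::prelude::{ParquetReadOptions, SessionContext};

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // "data.parquet.snappy" is an illustrative path; passing the
        // non-default suffix lets the new extension check accept it.
        let df = ctx
            .read_parquet(
                "data.parquet.snappy",
                ParquetReadOptions {
                    file_extension: "snappy",
                    ..Default::default()
                },
            )
            .await?;
        df.show().await?;
        Ok(())
    }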

File tree: 2 files changed, +140 -0 lines changed


datafusion/core/src/execution/context/mod.rs

Lines changed: 17 additions & 0 deletions
@@ -849,6 +849,23 @@ impl SessionContext {
         let table_paths = table_paths.to_urls()?;
         let session_config = self.copied_config();
         let listing_options = options.to_listing_options(&session_config);
+
+        let option_extension = listing_options.file_extension.clone();
+
+        if table_paths.is_empty() {
+            return exec_err!("No table paths were provided");
+        }
+
+        // check if the file extension matches the expected extension
+        for path in &table_paths {
+            let file_name = path.prefix().filename().unwrap_or_default();
+            if !path.as_str().ends_with(&option_extension) && file_name.contains('.') {
+                return exec_err!(
+                    "File '{file_name}' does not match the expected extension '{option_extension}'"
+                );
+            }
+        }
+
         let resolved_schema = options
             .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
             .await?;
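Read in isolation, the rule above is: a path fails only when its file name contains a dot yet the path does not end with the expected extension, so extensionless names (e.g. directory-style paths) still pass. A standalone sketch of that predicate, with extension_matches as a hypothetical helper rather than DataFusion API:

    // Hypothetical helper mirroring the check above, not DataFusion API.
    fn extension_matches(path: &str, expected: &str) -> bool {
        let file_name = path.rsplit('/').next().unwrap_or(path);
        path.ends_with(expected) || !file_name.contains('.')
    }

    fn main() {
        // Plain and multi-part suffixes pass when the path ends as expected.
        assert!(extension_matches("data/output1.parquet", ".parquet"));
        assert!(extension_matches("data/output3.parquet.snappy.parquet", ".parquet"));
        // An extensionless file name is not rejected.
        assert!(extension_matches("data/part-0", ".parquet"));
        // A mismatched extension is rejected (an execution error in DataFusion).
        assert!(!extension_matches("data/output2.parquet.snappy", ".parquet"));
    }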

datafusion/core/src/execution/context/parquet.rs

Lines changed: 123 additions & 0 deletions
@@ -74,6 +74,11 @@ impl SessionContext {
 mod tests {
     use async_trait::async_trait;

+    use crate::arrow::array::{Float32Array, Int32Array};
+    use crate::arrow::datatypes::{DataType, Field, Schema};
+    use crate::arrow::record_batch::RecordBatch;
+    use crate::dataframe::DataFrameWriteOptions;
+    use crate::parquet::basic::Compression;
     use crate::test_util::parquet_test_data;

     use super::*;
@@ -132,6 +137,124 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn read_from_different_file_extension() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        // Make up a new dataframe.
+        let write_df = ctx.read_batch(RecordBatch::try_new(
+            Arc::new(Schema::new(vec![
+                Field::new("purchase_id", DataType::Int32, false),
+                Field::new("price", DataType::Float32, false),
+                Field::new("quantity", DataType::Int32, false),
+            ])),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+                Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])),
+                Arc::new(Int32Array::from(vec![1, 3, 2, 4, 3])),
+            ],
+        )?)?;
+
+        // Write the dataframe to a parquet file named 'output1.parquet'
+        write_df
+            .clone()
+            .write_parquet(
+                "output1.parquet",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Write the dataframe to a parquet file named 'output2.parquet.snappy'
+        write_df
+            .clone()
+            .write_parquet(
+                "output2.parquet.snappy",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
+        write_df
+            .write_parquet(
+                "output3.parquet.snappy.parquet",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Read the dataframe from 'output1.parquet' with the default file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output1.parquet",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+
+        // Read the dataframe from 'output2.parquet.snappy' with the correct file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output2.parquet.snappy",
+                ParquetReadOptions {
+                    file_extension: "snappy",
+                    ..Default::default()
+                },
+            )
+            .await?;
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+
+        // Read the dataframe from 'output2.parquet.snappy' with the wrong (default) file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output2.parquet.snappy",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await;
+
+        assert_eq!(
+            read_df.unwrap_err().strip_backtrace(),
+            "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
+        );
+
+        // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output3.parquet.snappy.parquet",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+        Ok(())
+    }
+
     // Test for compilation error when calling read_* functions from an #[async_trait] function.
     // See https://github.com/apache/arrow-datafusion/issues/1154
     #[async_trait]