@@ -74,6 +74,11 @@ impl SessionContext {
7474mod tests {
7575 use async_trait:: async_trait;
7676
77+ use crate :: arrow:: array:: { Float32Array , Int32Array } ;
78+ use crate :: arrow:: datatypes:: { DataType , Field , Schema } ;
79+ use crate :: arrow:: record_batch:: RecordBatch ;
80+ use crate :: dataframe:: DataFrameWriteOptions ;
81+ use crate :: parquet:: basic:: Compression ;
7782 use crate :: test_util:: parquet_test_data;
7883
7984 use super :: * ;
@@ -132,6 +137,124 @@ mod tests {
132137 Ok ( ( ) )
133138 }
134139
140+ #[ tokio:: test]
141+ async fn read_from_different_file_extension ( ) -> Result < ( ) > {
142+ let ctx = SessionContext :: new ( ) ;
143+
144+ // Make up a new dataframe.
145+ let write_df = ctx. read_batch ( RecordBatch :: try_new (
146+ Arc :: new ( Schema :: new ( vec ! [
147+ Field :: new( "purchase_id" , DataType :: Int32 , false ) ,
148+ Field :: new( "price" , DataType :: Float32 , false ) ,
149+ Field :: new( "quantity" , DataType :: Int32 , false ) ,
150+ ] ) ) ,
151+ vec ! [
152+ Arc :: new( Int32Array :: from( vec![ 1 , 2 , 3 , 4 , 5 ] ) ) ,
153+ Arc :: new( Float32Array :: from( vec![ 1.12 , 3.40 , 2.33 , 9.10 , 6.66 ] ) ) ,
154+ Arc :: new( Int32Array :: from( vec![ 1 , 3 , 2 , 4 , 3 ] ) ) ,
155+ ] ,
156+ ) ?) ?;
157+
158+ // Write the dataframe to a parquet file named 'output1.parquet'
159+ write_df
160+ . clone ( )
161+ . write_parquet (
162+ "output1.parquet" ,
163+ DataFrameWriteOptions :: new ( ) . with_single_file_output ( true ) ,
164+ Some (
165+ WriterProperties :: builder ( )
166+ . set_compression ( Compression :: SNAPPY )
167+ . build ( ) ,
168+ ) ,
169+ )
170+ . await ?;
171+
172+ // Write the dataframe to a parquet file named 'output2.parquet.snappy'
173+ write_df
174+ . clone ( )
175+ . write_parquet (
176+ "output2.parquet.snappy" ,
177+ DataFrameWriteOptions :: new ( ) . with_single_file_output ( true ) ,
178+ Some (
179+ WriterProperties :: builder ( )
180+ . set_compression ( Compression :: SNAPPY )
181+ . build ( ) ,
182+ ) ,
183+ )
184+ . await ?;
185+
186+ // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
187+ write_df
188+ . write_parquet (
189+ "output3.parquet.snappy.parquet" ,
190+ DataFrameWriteOptions :: new ( ) . with_single_file_output ( true ) ,
191+ Some (
192+ WriterProperties :: builder ( )
193+ . set_compression ( Compression :: SNAPPY )
194+ . build ( ) ,
195+ ) ,
196+ )
197+ . await ?;
198+
199+ // Read the dataframe from 'output1.parquet' with the default file extension.
200+ let read_df = ctx
201+ . read_parquet (
202+ "output1.parquet" ,
203+ ParquetReadOptions {
204+ ..Default :: default ( )
205+ } ,
206+ )
207+ . await ?;
208+
209+ let results = read_df. collect ( ) . await ?;
210+ let total_rows: usize = results. iter ( ) . map ( |rb| rb. num_rows ( ) ) . sum ( ) ;
211+ assert_eq ! ( total_rows, 5 ) ;
212+
213+ // Read the dataframe from 'output2.parquet.snappy' with the correct file extension.
214+ let read_df = ctx
215+ . read_parquet (
216+ "output2.parquet.snappy" ,
217+ ParquetReadOptions {
218+ file_extension : "snappy" ,
219+ ..Default :: default ( )
220+ } ,
221+ )
222+ . await ?;
223+ let results = read_df. collect ( ) . await ?;
224+ let total_rows: usize = results. iter ( ) . map ( |rb| rb. num_rows ( ) ) . sum ( ) ;
225+ assert_eq ! ( total_rows, 5 ) ;
226+
227+ // Read the dataframe from 'output3.parquet.snappy.parquet' with the wrong file extension.
228+ let read_df = ctx
229+ . read_parquet (
230+ "output2.parquet.snappy" ,
231+ ParquetReadOptions {
232+ ..Default :: default ( )
233+ } ,
234+ )
235+ . await ;
236+
237+ assert_eq ! (
238+ read_df. unwrap_err( ) . strip_backtrace( ) ,
239+ "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
240+ ) ;
241+
242+ // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
243+ let read_df = ctx
244+ . read_parquet (
245+ "output3.parquet.snappy.parquet" ,
246+ ParquetReadOptions {
247+ ..Default :: default ( )
248+ } ,
249+ )
250+ . await ?;
251+
252+ let results = read_df. collect ( ) . await ?;
253+ let total_rows: usize = results. iter ( ) . map ( |rb| rb. num_rows ( ) ) . sum ( ) ;
254+ assert_eq ! ( total_rows, 5 ) ;
255+ Ok ( ( ) )
256+ }
257+
135258 // Test for compilation error when calling read_* functions from an #[async_trait] function.
136259 // See https://github.com/apache/arrow-datafusion/issues/1154
137260 #[ async_trait]
0 commit comments