@@ -2,9 +2,10 @@ use crate::arrow::datatypes::DataType;
22use crate :: error:: { DataFusionError , Result } ;
33use crate :: logical_plan:: Operator ;
44use crate :: physical_plan:: expressions:: coercion:: {
5- dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion ,
6- string_coercion , temporal_coercion,
5+ dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion ,
6+ temporal_coercion,
77} ;
8+ use crate :: scalar:: { MAX_PRECISION_FOR_DECIMAL128 , MAX_SCALE_FOR_DECIMAL128 } ;
89
910/// Coercion rules for all binary operators. Returns the output type
1011/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
@@ -30,12 +31,11 @@ pub(crate) fn coerce_types(
3031 Operator :: Like | Operator :: NotLike => like_coercion ( lhs_type, rhs_type) ,
3132 // for math expressions, the final value of the coercion is also the return type
3233 // because coercion favours higher information types
33- // TODO: support decimal data type
3434 Operator :: Plus
3535 | Operator :: Minus
3636 | Operator :: Modulo
3737 | Operator :: Divide
38- | Operator :: Multiply => numerical_coercion ( lhs_type, rhs_type) ,
38+ | Operator :: Multiply => mathematics_numerical_coercion ( & op , lhs_type, rhs_type) ,
3939 Operator :: RegexMatch
4040 | Operator :: RegexIMatch
4141 | Operator :: RegexNotMatch
@@ -143,12 +143,141 @@ fn get_comparison_common_decimal_type(
143143 }
144144}
145145
146+ // Convert the numeric data type to the decimal data type.
147+ // Now, we just support the signed integer type and floating-point type.
148+ fn convert_numeric_type_to_decimal ( numeric_type : & DataType ) -> Option < DataType > {
149+ match numeric_type {
150+ DataType :: Int8 => Some ( DataType :: Decimal ( 3 , 0 ) ) ,
151+ DataType :: Int16 => Some ( DataType :: Decimal ( 5 , 0 ) ) ,
152+ DataType :: Int32 => Some ( DataType :: Decimal ( 10 , 0 ) ) ,
153+ DataType :: Int64 => Some ( DataType :: Decimal ( 20 , 0 ) ) ,
154+ // TODO if we convert the floating-point data to the decimal type, it maybe overflow.
155+ DataType :: Float32 => Some ( DataType :: Decimal ( 14 , 7 ) ) ,
156+ DataType :: Float64 => Some ( DataType :: Decimal ( 30 , 15 ) ) ,
157+ _ => None ,
158+ }
159+ }
160+
161+ fn mathematics_numerical_coercion (
162+ mathematics_op : & Operator ,
163+ lhs_type : & DataType ,
164+ rhs_type : & DataType ,
165+ ) -> Option < DataType > {
166+ use arrow:: datatypes:: DataType :: * ;
167+
168+ // error on any non-numeric type
169+ if !is_numeric ( lhs_type) || !is_numeric ( rhs_type) {
170+ return None ;
171+ } ;
172+
173+ // same type => all good
174+ if lhs_type == rhs_type {
175+ return Some ( lhs_type. clone ( ) ) ;
176+ }
177+
178+ // these are ordered from most informative to least informative so
179+ // that the coercion removes the least amount of information
180+ match ( lhs_type, rhs_type) {
181+ ( Decimal ( _, _) , Decimal ( _, _) ) => {
182+ coercion_decimal_mathematics_type ( mathematics_op, lhs_type, rhs_type)
183+ }
184+ ( Decimal ( _, _) , _) => {
185+ let converted_decimal_type = convert_numeric_type_to_decimal ( rhs_type) ;
186+ if converted_decimal_type. is_none ( ) {
187+ None
188+ } else {
189+ coercion_decimal_mathematics_type (
190+ mathematics_op,
191+ lhs_type,
192+ & converted_decimal_type. unwrap ( ) ,
193+ )
194+ }
195+ }
196+ ( _, Decimal ( _, _) ) => {
197+ let converted_decimal_type = convert_numeric_type_to_decimal ( lhs_type) ;
198+ if converted_decimal_type. is_none ( ) {
199+ None
200+ } else {
201+ coercion_decimal_mathematics_type (
202+ mathematics_op,
203+ & converted_decimal_type. unwrap ( ) ,
204+ rhs_type,
205+ )
206+ }
207+ }
208+ ( Float64 , _) | ( _, Float64 ) => Some ( Float64 ) ,
209+ ( _, Float32 ) | ( Float32 , _) => Some ( Float32 ) ,
210+ ( Int64 , _) | ( _, Int64 ) => Some ( Int64 ) ,
211+ ( Int32 , _) | ( _, Int32 ) => Some ( Int32 ) ,
212+ ( Int16 , _) | ( _, Int16 ) => Some ( Int16 ) ,
213+ ( Int8 , _) | ( _, Int8 ) => Some ( Int8 ) ,
214+ ( UInt64 , _) | ( _, UInt64 ) => Some ( UInt64 ) ,
215+ ( UInt32 , _) | ( _, UInt32 ) => Some ( UInt32 ) ,
216+ ( UInt16 , _) | ( _, UInt16 ) => Some ( UInt16 ) ,
217+ ( UInt8 , _) | ( _, UInt8 ) => Some ( UInt8 ) ,
218+ _ => None ,
219+ }
220+ }
221+
222+ fn create_decimal_type ( precision : usize , scale : usize ) -> DataType {
223+ DataType :: Decimal (
224+ MAX_PRECISION_FOR_DECIMAL128 . min ( precision) ,
225+ MAX_SCALE_FOR_DECIMAL128 . min ( scale) ,
226+ )
227+ }
228+
229+ fn coercion_decimal_mathematics_type (
230+ mathematics_op : & Operator ,
231+ left_decimal_type : & DataType ,
232+ right_decimal_type : & DataType ,
233+ ) -> Option < DataType > {
234+ use arrow:: datatypes:: DataType :: * ;
235+ match ( left_decimal_type, right_decimal_type) {
236+ ( Decimal ( p1, s1) , Decimal ( p2, s2) ) => {
237+ match mathematics_op {
238+ Operator :: Plus | Operator :: Minus => {
239+ // max(s1, s2)
240+ let result_scale = * s1. max ( s2) ;
241+ // max(s1, s2) + max(p1-s1, p2-s2) + 1
242+ let result_precision = result_scale + ( * p1 - * s1) . max ( * p2 - * s2) + 1 ;
243+ Some ( create_decimal_type ( result_precision, result_scale) )
244+ }
245+ Operator :: Multiply => {
246+ // s1 + s2
247+ let result_scale = * s1 + * s2;
248+ // p1 + p2 + 1
249+ let result_precision = * p1 + * p2 + 1 ;
250+ Some ( create_decimal_type ( result_precision, result_scale) )
251+ }
252+ Operator :: Divide => {
253+ // max(6, s1 + p2 + 1)
254+ let result_scale = 6 . max ( * s1 + * p2 + 1 ) ;
255+ // p1 - s1 + s2 + max(6, s1 + p2 + 1)
256+ let result_precision = result_scale + * p1 - * s1 + * s2;
257+ Some ( create_decimal_type ( result_precision, result_scale) )
258+ }
259+ Operator :: Modulo => {
260+ // max(s1, s2)
261+ let result_scale = * s1. max ( s2) ;
262+ // min(p1-s1, p2-s2) + max(s1, s2)
263+ let result_precision = result_scale + ( * p1 - * s1) . min ( * p2 - * s2) ;
264+ Some ( create_decimal_type ( result_precision, result_scale) )
265+ }
266+ _ => unreachable ! ( ) ,
267+ }
268+ }
269+ _ => unreachable ! ( ) ,
270+ }
271+ }
272+
146273#[ cfg( test) ]
147274mod tests {
148275 use crate :: arrow:: datatypes:: DataType ;
149276 use crate :: error:: { DataFusionError , Result } ;
150277 use crate :: logical_plan:: Operator ;
151- use crate :: physical_plan:: coercion_rule:: binary_rule:: coerce_types;
278+ use crate :: physical_plan:: coercion_rule:: binary_rule:: {
279+ coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal,
280+ } ;
152281
153282 #[ test]
154283
@@ -207,4 +336,70 @@ mod tests {
207336 assert ! ( result_type. is_err( ) ) ;
208337 Ok ( ( ) )
209338 }
339+
#[test]
fn test_decimal_mathematics_op_type() {
    // Each supported numeric type paired with the decimal type it converts to.
    let conversions = vec![
        (DataType::Int8, DataType::Decimal(3, 0)),
        (DataType::Int16, DataType::Decimal(5, 0)),
        (DataType::Int32, DataType::Decimal(10, 0)),
        (DataType::Int64, DataType::Decimal(20, 0)),
        (DataType::Float32, DataType::Decimal(14, 7)),
        (DataType::Float64, DataType::Decimal(30, 15)),
    ];
    for (numeric_type, expected_decimal) in conversions.iter() {
        assert_eq!(
            convert_numeric_type_to_decimal(numeric_type).unwrap(),
            *expected_decimal
        );
    }

    // Result type of every arithmetic operator on Decimal(10, 3) and Decimal(20, 4).
    let left_decimal_type = DataType::Decimal(10, 3);
    let right_decimal_type = DataType::Decimal(20, 4);
    let expectations = vec![
        (Operator::Plus, DataType::Decimal(21, 4)),
        (Operator::Minus, DataType::Decimal(21, 4)),
        (Operator::Multiply, DataType::Decimal(31, 7)),
        (Operator::Divide, DataType::Decimal(35, 24)),
        (Operator::Modulo, DataType::Decimal(11, 4)),
    ];
    for (op, expected_type) in expectations.iter() {
        let result = coercion_decimal_mathematics_type(
            op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(*expected_type, result.unwrap());
    }
}
210405}
0 commit comments