Skip to content

Commit 2bc413a

Browse files
committed
support decimal to arithmetic operation
1 parent a2551db commit 2bc413a

3 files changed

Lines changed: 472 additions & 6 deletions

File tree

ballista/rust/core/src/serde/logical_plan/to_proto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl protobuf::IntervalUnit {
6060
match interval_unit {
6161
IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth,
6262
IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
63-
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
63+
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
6464
}
6565
}
6666

datafusion/src/physical_plan/coercion_rule/binary_rule.rs

Lines changed: 200 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use crate::arrow::datatypes::DataType;
22
use crate::error::{DataFusionError, Result};
33
use crate::logical_plan::Operator;
44
use crate::physical_plan::expressions::coercion::{
5-
dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion,
6-
string_coercion, temporal_coercion,
5+
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion,
6+
temporal_coercion,
77
};
8+
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};
89

910
/// Coercion rules for all binary operators. Returns the output type
1011
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
@@ -30,12 +31,11 @@ pub(crate) fn coerce_types(
3031
Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type),
3132
// for math expressions, the final value of the coercion is also the return type
3233
// because coercion favours higher information types
33-
// TODO: support decimal data type
3434
Operator::Plus
3535
| Operator::Minus
3636
| Operator::Modulo
3737
| Operator::Divide
38-
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
38+
| Operator::Multiply => mathematics_numerical_coercion(&op, lhs_type, rhs_type),
3939
Operator::RegexMatch
4040
| Operator::RegexIMatch
4141
| Operator::RegexNotMatch
@@ -143,12 +143,141 @@ fn get_comparison_common_decimal_type(
143143
}
144144
}
145145

146+
// Convert the numeric data type to the decimal data type.
147+
// Now, we just support the signed integer type and floating-point type.
148+
fn convert_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
149+
match numeric_type {
150+
DataType::Int8 => Some(DataType::Decimal(3, 0)),
151+
DataType::Int16 => Some(DataType::Decimal(5, 0)),
152+
DataType::Int32 => Some(DataType::Decimal(10, 0)),
153+
DataType::Int64 => Some(DataType::Decimal(20, 0)),
154+
// TODO if we convert the floating-point data to the decimal type, it maybe overflow.
155+
DataType::Float32 => Some(DataType::Decimal(14, 7)),
156+
DataType::Float64 => Some(DataType::Decimal(30, 15)),
157+
_ => None,
158+
}
159+
}
160+
161+
fn mathematics_numerical_coercion(
162+
mathematics_op: &Operator,
163+
lhs_type: &DataType,
164+
rhs_type: &DataType,
165+
) -> Option<DataType> {
166+
use arrow::datatypes::DataType::*;
167+
168+
// error on any non-numeric type
169+
if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
170+
return None;
171+
};
172+
173+
// same type => all good
174+
if lhs_type == rhs_type {
175+
return Some(lhs_type.clone());
176+
}
177+
178+
// these are ordered from most informative to least informative so
179+
// that the coercion removes the least amount of information
180+
match (lhs_type, rhs_type) {
181+
(Decimal(_, _), Decimal(_, _)) => {
182+
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type)
183+
}
184+
(Decimal(_, _), _) => {
185+
let converted_decimal_type = convert_numeric_type_to_decimal(rhs_type);
186+
if converted_decimal_type.is_none() {
187+
None
188+
} else {
189+
coercion_decimal_mathematics_type(
190+
mathematics_op,
191+
lhs_type,
192+
&converted_decimal_type.unwrap(),
193+
)
194+
}
195+
}
196+
(_, Decimal(_, _)) => {
197+
let converted_decimal_type = convert_numeric_type_to_decimal(lhs_type);
198+
if converted_decimal_type.is_none() {
199+
None
200+
} else {
201+
coercion_decimal_mathematics_type(
202+
mathematics_op,
203+
&converted_decimal_type.unwrap(),
204+
rhs_type,
205+
)
206+
}
207+
}
208+
(Float64, _) | (_, Float64) => Some(Float64),
209+
(_, Float32) | (Float32, _) => Some(Float32),
210+
(Int64, _) | (_, Int64) => Some(Int64),
211+
(Int32, _) | (_, Int32) => Some(Int32),
212+
(Int16, _) | (_, Int16) => Some(Int16),
213+
(Int8, _) | (_, Int8) => Some(Int8),
214+
(UInt64, _) | (_, UInt64) => Some(UInt64),
215+
(UInt32, _) | (_, UInt32) => Some(UInt32),
216+
(UInt16, _) | (_, UInt16) => Some(UInt16),
217+
(UInt8, _) | (_, UInt8) => Some(UInt8),
218+
_ => None,
219+
}
220+
}
221+
222+
fn create_decimal_type(precision: usize, scale: usize) -> DataType {
223+
DataType::Decimal(
224+
MAX_PRECISION_FOR_DECIMAL128.min(precision),
225+
MAX_SCALE_FOR_DECIMAL128.min(scale),
226+
)
227+
}
228+
229+
fn coercion_decimal_mathematics_type(
230+
mathematics_op: &Operator,
231+
left_decimal_type: &DataType,
232+
right_decimal_type: &DataType,
233+
) -> Option<DataType> {
234+
use arrow::datatypes::DataType::*;
235+
match (left_decimal_type, right_decimal_type) {
236+
(Decimal(p1, s1), Decimal(p2, s2)) => {
237+
match mathematics_op {
238+
Operator::Plus | Operator::Minus => {
239+
// max(s1, s2)
240+
let result_scale = *s1.max(s2);
241+
// max(s1, s2) + max(p1-s1, p2-s2) + 1
242+
let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1;
243+
Some(create_decimal_type(result_precision, result_scale))
244+
}
245+
Operator::Multiply => {
246+
// s1 + s2
247+
let result_scale = *s1 + *s2;
248+
// p1 + p2 + 1
249+
let result_precision = *p1 + *p2 + 1;
250+
Some(create_decimal_type(result_precision, result_scale))
251+
}
252+
Operator::Divide => {
253+
// max(6, s1 + p2 + 1)
254+
let result_scale = 6.max(*s1 + *p2 + 1);
255+
// p1 - s1 + s2 + max(6, s1 + p2 + 1)
256+
let result_precision = result_scale + *p1 - *s1 + *s2;
257+
Some(create_decimal_type(result_precision, result_scale))
258+
}
259+
Operator::Modulo => {
260+
// max(s1, s2)
261+
let result_scale = *s1.max(s2);
262+
// min(p1-s1, p2-s2) + max(s1, s2)
263+
let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2);
264+
Some(create_decimal_type(result_precision, result_scale))
265+
}
266+
_ => unreachable!(),
267+
}
268+
}
269+
_ => unreachable!(),
270+
}
271+
}
272+
146273
#[cfg(test)]
147274
mod tests {
148275
use crate::arrow::datatypes::DataType;
149276
use crate::error::{DataFusionError, Result};
150277
use crate::logical_plan::Operator;
151-
use crate::physical_plan::coercion_rule::binary_rule::coerce_types;
278+
use crate::physical_plan::coercion_rule::binary_rule::{
279+
coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal,
280+
};
152281

153282
#[test]
154283

@@ -207,4 +336,70 @@ mod tests {
207336
assert!(result_type.is_err());
208337
Ok(())
209338
}
339+
340+
#[test]
fn test_decimal_mathematics_op_type() {
    // Each primitive numeric type and the decimal type it is promoted to.
    let conversions = [
        (DataType::Int8, DataType::Decimal(3, 0)),
        (DataType::Int16, DataType::Decimal(5, 0)),
        (DataType::Int32, DataType::Decimal(10, 0)),
        (DataType::Int64, DataType::Decimal(20, 0)),
        (DataType::Float32, DataType::Decimal(14, 7)),
        (DataType::Float64, DataType::Decimal(30, 15)),
    ];
    for (numeric_type, expected_decimal) in conversions.iter() {
        assert_eq!(
            convert_numeric_type_to_decimal(numeric_type).unwrap(),
            *expected_decimal
        );
    }

    // Result type of every arithmetic operator applied to
    // Decimal(10, 3) and Decimal(20, 4).
    let left_decimal_type = DataType::Decimal(10, 3);
    let right_decimal_type = DataType::Decimal(20, 4);
    let expectations = [
        (Operator::Plus, DataType::Decimal(21, 4)),
        (Operator::Minus, DataType::Decimal(21, 4)),
        (Operator::Multiply, DataType::Decimal(31, 7)),
        (Operator::Divide, DataType::Decimal(35, 24)),
        (Operator::Modulo, DataType::Decimal(11, 4)),
    ];
    for (op, expected_type) in expectations.iter() {
        let result = coercion_decimal_mathematics_type(
            op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(*expected_type, result.unwrap());
    }
}
210405
}

0 commit comments

Comments
 (0)