From d5af58d4e98f229cd08749a2a2f85a4930822fb3 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sat, 18 May 2024 07:27:48 -0700 Subject: [PATCH 01/13] feat: add hex scalar function --- .../datafusion/expressions/scalar_funcs.rs | 7 + .../expressions/scalar_funcs/hex.rs | 191 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 7 + .../apache/comet/CometExpressionSuite.scala | 40 ++++ 4 files changed, 245 insertions(+) create mode 100644 core/src/execution/datafusion/expressions/scalar_funcs/hex.rs diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 8c5e1f3916..9736bb26d3 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -55,6 +55,9 @@ use unicode_segmentation::UnicodeSegmentation; mod unhex; use unhex::spark_unhex; +mod hex; +use hex::spark_hex; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -108,6 +111,10 @@ pub fn create_comet_physical_fun( "make_decimal" => { make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } + "hex" => { + let func = Arc::new(spark_hex); + make_comet_scalar_udf!("hex", func, without data_type) + } "unhex" => { let func = Arc::new(spark_unhex); make_comet_scalar_udf!("unhex", func, without data_type) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs new file mode 100644 index 0000000000..9e746fad75 --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::as_string_array; +use arrow_array::StringArray; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_binary_array, as_int64_array}, + exec_err, DataFusionError, ScalarValue, +}; +use std::fmt::Write; + +fn hex_bytes(bytes: &[u8]) -> Vec { + let length = bytes.len(); + let mut value = vec![0; length * 2]; + let mut i = 0; + while i < length { + value[i * 2] = (bytes[i] & 0xF0) >> 4; + value[i * 2 + 1] = bytes[i] & 0x0F; + i += 1; + } + value +} + +fn hex_int64(num: i64) -> String { + if num >= 0 { + format!("{:X}", num) + } else { + format!("{:016X}", num as u64) + } +} + +fn hex_string(s: &str) -> Vec { + hex_bytes(s.as_bytes()) +} + +fn hex_bytes_to_string(bytes: &[u8]) -> Result { + let mut hex_string = String::with_capacity(bytes.len() * 2); + for byte in bytes { + write!(&mut hex_string, "{:01X}", byte)?; + } + Ok(hex_string) +} + +pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int64)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: Vec> = array + .iter() + .map(|v| v.map(|v| hex_bytes_to_string(&hex_string(v))).transpose()) + .collect::>()?; + + let string_array = StringArray::from(hexed); + + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: Vec> = array + .iter() + .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .collect::>()?; + + let string_array = StringArray::from(hexed); + + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + _ => exec_err!("hex expects a string, binary or integer argument"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Int64(Some(v)) => { + let hex_string = hex_int64(*v); + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) + } + ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => { + let hex_bytes = hex_bytes(v); + let hex_string = hex_bytes_to_string(&hex_bytes)?; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) + } + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { + let hex_bytes = hex_string(v); + let hex_string = hex_bytes_to_string(&hex_bytes)?; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) + } + ScalarValue::Int64(None) | ScalarValue::Utf8(None) | ScalarValue::Binary(None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + _ => exec_err!("hex expects a string, binary or integer argument"), + }, + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::as_string_array; + use arrow_array::{Int64Array, StringArray}; + use datafusion::logical_expr::ColumnarValue; + + #[test] + fn test_hex_bytes() { + let bytes = [0x01, 0x02, 0x03, 0x04]; + let hexed = super::hex_bytes(&bytes); + assert_eq!(hexed, vec![0, 1, 0, 2, 0, 3, 0, 4]); + } + + #[test] + fn test_hex_bytes_to_string() -> Result<(), std::fmt::Error> { + let bytes = [0x01, 0x02, 0x03, 0x04]; + let hexed = super::hex_bytes_to_string(&bytes)?; + assert_eq!(hexed, "1234".to_string()); + + let large_bytes = [0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0]; + let hexed = super::hex_bytes_to_string(&large_bytes)?; + assert_eq!(hexed, "123456789ABCDEF0".to_string()); + + Ok(()) + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cf7c86a9fd..0ca99a0c5a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1483,6 +1483,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) + case Hex(child) => + val childExpr = exprToProtoInternal(child, inputs) + val optExpr = + scalarExprToProtoWithReturnType("hex", StringType, childExpr) + + optExprWithInfo(optExpr, expr, child) + case e: Unhex if !isSpark32 => val unHex = unhexSerde(e) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f3fd50e9e9..058d99552e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1038,6 +1038,46 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("hex") { + val str_table = "string_hex_table" + withTable(str_table) { + sql(s"create table $str_table(col string) using parquet") + + sql(s"""INSERT INTO $str_table VALUES + |('Spark SQL'), + |('string'), + |(''), + |('###'), + |('G123'), + |(NULL), + |('hello'), + |('A1B'), + |('0A1B')""".stripMargin) + + checkSparkAnswerAndOperator(s"SELECT hex(col) FROM $str_table") + checkSparkAnswerAndOperator(s"SELECT hex(CAST(col AS BINARY)) FROM $str_table") + } + + val int_table = "int_hex_table" + withTable(int_table) { + sql(s"create table $int_table(col int) using parquet") + + sql(s"""INSERT INTO $int_table VALUES + |(-1), + |(0), + |(1), + |(2), + |(3), + |(4), + |(NULL), + |(5), + |(6), + |(7), + |(8)""".stripMargin) + + checkSparkAnswerAndOperator(s"SELECT hex(col) FROM $int_table") + } + } test("unhex") { // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not From 2bfcf25a6705ec47d857670c65dbba979ca10e23 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 21 May 2024 08:53:39 -0700 Subject: [PATCH 02/13] test: change hex test to use makeParquetFileAllTypes, support more types --- .../expressions/scalar_funcs/hex.rs | 134 +++++++++++++++++- .../apache/comet/CometExpressionSuite.scala | 48 ++----- 2 files changed, 141 insertions(+), 41 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 9e746fad75..8a5df6a980 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use arrow::array::as_string_array; -use arrow_array::StringArray; +use arrow_array::{Int16Array, Int8Array, StringArray}; use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ - cast::{as_binary_array, as_int64_array}, + cast::{as_binary_array, as_fixed_size_binary_array, as_int32_array, as_int64_array}, exec_err, DataFusionError, ScalarValue, }; use std::fmt::Write; @@ -47,6 +47,30 @@ fn hex_int64(num: i64) -> String { } } +fn hex_int32(num: i32) -> String { + if num >= 0 { + format!("{:X}", num) + } else { + format!("{:08X}", num as u32) + } +} + +fn hex_int16(num: i16) -> String { + if num >= 0 { + format!("{:X}", num) + } else { + format!("{:04X}", num as u16) + } +} + +fn hex_int8(num: i8) -> String { + if num >= 0 { + format!("{:X}", num) + } else { + format!("{:02X}", num as u8) + } +} + fn hex_string(s: &str) -> Vec { hex_bytes(s.as_bytes()) } @@ -54,7 +78,7 @@ fn hex_string(s: &str) -> Vec { fn hex_bytes_to_string(bytes: &[u8]) -> Result { let mut hex_string = String::with_capacity(bytes.len() * 2); for byte in bytes { - write!(&mut hex_string, "{:01X}", byte)?; + write!(&mut hex_string, "{:X}", byte)?; } Ok(hex_string) } @@ -76,6 +100,46 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + let array = as_int32_array(array)?; + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int32)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Int16 => { + let array = array.as_any().downcast_ref::().unwrap(); + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int16)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Int8 => { + let array = array.as_any().downcast_ref::().unwrap(); + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int8)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::UInt64 => { + let array = as_int64_array(array)?; + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int64)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::UInt8 => { + let array = array.as_any().downcast_ref::().unwrap(); + + let hexed: Vec> = array.iter().map(|v| v.map(hex_int8)).collect(); + + let string_array = StringArray::from(hexed); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } DataType::Utf8 => { let array = as_string_array(array); @@ -100,7 +164,22 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result exec_err!("hex expects a string, binary or integer argument"), + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: Vec> = array + .iter() + .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .collect::>()?; + + let string_array = StringArray::from(hexed); + + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + _ => exec_err!( + "hex expects a string, binary or integer argument, got {:?}", + array.data_type() + ), }, ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Int64(Some(v)) => { @@ -108,7 +187,9 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + ScalarValue::Binary(Some(v)) + | ScalarValue::LargeBinary(Some(v)) + | ScalarValue::FixedSizeBinary(_, Some(v)) => { let hex_bytes = hex_bytes(v); let hex_string = hex_bytes_to_string(&hex_bytes)?; @@ -120,10 +201,16 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + ScalarValue::Int64(None) + | ScalarValue::Utf8(None) + | ScalarValue::Binary(None) + | ScalarValue::FixedSizeBinary(_, None) => { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - _ => exec_err!("hex expects a string, binary or integer argument"), + _ => exec_err!( + "hex expects a string, binary or integer argument, got {:?}", + scalar.data_type() + ), }, } } @@ -156,6 +243,39 @@ mod test { Ok(()) } + #[test] + fn test_hex_i8() { + let num = 123; + let hexed = super::hex_int8(num); + assert_eq!(hexed, "7B".to_string()); + + let num = -1; + let hexed = super::hex_int8(num); + assert_eq!(hexed, "FF".to_string()); + } + + #[test] + fn test_hex_i16() { + let num = 1234; + let hexed = super::hex_int16(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int16(num); + assert_eq!(hexed, "FFFF".to_string()); + } + + #[test] + fn test_hex_i32() { + let num = 1234; + let hexed = super::hex_int32(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int32(num); + assert_eq!(hexed, "FFFFFFFF".to_string()); + } + #[test] fn test_hex_int64() { let num = 1234; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 058d99552e..49b1657dec 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1039,43 +1039,23 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } test("hex") { - val str_table = "string_hex_table" - withTable(str_table) { - sql(s"create table $str_table(col string) using parquet") - - sql(s"""INSERT INTO $str_table VALUES - |('Spark SQL'), - |('string'), - |(''), - |('###'), - |('G123'), - |(NULL), - |('hello'), - |('A1B'), - |('0A1B')""".stripMargin) - - checkSparkAnswerAndOperator(s"SELECT hex(col) FROM $str_table") - checkSparkAnswerAndOperator(s"SELECT hex(CAST(col AS BINARY)) FROM $str_table") - } + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "hex.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - val int_table = "int_hex_table" - withTable(int_table) { - sql(s"create table $int_table(col int) using parquet") + withParquetTable(path.toString, "tbl") { + // ints + checkSparkAnswerAndOperator("SELECT hex(_2), hex(_3), hex(_4), hex(_5) FROM tbl") - sql(s"""INSERT INTO $int_table VALUES - |(-1), - |(0), - |(1), - |(2), - |(3), - |(4), - |(NULL), - |(5), - |(6), - |(7), - |(8)""".stripMargin) + // uints, uint8 and uint16 not working yet + // checkSparkAnswerAndOperator("SELECT hex(_9), hex(_10), hex(_11), hex(_12) FROM tbl") + checkSparkAnswerAndOperator("SELECT hex(_11), hex(_12) FROM tbl") - checkSparkAnswerAndOperator(s"SELECT hex(col) FROM $int_table") + // strings, binary + checkSparkAnswerAndOperator("SELECT hex(_8), hex(_14) FROM tbl") + } + } } } test("unhex") { From 1caa6fdbbd8f6ecf9ad68b84a161ca8e18ac748d Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 21 May 2024 12:43:22 -0700 Subject: [PATCH 03/13] test: add more columns to spark test --- .../expressions/scalar_funcs/hex.rs | 34 ++++++++++++------- .../apache/comet/CometExpressionSuite.scala | 12 ++----- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 8a5df6a980..0e1880b160 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -27,18 +27,6 @@ use datafusion_common::{ }; use std::fmt::Write; -fn hex_bytes(bytes: &[u8]) -> Vec { - let length = bytes.len(); - let mut value = vec![0; length * 2]; - let mut i = 0; - while i < length { - value[i * 2] = (bytes[i] & 0xF0) >> 4; - value[i * 2 + 1] = bytes[i] & 0x0F; - i += 1; - } - value -} - fn hex_int64(num: i64) -> String { if num >= 0 { format!("{:X}", num) @@ -71,6 +59,18 @@ fn hex_int8(num: i8) -> String { } } +fn hex_bytes(bytes: &[u8]) -> Vec { + let length = bytes.len(); + let mut value = vec![0; length * 2]; + let mut i = 0; + while i < length { + value[i * 2] = (bytes[i] & 0xF0) >> 4; + value[i * 2 + 1] = bytes[i] & 0x0F; + i += 1; + } + value +} + fn hex_string(s: &str) -> Vec { hex_bytes(s.as_bytes()) } @@ -230,6 +230,16 @@ mod test { assert_eq!(hexed, vec![0, 1, 0, 2, 0, 3, 0, 4]); } + #[test] + fn test_hex_string() { + let s = "1234"; + let hexed = super::hex_string(s); + assert_eq!(hexed, vec![0x31, 0x32, 0x33, 0x34]); + + let hexed_string = super::hex_bytes_to_string(&hexed).unwrap(); + assert_eq!(hexed_string, "31323334".to_string()); + } + #[test] fn test_hex_bytes_to_string() -> Result<(), std::fmt::Error> { let bytes = [0x01, 0x02, 0x03, 0x04]; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 49b1657dec..2b4a7a2edb 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1045,15 +1045,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) withParquetTable(path.toString, "tbl") { - // ints - checkSparkAnswerAndOperator("SELECT hex(_2), hex(_3), hex(_4), hex(_5) FROM tbl") - - // uints, uint8 and uint16 not working yet - // checkSparkAnswerAndOperator("SELECT hex(_9), hex(_10), hex(_11), hex(_12) FROM tbl") - checkSparkAnswerAndOperator("SELECT hex(_11), hex(_12) FROM tbl") - - // strings, binary - checkSparkAnswerAndOperator("SELECT hex(_8), hex(_14) FROM tbl") + // _9 and _10 (uint8 and uint16), and _13 (Dictionary(Int32, Utf8)) are not supported + checkSparkAnswerAndOperator( + "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl") } } } From 385def84b723aee8e5a9ea97e2e3561cf8fc1f12 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 21 May 2024 13:55:24 -0700 Subject: [PATCH 04/13] refactor: remove extra rust code --- .../expressions/scalar_funcs/hex.rs | 113 +----------------- 1 file changed, 3 insertions(+), 110 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 0e1880b160..a26ee1ab56 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use arrow::array::as_string_array; -use arrow_array::{Int16Array, Int8Array, StringArray}; +use arrow_array::StringArray; use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ - cast::{as_binary_array, as_fixed_size_binary_array, as_int32_array, as_int64_array}, + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, exec_err, DataFusionError, ScalarValue, }; use std::fmt::Write; @@ -35,30 +35,6 @@ fn hex_int64(num: i64) -> String { } } -fn hex_int32(num: i32) -> String { - if num >= 0 { - format!("{:X}", num) - } else { - format!("{:08X}", num as u32) - } -} - -fn hex_int16(num: i16) -> String { - if num >= 0 { - format!("{:X}", num) - } else { - format!("{:04X}", num as u16) - } -} - -fn hex_int8(num: i8) -> String { - if num >= 0 { - format!("{:X}", num) - } else { - format!("{:02X}", num as u8) - } -} - fn hex_bytes(bytes: &[u8]) -> Vec { let length = bytes.len(); let mut value = vec![0; length * 2]; @@ -100,47 +76,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { - let array = as_int32_array(array)?; - - let hexed: Vec> = array.iter().map(|v| v.map(hex_int32)).collect(); - - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::Int16 => { - let array = array.as_any().downcast_ref::().unwrap(); - - let hexed: Vec> = array.iter().map(|v| v.map(hex_int16)).collect(); - - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::Int8 => { - let array = array.as_any().downcast_ref::().unwrap(); - - let hexed: Vec> = array.iter().map(|v| v.map(hex_int8)).collect(); - - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::UInt64 => { - let array = as_int64_array(array)?; - - let hexed: Vec> = array.iter().map(|v| v.map(hex_int64)).collect(); - - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::UInt8 => { - let array = array.as_any().downcast_ref::().unwrap(); - - let hexed: Vec> = array.iter().map(|v| v.map(hex_int8)).collect(); - - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::Utf8 => { + DataType::Utf8 | DataType::LargeUtf8 => { let array = as_string_array(array); let hexed: Vec> = array @@ -230,16 +166,6 @@ mod test { assert_eq!(hexed, vec![0, 1, 0, 2, 0, 3, 0, 4]); } - #[test] - fn test_hex_string() { - let s = "1234"; - let hexed = super::hex_string(s); - assert_eq!(hexed, vec![0x31, 0x32, 0x33, 0x34]); - - let hexed_string = super::hex_bytes_to_string(&hexed).unwrap(); - assert_eq!(hexed_string, "31323334".to_string()); - } - #[test] fn test_hex_bytes_to_string() -> Result<(), std::fmt::Error> { let bytes = [0x01, 0x02, 0x03, 0x04]; @@ -253,39 +179,6 @@ mod test { Ok(()) } - #[test] - fn test_hex_i8() { - let num = 123; - let hexed = super::hex_int8(num); - assert_eq!(hexed, "7B".to_string()); - - let num = -1; - let hexed = super::hex_int8(num); - assert_eq!(hexed, "FF".to_string()); - } - - #[test] - fn test_hex_i16() { - let num = 1234; - let hexed = super::hex_int16(num); - assert_eq!(hexed, "4D2".to_string()); - - let num = -1; - let hexed = super::hex_int16(num); - assert_eq!(hexed, "FFFF".to_string()); - } - - #[test] - fn test_hex_i32() { - let num = 1234; - let hexed = super::hex_int32(num); - assert_eq!(hexed, "4D2".to_string()); - - let num = -1; - let hexed = super::hex_int32(num); - assert_eq!(hexed, "FFFFFFFF".to_string()); - } - #[test] fn test_hex_int64() { let num = 1234; From 8033c23a4cc5e1bcb60f7bfa5d5e0753100d62f3 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Tue, 21 May 2024 17:57:41 -0700 Subject: [PATCH 05/13] feat: support dictionary --- .../expressions/scalar_funcs/hex.rs | 165 +++++++++++++++++- .../apache/comet/CometExpressionSuite.scala | 4 +- 2 files changed, 163 insertions(+), 6 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index a26ee1ab56..a745067e19 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -17,7 +17,10 @@ use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::{ + array::{as_dictionary_array, as_string_array}, + datatypes::Int32Type, +}; use arrow_array::StringArray; use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; @@ -112,8 +115,69 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_int64_array(dict.values())?; + let values = hexed_values + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(); + + let keys = dict.keys().clone(); + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Utf8) => { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_string_array(dict.values()); + let values: Vec> = hexed_values + .iter() + .map(|v| v.map(|v| hex_bytes_to_string(&hex_string(v))).transpose()) + .collect::>()?; + + let keys = dict.keys().clone(); + + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Binary) => { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_binary_array(dict.values())?; + let values: Vec> = hexed_values + .iter() + .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .collect::>()?; + + let keys = dict.keys().clone(); + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } _ => exec_err!( - "hex expects a string, binary or integer argument, got {:?}", + "hex got an unexpected argument type: {:?}", array.data_type() ), }, @@ -144,7 +208,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result exec_err!( - "hex expects a string, binary or integer argument, got {:?}", + "hex got an unexpected argument type: {:?}", scalar.data_type() ), }, @@ -155,10 +219,103 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("1"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("31"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + #[test] fn test_hex_bytes() { let bytes = [0x01, 0x02, 0x03, 0x04]; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2b4a7a2edb..5e4de75bce 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1045,9 +1045,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) withParquetTable(path.toString, "tbl") { - // _9 and _10 (uint8 and uint16), and _13 (Dictionary(Int32, Utf8)) are not supported + // _9 and _10 (uint8 and uint16) not supported checkSparkAnswerAndOperator( - "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl") + "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl") } } } From 8c049a9d11963d762745750bd208f99f877646be Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 10:00:33 -0700 Subject: [PATCH 06/13] fix: simplify hex_int64 --- .../execution/datafusion/expressions/scalar_funcs/hex.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index a745067e19..7896959455 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -31,11 +31,7 @@ use datafusion_common::{ use std::fmt::Write; fn hex_int64(num: i64) -> String { - if num >= 0 { - format!("{:X}", num) - } else { - format!("{:016X}", num as u64) - } + format!("{:X}", num) } fn hex_bytes(bytes: &[u8]) -> Vec { From 1539d6a454034d3f7a83cde2230197c328805514 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 10:06:43 -0700 Subject: [PATCH 07/13] refactor: combine functions for hex byte/string --- .../expressions/scalar_funcs/hex.rs | 56 +++++-------------- 1 file changed, 13 insertions(+), 43 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 7896959455..fae628de61 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -34,30 +34,22 @@ fn hex_int64(num: i64) -> String { format!("{:X}", num) } -fn hex_bytes(bytes: &[u8]) -> Vec { +fn hex_bytes(bytes: &[u8]) -> Result { let length = bytes.len(); - let mut value = vec![0; length * 2]; + let mut hex_string = String::with_capacity(bytes.len() * 2); let mut i = 0; while i < length { - value[i * 2] = (bytes[i] & 0xF0) >> 4; - value[i * 2 + 1] = bytes[i] & 0x0F; + write!(&mut hex_string, "{:X}", (bytes[i] & 0xF0) >> 4)?; + write!(&mut hex_string, "{:X}", bytes[i] & 0x0F)?; i += 1; } - value + Ok(hex_string) } -fn hex_string(s: &str) -> Vec { +fn hex_string(s: &str) -> Result { hex_bytes(s.as_bytes()) } -fn hex_bytes_to_string(bytes: &[u8]) -> Result { - let mut hex_string = String::with_capacity(bytes.len() * 2); - for byte in bytes { - write!(&mut hex_string, "{:X}", byte)?; - } - Ok(hex_string) -} - pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return Err(DataFusionError::Internal( @@ -80,7 +72,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = array .iter() - .map(|v| v.map(|v| hex_bytes_to_string(&hex_string(v))).transpose()) + .map(|v| v.map(|v| hex_string(v)).transpose()) .collect::>()?; let string_array = StringArray::from(hexed); @@ -92,7 +84,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = array .iter() - .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; let string_array = StringArray::from(hexed); @@ -104,7 +96,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = array .iter() - .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; let string_array = StringArray::from(hexed); @@ -137,7 +129,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = hexed_values .iter() - .map(|v| v.map(|v| hex_bytes_to_string(&hex_string(v))).transpose()) + .map(|v| v.map(|v| hex_string(v)).transpose()) .collect::>()?; let keys = dict.keys().clone(); @@ -158,7 +150,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = hexed_values .iter() - .map(|v| v.map(|v| hex_bytes_to_string(&hex_bytes(v))).transpose()) + .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; let keys = dict.keys().clone(); @@ -186,14 +178,12 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { - let hex_bytes = hex_bytes(v); - let hex_string = hex_bytes_to_string(&hex_bytes)?; + let hex_string = hex_bytes(v)?; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) } ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { - let hex_bytes = hex_string(v); - let hex_string = hex_bytes_to_string(&hex_bytes)?; + let hex_string = hex_string(v)?; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) } @@ -312,26 +302,6 @@ mod test { assert_eq!(result, &expected); } - #[test] - fn test_hex_bytes() { - let bytes = [0x01, 0x02, 0x03, 0x04]; - let hexed = super::hex_bytes(&bytes); - assert_eq!(hexed, vec![0, 1, 0, 2, 0, 3, 0, 4]); - } - - #[test] - fn test_hex_bytes_to_string() -> Result<(), std::fmt::Error> { - let bytes = [0x01, 0x02, 0x03, 0x04]; - let hexed = super::hex_bytes_to_string(&bytes)?; - assert_eq!(hexed, "1234".to_string()); - - let large_bytes = [0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0]; - let hexed = super::hex_bytes_to_string(&large_bytes)?; - assert_eq!(hexed, "123456789ABCDEF0".to_string()); - - Ok(()) - } - #[test] fn test_hex_int64() { let num = 1234; From ecd2876573dec6adb39684abb402f5e9edc0c68b Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 10:13:06 -0700 Subject: [PATCH 08/13] refactor: update vec collection --- .../expressions/scalar_funcs/hex.rs | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index fae628de61..ea465da2b2 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -62,46 +62,40 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { let array = as_int64_array(array)?; - let hexed: Vec> = array.iter().map(|v| v.map(hex_int64)).collect(); + let hexed_array: StringArray = + array.iter().map(|v| v.map(|v| hex_int64(v))).collect(); - let string_array = StringArray::from(hexed); - Ok(ColumnarValue::Array(Arc::new(string_array))) + Ok(ColumnarValue::Array(Arc::new(hexed_array))) } DataType::Utf8 | DataType::LargeUtf8 => { let array = as_string_array(array); - let hexed: Vec> = array + let hexed: StringArray = array .iter() .map(|v| v.map(|v| hex_string(v)).transpose()) .collect::>()?; - let string_array = StringArray::from(hexed); - - Ok(ColumnarValue::Array(Arc::new(string_array))) + Ok(ColumnarValue::Array(Arc::new(hexed))) } DataType::Binary => { let array = as_binary_array(array)?; - let hexed: Vec> = array + let hexed: StringArray = array .iter() .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; - let string_array = StringArray::from(hexed); - - Ok(ColumnarValue::Array(Arc::new(string_array))) + Ok(ColumnarValue::Array(Arc::new(hexed))) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - let hexed: Vec> = array + let hexed: StringArray = array .iter() .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; - let string_array = StringArray::from(hexed); - - Ok(ColumnarValue::Array(Arc::new(string_array))) + Ok(ColumnarValue::Array(Arc::new(hexed))) } DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Int64) => { let dict = as_dictionary_array::(&array); From ecf57a802b59ed142086eeabf744cfd67be48d7d Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 10:25:14 -0700 Subject: [PATCH 09/13] refactor: refactor hex to support byte ref --- .../expressions/scalar_funcs/hex.rs | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index ea465da2b2..7f7c2fa13a 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -34,22 +34,16 @@ fn hex_int64(num: i64) -> String { format!("{:X}", num) } -fn hex_bytes(bytes: &[u8]) -> Result { +fn hex_bytes>(bytes: T) -> Result { + let bytes = bytes.as_ref(); let length = bytes.len(); - let mut hex_string = String::with_capacity(bytes.len() * 2); - let mut i = 0; - while i < length { - write!(&mut hex_string, "{:X}", (bytes[i] & 0xF0) >> 4)?; - write!(&mut hex_string, "{:X}", bytes[i] & 0x0F)?; - i += 1; + let mut hex_string = String::with_capacity(length * 2); + for &byte in bytes { + write!(&mut hex_string, "{:02X}", byte)?; } Ok(hex_string) } -fn hex_string(s: &str) -> Result { - hex_bytes(s.as_bytes()) -} - pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return Err(DataFusionError::Internal( @@ -72,7 +66,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result>()?; Ok(ColumnarValue::Array(Arc::new(hexed))) @@ -123,7 +117,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = hexed_values .iter() - .map(|v| v.map(|v| hex_string(v)).transpose()) + .map(|v| v.map(|v| hex_bytes(v)).transpose()) .collect::>()?; let keys = dict.keys().clone(); @@ -177,7 +171,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { - let hex_string = hex_string(v)?; + let hex_string = hex_bytes(v)?; Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) } From 5e178974abe56283b5f13c6a683df0a705dd8116 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 10:29:01 -0700 Subject: [PATCH 10/13] style: fix clippy --- .../datafusion/expressions/scalar_funcs/hex.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 7f7c2fa13a..8c3085f936 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -56,8 +56,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { let array = as_int64_array(array)?; - let hexed_array: StringArray = - array.iter().map(|v| v.map(|v| hex_int64(v))).collect(); + let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect(); Ok(ColumnarValue::Array(Arc::new(hexed_array))) } @@ -66,7 +65,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result>()?; Ok(ColumnarValue::Array(Arc::new(hexed))) @@ -76,7 +75,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result>()?; Ok(ColumnarValue::Array(Arc::new(hexed))) @@ -86,7 +85,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result>()?; Ok(ColumnarValue::Array(Arc::new(hexed))) @@ -117,7 +116,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = hexed_values .iter() - .map(|v| v.map(|v| hex_bytes(v)).transpose()) + .map(|v| v.map(hex_bytes).transpose()) .collect::>()?; let keys = dict.keys().clone(); @@ -138,7 +137,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result> = hexed_values .iter() - .map(|v| v.map(|v| hex_bytes(v)).transpose()) + .map(|v| v.map(hex_bytes).transpose()) .collect::>()?; let keys = dict.keys().clone(); From 88bdcde1babcd75cb52e18164e1c945bbbbbf857 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Fri, 24 May 2024 15:51:56 -0700 Subject: [PATCH 11/13] refactor: remove scalar handling --- .../expressions/scalar_funcs/hex.rs | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index 8c3085f936..bb37930684 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -26,7 +26,7 @@ use arrow_schema::DataType; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, - exec_err, DataFusionError, ScalarValue, + exec_err, DataFusionError, }; use std::fmt::Write; @@ -156,35 +156,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result match scalar { - ScalarValue::Int64(Some(v)) => { - let hex_string = hex_int64(*v); - - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) - } - ScalarValue::Binary(Some(v)) - | ScalarValue::LargeBinary(Some(v)) - | ScalarValue::FixedSizeBinary(_, Some(v)) => { - let hex_string = hex_bytes(v)?; - - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) - } - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => { - let hex_string = hex_bytes(v)?; - - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(hex_string)))) - } - ScalarValue::Int64(None) - | ScalarValue::Utf8(None) - | ScalarValue::Binary(None) - | ScalarValue::FixedSizeBinary(_, None) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - _ => exec_err!( - "hex got an unexpected argument type: {:?}", - scalar.data_type() - ), - }, + _ => exec_err!("native hex does not support scalar values at this time"), } } From e7062c61a42ec254ec5db203845ec66e21a56bf9 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Wed, 29 May 2024 16:31:51 -0700 Subject: [PATCH 12/13] style: new lines in expression test file --- .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 5e4de75bce..a2c7f41974 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1038,6 +1038,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("hex") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => @@ -1052,6 +1053,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("unhex") { // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not From 129f00ab1379f16ac6f3a68bb1dc73dc1fcebeb8 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Wed, 29 May 2024 16:35:53 -0700 Subject: [PATCH 13/13] fix: handle large strings --- .../datafusion/expressions/scalar_funcs/hex.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs index bb37930684..ea572574a1 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::{ - array::{as_dictionary_array, as_string_array}, + array::{as_dictionary_array, as_largestring_array, as_string_array}, datatypes::Int32Type, }; use arrow_array::StringArray; @@ -60,7 +60,7 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + DataType::Utf8 => { let array = as_string_array(array); let hexed: StringArray = array @@ -70,6 +70,16 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } DataType::Binary => { let array = as_binary_array(array)?;