diff --git a/build.sbt b/build.sbt index 64180f7c..b8cd33b4 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.7.0" +ThisBuild / version := "0.8.0" ThisBuild / scalaVersion := scala213 diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 436a9f88..a3bcaac2 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -2319,4 +2319,103 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\(double\\)(\\d)", "(double) $1") } + it should "handle string function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(string) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (def e1 = (def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.trim() : null); e1 != null ? e1.length() : null); left == null ? false : left > 10" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "len": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.length() : null)" + | } + | }, + | "lower": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.lower() : null)" + | } + | }, + | "upper": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.upper() : null)" + | } + | }, + | "substr": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); (arg0 == null) ? null : ((1 - 1) < 0 || (1 - 1 + 3) > arg0.length()) ? null : arg0.substring((1 - 1), (1 - 1 + 3)))" + | } + | }, + | "trim": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.trim() : null)" + | } + | }, + | "concat": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); (arg0 == null) ? null : String.valueOf(arg0) + \"_test\" + String.valueOf(1))" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defa", "def a") + .replaceAll("defe", "def e") + .replaceAll("defl", "def l") + .replaceAll("def_", "def _") + .replaceAll("=_", " = _") + .replaceAll(",_", ", _") + .replaceAll(",\\(", ", (") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll(":\\(", " : (") + .replaceAll(":0", " : 0") + .replaceAll(",(\\d)", ", $1") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("\\+", " + ") + .replaceAll("-", " - ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll(">", " > ") + .replaceAll("<", " < ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("false:", "false : ") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index f19f2457..24551fa0 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -2308,4 +2308,103 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\(double\\)(\\d)", "(double) $1") } + it should "handle string function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(string) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (def e1 = (def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.trim() : null); e1 != null ? e1.length() : null); left == null ? false : left > 10" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "len": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.length() : null)" + | } + | }, + | "lower": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.lower() : null)" + | } + | }, + | "upper": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.upper() : null)" + | } + | }, + | "substr": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); (arg0 == null) ? null : ((1 - 1) < 0 || (1 - 1 + 3) > arg0.length()) ? null : arg0.substring((1 - 1), (1 - 1 + 3)))" + | } + | }, + | "trim": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null ? e0.trim() : null)" + | } + | }, + | "concat": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); (arg0 == null) ? null : String.valueOf(arg0) + \"_test\" + String.valueOf(1))" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defa", "def a") + .replaceAll("defe", "def e") + .replaceAll("defl", "def l") + .replaceAll("def_", "def _") + .replaceAll("=_", " = _") + .replaceAll(",_", ", _") + .replaceAll(",\\(", ", (") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll(":\\(", " : (") + .replaceAll(":0", " : 0") + .replaceAll(",(\\d)", ", $1") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("\\+", " + ") + .replaceAll("-", " - ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll(">", " > ") + .replaceAll("<", " < ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("false:", "false : ") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 8bf6ab49..b101a66f 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -1025,3 +1025,98 @@ case class SQLAtan2(y: PainlessScript, x: PainlessScript) extends MathematicalFu override def args: List[PainlessScript] = List(y, x) override def nullable: Boolean = y.nullable || x.nullable } + +sealed trait StringFunction[Out <: SQLType] + extends SQLTransformFunction[SQLVarchar, Out] + with SQLFunctionWithIdentifier { + override def inputType: SQLVarchar = SQLTypes.Varchar + + override def outputType: Out + + def operator: SQLStringOperator + + override def fun: Option[PainlessScript] = Some(operator) + + override def identifier: SQLIdentifier = SQLIdentifier("", functions = this :: Nil) + + override def toSQL(base: String): String = s"$sql($base)" + + override def sql: String = + if (args.isEmpty) + s"${fun.map(_.sql).getOrElse("")}" + else + super.sql +} + +case class SQLStringFunction(operator: SQLStringOperator) extends StringFunction[SQLVarchar] { + override def outputType: SQLVarchar = SQLTypes.Varchar + override def args: List[PainlessScript] = List.empty + +} + +case class SQLSubstring(str: PainlessScript, start: Int, length: Option[Int]) + extends StringFunction[SQLVarchar] { + override def outputType: SQLVarchar = SQLTypes.Varchar + override def operator: SQLStringOperator = Substring + + override def args: List[PainlessScript] = + List(str, SQLIntValue(start)) ++ length.map(l => SQLIntValue(l)).toList + + override def nullable: Boolean = str.nullable + + override def toPainlessCall(callArgs: List[String]): String = { + callArgs match { + // SUBSTRING(expr, start, length) + case List(arg0, arg1, arg2) => + s"(($arg1 - 1) < 0 || ($arg1 - 1 + $arg2) > $arg0.length()) ? null : $arg0.substring(($arg1 - 1), ($arg1 - 1 + $arg2))" + + // SUBSTRING(expr, start) + case List(arg0, arg1) => + s"(($arg1 - 1) < 0 || ($arg1 - 1) >= $arg0.length()) ? null : $arg0.substring(($arg1 - 1))" + + case _ => throw new IllegalArgumentException("SUBSTRING requires 2 or 3 arguments") + } + } + + override def validate(): Either[String, Unit] = + if (start < 1) + Left("SUBSTRING start position must be greater than or equal to 1 (SQL is 1-based)") + else if (length.exists(_ < 0)) + Left("SUBSTRING length must be non-negative") + else Right(()) + + override def toSQL(base: String): String = sql + +} + +case class SQLConcat(values: List[PainlessScript]) extends StringFunction[SQLVarchar] { + override def outputType: SQLVarchar = SQLTypes.Varchar + override def operator: SQLStringOperator = Concat + + override def args: List[PainlessScript] = values + + override def nullable: Boolean = values.exists(_.nullable) + + override def toPainlessCall(callArgs: List[String]): String = { + if (callArgs.isEmpty) + throw new IllegalArgumentException("CONCAT requires at least one argument") + else + callArgs.zipWithIndex + .map { case (arg, idx) => + SQLTypeUtils.coerce(arg, values(idx).out, SQLTypes.Varchar, nullable = false) + } + .mkString(operator.painless) + } + + override def validate(): Either[String, Unit] = + if (values.isEmpty) Left("CONCAT requires at least one argument") + else Right(()) + + override def toSQL(base: String): String = sql +} + +case object SQLLength extends StringFunction[SQLBigInt] { + override def outputType: SQLBigInt = SQLTypes.BigInt + override def operator: SQLStringOperator = Length + override def args: List[PainlessScript] = List.empty +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index cb572ef9..9dbe21c2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -99,6 +99,23 @@ case object IsNotNull extends SQLExpr("is not null") with SQLComparisonOperator case object Match extends SQLExpr("match") with SQLComparisonOperator case object Against extends SQLExpr("against") with SQLRegex +sealed trait SQLStringOperator extends SQLOperator { + override def painless: String = s".${sql.toLowerCase()}()" +} +case object Concat extends SQLExpr("concat") with SQLStringOperator { + override def painless: String = " + " +} +case object Lower extends SQLExpr("lower") with SQLStringOperator +case object Upper extends SQLExpr("upper") with SQLStringOperator +case object Trim extends SQLExpr("trim") with SQLStringOperator +//case object LTrim extends SQLExpr("ltrim") with SQLStringOperator +//case object RTrim extends SQLExpr("rtrim") with SQLStringOperator +case object Substring extends SQLExpr("substring") with SQLStringOperator { + override def painless: String = ".substring" +} +case object To extends SQLExpr("to") with SQLRegex +case object Length extends SQLExpr("length") with SQLStringOperator + sealed trait SQLLogicalOperator extends SQLExpressionOperator case object Not extends SQLExpr("not") with SQLLogicalOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 201c18cc..21050a86 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -487,8 +487,48 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => mf => mf.identifier } + def concatFunction: PackratParser[StringFunction[SQLVarchar]] = + Concat.regex ~ start ~ rep1sep(valueExpr, separator) ~ end ^^ { case _ ~ _ ~ vs ~ _ => + SQLConcat(vs) + } + + def substringFunction: PackratParser[StringFunction[SQLVarchar]] = + Substring.regex ~ start ~ valueExpr ~ (From.regex | separator) ~ long ~ ((To.regex | separator) ~ long).? ~ end ^^ { + case _ ~ _ ~ v ~ _ ~ s ~ eOpt ~ _ => + SQLSubstring(v, s.value.toInt, eOpt.map { case _ ~ e => e.value.toInt }) + } + + def stringFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (concatFunction | substringFunction) ^^ { sf => + sf.identifier + } + + def length: PackratParser[StringFunction[SQLBigInt]] = + Length.regex ^^ { _ => + SQLLength + } + + def lower: PackratParser[StringFunction[SQLVarchar]] = + Lower.regex ^^ { _ => + SQLStringFunction(Lower) + } + + def upper: PackratParser[StringFunction[SQLVarchar]] = + Upper.regex ^^ { _ => + SQLStringFunction(Upper) + } + + def trim: PackratParser[StringFunction[SQLVarchar]] = + Trim.regex ^^ { _ => + SQLStringFunction(Trim) + } + + def string_functions: Parser[ + StringFunction[_] + ] = /*concatFunction | substringFunction |*/ length | lower | upper | trim + def sql_functions: PackratParser[SQLFunction] = - aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions + aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions | string_functions //private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" @@ -547,6 +587,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => "double", "pi", "boolean", + "distance", "time", "date", "datetime", @@ -595,7 +636,18 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => "acos", "tan", "atan", - "atan2" + "atan2", + "concat", + "substr", + "substring", + "to", + "length", + "lower", + "upper", + "trim" +// "ltrim", +// "rtrim", +// "replace", ) private val identifierRegexStr = @@ -681,7 +733,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def identifierWithTransformation: PackratParser[SQLIdentifier] = - mathematicalFunctionWithIdentifier | castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + mathematicalFunctionWithIdentifier | castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier | stringFunctionWithIdentifier def identifierWithAggregation: PackratParser[SQLIdentifier] = aggregates ~ start ~ (identifierWithFunction | identifierWithIntervalFunction | identifier) ~ end ^^ { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala index fcf2ba0a..738ac852 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -125,6 +125,10 @@ object SQLTypeUtils { case (_, _) if from == to => return expr + // ---- Any -> VARCHAR ---- + case (_, SQLTypes.Varchar) => + s"String.valueOf($expr)" + // ---- PAR DEFAUT ---- case _ => return expr // fallback diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index b09f0f34..3137cfcc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -158,6 +158,9 @@ object Queries { val mathematical: String = "select identifier, (abs(identifier) + 1.0) * 2, ceil(identifier), floor(identifier), sqrt(identifier), exp(identifier), log(identifier), log10(identifier), pow(identifier, 3), round(identifier), round(identifier, 2), sign(identifier), cos(identifier), acos(identifier), sin(identifier), asin(identifier), tan(identifier), atan(identifier), atan2(identifier, 3.0) from Table where sqrt(identifier) > 100.0" + + val string: String = + "select identifier, length(identifier2) as len, lower(identifier2) as lower, upper(identifier2) as upper, substring(identifier2, 1, 3) as substr, trim(identifier2) as trim, concat(identifier2, '_test', 1) as concat from Table where length(trim(identifier2)) > 10" } /** Created by smanciot on 15/02/17. @@ -624,4 +627,11 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { ) } + it should "parse string functions" in { + val result = SQLParser(string) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + string + ) + } + }