From 7a65f9c245351d3fcd404cdadbef9c2b84f0a1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Tue, 9 Sep 2025 14:31:11 +0200 Subject: [PATCH 01/18] add SQLLogicalFunction and corresponding criteria, implements PainlessScript for every expression, mixed SQLToken with SQLValidation, validate the sql search request after parsing --- .../elastic/sql/bridge/ElasticQuery.scala | 6 +- .../elastic/sql/bridge/package.scala | 14 ++ .../elastic/sql/SQLQuerySpec.scala | 105 ++++++++++++ .../app/softnetwork/elastic/sql/SQLFrom.scala | 8 + .../softnetwork/elastic/sql/SQLFunction.scala | 48 +++++- .../softnetwork/elastic/sql/SQLGroupBy.scala | 37 ++-- .../softnetwork/elastic/sql/SQLHaving.scala | 2 + .../softnetwork/elastic/sql/SQLOperator.scala | 9 +- .../softnetwork/elastic/sql/SQLParser.scala | 89 ++++++---- .../elastic/sql/SQLSearchRequest.scala | 40 +++++ .../softnetwork/elastic/sql/SQLSelect.scala | 9 + .../app/softnetwork/elastic/sql/SQLType.scala | 13 +- .../softnetwork/elastic/sql/SQLTypes.scala | 2 + .../elastic/sql/SQLValidator.scala | 15 +- .../softnetwork/elastic/sql/SQLWhere.scala | 158 ++++++++++++++---- .../app/softnetwork/elastic/sql/package.scala | 18 +- .../elastic/sql/SQLParserSpec.scala | 32 ++++ 17 files changed, 498 insertions(+), 107 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index c4943563..a70a7681 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -13,7 +13,9 @@ import app.softnetwork.elastic.sql.{ SQLExpression, SQLIn, SQLIsNotNull, - SQLIsNull + SQLIsNotNullCriteria, + SQLIsNull, + SQLIsNullCriteria } import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.searches.queries.Query @@ -71,6 +73,8 @@ case class 
ElasticQuery(filter: ElasticFilter) { case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression case dateMath: SQLComparisonDateMath => dateMath + case isNull: SQLIsNullCriteria => isNull + case isNotNull: SQLIsNotNullCriteria => isNotNull case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 160ee4e2..654272e1 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -335,6 +335,20 @@ package object bridge { existsQuery(identifier.name) } + implicit def isNullCriteriaToQuery( + isNull: SQLIsNullCriteria + ): Query = { + import isNull._ + not(existsQuery(identifier.name)) + } + + implicit def isNotNullCriteriaToQuery( + isNotNull: SQLIsNotNullCriteria + ): Query = { + import isNotNull._ + existsQuery(identifier.name) + } + implicit def inToQuery[R, T <: SQLValue[R]](in: SQLIn[R, T]): Query = { import in._ val _values: Seq[Any] = values.innerValues diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 97892ef4..ee31f217 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1475,4 +1475,109 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") } + it should "handle is_null function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | 
"match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(doc['identifier'].value == null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("==", " == ") + } + + it should "handle is_notnull function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnotnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(doc['identifier2'].value != null)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("!=", " != ") + } + + it should "handle is_null criteria as must_not exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle is_notnull criteria as exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNotNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala index 15f36837..dc7e65a9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala @@ -27,4 +27,12 @@ 
case class SQLFrom(tables: Seq[SQLTable]) extends Updateable { } def update(request: SQLSearchRequest): SQLFrom = this.copy(tables = tables.map(_.update(request))) + + override def validate(): Either[String, Unit] = { + if (tables.isEmpty) { + Left("At least one table is required in FROM clause") + } else { + Right(()) + } + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 96ca2453..ceb56479 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -4,6 +4,7 @@ import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql + def system: Boolean = false } sealed trait SQLFunctionWithIdentifier extends SQLFunction { @@ -26,7 +27,7 @@ object SQLFunctionUtils { } -trait SQLFunctionChain extends SQLFunction with SQLValidation { +trait SQLFunctionChain extends SQLFunction { def functions: List[SQLFunction] override def validate(): Either[String, Unit] = @@ -43,6 +44,10 @@ trait SQLFunctionChain extends SQLFunction with SQLValidation { } lazy val aggregation: Boolean = aggregateFunction.isDefined + + override def in: SQLType = functions.lastOption.map(_.in).getOrElse(super.in) + + override def out: SQLType = functions.headOption.map(_.out).getOrElse(super.out) } sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] @@ -50,6 +55,8 @@ sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] with PainlessScript { def inputType: In def outputType: Out + override def in: SQLType = inputType + override def out: SQLType = outputType } sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] @@ -174,7 +181,7 @@ case class SQLAddInterval(interval: TimeInterval) override def script: String = s"${operator.script}${interval.script}" } -case class 
SQLSubstractInterval(interval: TimeInterval) +case class SQLSubtractInterval(interval: TimeInterval) extends SQLExpr(interval.sql) with SQLArithmeticFunction[SQLDateTime, SQLDateTime] with MathScript { @@ -191,18 +198,29 @@ sealed trait DateFunction extends DateTimeFunction sealed trait TimeFunction extends DateTimeFunction -sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript { +sealed trait SystemFunction extends SQLFunction { + override def system: Boolean = true +} + +sealed trait CurrentDateTimeFunction + extends DateTimeFunction + with PainlessScript + with MathScript + with SystemFunction { override def painless: String = "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" override def script: String = "now" + override def out: SQLType = SQLTypes.DateTime } sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction { override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" + override def out: SQLType = SQLTypes.Date } sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction { override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" + override def out: SQLType = SQLTypes.Time } case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction @@ -396,3 +414,27 @@ case class FormatDateTime(identifier: SQLIdentifier, format: String) override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').format($base)" } + +sealed trait SQLLogicalFunction[In <: SQLType] + extends SQLTransformFunction[In, SQLBool] + with SQLFunctionWithIdentifier { + def operator: SQLLogicalOperator + override def outputType: SQLBool = SQLTypes.Boolean + override def toPainless(base: String): String = s"($base$painless)" +} + +case class SQLIsNullFunction(identifier: SQLIdentifier) + extends SQLExpr("isnull") + with SQLLogicalFunction[SQLAny] { + 
override def operator: SQLLogicalOperator = IsNull + override def inputType: SQLAny = SQLTypes.Any + override def painless: String = s" == null" +} + +case class SQLIsNotNullFunction(identifier: SQLIdentifier) + extends SQLExpr("isnotnull") + with SQLLogicalFunction[SQLAny] { + override def operator: SQLLogicalOperator = IsNotNull + override def inputType: SQLAny = SQLTypes.Any + override def painless: String = s" != null" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index e30e7f73..3fa97be2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -9,6 +9,14 @@ case class SQLGroupBy(buckets: Seq[SQLBucket]) extends Updateable { lazy val bucketNames: Map[String, SQLBucket] = buckets.map { b => b.identifier.identifierName -> b }.toMap + + override def validate(): Either[String, Unit] = { + if (buckets.isEmpty) { + Left("At least one bucket is required in GROUP BY clause") + } else { + Right(()) + } + } } case class SQLBucket( @@ -95,19 +103,13 @@ object BucketSelectorScript { extractBucketsPath(left) ++ extractBucketsPath(right) case relation: ElasticRelation => extractBucketsPath(relation.criteria) case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector - case b: BinaryExpression => - import b._ - if (left.aggregation && right.aggregation) - Map(left.aliasOrName -> left.aliasOrName, right.aliasOrName -> right.aliasOrName) - else if (left.aggregation) - Map(left.aliasOrName -> left.aliasOrName) - else if (right.aggregation) - Map(right.aliasOrName -> right.aliasOrName) - else - Map.empty case e: Expression if e.aggregation => import e._ - Map(identifier.aliasOrName -> identifier.aliasOrName) + maybeValue match { + case Some(v: SQLIdentifier) if v.aggregation => + Map(identifier.aliasOrName -> identifier.aliasOrName, v.aliasOrName -> v.aliasOrName) + case _ => 
Map(identifier.aliasOrName -> identifier.aliasOrName) + } case _ => Map.empty } @@ -155,18 +157,7 @@ object BucketSelectorScript { case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector case e: Expression if e.aggregation => - val param = - s"params.${e.identifier.aliasOrName}" - e.maybeValue match { - case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) - case None => - e.operator match { - case IsNull => s"$param == null" - case IsNotNull => s"$param != null" - case _ => - throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") - } - } + e.painless case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala index a96351da..97ed5dc9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala @@ -9,4 +9,6 @@ case class SQLHaving(criteria: Option[SQLCriteria]) extends Updateable { } def update(request: SQLSearchRequest): SQLHaving = this.copy(criteria = criteria.map(_.update(request))) + + override def validate(): Either[String, Unit] = criteria.map(_.validate()).getOrElse(Right(())) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 21342118..6ee913cb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -1,6 +1,13 @@ package app.softnetwork.elastic.sql -trait SQLOperator extends SQLToken +trait SQLOperator extends SQLToken with PainlessScript { + override def painless: String = this match { + case And => "&&" + case Or => "||" + case Not => "!" 
+ case _ => sql + } +} sealed trait ArithmeticOperator extends SQLOperator with MathScript { override def toString: String = s" $sql " diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index ad1b1bd4..02642d89 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -27,8 +27,12 @@ object SQLParser def request: PackratParser[SQLSearchRequest] = { phrase(select ~ from ~ where.? ~ groupBy.? ~ having.? ~ orderBy.? ~ limit.?) ^^ { case s ~ f ~ w ~ g ~ h ~ o ~ l => - SQLSearchRequest(s, f, w, g, h, o, l) - .update() + val request = SQLSearchRequest(s, f, w, g, h, o, l).update() + request.validate() match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + request } } @@ -143,9 +147,9 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLAddInterval(it) } - def substractInterval: PackratParser[SQLSubstractInterval] = + def substractInterval: PackratParser[SQLSubtractInterval] = substract ~ interval ^^ { case _ ~ it => - SQLSubstractInterval(it) + SQLSubtractInterval(it) } def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = @@ -251,16 +255,12 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) def painless_identifier: PackratParser[SQLIdentifier] = - repsep( + rep1sep( date_trunc | extractors | date_functions | datetime_functions, start ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? 
~ rep( end ) ^^ { case f ~ _ ~ i ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } i match { case Some(id) => id.copy(functions = id.functions ++ f) case None => @@ -281,8 +281,21 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLIdentifier("", functions = dd :: Nil) } + def is_null: PackratParser[SQLLogicalFunction[_]] = + "(?i)isnull".r ~ start ~ (painless_identifier | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) + } + + def is_notnull: PackratParser[SQLLogicalFunction[_]] = + "(?i)isnotnull".r ~ start ~ (painless_identifier | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) + } + + def logical_functions: PackratParser[SQLLogicalFunction[_]] = + is_null | is_notnull + def sql_functions: PackratParser[SQLFunction] = - aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions + aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" @@ -305,8 +318,13 @@ trait SQLParser extends RegexParsers with PackratParsers { t.identifier.copy(functions = t +: t.identifier.functions) } + private[this] def logicalFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (is_null | is_notnull) ^^ { t => + t.identifier.copy(functions = t +: t.identifier.functions) + } + def identifierWithTransformation: PackratParser[SQLIdentifier] = - dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier | logicalFunctionWithIdentifier def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => @@ -326,10 +344,6 @@ trait SQLParser extends RegexParsers with PackratParsers { ) ~ 
start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep1( end ) ^^ { case f ~ _ ~ i ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } i match { case None => f.lastOption match { @@ -347,7 +361,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifierWithTransformation | date_diff_identifier | identifier) ~ alias.? ^^ { + (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } @@ -406,15 +420,17 @@ trait SQLWhereParser { private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) + private def any_identifier: PackratParser[SQLIdentifier] = + identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier + private def equality: PackratParser[SQLExpression] = - not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { + not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def like: PackratParser[SQLExpression] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { - case i ~ n ~ _ ~ v => - SQLExpression(i, Like, v, n) + any_identifier ~ not.? 
~ Like.regex ~ literal ^^ { case i ~ n ~ _ ~ v => + SQLExpression(i, Like, v, n) } private def ge: PackratParser[SQLComparisonOperator] = Ge.sql ^^ (_ => Ge) @@ -426,14 +442,14 @@ trait SQLWhereParser { def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { + not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def in: PackratParser[SQLExpressionOperator] = In.regex ^^ (_ => In) private def inLiteral: PackratParser[SQLCriteria] = - identifier ~ not.? ~ in ~ start ~ rep1sep(literal, separator) ~ end ^^ { + any_identifier ~ not.? ~ in ~ start ~ rep1sep(literal, separator) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => SQLIn( i, @@ -443,7 +459,7 @@ trait SQLWhereParser { } private def inDoubles: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + any_identifier ~ not.? ~ in ~ start ~ rep1sep( double, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -455,7 +471,7 @@ trait SQLWhereParser { } private def inLongs: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + any_identifier ~ not.? ~ in ~ start ~ rep1sep( long, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -467,17 +483,17 @@ trait SQLWhereParser { } def between: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { + any_identifier ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLiteralFromTo(from, to), n) } def betweenLongs: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? 
~ Between.regex ~ long ~ and ~ long ^^ { + any_identifier ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLongFromTo(from, to), n) } def betweenDoubles: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { + any_identifier ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLDoubleFromTo(from, to), n) } @@ -488,18 +504,16 @@ trait SQLWhereParser { def matchCriteria: PackratParser[SQLMatch] = Match.regex ~ start ~ rep1sep( - identifier, + any_identifier, separator ) ~ end ~ Against.regex ~ start ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ _ ~ _ ~ l ~ _ => SQLMatch(i, l) } - private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = { - // identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifier - not.? ~ (identifierWithAggregation | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = + not.? ~ any_identifier ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? 
^^ { case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) } - } def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) @@ -507,8 +521,13 @@ trait SQLWhereParser { def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) + def logical_criteria: PackratParser[SQLCriteria] = + (is_null | is_notnull) ^^ { case SQLLogicalFunctionAsCriteria(c) => + c + } + def criteria: PackratParser[SQLCriteria] = - (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria) ^^ ( + (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria | logical_criteria) ^^ ( c => c ) @@ -740,10 +759,6 @@ trait SQLOrderByParser { def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } (n, f) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index a578cc53..87ec07da 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -51,4 +51,44 @@ case class SQLSearchRequest( } lazy val buckets: Seq[SQLBucket] = groupBy.map(_.buckets).getOrElse(Seq.empty) + + override def validate(): Either[String, Unit] = { + for { + _ <- from.validate() + _ <- select.validate() + _ <- where.map(_.validate()).getOrElse(Right(())) + _ <- groupBy.map(_.validate()).getOrElse(Right(())) + _ <- having.map(_.validate()).getOrElse(Right(())) + _ <- orderBy.map(_.validate()).getOrElse(Right(())) + _ <- 
limit.map(_.validate()).getOrElse(Right(())) + /*_ <- { + // validate that having clauses are only applied when group by is present + if (having.isDefined && groupBy.isEmpty) { + Left("HAVING clauses can only be applied when GROUP BY is present") + } else { + Right(()) + } + }*/ + _ <- { + // validate that non-aggregated fields are not present when group by is present + if (groupBy.isDefined) { + val nonAggregatedFields = select.fields.filterNot(f => f.aggregation || f.isScriptField) + val invalidFields = nonAggregatedFields.filterNot(f => + buckets.exists(b => + b.name == f.fieldAlias.map(_.alias).getOrElse(f.sourceField.replace(".", "_")) + ) + ) + if (invalidFields.nonEmpty) { + Left( + s"Non-aggregated fields ${invalidFields.map(_.sql).mkString(", ")} cannot be selected when GROUP BY is present" + ) + } else { + Right(()) + } + } else { + Right(()) + } + } + } yield () + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index e2991f9d..5aa56a26 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -29,6 +29,8 @@ sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript def painless: String = identifier.painless lazy val scriptName: String = fieldAlias.map(_.alias).getOrElse(sourceField) + + override def validate(): Either[String, Unit] = identifier.validate() } case class SQLField( @@ -58,4 +60,11 @@ case class SQLSelect( }.toMap def update(request: SQLSearchRequest): SQLSelect = this.copy(fields = fields.map(_.update(request)), except = except.map(_.update(request))) + + override def validate(): Either[String, Unit] = + if (fields.isEmpty) { + Left("At least one field is required in SELECT clause") + } else { + fields.map(_.validate()).find(_.isLeft).getOrElse(Right(())) + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala 
b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala index ff4cebf7..a5271a48 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -1,18 +1,25 @@ package app.softnetwork.elastic.sql +import SQLTypes._ sealed trait SQLType { def typeId: String } trait SQLAny extends SQLType trait SQLTemporal extends SQLType trait SQLDate extends SQLTemporal +trait SQLTime extends SQLTemporal trait SQLDateTime extends SQLTemporal trait SQLNumber extends SQLType trait SQLString extends SQLType +trait SQLBool extends SQLType object SQLTypeCompatibility { def matches(out: SQLType, in: SQLType): Boolean = out.typeId == in.typeId || - (out.typeId == "temporal" && Set("date", "datetime").contains(in.typeId)) || - (in.typeId == "temporal" && Set("date", "datetime").contains(out.typeId)) || - out.typeId == "any" || in.typeId == "any" + (out.typeId == Temporal.typeId && Set(Date.typeId, DateTime.typeId, Time.typeId).contains( + in.typeId + )) || + (in.typeId == Temporal.typeId && Set(Date.typeId, DateTime.typeId, Time.typeId).contains( + out.typeId + )) || + out.typeId == Any.typeId || in.typeId == Any.typeId } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index 131c9a01..78f7bd49 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -4,7 +4,9 @@ object SQLTypes { case object Any extends SQLAny { val typeId = "any" } case object Temporal extends SQLTemporal { val typeId = "temporal" } case object Date extends SQLDate { val typeId = "date" } + case object Time extends SQLTime { val typeId = "time" } case object DateTime extends SQLDateTime { val typeId = "datetime" } case object Number extends SQLNumber { val typeId = "number" } case object String extends SQLString { val typeId = "string" } + case object Boolean 
extends SQLBool { val typeId = "boolean" } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index 776bab53..1da0509e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -7,16 +7,19 @@ object SQLValidator { val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } unaryFuncs.sliding(2).foreach { case Seq(f1, f2) => - if (!SQLTypeCompatibility.matches(f2.outputType, f1.inputType)) { - return Left( - s"Type mismatch: output '${f2.outputType.typeId}' of `${f2.sql}` " + - s"is not compatible with input '${f1.inputType.typeId}' of `${f1.sql}`" - ) - } + validateTypesMatching(f2.outputType, f1.inputType) case _ => // ok } Right(()) } + + def validateTypesMatching(out: SQLType, in: SQLType): Either[String, Unit] = { + if (SQLTypeCompatibility.matches(out, in)) { + Right(()) + } else { + Left(s"Type mismatch: output '${out.typeId}' is not compatible with input '${in.typeId}'") + } + } } trait SQLValidation { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index b40a09d9..86e2e5c6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -85,6 +85,11 @@ case class SQLPredicate( override def matchCriteria: Boolean = leftCriteria.matchCriteria || rightCriteria.matchCriteria + override def validate(): Either[String, Unit] = + for { + _ <- leftCriteria.validate() + _ <- rightCriteria.validate() + } yield () } sealed trait ElasticFilter @@ -95,6 +100,7 @@ sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLFunctionChain override def group: Boolean = false override lazy val limit: Option[SQLLimit] = identifier.limit override val functions: List[SQLFunction] = 
identifier.functions + override def validate(): Either[String, Unit] = identifier.validate() } case class ElasticBoolQuery( @@ -151,12 +157,77 @@ case class ElasticBoolQuery( } -sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { +sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter with PainlessScript { def maybeValue: Option[SQLToken] def maybeNot: Option[Not.type] def notAsString: String = maybeNot.map(v => s"$v ").getOrElse("") def valueAsString: String = maybeValue.map(v => s" $v").getOrElse("") override def sql = s"$identifier $notAsString$operator$valueAsString" + + override lazy val aggregation: Boolean = maybeValue match { + case Some(v: SQLFunctionChain) => identifier.aggregation || v.aggregation + case _ => identifier.aggregation + } + + def painlessNot: String = operator match { + case _: SQLComparisonOperator => "" + case _ => maybeNot.map(_.painless).getOrElse("") + } + + def painlessOp: String = operator match { + case o: SQLComparisonOperator if maybeNot.isDefined => o.not.painless + case _ => operator.painless + } + + def painlessValue: String = maybeValue + .map { + case v: SQLValue[_] => v.painless + case v: SQLValues[_, _] => v.painless + case v: SQLIdentifier => v.painless + case v => v.sql + } + .getOrElse { + operator match { + case IsNull | IsNotNull => "null" + case _ => "" + } + } + + override def painless: String = s"$painlessNot${identifier.painless} $painlessOp $painlessValue" + + def maybeScript: Option[SQLScript] = if ( + identifier.functions.nonEmpty || maybeValue.exists { + case v: SQLFunctionChain => v.functions.nonEmpty + case _ => false + } + ) { + Some(SQLScript(painless)) + } else None + + def isScript: Boolean = maybeScript.isDefined + + override def out: SQLType = identifier.out + + override def validate(): Either[String, Unit] = { + for { + _ <- identifier.validate() + _ <- maybeValue match { + case Some(v) => + v.validate() match { + case Left(err) => Left(s"$err in 
expression: $this") + case Right(_) => + SQLValidator.validateTypesMatching(out, v.out) match { + case Left(_) => + Left( + s"Type mismatch: '${out.typeId}' is not compatible with '${v.out.typeId}' in expression: $this" + ) + case Right(_) => Right(()) + } + } + case _ => Right(()) + } + } yield () + } } case class SQLExpression( @@ -168,7 +239,12 @@ case class SQLExpression( override def maybeValue: Option[SQLToken] = Option(value) override def update(request: SQLSearchRequest): SQLCriteria = { - val updated = this.copy(identifier = identifier.update(request)) + val updated = + value match { + case id: SQLIdentifier => + this.copy(identifier = identifier.update(request), value = id.update(request)) + case _ => this.copy(identifier = identifier.update(request)) + } if (updated.nested) { ElasticNested(updated, limit) } else @@ -178,33 +254,6 @@ case class SQLExpression( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } -sealed trait BinaryExpression extends Expression { - def left: SQLIdentifier - def right: SQLIdentifier - override def identifier: SQLIdentifier = left - override def maybeValue: Option[SQLToken] = Some(right) - override lazy val aggregation: Boolean = left.aggregation || right.aggregation -} - -case class SQLBinaryExpression( - left: SQLIdentifier, - operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... 
- right: SQLIdentifier, - maybeNot: Option[Not.type] = None -) extends BinaryExpression - with PainlessScript { - override def update(request: SQLSearchRequest): SQLCriteria = - this.copy(left = left.update(request), right = right.update(request)) - - override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - - override def painless: String = { - val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless - s"${left.painless} $painlessOp ${right.painless}" - } - -} - case class SQLIsNull(identifier: SQLIdentifier) extends Expression { override val operator: SQLOperator = IsNull @@ -241,6 +290,49 @@ case class SQLIsNotNull(identifier: SQLIdentifier) extends Expression { override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } +sealed trait SQLCriteriaWithLogicalFunction[In <: SQLType] extends Expression { + def logicalFunction: SQLLogicalFunction[In] + override def maybeValue: Option[SQLToken] = None + override def maybeNot: Option[Not.type] = None + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + override val functions: List[SQLFunction] = List(logicalFunction) + override def sql = s"${logicalFunction.sql}($identifier)" +} + +object SQLLogicalFunctionAsCriteria { + def unapply(f: SQLLogicalFunction[_]): Option[SQLCriteria] = f match { + case SQLIsNullFunction(id) => Some(SQLIsNullCriteria(id)) + case SQLIsNotNullFunction(id) => Some(SQLIsNotNullCriteria(id)) + case _ => None + } +} + +case class SQLIsNullCriteria(identifier: SQLIdentifier) + extends SQLCriteriaWithLogicalFunction[SQLAny] { + override val logicalFunction: SQLLogicalFunction[SQLAny] = SQLIsNullFunction(identifier) + override val operator: SQLOperator = IsNull + override def update(request: SQLSearchRequest): SQLCriteria = { + val updated = this.copy(identifier = identifier.update(request)) + if (updated.nested) { + ElasticNested(updated, limit) + } else + updated + } +} + +case class 
SQLIsNotNullCriteria(identifier: SQLIdentifier) + extends SQLCriteriaWithLogicalFunction[SQLAny] { + override val logicalFunction: SQLLogicalFunction[SQLAny] = SQLIsNotNullFunction(identifier) + override val operator: SQLOperator = IsNotNull + override def update(request: SQLSearchRequest): SQLCriteria = { + val updated = this.copy(identifier = identifier.update(request)) + if (updated.nested) { + ElasticNested(updated, limit) + } else + updated + } +} + case class SQLIn[R, +T <: SQLValue[R]]( identifier: SQLIdentifier, values: SQLValues[R, T], @@ -392,6 +484,9 @@ case class SQLComparisonDateMath( } } } + + override def maybeScript: Option[SQLScript] = Some(SQLScript(script)) + } case class ElasticMatch( @@ -473,4 +568,9 @@ case class SQLWhere(criteria: Option[SQLCriteria]) extends Updateable { def update(request: SQLSearchRequest): SQLWhere = this.copy(criteria = criteria.map(_.update(request))) + override def validate(): Either[String, Unit] = criteria match { + case Some(c) => c.validate() + case _ => Right(()) + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 35448221..f605aaa0 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -16,9 +16,11 @@ package object sql { case _ => "" } - trait SQLToken extends Serializable { + trait SQLToken extends Serializable with SQLValidation { def sql: String override def toString: String = sql + def in: SQLType = SQLTypes.Any + def out: SQLType = SQLTypes.Any } trait PainlessScript extends SQLToken { @@ -68,6 +70,7 @@ package object sql { case class SQLBoolean(override val value: Boolean) extends SQLValue[Boolean](value) { override def sql: String = value.toString + override def out: SQLType = SQLTypes.Boolean } case class SQLLiteral(override val value: String) extends SQLValue[String](value) { @@ -96,6 +99,7 @@ package object sql { case 
_ => super.choose(values, operator, separator) } } + override def out: SQLType = SQLTypes.String } sealed abstract class SQLNumeric[T: Numeric](override val value: T)(implicit @@ -129,6 +133,7 @@ package object sql { def ne: Seq[T] => Boolean = { _.forall { _ != value } } + override def out: SQLType = SQLTypes.Number } case class SQLLong(override val value: Long) extends SQLNumeric[Long](value) @@ -170,8 +175,10 @@ package object sql { } sealed abstract class SQLValues[+R: TypeTag, +T <: SQLValue[R]](val values: Seq[T]) - extends SQLToken { + extends SQLToken + with PainlessScript { override def sql = s"(${values.map(_.sql).mkString(",")})" + override def painless: String = s"[${values.map(_.painless).mkString(",")}]" lazy val innerValues: Seq[R] = values.map(_.value) } @@ -268,7 +275,10 @@ package object sql { lazy val aliasOrName: String = fieldAlias.getOrElse(name) override def painless: String = { - val base = if (name.nonEmpty) s"doc['$name'].value" else "" + val base = + if (aggregation && functions.size == 1) s"params.$aliasOrName" + else if (name.nonEmpty) s"doc['$name'].value" + else "" val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse orderedFunctions.foldLeft(base) { case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) @@ -337,5 +347,5 @@ package object sql { } } - case class SQLScript(script: String) extends SQLExpr(script) + case class SQLScript(script: String) extends SQLExpr(script) with MathScript } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index e961a9fc..e44ea280 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -133,6 +133,10 @@ object Queries { val dateTimeSub = "select identifier, datetime_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val isnull = "select 
isnull(identifier) as flag from Table" + val isnotnull = "select identifier, isnotnull(identifier2) as flag from Table" + val isNullCriteria = "select * from Table where isnull(identifier)" + val isNotNullCriteria = "select * from Table where isnotnull(identifier)" } /** Created by smanciot on 15/02/17. @@ -506,4 +510,32 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { dateTimeSub ) } + + it should "parse isnull function" in { + val result = SQLParser(isnull) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isnull + ) + } + + it should "parse isnotnull function" in { + val result = SQLParser(isnotnull) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isnotnull + ) + } + + it should "parse isnull criteria" in { + val result = SQLParser(isNullCriteria) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isNullCriteria + ) + } + + it should "parse isnotnull criteria" in { + val result = SQLParser(isNotNullCriteria) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isNotNullCriteria + ) + } } From bfdf95df7ce170bf45c63a24272b9c7c26fae9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Tue, 9 Sep 2025 14:51:21 +0200 Subject: [PATCH 02/18] override validation for SQLMultiSearchRequest --- .../app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala index 9184eef0..ed13841d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala @@ -5,4 +5,7 @@ case class SQLMultiSearchRequest(requests: Seq[SQLSearchRequest]) extends SQLTok def update(): SQLMultiSearchRequest = this.copy(requests = 
requests.map(_.update())) + override def validate(): Either[String, Unit] = { + requests.map(_.validate()).find(_.isLeft).getOrElse(Right(())) + } } From 4abc29d6f7d48831a0dc66163d9e627ea929ae4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 10 Sep 2025 09:45:37 +0200 Subject: [PATCH 03/18] apply sql type and retrieve math script for a chain of functions, remove SQLComparisonDateMath, update identifier, expressions and implicits accordingly, override painless for like, in and match --- .../elastic/sql/bridge/ElasticQuery.scala | 2 - .../elastic/sql/bridge/package.scala | 51 +++++---- .../elastic/sql/SQLQuerySpec.scala | 5 +- .../elastic/sql/bridge/ElasticQuery.scala | 2 - .../elastic/sql/bridge/package.scala | 51 +++++---- .../elastic/sql/SQLQuerySpec.scala | 5 +- .../softnetwork/elastic/sql/SQLFunction.scala | 102 ++++++++++++------ .../softnetwork/elastic/sql/SQLGroupBy.scala | 42 +++----- .../softnetwork/elastic/sql/SQLOperator.scala | 10 +- .../softnetwork/elastic/sql/SQLParser.scala | 27 ++--- .../softnetwork/elastic/sql/SQLWhere.scala | 88 ++++----------- .../app/softnetwork/elastic/sql/package.scala | 13 ++- 12 files changed, 189 insertions(+), 209 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index a70a7681..61fd88f1 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,7 +9,6 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, - SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, @@ -72,7 +71,6 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression - case 
dateMath: SQLComparisonDateMath => dateMath case isNull: SQLIsNullCriteria => isNull case isNotNull: SQLIsNotNullCriteria => isNotNull case other => diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 654272e1..52410f0d 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -143,8 +143,13 @@ package object bridge { implicit def expressionToQuery(expression: SQLExpression): Query = { import expression._ + if (aggregation) + return matchAllQuery() + if (identifier.functions.nonEmpty) { + return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } value match { - case n: SQLNumeric[_] if !aggregation => + case n: SQLNumeric[_] => operator match { case Ge => maybeNot match { @@ -226,7 +231,7 @@ package object bridge { } case _ => matchAllQuery() } - case l: SQLLiteral if !aggregation => + case l: SQLLiteral => operator match { case Like => maybeNot match { @@ -279,7 +284,7 @@ package object bridge { } case _ => matchAllQuery() } - case b: SQLBoolean if !aggregation => + case b: SQLBoolean => operator match { case Eq => maybeNot match { @@ -297,27 +302,27 @@ package object bridge { } case _ => matchAllQuery() } - case _ => matchAllQuery() - } - } - - implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { - import dateMath._ - if (aggregation) - return matchAllQuery() - dateTimeFunction match { - case _: CurrentTimeFunction => - scriptQuery(Script(script = script).lang("painless").scriptType("source")) - case _ => - val op = if (maybeNot.isDefined) operator.not else operator - op match { - case Gt => rangeQuery(identifier.name) gt script - case Ge => rangeQuery(identifier.name) gte script - case Lt => rangeQuery(identifier.name) lt script - case Le => rangeQuery(identifier.name) 
lte script - case Eq => rangeQuery(identifier.name) gte script lte script - case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + case i: SQLIdentifier => + operator match { + case op: SQLComparisonOperator => + i.toScript match { + case Some(script) => + val o = if (maybeNot.isDefined) op.not else op + o match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } + case _ => matchAllQuery() } } diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index ee31f217..ccf23d6b 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -976,7 +976,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" | } | } | }, @@ -984,7 +984,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" + | "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" | } | } | } @@ -1001,7 +1001,6 @@ class 
SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll(">=", " >= ") .replaceAll("<", " < ") - .replaceAll("return", "return ") } it should "handle having with date functions" in { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index 9de90dd8..77c7ab92 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,7 +9,6 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, - SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, @@ -70,7 +69,6 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression - case dateMath: SQLComparisonDateMath => dateMath case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index b2edb050..2377cfc7 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -144,8 +144,13 @@ package object bridge { implicit def expressionToQuery(expression: SQLExpression): Query = { import expression._ + if (aggregation) + return matchAllQuery() + if (identifier.functions.nonEmpty) { + return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } value match { - case n: SQLNumeric[_] if !aggregation => + case n: SQLNumeric[_] => operator match { case Ge => maybeNot match { @@ -227,7 +232,7 @@ package object bridge { } case _ => 
matchAllQuery() } - case l: SQLLiteral if !aggregation => + case l: SQLLiteral => operator match { case Like => maybeNot match { @@ -280,7 +285,7 @@ package object bridge { } case _ => matchAllQuery() } - case b: SQLBoolean if !aggregation => + case b: SQLBoolean => operator match { case Eq => maybeNot match { @@ -298,27 +303,27 @@ package object bridge { } case _ => matchAllQuery() } - case _ => matchAllQuery() - } - } - - implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { - import dateMath._ - if (aggregation) - return matchAllQuery() - dateTimeFunction match { - case _: CurrentTimeFunction => - scriptQuery(Script(script = script).lang("painless").scriptType("source")) - case _ => - val op = if (maybeNot.isDefined) operator.not else operator - op match { - case Gt => rangeQuery(identifier.name) gt script - case Ge => rangeQuery(identifier.name) gte script - case Lt => rangeQuery(identifier.name) lt script - case Le => rangeQuery(identifier.name) lte script - case Eq => rangeQuery(identifier.name) gte script lte script - case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + case i: SQLIdentifier => + operator match { + case op: SQLComparisonOperator => + i.toScript match { + case Some(script) => + val o = if (maybeNot.isDefined) op.not else op + o match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } + case _ => matchAllQuery() } } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala 
b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index d7f25e09..0a4826fe 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -975,7 +975,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" | } | } | }, @@ -983,7 +983,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" + | "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" | } | } | } @@ -1000,7 +1000,6 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll(">=", " >= ") .replaceAll("<", " < ") - .replaceAll("return", "return ") } it should "handle having with date functions" in { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index ceb56479..1dccfe10 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -5,6 +5,7 @@ import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql def system: Boolean = false + def applyType(in: SQLType): SQLType = out } sealed trait SQLFunctionWithIdentifier extends SQLFunction { @@ -13,16 +14,16 @@ sealed trait SQLFunctionWithIdentifier extends SQLFunction { object SQLFunctionUtils { def aggregateAndTransformFunctions( - 
identifier: Identifier + chain: SQLFunctionChain ): (List[SQLFunction], List[SQLFunction]) = { - identifier.functions.partition { + chain.functions.partition { case _: AggregateFunction => true case _ => false } } - def transformFunctions(identifier: Identifier): List[SQLFunction] = { - aggregateAndTransformFunctions(identifier)._2 + def transformFunctions(chain: SQLFunctionChain): List[SQLFunction] = { + aggregateAndTransformFunctions(chain)._2 } } @@ -30,24 +31,52 @@ object SQLFunctionUtils { trait SQLFunctionChain extends SQLFunction { def functions: List[SQLFunction] - override def validate(): Either[String, Unit] = - SQLValidator.validateChain(functions) + override def validate(): Either[String, Unit] = { + if (aggregations.size > 1) { + Left("Only one aggregation function is allowed in a function chain") + } else if (aggregations.size == 1 && !functions.head.isInstanceOf[AggregateFunction]) { + Left("Aggregation function must be the first function in the chain") + } else { + SQLValidator.validateChain(functions) + } + } override def toSQL(base: String): String = functions.reverse.foldLeft(base)((expr, fun) => { fun.toSQL(expr) }) - lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { - case Some(af: AggregateFunction) => Some(af) - case _ => None + def toScript: Option[String] = { + val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse + orderedFunctions.foldLeft(Option("")) { + case (expr, f: MathScript) if expr.isDefined => Option(s"${expr.get}${f.script}") + case (_, _) => None // ignore non math scripts + } match { + case Some(s) if s.nonEmpty => + out match { + case SQLTypes.Date => Some(s"$s/d") + case _ => Some(s) + } + case _ => None + } + } + + private[this] lazy val aggregations = functions.collect { case af: AggregateFunction => + af } + lazy val aggregateFunction: Option[AggregateFunction] = aggregations.headOption + lazy val aggregation: Boolean = aggregateFunction.isDefined override def in: 
SQLType = functions.lastOption.map(_.in).getOrElse(super.in) - override def out: SQLType = functions.headOption.map(_.out).getOrElse(super.out) + override def out: SQLType = { + val baseType = super.out + functions.reverse.foldLeft(baseType) { (currentType, fun) => + fun.applyType(currentType) + } + } } sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] @@ -57,6 +86,7 @@ sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] def outputType: Out override def in: SQLType = inputType override def out: SQLType = outputType + override def applyType(in: SQLType): SQLType = outputType } sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] @@ -74,9 +104,17 @@ sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnar } sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] - extends SQLTransformFunction[In, Out] { + extends SQLTransformFunction[In, Out] + with MathScript { def operator: ArithmeticOperator override def toSQL(base: String): String = s"$base$operator$sql" + override def applyType(in: SQLType): SQLType = in /*match { + case SQLTypes.Date => SQLTypes.Date // a Date remains a Date + case SQLTypes.Time => SQLTypes.Time // a Time remains a Time + case SQLTypes.DateTime => SQLTypes.DateTime // a DateTime remains a DateTime + case SQLTypes.Number => SQLTypes.Number // a Number remains a Number + case _ => outputType // fallback + }*/ } sealed trait ParametrizedFunction extends SQLFunction { @@ -172,8 +210,7 @@ object TimeInterval { case class SQLAddInterval(interval: TimeInterval) extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] - with MathScript { + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] { override def operator: ArithmeticOperator = Add override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime @@ -183,8 +220,7 @@ case class SQLAddInterval(interval: TimeInterval) case class 
SQLSubtractInterval(interval: TimeInterval) extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] - with MathScript { + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] { override def operator: ArithmeticOperator = Subtract override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime @@ -192,35 +228,37 @@ case class SQLSubtractInterval(interval: TimeInterval) override def script: String = s"${operator.script}${interval.script}" } -sealed trait DateTimeFunction extends SQLFunction +sealed trait DateTimeFunction extends SQLFunction { + def now: String = "ZonedDateTime.now(ZoneId.of('Z'))" + override def out: SQLType = SQLTypes.DateTime +} -sealed trait DateFunction extends DateTimeFunction +sealed trait DateFunction extends DateTimeFunction { + override def out: SQLType = SQLTypes.Date +} -sealed trait TimeFunction extends DateTimeFunction +sealed trait TimeFunction extends DateTimeFunction { + override def out: SQLType = SQLTypes.Time +} sealed trait SystemFunction extends SQLFunction { override def system: Boolean = true } -sealed trait CurrentDateTimeFunction - extends DateTimeFunction - with PainlessScript - with MathScript - with SystemFunction { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +sealed trait CurrentFunction extends SystemFunction with PainlessScript + +sealed trait CurrentDateTimeFunction extends DateTimeFunction with CurrentFunction with MathScript { + override def painless: String = now override def script: String = "now" - override def out: SQLType = SQLTypes.DateTime } -sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" - override def out: SQLType = SQLTypes.Date +sealed trait CurrentDateFunction extends DateFunction with CurrentFunction with MathScript { + 
override def painless: String = s"$now.toLocalDate()" + override def script: String = "now" } -sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" - override def out: SQLType = SQLTypes.Time +sealed trait CurrentTimeFunction extends TimeFunction with CurrentFunction { + override def painless: String = s"$now.toLocalTime()" } case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 3fa97be2..78370198 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -118,9 +118,8 @@ object BucketSelectorScript { val leftStr = toPainless(left) val rightStr = toPainless(right) val opStr = op match { - case And => "&&" - case Or => "||" - case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") + case And | Or => op.painless + case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") } val not = maybeNot.nonEmpty if (group || not) @@ -130,34 +129,21 @@ object BucketSelectorScript { case relation: ElasticRelation => toPainless(relation.criteria) - case SQLComparisonDateMath(identifier, op, dateFunc, arithOp, interval, maybeNot) - if identifier.aggregation => - val painlessOp = if (maybeNot.nonEmpty) op.not.painless else op.painless - val paramName = identifier.aliasOrName - // always use a correct "now" creation - val now = "ZonedDateTime.now(ZoneId.of('Z'))" - - // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) - val rightBase = (arithOp, interval) match { - case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" - case (Some(Subtract), Some(i)) => s"$now.minus(${i.painless})" - case _ => now 
- } - - val rightZdt = dateFunc match { - // truncate only after arithmetic for CurrentDate - case _: CurrentDateFunction => s"$rightBase.truncatedTo(ChronoUnit.DAYS)" - case _: CurrentTimeFunction => s"$rightBase.truncatedTo(ChronoUnit.SECONDS)" - case _ => rightBase - } - - // protect against null params and compare epoch millis - s"(params.$paramName != null) && (params.$paramName $painlessOp $rightZdt.toInstant().toEpochMilli())" - case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector case e: Expression if e.aggregation => - e.painless + val paramName = e.identifier.paramName + e.out match { + case SQLTypes.Date if e.operator.isInstanceOf[SQLComparisonOperator] => + // protect against null params and compare epoch millis + s"($paramName != null) && (${e.painless}.truncatedTo(ChronoUnit.DAYS).toInstant().toEpochMilli())" + case SQLTypes.Time if e.operator.isInstanceOf[SQLComparisonOperator] => + s"($paramName != null) && (${e.painless}.truncatedTo(ChronoUnit.SECONDS).toInstant().toEpochMilli())" + case SQLTypes.DateTime if e.operator.isInstanceOf[SQLComparisonOperator] => + s"($paramName != null) && (${e.painless}.toInstant().toEpochMilli())" + case _ => + e.painless + } case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 6ee913cb..694a930f 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -2,10 +2,12 @@ package app.softnetwork.elastic.sql trait SQLOperator extends SQLToken with PainlessScript { override def painless: String = this match { - case And => "&&" - case Or => "||" - case Not => "!" - case _ => sql + case And => "&&" + case Or => "||" + case Not => "!" 
+ case In => ".contains" + case Like | Match => ".matches" + case _ => sql } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 02642d89..a2e552c2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -7,11 +7,6 @@ import TimeUnit._ /** Created by smanciot on 27/06/2018. * * SQL Parser for ElasticSearch - * - * TODO implements SQL : - * - JOIN, - * - GROUP BY, - * - HAVING, etc. */ object SQLParser extends SQLParser @@ -114,24 +109,23 @@ trait SQLParser extends RegexParsers with PackratParsers { TimeInterval(l.value.toInt, u) } - def current_date: PackratParser[CurrentDateTimeFunction] = + def current_date: PackratParser[CurrentFunction] = CurrentDate.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate } - def current_time: PackratParser[CurrentDateTimeFunction] = + def current_time: PackratParser[CurrentFunction] = CurrentTime.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime } - def current_timestamp: PackratParser[CurrentDateTimeFunction] = + def current_timestamp: PackratParser[CurrentFunction] = CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp } - def now: PackratParser[CurrentDateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { - case _ ~ s ~ t => - if (s.isDefined && t.isDefined) NowWithParens else Now + def now: PackratParser[CurrentFunction] = Now.regex ~ start.? ~ end.? 
^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) NowWithParens else Now } def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) @@ -424,7 +418,7 @@ trait SQLWhereParser { identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier private def equality: PackratParser[SQLExpression] = - not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { + not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -442,7 +436,7 @@ trait SQLWhereParser { def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { + not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | long | literal | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -510,11 +504,6 @@ trait SQLWhereParser { SQLMatch(i, l) } - private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = - not.? ~ any_identifier ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? 
^^ { - case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) - } - def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) def or: PackratParser[SQLPredicateOperator] = Or.regex ^^ (_ => Or) @@ -527,7 +516,7 @@ trait SQLWhereParser { } def criteria: PackratParser[SQLCriteria] = - (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria | logical_criteria) ^^ ( + (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria | logical_criteria) ^^ ( c => c ) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 86e2e5c6..c32568a1 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -193,20 +193,27 @@ sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter wit } } - override def painless: String = s"$painlessNot${identifier.painless} $painlessOp $painlessValue" - - def maybeScript: Option[SQLScript] = if ( - identifier.functions.nonEmpty || maybeValue.exists { - case v: SQLFunctionChain => v.functions.nonEmpty - case _ => false + private[this] lazy val base: String = + out match { + case SQLTypes.Time if identifier.paramName.nonEmpty => + s"${identifier.paramName}.toLocalTime()" + case _ => identifier.paramName } - ) { - Some(SQLScript(painless)) - } else None - - def isScript: Boolean = maybeScript.isDefined - override def out: SQLType = identifier.out + override def painless: String = + s"$painlessNot${identifier.toPainless(base)} $painlessOp $painlessValue" + + override def out: SQLType = + (identifier.out, maybeValue) match { + case (idType, Some(v)) => + v match { + case value: SQLValue[_] => value.out + case 
values: SQLValues[_, _] => values.out + case id: SQLIdentifier => id.out + case _ => idType + } + case (idType, None) => idType + } override def validate(): Either[String, Unit] = { for { @@ -356,6 +363,8 @@ case class SQLIn[R, +T <: SQLValue[R]]( override def maybeValue: Option[SQLToken] = Some(values) override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def painless: String = s"$painlessNot${identifier.painless} $painlessOp($painlessValue)" } case class SQLBetween[+T]( @@ -436,59 +445,6 @@ case class SQLMatch( override def group: Boolean = false } -case class SQLComparisonDateMath( - identifier: SQLIdentifier, - operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... - dateTimeFunction: CurrentDateTimeFunction, // CurrentDate, Now, CurrentTimestamp, CurrentTime, ... - arithmeticOperator: Option[ArithmeticOperator] = - None, // Plus or Minus between dateTimeFunction and interval - interval: Option[TimeInterval] = None, // optional interval - maybeNot: Option[Not.type] = None -) extends Expression - with MathScript { - override def sql: String = { - s"$identifier ${operator.sql} $dateTimeFunction${asString(arithmeticOperator)}${asString(interval)}" - } - override def update(request: SQLSearchRequest): SQLCriteria = - this.copy(identifier = identifier.update(request)) - - override def maybeValue: Option[SQLToken] = Some(SQLScript(script)) - - override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - - override def script: String = { - dateTimeFunction match { - case _: CurrentTimeFunction => - val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless - (arithmeticOperator, interval) match { - case (Some(Add), Some(i)) => // compare doc time with now + interval - s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().plus(${i.value}, ${i.unit.painless});" - - case (Some(Subtract), Some(i)) => // compare doc time with now - s"return 
doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().minus(${i.value}, ${i.unit.painless});" - - case _ => - s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now();" - } - case _ => - val base = s"${dateTimeFunction.script}" - val dateMath = - (arithmeticOperator, interval) match { - case (Some(Add), Some(i)) => s"$base+${i.script}" - case (Some(Subtract), Some(i)) => s"$base-${i.script}" - case _ => base - } - dateTimeFunction match { - case _: CurrentDateFunction => s"$dateMath/d" - case _ => dateMath - } - } - } - - override def maybeScript: Option[SQLScript] = Some(SQLScript(script)) - -} - case class ElasticMatch( identifier: SQLIdentifier, value: SQLLiteral, @@ -507,6 +463,8 @@ case class ElasticMatch( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this override def matchCriteria: Boolean = true + + override def painless: String = s"$painlessNot${identifier.painless} $painlessOp($painlessValue)" } sealed abstract class ElasticRelation(val criteria: SQLCriteria, val operator: ElasticOperator) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index f605aaa0..7c384744 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -274,11 +274,12 @@ package object sql { lazy val aliasOrName: String = fieldAlias.getOrElse(name) - override def painless: String = { - val base = - if (aggregation && functions.size == 1) s"params.$aliasOrName" - else if (name.nonEmpty) s"doc['$name'].value" - else "" + def paramName: String = + if (aggregation && functions.size == 1) s"params.$aliasOrName" + else if (name.nonEmpty) s"doc['$name'].value" + else "" + + def toPainless(base: String): String = { val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse orderedFunctions.foldLeft(base) { case (expr, f: 
SQLTransformFunction[_, _]) => f.toPainless(expr) @@ -287,6 +288,8 @@ package object sql { } } + override def painless: String = toPainless(paramName) + } case class SQLIdentifier( From 9d8d7ef24817576f0f295428be7ed35352ccce0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 10 Sep 2025 09:47:53 +0200 Subject: [PATCH 04/18] update version --- build.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index bae919ad..4b12cadf 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.5.0" +ThisBuild / version := "0.6.0" ThisBuild / scalaVersion := scala213 From 32b78c189ad982b27647a20d1e4899d9401fb8a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Wed, 10 Sep 2025 11:12:26 +0200 Subject: [PATCH 05/18] add baseType to SQLToken, update out within SQLFunctionChain --- .../app/softnetwork/elastic/sql/SQLFunction.scala | 11 +++-------- .../scala/app/softnetwork/elastic/sql/package.scala | 5 +++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 1dccfe10..bae74093 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -72,11 +72,12 @@ trait SQLFunctionChain extends SQLFunction { override def in: SQLType = functions.lastOption.map(_.in).getOrElse(super.in) override def out: SQLType = { - val baseType = super.out + val baseType = functions.lastOption.map(_.in).getOrElse(super.baseType) functions.reverse.foldLeft(baseType) { (currentType, fun) => fun.applyType(currentType) } } + } sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] @@ -108,13 +109,7 @@ sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] with MathScript { 
def operator: ArithmeticOperator override def toSQL(base: String): String = s"$base$operator$sql" - override def applyType(in: SQLType): SQLType = in /*match { - case SQLTypes.Date => SQLTypes.Date // a Date remains a Date - case SQLTypes.Time => SQLTypes.Time // a Time remains a Time - case SQLTypes.DateTime => SQLTypes.DateTime // a DateTime remains a DateTime - case SQLTypes.Number => SQLTypes.Number // a Number remains a Number - case _ => outputType // fallback - }*/ + override def applyType(in: SQLType): SQLType = in } sealed trait ParametrizedFunction extends SQLFunction { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 7c384744..8a2351b4 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -19,8 +19,9 @@ package object sql { trait SQLToken extends Serializable with SQLValidation { def sql: String override def toString: String = sql - def in: SQLType = SQLTypes.Any - def out: SQLType = SQLTypes.Any + def baseType: SQLType = SQLTypes.Any + def in: SQLType = baseType + def out: SQLType = baseType } trait PainlessScript extends SQLToken { From 27e6bfef6e3b98e469d6a1eede629d283f99d3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 12 Sep 2025 11:41:57 +0200 Subject: [PATCH 06/18] add system and nullable properties for all tokens, add nullValue for painless scripts, add input expr to all sql functions set by applying applyTo, add sql interval function to update out sql type, update painless script computation, implements coalesce and nullif logical functions, add fixes for parser, update sql types and validations --- .../elastic/sql/SQLQuerySpec.scala | 331 +++++++++++++-- .../elastic/sql/bridge/ElasticQuery.scala | 6 +- .../elastic/sql/bridge/package.scala | 14 + .../elastic/sql/SQLQuerySpec.scala | 390 +++++++++++++++++- 
.../softnetwork/elastic/sql/SQLFunction.scala | 207 +++++++++- .../softnetwork/elastic/sql/SQLGroupBy.scala | 57 --- .../softnetwork/elastic/sql/SQLOperator.scala | 4 + .../softnetwork/elastic/sql/SQLParser.scala | 183 ++++++-- .../app/softnetwork/elastic/sql/SQLType.scala | 13 - .../elastic/sql/SQLTypeUtils.scala | 113 +++++ .../softnetwork/elastic/sql/SQLTypes.scala | 12 +- .../elastic/sql/SQLValidator.scala | 10 +- .../app/softnetwork/elastic/sql/package.scala | 51 ++- .../sql/SQLDateTimeFunctionSuite.scala | 10 +- .../elastic/sql/SQLParserSpec.scala | 56 ++- 15 files changed, 1228 insertions(+), 229 deletions(-) create mode 100644 sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index ccf23d6b..edaa0b80 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -882,14 +882,29 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": "painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.minus(35, ChronoUnit.MINUTES) : null)" | } | } | }, | "_source": { - | "includes": ["identifier"] + | "includes": [ + | "identifier" + | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("!=null", " != null") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "filter with date time and interval" in { @@ -1040,7 +1055,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | "source": "params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)" | } | } | } @@ -1202,7 +1217,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(e0, LocalDate::from) : null)" | } | } | } @@ -1211,10 +1226,20 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll(",ChronoUnit", ", ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } @@ -1258,7 +1283,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "(def e2 = (def e1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); e1 != null ? e1.truncatedTo(ChronoUnit.MINUTES) : null); e2 != null ? e2.get(ChronoUnit.YEARS) : null)" | } | } | } @@ -1267,9 +1292,19 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",ZonedDateTime", ", ZonedDateTime") } @@ -1288,7 +1323,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? 
null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | }, @@ -1299,7 +1334,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle aggregation with date_diff function" in { @@ -1324,7 +1373,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | } @@ -1333,8 +1382,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") - .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll("ZonedDateTime", " ZonedDateTime") } @@ -1360,7 +1422,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1369,7 +1431,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle date_sub function as script field" in { @@ -1394,7 +1473,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? 
e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1403,7 +1482,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_add function as script field" in { @@ -1428,7 +1524,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1437,7 +1533,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_sub function as script field" in { @@ -1462,7 +1575,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1471,7 +1584,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle is_null function as script field" in { @@ -1488,12 +1618,28 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "flag": { | "script": { | "lang": "painless", - | "source": "(doc['identifier'].value == null)" + | "source": "(def e0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? 
null : doc['identifier'].value); e0 == null)" | } | } | }, | "_source": true - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("==", " == ") + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle is_notnull function as script field" in { @@ -1503,23 +1649,39 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { println(query) query shouldBe """{ - | "query": { - | "match_all": {} - | }, - | "script_fields": { - | "flag": { - | "script": { - | "lang": "painless", - | "source": "(doc['identifier2'].value != null)" - | } - | } - | }, - | "_source": { - | "includes": [ - | "identifier" - | ] - | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("!=", " != ") + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle is_null criteria as must_not exists" in { @@ -1579,4 +1741,97 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin.replaceAll("\\s+", "") } + + it should "handle coalesce function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(coalesce) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.minus(35, ChronoUnit.MINUTES) : null);if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll(";}", "; }") + .replaceAll(";e", "; e") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle nullif function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(nullif) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";def", "; def") + .replaceAll(";return", "; return") + .replaceAll("returnv", " return v") + .replaceAll("returne", " return e") + .replaceAll(";}", "; }") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") + } + } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index 77c7ab92..3a532263 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -12,7 +12,9 @@ import app.softnetwork.elastic.sql.{ SQLExpression, SQLIn, SQLIsNotNull, - SQLIsNull + SQLIsNotNullCriteria, + SQLIsNull, + SQLIsNullCriteria } import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.requests.searches.queries.Query @@ -69,6 +71,8 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression + case isNull: SQLIsNullCriteria => isNull + case isNotNull: SQLIsNotNullCriteria => isNotNull case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 2377cfc7..4089249f 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ 
b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -341,6 +341,20 @@ package object bridge { existsQuery(identifier.name) } + implicit def isNullCriteriaToQuery( + isNull: SQLIsNullCriteria + ): Query = { + import isNull._ + not(existsQuery(identifier.name)) + } + + implicit def isNotNullCriteriaToQuery( + isNotNull: SQLIsNotNullCriteria + ): Query = { + import isNotNull._ + existsQuery(identifier.name) + } + implicit def inToQuery[R, T <: SQLValue[R]](in: SQLIn[R, T]): Query = { import in._ val _values: Seq[Any] = values.innerValues diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 0a4826fe..b2ce0edf 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -881,14 +881,28 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": "painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.minus(35, ChronoUnit.MINUTES) : null)" | } | } | }, | "_source": { - | "includes": ["identifier"] + | "includes": [ + | "identifier" + | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("!=null", " != null") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "filter with date time and interval" in { @@ -1039,7 +1053,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | "source": "params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)" | } | } | } @@ -1199,7 +1213,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(e0, LocalDate::from) : null)" | } | } | } @@ -1208,10 +1222,20 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } @@ -1255,7 +1279,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "(def e2 = (def e1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); e1 != null ? e1.truncatedTo(ChronoUnit.MINUTES) : null); e2 != null ? e2.get(ChronoUnit.YEARS) : null)" | } | } | } @@ -1264,9 +1288,19 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",ZonedDateTime", ", ZonedDateTime") } @@ -1285,7 +1319,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | }, @@ -1296,7 +1330,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle aggregation with date_diff function" in { @@ -1321,7 +1369,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | } @@ -1330,8 +1378,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") - .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll("ZonedDateTime", " ZonedDateTime") } @@ -1357,7 +1418,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1366,7 +1427,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle date_sub function as script field" in { @@ -1391,7 +1468,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? 
e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1400,7 +1477,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_add function as script field" in { @@ -1425,7 +1518,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1434,7 +1527,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_sub function as script field" in { @@ -1459,7 +1568,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1468,7 +1577,250 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle is_null function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? 
null : doc['identifier'].value); e0 == null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_notnull function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnotnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_null criteria as must_not exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle is_notnull criteria as exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNotNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle coalesce function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(coalesce) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.minus(35, ChronoUnit.MINUTES) : null);if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll(";}", "; }") + .replaceAll(";e", "; e") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle nullif function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(nullif) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";def", "; def") + .replaceAll(";return", "; return") + .replaceAll("returnv", " return v") + .replaceAll("returne", " return e") + .replaceAll(";}", "; }") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index bae74093..4a82f061 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -1,17 +1,26 @@ package app.softnetwork.elastic.sql +import scala.util.Try import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql - def system: Boolean = false def applyType(in: SQLType): SQLType = out + var expr: SQLToken = _ + def applyTo(expr: SQLToken): Unit = { + this.expr = expr + } + override def nullable: Boolean = Try(expr.nullable).getOrElse(true) } sealed trait SQLFunctionWithIdentifier extends SQLFunction { def identifier: SQLIdentifier } +trait SQLFunctionWithValue[+T] extends SQLFunction { + def value: T +} + object SQLFunctionUtils { def aggregateAndTransformFunctions( chain: SQLFunctionChain @@ -61,6 +70,17 @@ trait SQLFunctionChain extends SQLFunction { } } + override def system: Boolean = functions.lastOption.exists(_.system) + + override def applyTo(expr: SQLToken): Unit = { + super.applyTo(expr) + val orderedFunctions = functions.reverse + orderedFunctions.foldLeft(expr) { (currentExpr, fun) => + 
fun.applyTo(currentExpr) + fun + } + } + private[this] lazy val aggregations = functions.collect { case af: AggregateFunction => af } @@ -98,10 +118,16 @@ sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] def left: PainlessScript def right: PainlessScript + override def nullable: Boolean = left.nullable || right.nullable } sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] { - def toPainless(base: String): String = s"$base$painless" + def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; e$idx != null ? e$idx$painless : null)" + else + s"$base$painless" + } } sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] @@ -184,6 +210,31 @@ sealed trait TimeInterval extends PainlessScript with MathScript { override def painless: String = s"$value, ${unit.painless}" override def script: String = TimeInterval.script(this) + + def applyType(in: SQLType): Either[String, SQLType] = { + import TimeUnit._ + in match { + case SQLTypes.Date => + unit match { + case Year | Month | Day => Right(SQLTypes.Date) + case Hour | Minute | Second => Right(SQLTypes.Timestamp) + case _ => Left(s"Invalid interval unit $unit for DATE") + } + case SQLTypes.Time => + unit match { + case Hour | Minute | Second => Right(SQLTypes.Time) + case _ => Left(s"Invalid interval unit $unit for TIME") + } + case SQLTypes.DateTime => + Right(SQLTypes.Timestamp) + case SQLTypes.Timestamp => + Right(SQLTypes.Timestamp) + case SQLTypes.Temporal => + Right(SQLTypes.Timestamp) + case _ => + Left(s"Intervals not supported for type $in") + } + } } import TimeUnit._ @@ -203,24 +254,32 @@ object TimeInterval { } } +sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLDateTime, SQLDateTime] { + def interval: TimeInterval + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def script: String = 
s"${operator.script}${interval.script}" + + override def applyType(in: SQLType): SQLType = interval.applyType(in).getOrElse(out) + + override def validate(): Either[String, Unit] = interval.applyType(out) match { + case Left(err) => Left(err) + case Right(_) => Right(()) + } +} + case class SQLAddInterval(interval: TimeInterval) extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] { + with SQLIntervalFunction { override def operator: ArithmeticOperator = Add - override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLDateTime = SQLTypes.DateTime override def painless: String = s".plus(${interval.painless})" - override def script: String = s"${operator.script}${interval.script}" } case class SQLSubtractInterval(interval: TimeInterval) extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] { + with SQLIntervalFunction { override def operator: ArithmeticOperator = Subtract - override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLDateTime = SQLTypes.DateTime override def painless: String = s".minus(${interval.painless})" - override def script: String = s"${operator.script}${interval.script}" } sealed trait DateTimeFunction extends SQLFunction { @@ -333,7 +392,16 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) override def toSQL(base: String): String = { s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" } - override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" + override def painless: String = { + if (start.nullable && end.nullable) + s"(def s = ${start.painless}; def e = ${end.painless}; s != null && e != null ? ${unit.painless}.between(s, e) : null)" + else if (start.nullable) + s"(def s = ${start.painless}; s != null ? ${unit.painless}.between(s, ${end.painless}) : null)" + else if (end.nullable) + s"(def e = ${end.painless}; e != null ? 
${unit.painless}.between(${start.painless}, e) : null)" + else + s"${unit.painless}.between(${start.painless}, ${end.painless})" + } } case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) @@ -373,8 +441,11 @@ case class ParseDate(identifier: SQLIdentifier, format: String) s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? DateTimeFormatter.ofPattern('$format').parse(e$idx, LocalDate::from) : null)" + else + s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } case class FormatDate(identifier: SQLIdentifier, format: String) @@ -388,8 +459,11 @@ case class FormatDate(identifier: SQLIdentifier, format: String) s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').format($base)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? DateTimeFormatter.ofPattern('$format').format(e$idx) : null)" + else + s"DateTimeFormatter.ofPattern('$format').format($base)" } case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) @@ -429,8 +503,11 @@ case class ParseDateTime(identifier: SQLIdentifier, format: String) s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? 
DateTimeFormatter.ofPattern('$format').parse(e$idx, ZonedDateTime::from) : null)" + else + s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" } case class FormatDateTime(identifier: SQLIdentifier, format: String) @@ -444,8 +521,11 @@ case class FormatDateTime(identifier: SQLIdentifier, format: String) s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').format($base)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? DateTimeFormatter.ofPattern('$format').format(e$idx) : null)" + else + s"DateTimeFormatter.ofPattern('$format').format($base)" } sealed trait SQLLogicalFunction[In <: SQLType] @@ -453,7 +533,7 @@ sealed trait SQLLogicalFunction[In <: SQLType] with SQLFunctionWithIdentifier { def operator: SQLLogicalOperator override def outputType: SQLBool = SQLTypes.Boolean - override def toPainless(base: String): String = s"($base$painless)" + override def toPainless(base: String, idx: Int): String = s"($base$painless)" } case class SQLIsNullFunction(identifier: SQLIdentifier) @@ -462,6 +542,12 @@ case class SQLIsNullFunction(identifier: SQLIdentifier) override def operator: SQLLogicalOperator = IsNull override def inputType: SQLAny = SQLTypes.Any override def painless: String = s" == null" + override def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; e$idx$painless)" + else + s"$base$painless" + } } case class SQLIsNotNullFunction(identifier: SQLIdentifier) @@ -470,4 +556,85 @@ case class SQLIsNotNullFunction(identifier: SQLIdentifier) override def operator: SQLLogicalOperator = IsNotNull override def inputType: SQLAny = SQLTypes.Any override def painless: String = s" != null" + override def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; 
e$idx$painless)" + else + s"$base$painless" + } +} + +case class SQLCoalesce(values: List[PainlessScript]) extends SQLLogicalFunction[SQLAny] { + override def operator: SQLLogicalOperator = Coalesce + + override def identifier: SQLIdentifier = SQLIdentifier("") + + override def inputType: SQLAny = SQLTypes.Any + + override lazy val sql: String = s"$Coalesce(${values.map(_.sql).mkString(", ")})" + + // Reprend l’idée de SQLValues mais pour n’importe quel token + override def out: SQLType = SQLTypeUtils.leastCommonSuperType(values.map(_.out).distinct) + + override def applyType(in: SQLType): SQLType = out + + override def validate(): Either[String, Unit] = { + if (values.isEmpty) Left("COALESCE requires at least one argument") + else Right(()) + } + + override def toPainless(base: String, idx: Int): String = s"$base$painless" + + override def painless: String = { + require(values.nonEmpty, "COALESCE requires at least one argument") + + val checks = values + .take(values.length - 1) + .zipWithIndex + .map { case (v, index) => + var check = s"def v$index = ${SQLTypeUtils.coerce(v, out)};" + check += s"if (v$index != null) return v$index;" + check + } + .mkString(" ") + // final fallback + s"{ $checks return ${SQLTypeUtils.coerce(values.last, out)}; }" + } + + override def nullable: Boolean = values.forall(_.nullable) +} + +case class SQLNullIf(expr1: PainlessScript, expr2: PainlessScript) + extends SQLLogicalFunction[SQLAny] { + override def operator: SQLLogicalOperator = NullIf + + override def identifier: SQLIdentifier = SQLIdentifier("") + + override def inputType: SQLAny = SQLTypes.Any + + override def sql: String = s"$NullIf(${expr1.sql}, ${expr2.sql})" + + override def out: SQLType = expr1.out + + override def applyType(in: SQLType): SQLType = out + + override def painless: String = { + val e1 = expr1.painless + val e2 = expr2.painless + s"""{ def e1 = $e1; + |def e2 = $e2; + |return e1 == e2 ? 
null : e1; + |}""".stripMargin.replaceAll("\n", " ") + } +} + +case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = true) + extends SQLTransformFunction[SQLType, SQLType] { + override def inputType: SQLType = value.out + override def outputType: SQLType = targetType + + override def sql: String = + s"$Cast(${value.sql} ${if (as) s"$Alias " else ""}${targetType.typeId})" + + override def painless: String = SQLTypeUtils.coerce(value, targetType) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 78370198..39220f68 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -41,63 +41,6 @@ case class SQLBucket( object BucketSelectorScript { - private[this] def painlessIn(param: String, values: Seq[SQLValue[_]], not: Boolean): String = { - val ret = s"[${values.map { _.painless }.mkString(", ")}].contains($param)" - if (not) s"!$ret" else ret - } - - private[this] def painlessBetween( - param: String, - lower: SQLValue[_], - upper: SQLValue[_], - not: Boolean - ): String = { - val ret = s"($param >= ${lower.painless} && $param <= ${upper.painless})" - if (not) s"!$ret" else ret - } - - private[this] def toPainless( - param: String, - operator: SQLOperator, - value: SQLToken, - not: Boolean - ): String = { - operator match { - case o: SQLComparisonOperator => - val valueStr = - value match { - case v: SQLBoolean => v.painless - case v: SQLDouble => v.painless - case v: SQLLiteral => v.painless - case v: SQLLong => v.painless - case _ => - throw new IllegalArgumentException( - s"Unsupported value type in bucket_selector: $value" - ) - } - if (not) - s"$param ${o.not.painless} $valueStr" - else - s"$param ${o.painless} $valueStr" - case In => - value match { - case SQLDoubleValues(vals) => painlessIn(param, vals, not) - case SQLLiteralValues(vals) => 
painlessIn(param, vals, not) - case SQLLongValues(vals) => painlessIn(param, vals, not) - case _ => throw new IllegalArgumentException("IN requires a list") - } - case Between => - value match { - case SQLDoubleFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case SQLLiteralFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case SQLLongFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case _ => throw new IllegalArgumentException("BETWEEN requires two values") - } - case _ => - throw new IllegalArgumentException(s"Unsupported operator in bucket_selector: $operator") - } - } - def extractBucketsPath(criteria: SQLCriteria): Map[String, String] = criteria match { case SQLPredicate(left, _, right, _, _) => extractBucketsPath(left) ++ extractBucketsPath(right) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 694a930f..0de0b03d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -57,6 +57,10 @@ case object IsNull extends SQLExpr("is null") with SQLLogicalOperator case object IsNotNull extends SQLExpr("is not null") with SQLLogicalOperator case object Not extends SQLExpr("not") with SQLLogicalOperator case object Match extends SQLExpr("match") with SQLLogicalOperator +case object Coalesce extends SQLExpr("coalesce") with SQLLogicalOperator +case object NullIf extends SQLExpr("nullif") with SQLLogicalOperator +case object Exists extends SQLExpr("exists") with SQLLogicalOperator +case object Cast extends SQLExpr("cast") with SQLLogicalOperator case object Against extends SQLExpr("against") with SQLRegex diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index a2e552c2..d7737b91 100644 --- 
a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -69,6 +69,10 @@ trait SQLParser extends RegexParsers with PackratParsers { def boolean: PackratParser[SQLBoolean] = """(true|false)""".r ^^ (bool => SQLBoolean(bool.toBoolean)) + /*def value_identifier: PackratParser[SQLIdentifier] = (literal | long | double | boolean) ^^ { v => + SQLIdentifier("", functions = v :: Nil) + }*/ + def start: PackratParser[SQLDelimiter] = "(" ^^ (_ => StartPredicate) def end: PackratParser[SQLDelimiter] = ")" ^^ (_ => EndPredicate) @@ -109,30 +113,33 @@ trait SQLParser extends RegexParsers with PackratParsers { TimeInterval(l.value.toInt, u) } + def parens: PackratParser[List[SQLDelimiter]] = + start ~ end ^^ { case s ~ e => s :: e :: Nil } + def current_date: PackratParser[CurrentFunction] = - CurrentDate.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate + CurrentDate.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurentDateWithParens else CurrentDate } def current_time: PackratParser[CurrentFunction] = - CurrentTime.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime + CurrentTime.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurrentTimeWithParens else CurrentTime } def current_timestamp: PackratParser[CurrentFunction] = - CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp + CurrentTimestamp.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurrentTimestampWithParens else CurrentTimestamp } - def now: PackratParser[CurrentFunction] = Now.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) NowWithParens else Now + def now: PackratParser[CurrentFunction] = Now.regex ~ parens.? 
^^ { case _ ~ p => + if (p.isDefined) NowWithParens else Now } def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) - def substract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) + def subtract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) - def intervalOperator: PackratParser[ArithmeticOperator] = add | substract + def intervalOperator: PackratParser[ArithmeticOperator] = add | subtract def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator @@ -142,15 +149,13 @@ trait SQLParser extends RegexParsers with PackratParsers { } def substractInterval: PackratParser[SQLSubtractInterval] = - substract ~ interval ^^ { case _ ~ it => + subtract ~ interval ^^ { case _ ~ it => SQLSubtractInterval(it) } def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = addInterval | substractInterval - def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction - def identifierWithSystemFunction: PackratParser[SQLIdentifier] = (current_date | current_time | current_timestamp | now) ~ intervalFunction.? 
^^ { case f1 ~ f2 => @@ -161,7 +166,7 @@ trait SQLParser extends RegexParsers with PackratParsers { } def date_trunc: PackratParser[SQLFunctionWithIdentifier] = - "(?i)date_trunc".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_trunc".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ i ~ _ ~ u ~ _ => DateTrunc(i, u) } @@ -192,25 +197,30 @@ trait SQLParser extends RegexParsers with PackratParsers { extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateAdd(i, t) } def date_sub: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateSub(i, t) } def parse_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)parse_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { - case _ ~ _ ~ i ~ _ ~ f ~ _ => - ParseDate(i, f.value) + "(?i)parse_date".r ~ start ~ 
(identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ li ~ _ ~ f ~ _ => + li match { + case l: SQLLiteral => + ParseDate(SQLIdentifier("", functions = l :: Nil), f.value) + case i: SQLIdentifier => + ParseDate(i, f.value) + } } def format_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)format_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDate(i, f.value) } @@ -218,25 +228,30 @@ trait SQLParser extends RegexParsers with PackratParsers { def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date def datetime_add: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeAdd(i, t) } def datetime_sub: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeSub(i, t) } def parse_datetime: 
PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)parse_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { - case _ ~ _ ~ i ~ _ ~ f ~ _ => - ParseDateTime(i, f.value) + "(?i)parse_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ li ~ _ ~ f ~ _ => + li match { + case l: SQLLiteral => + ParseDateTime(SQLIdentifier("", functions = l :: Nil), f.value) + case i: SQLIdentifier => + ParseDateTime(i, f.value) + } } def format_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)format_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDateTime(i, f.value) } @@ -248,11 +263,11 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def painless_identifier: PackratParser[SQLIdentifier] = + def identifierWithTemporalFunction: PackratParser[SQLIdentifier] = rep1sep( date_trunc | extractors | date_functions | datetime_functions, start - ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep( + ) ~ start.? ~ (identifierWithSystemFunction | identifier).? 
~ rep( end ) ^^ { case f ~ _ ~ i ~ _ => i match { @@ -267,7 +282,7 @@ trait SQLParser extends RegexParsers with PackratParsers { } def date_diff: PackratParser[SQLBinaryFunction[_, _, _]] = - "(?i)date_diff".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_diff".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) } @@ -276,25 +291,100 @@ trait SQLParser extends RegexParsers with PackratParsers { } def is_null: PackratParser[SQLLogicalFunction[_]] = - "(?i)isnull".r ~ start ~ (painless_identifier | identifierWithArithmeticFunction | identifier) ~ end ^^ { + "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) } def is_notnull: PackratParser[SQLLogicalFunction[_]] = - "(?i)isnotnull".r ~ start ~ (painless_identifier | identifierWithArithmeticFunction | identifier) ~ end ^^ { + "(?i)isnotnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) } - def logical_functions: PackratParser[SQLLogicalFunction[_]] = - is_null | is_notnull + private[this] def valueExpr: PackratParser[PainlessScript] = + // les plus spécifiques en premier + identifierWithTransformation | // transformations appliquées à un identifier + date_diff_identifier | // date_diff(...) 
retournant un identifier-like + identifierWithSystemFunction | // CURRENT_DATE, NOW, etc. (+/- interval) + identifierWithArithmeticFunction | // foo - interval ... + identifierWithTemporalFunction | // chaîne de fonctions appliquées à un identifier + identifierWithFunction | // fonctions appliquées à un identifier + literal | // 'string' + long | + double | + boolean | + identifier + + def coalesce: PackratParser[SQLCoalesce] = + Coalesce.regex ~ start ~ rep1sep( + valueExpr, + separator + ) ~ end ^^ { case _ ~ _ ~ ids ~ _ => + SQLCoalesce(ids) + } + + def nullif: PackratParser[SQLNullIf] = + NullIf.regex ~ start ~ valueExpr ~ separator ~ valueExpr ~ end ^^ { + case _ ~ _ ~ id1 ~ _ ~ id2 ~ _ => SQLNullIf(id1, id2) + } + + def logical_functions: PackratParser[SQLTransformFunction[_, _]] = + is_null | is_notnull | coalesce | nullif def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions - private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + //private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + + private val reservedKeywords = Seq( + "select", + "from", + "where", + "group", + "having", + "order", + "limit", + "as", + "by", + "except", + "unnest", + "current_date", + "current_time", + "current_datetime", + "current_timestamp", + "now", + "coalesce", + "nullif", + "isnull", + "isnotnull", + "date_add", + "date_sub", + "parse_date", + "parse_datetime", + "format_date", + "format_datetime", + "date_trunc", + "extract", + "date_diff", + "datetime_add", + "datetime_sub", + "interval", + "year", + "month", + "day", + "hour", + "minute", + "second" + ) + + private val identifierRegexStr = + s"""(?i)(?!(?:${reservedKeywords.mkString( + "|" + )})\\b)[\\*a-zA-Z_][a-zA-Z0-9_.\\[\\]\\*]*""" + + private val identifierRegex = identifierRegexStr.r // scala.util.matching.Regex def identifier: PackratParser[SQLIdentifier] = - 
Distinct.regex.? ~ regexIdentifier.r ^^ { case d ~ i => + Distinct.regex.? ~ identifierRegex ^^ { case d ~ i => SQLIdentifier( i, None, @@ -303,22 +393,31 @@ trait SQLParser extends RegexParsers with PackratParsers { } private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (parse_date | format_date | date_add | date_sub) ^^ { t => - t.identifier.copy(functions = t +: t.identifier.functions) + (parse_date | format_date | date_add | date_sub) ~ arithmeticFunction.? ^^ { case t ~ af => + af match { + case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) + case None => t.identifier.copy(functions = t +: t.identifier.functions) + } } private[this] def dateTimeFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ^^ { t => - t.identifier.copy(functions = t +: t.identifier.functions) + (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ~ arithmeticFunction.? 
^^ { + case t ~ af => + af match { + case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) + case None => t.identifier.copy(functions = t +: t.identifier.functions) + } } private[this] def logicalFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (is_null | is_notnull) ^^ { t => + (is_null | is_notnull | coalesce | nullif) ^^ { t => t.identifier.copy(functions = t +: t.identifier.functions) } def identifierWithTransformation: PackratParser[SQLIdentifier] = - dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier | logicalFunctionWithIdentifier + logicalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + + def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => @@ -516,7 +615,7 @@ trait SQLWhereParser { } def criteria: PackratParser[SQLCriteria] = - (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria | logical_criteria) ^^ ( + (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | /*coalesce | nullif |*/ sql_distance | matchCriteria | logical_criteria) ^^ ( c => c ) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala index a5271a48..cc8bda0c 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -1,5 +1,4 @@ package app.softnetwork.elastic.sql -import SQLTypes._ sealed trait SQLType { def typeId: String } @@ -11,15 +10,3 @@ trait SQLDateTime extends SQLTemporal trait SQLNumber extends SQLType trait SQLString extends SQLType trait SQLBool extends SQLType - -object 
SQLTypeCompatibility { - def matches(out: SQLType, in: SQLType): Boolean = - out.typeId == in.typeId || - (out.typeId == Temporal.typeId && Set(Date.typeId, DateTime.typeId, Time.typeId).contains( - in.typeId - )) || - (in.typeId == Temporal.typeId && Set(Date.typeId, DateTime.typeId, Time.typeId).contains( - out.typeId - )) || - out.typeId == Any.typeId || in.typeId == Any.typeId -} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala new file mode 100644 index 00000000..9d177502 --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -0,0 +1,113 @@ +package app.softnetwork.elastic.sql +import SQLTypes._ + +object SQLTypeUtils { + + def matches(out: SQLType, in: SQLType): Boolean = + out.typeId == in.typeId || + (out.typeId == Temporal.typeId && Set( + Date.typeId, + DateTime.typeId, + Time.typeId, + Timestamp.typeId + ).contains( + in.typeId + )) || + (in.typeId == Temporal.typeId && Set( + Date.typeId, + DateTime.typeId, + Time.typeId, + Timestamp.typeId + ).contains( + out.typeId + )) || + (out.typeId == Number.typeId && Set(Int.typeId, Long.typeId, Double.typeId, Float.typeId) + .contains( + in.typeId + )) || + (in.typeId == Number.typeId && Set(Int.typeId, Long.typeId, Double.typeId, Float.typeId) + .contains( + out.typeId + )) || + out.typeId == Any.typeId || in.typeId == Any.typeId || + out.typeId == Null.typeId || in.typeId == Null.typeId + + def leastCommonSuperType(types: List[SQLType]): SQLType = { + val distinct = types.distinct + if (distinct.size == 1) return distinct.head + + // 1. String + if (distinct.exists(matches(SQLTypes.String, _))) return SQLTypes.String + + // 2. 
Number + if (distinct.exists(matches(SQLTypes.Double, _))) return SQLTypes.Double + if (distinct.exists(matches(SQLTypes.Long, _))) return SQLTypes.Long + if (distinct.exists(matches(SQLTypes.Int, _))) return SQLTypes.Int + if (distinct.exists(matches(SQLTypes.Number, _))) return SQLTypes.Number + + // 3. Temporal + if (distinct.exists(matches(SQLTypes.Timestamp, _))) return SQLTypes.Timestamp + if (distinct.exists(matches(SQLTypes.DateTime, _))) return SQLTypes.DateTime + + // mixed case DATE + TIME → DATETIME + if (distinct.exists(matches(SQLTypes.Date, _)) && distinct.exists(matches(SQLTypes.Time, _))) + return SQLTypes.DateTime + + if (distinct.exists(matches(SQLTypes.Date, _))) return SQLTypes.Date + if (distinct.exists(matches(SQLTypes.Time, _))) return SQLTypes.Time + if (distinct.exists(matches(SQLTypes.Temporal, _))) return SQLTypes.Timestamp + + // 4. Null or Any + if (distinct.exists(matches(SQLTypes.Null, _))) return SQLTypes.Any + if (distinct.exists(matches(SQLTypes.Any, _))) return SQLTypes.Any + + // 5. 
Fallback + SQLTypes.Any + } + + def coerce(in: PainlessScript, to: SQLType): String = { + val expr = in.painless + val from = in.out + val ret = { + (from, to) match { + // ---- DATE & TIME ---- + case (SQLTypes.Date, SQLTypes.DateTime | SQLTypes.Timestamp) => + s"($expr).atStartOfDay(ZoneId.of('Z'))" + case (SQLTypes.DateTime | SQLTypes.Timestamp, SQLTypes.Date) => + s"($expr).toLocalDate()" + case (SQLTypes.DateTime | SQLTypes.Timestamp, SQLTypes.Time) => + s"($expr).toLocalTime()" + + // ---- NUMERIQUES ---- + case (SQLTypes.Int, SQLTypes.Long) => + s"((long) $expr)" + case (SQLTypes.Int, SQLTypes.Double) => + s"((double) $expr)" + case (SQLTypes.Long, SQLTypes.Double) => + s"((double) $expr)" + + // ---- NUMERIC <-> TEMPORAL ---- + case (SQLTypes.Long, SQLTypes.Timestamp) => + s"Instant.ofEpochMilli($expr).atZone(ZoneId.of('Z'))" + case (SQLTypes.Timestamp, SQLTypes.Long) => + s"$expr.toInstant().toEpochMilli()" + + // ---- BOOLEEN -> NUMERIC ---- + case (SQLTypes.Boolean, SQLTypes.Number) => + s"($expr ? 1 : 0)" + + // ---- IDENTITY ---- + case (_, _) if from == to => + return expr + + // ---- PAR DEFAUT ---- + case _ => + return expr // fallback + } + } + if (!in.nullable) + return ret + s"($expr != null ? 
$ret : null)" + } + +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index 78f7bd49..2dea0813 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -2,11 +2,17 @@ package app.softnetwork.elastic.sql object SQLTypes { case object Any extends SQLAny { val typeId = "any" } + case object Null extends SQLAny { val typeId = "null" } case object Temporal extends SQLTemporal { val typeId = "temporal" } - case object Date extends SQLDate { val typeId = "date" } - case object Time extends SQLTime { val typeId = "time" } - case object DateTime extends SQLDateTime { val typeId = "datetime" } + case object Date extends SQLTemporal with SQLDate { val typeId = "date" } + case object Time extends SQLTemporal with SQLTime { val typeId = "time" } + case object DateTime extends SQLTemporal with SQLDateTime { val typeId = "datetime" } + case object Timestamp extends SQLTemporal with SQLDateTime { val typeId = "timestamp" } case object Number extends SQLNumber { val typeId = "number" } + case object Int extends SQLNumber { val typeId = "integer" } + case object Long extends SQLNumber { val typeId = "long" } + case object Double extends SQLNumber { val typeId = "double" } + case object Float extends SQLNumber { val typeId = "float" } case object String extends SQLString { val typeId = "string" } case object Boolean extends SQLBool { val typeId = "boolean" } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index 1da0509e..15994d96 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -4,6 +4,14 @@ object SQLValidator { def validateChain(functions: List[SQLFunction]): Either[String, Unit] = { // validate 
function chain type compatibility + functions match { + case Nil => return Right(()) + case _ => + } + functions.map(_.validate()).find(_.isLeft) match { + case Some(left) => return left + case None => + } val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } unaryFuncs.sliding(2).foreach { case Seq(f1, f2) => @@ -14,7 +22,7 @@ object SQLValidator { } def validateTypesMatching(out: SQLType, in: SQLType): Either[String, Unit] = { - if (SQLTypeCompatibility.matches(out, in)) { + if (SQLTypeUtils.matches(out, in)) { Right(()) } else { Left(s"Type mismatch: output '${out.typeId}' is not compatible with input '${in.typeId}'") diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 8a2351b4..fac88a1e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -22,10 +22,13 @@ package object sql { def baseType: SQLType = SQLTypes.Any def in: SQLType = baseType def out: SQLType = baseType + def system: Boolean = false + def nullable: Boolean = !system } trait PainlessScript extends SQLToken { def painless: String + def nullValue: String = "null" } trait MathScript extends SQLToken { @@ -40,9 +43,15 @@ package object sql { case object Distinct extends SQLExpr("distinct") with SQLRegex + case object Empty extends SQLExpr("") with PainlessScript { + override def painless: String = "" + override def nullable: Boolean = false + } + abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) extends SQLToken - with PainlessScript { + with PainlessScript + with SQLFunctionWithValue[T] { def choose[R >: T]( values: Seq[R], operator: Option[SQLExpressionOperator], @@ -61,12 +70,14 @@ package object sql { case _ => values.headOption } } - def painless: String = value match { + override def painless: String = value match { case s: String => s""""$s"""" case b: Boolean => b.toString 
case n: Number => n.toString case _ => value.toString } + + override def nullable: Boolean = false } case class SQLBoolean(override val value: Boolean) extends SQLValue[Boolean](value) { @@ -75,7 +86,7 @@ package object sql { } case class SQLLiteral(override val value: String) extends SQLValue[String](value) { - override def sql: String = s""""$value"""" + override def sql: String = s"""'$value'""" import SQLImplicits._ private lazy val pattern: Pattern = value.pattern def like: Seq[String] => Boolean = { @@ -137,9 +148,13 @@ package object sql { override def out: SQLType = SQLTypes.Number } - case class SQLLong(override val value: Long) extends SQLNumeric[Long](value) + case class SQLLong(override val value: Long) extends SQLNumeric[Long](value) { + override def out: SQLType = SQLTypes.Long + } - case class SQLDouble(override val value: Double) extends SQLNumeric[Double](value) + case class SQLDouble(override val value: Double) extends SQLNumeric[Double](value) { + override def out: SQLType = SQLTypes.Double + } sealed abstract class SQLFromTo[+T](val from: SQLValue[T], val to: SQLValue[T]) extends SQLToken { override def sql = s"${from.sql} and ${to.sql}" @@ -181,6 +196,7 @@ package object sql { override def sql = s"(${values.map(_.sql).mkString(",")})" override def painless: String = s"[${values.map(_.painless).mkString(",")}]" lazy val innerValues: Seq[R] = values.map(_.value) + override def nullable: Boolean = values.exists(_.nullable) } case class SQLLiteralValues(override val values: Seq[SQLLiteral]) @@ -264,6 +280,8 @@ package object sql { def fieldAlias: Option[String] def bucket: Option[SQLBucket] + applyTo(this) + lazy val identifierName: String = functions.reverse.foldLeft(name)((expr, fun) => { fun.toSQL(expr) @@ -277,20 +295,31 @@ package object sql { def paramName: String = if (aggregation && functions.size == 1) s"params.$aliasOrName" - else if (name.nonEmpty) s"doc['$name'].value" + else if (name.nonEmpty) + s"doc['$name'].value" else "" def 
toPainless(base: String): String = { val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse - orderedFunctions.foldLeft(base) { - case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (expr, f: PainlessScript) => s"$expr${f.painless}" - case (expr, f) => f.toSQL(expr) // fallback + var expr = base + orderedFunctions.zipWithIndex.foreach { case (f, idx) => + f match { + case f: SQLTransformFunction[_, _] => expr = f.toPainless(expr, idx) + case f: PainlessScript => expr = s"$expr${f.painless}" + case f => expr = f.toSQL(expr) // fallback + } } + expr } - override def painless: String = toPainless(paramName) + override def painless: String = toPainless( + if (nullable) + s"(!doc.containsKey('$name') || doc['$name'].empty ? $nullValue : doc['$name'].value)" + else + paramName + ) + override def nullable: Boolean = this.name.nonEmpty && (!aggregation || functions.size > 1) } case class SQLIdentifier( diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala index 2017093c..dc9c05fc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -36,7 +36,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { require(transforms.nonEmpty, "No transforms provided") val initial: (String, SQLType) = - (transforms.head.toPainless(base), transforms.head.outputType.asInstanceOf[SQLType]) + (transforms.head.toPainless(base, 0), transforms.head.outputType.asInstanceOf[SQLType]) val (finalExpr, _) = transforms.tail.foldLeft(initial) { case ((expr, currentType), t: SQLUnaryFunction[_, _]) => @@ -45,7 +45,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" ) } - (t.toPainless(expr), 
t.outputType.asInstanceOf[SQLType]) + (t.toPainless(expr, 0), t.outputType.asInstanceOf[SQLType]) } finalExpr @@ -77,8 +77,8 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { val names = chain.map(_.sql).mkString(" -> ") test(s"Valid chain $idx: $names") { val chained = chainTransformsTyped(baseDate, chain) - val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate)) { (expr, f) => - f.toPainless(expr) + val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate, 0)) { (expr, f) => + f.toPainless(expr, 0) } // On ne teste que la génération de code Painless sans évaluer le résultat assert(chained.nonEmpty) @@ -88,7 +88,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { // Test simple pour chaque fonction individuelle transformFunctions.foreach { f => test(s"Single transformation ${f.sql}") { - val result = f.toPainless(baseDate) + val result = f.toPainless(baseDate, 0) assert(result.nonEmpty) } } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index e44ea280..22e6da55 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -10,17 +10,17 @@ object Queries { val numericalGt = "select * from Table where identifier > 1" val numericalGe = "select * from Table where identifier >= 1" val numericalNe = "select * from Table where identifier <> 1" - val literalEq = """select * from Table where identifier = "un"""" - val literalLt = "select * from Table where createdAt < \"now-35M/M\"" - val literalLe = "select * from Table where createdAt <= \"now-35M/M\"" - val literalGt = "select * from Table where createdAt > \"now-35M/M\"" - val literalGe = "select * from Table where createdAt >= \"now-35M/M\"" - val literalNe = """select * from Table where identifier <> "un"""" + val literalEq = """select * from Table where identifier = 'un'""" + 
val literalLt = "select * from Table where createdAt < 'now-35M/M'" + val literalLe = "select * from Table where createdAt <= 'now-35M/M'" + val literalGt = "select * from Table where createdAt > 'now-35M/M'" + val literalGe = "select * from Table where createdAt >= 'now-35M/M'" + val literalNe = """select * from Table where identifier <> 'un'""" val boolEq = """select * from Table where identifier = true""" val boolNe = """select * from Table where identifier <> false""" - val literalLike = """select * from Table where identifier like "%un%"""" - val literalNotLike = """select * from Table where identifier not like "%un%"""" - val betweenExpression = """select * from Table where identifier between "1" and "2"""" + val literalLike = """select * from Table where identifier like '%un%'""" + val literalNotLike = """select * from Table where identifier not like '%un%'""" + val betweenExpression = """select * from Table where identifier between '1' and '2'""" val andPredicate = "select * from Table where identifier1 = 1 and identifier2 > 2" val orPredicate = "select * from Table where identifier1 = 1 or identifier2 > 2" val leftPredicate = @@ -40,28 +40,28 @@ object Queries { "select * from Table where identifier1 = 1 and parent(parent.identifier2 > 2 or parent.identifier3 = 3)" val parentCriteria = "select * from Table where identifier1 = 1 and parent(parent.identifier3 = 3)" - val inLiteralExpression = "select * from Table where identifier in (\"val1\",\"val2\",\"val3\")" + val inLiteralExpression = "select * from Table where identifier in ('val1','val2','val3')" val inNumericalExpressionWithIntValues = "select * from Table where identifier in (1,2,3)" val inNumericalExpressionWithDoubleValues = "select * from Table where identifier in (1.0,2.1,3.4)" val notInLiteralExpression = - "select * from Table where identifier not in (\"val1\",\"val2\",\"val3\")" + "select * from Table where identifier not in ('val1','val2','val3')" val notInNumericalExpressionWithIntValues = 
"select * from Table where identifier not in (1,2,3)" val notInNumericalExpressionWithDoubleValues = "select * from Table where identifier not in (1.0,2.1,3.4)" val nestedWithBetween = - "select * from Table where nested(ciblage.Archivage_CreationDate between \"now-3M/M\" and \"now\" and ciblage.statutComportement = 1)" - val count = "select count(t.id) as c1 from Table as t where t.nom = \"Nom\"" - val countDistinct = "select count(distinct t.id) as c2 from Table as t where t.nom = \"Nom\"" + "select * from Table where nested(ciblage.Archivage_CreationDate between 'now-3M/M' and 'now' and ciblage.statutComportement = 1)" + val count = "select count(t.id) as c1 from Table as t where t.nom = 'Nom'" + val countDistinct = "select count(distinct t.id) as c2 from Table as t where t.nom = 'Nom'" val countNested = - "select count(email.value) as email from crmgp where profile.postalCode in (\"75001\",\"75002\")" + "select count(email.value) as email from crmgp where profile.postalCode in ('75001','75002')" val isNull = "select * from Table where identifier is null" val isNotNull = "select * from Table where identifier is not null" val geoDistanceCriteria = - "select * from Table where distance(profile.location,(-70.0,40.0)) <= \"5km\"" + "select * from Table where distance(profile.location,(-70.0,40.0)) <= '5km'" val except = "select * except(col1,col2) from Table" val matchCriteria = - "select * from Table where match (identifier1,identifier2,identifier3) against (\"value\")" + "select * from Table where match (identifier1,identifier2,identifier3) against ('value')" val groupBy = "select identifier, count(identifier2) from Table where identifier2 is not null group by identifier" val orderBy = "select * from Table order by identifier desc" @@ -77,7 +77,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country |from Customers |group by Country, City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 + |having Country <> 'USA' and City 
<> 'Berlin' and count(CustomerID) > 1 |order by count(CustomerID) desc, Country asc""".stripMargin.replaceAll("\n", " ") val dateTimeWithIntervalFields: String = "select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" @@ -93,7 +93,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table |group by Country, City - |having Country <> "USA" and City != "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |having Country <> 'USA' and City != 'Berlin' and count(CustomerID) > 1 and lastSeen > now - interval 7 day |order by Country asc""".stripMargin .replaceAll("\n", " ") val parseDate = @@ -137,6 +137,10 @@ object Queries { val isnotnull = "select identifier, isnotnull(identifier2) as flag from Table" val isNullCriteria = "select * from Table where isnull(identifier)" val isNotNullCriteria = "select * from Table where isnotnull(identifier)" + val coalesce: String = + "select coalesce(createdAt - interval 35 minute, current_date) as c, identifier from Table" + val nullif: String = + "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" } /** Created by smanciot on 15/02/17. 
@@ -538,4 +542,18 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { isNotNullCriteria ) } + + it should "parse coalesce function" in { + val result = SQLParser(coalesce) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + coalesce + ) + } + + it should "parse nullif function" in { + val result = SQLParser(nullif) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + nullif + ) + } } From 14c1e9b3d60d2a6e5667ec62f92262c165ea6f11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 12 Sep 2025 12:15:09 +0200 Subject: [PATCH 07/18] add \- to identifier regex --- sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index d7737b91..1d00f37a 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -379,7 +379,7 @@ trait SQLParser extends RegexParsers with PackratParsers { private val identifierRegexStr = s"""(?i)(?!(?:${reservedKeywords.mkString( "|" - )})\\b)[\\*a-zA-Z_][a-zA-Z0-9_.\\[\\]\\*]*""" + )})\\b)[\\*a-zA-Z_\\-][a-zA-Z0-9_\\-.\\[\\]\\*]*""" private val identifierRegex = identifierRegexStr.r // scala.util.matching.Regex From ee0606b76d49a0817bb4cab4a9941469a687448c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 12 Sep 2025 12:57:44 +0200 Subject: [PATCH 08/18] rename SQLLogicalFunction to SQLConditionalFunction, update operators mixed in --- .../softnetwork/elastic/sql/SQLFunction.scala | 20 ++++----- .../softnetwork/elastic/sql/SQLOperator.scala | 43 +++++++++++-------- .../softnetwork/elastic/sql/SQLParser.scala | 6 +-- .../softnetwork/elastic/sql/SQLWhere.scala | 22 +++++----- 4 files changed, 49 insertions(+), 42 deletions(-) diff --git 
a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 4a82f061..78b6aac7 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -528,18 +528,18 @@ case class FormatDateTime(identifier: SQLIdentifier, format: String) s"DateTimeFormatter.ofPattern('$format').format($base)" } -sealed trait SQLLogicalFunction[In <: SQLType] +sealed trait SQLConditionalFunction[In <: SQLType] extends SQLTransformFunction[In, SQLBool] with SQLFunctionWithIdentifier { - def operator: SQLLogicalOperator + def operator: SQLConditionalOperator override def outputType: SQLBool = SQLTypes.Boolean override def toPainless(base: String, idx: Int): String = s"($base$painless)" } case class SQLIsNullFunction(identifier: SQLIdentifier) extends SQLExpr("isnull") - with SQLLogicalFunction[SQLAny] { - override def operator: SQLLogicalOperator = IsNull + with SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNull override def inputType: SQLAny = SQLTypes.Any override def painless: String = s" == null" override def toPainless(base: String, idx: Int): String = { @@ -552,8 +552,8 @@ case class SQLIsNullFunction(identifier: SQLIdentifier) case class SQLIsNotNullFunction(identifier: SQLIdentifier) extends SQLExpr("isnotnull") - with SQLLogicalFunction[SQLAny] { - override def operator: SQLLogicalOperator = IsNotNull + with SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNotNull override def inputType: SQLAny = SQLTypes.Any override def painless: String = s" != null" override def toPainless(base: String, idx: Int): String = { @@ -564,8 +564,8 @@ case class SQLIsNotNullFunction(identifier: SQLIdentifier) } } -case class SQLCoalesce(values: List[PainlessScript]) extends SQLLogicalFunction[SQLAny] { - override def operator: SQLLogicalOperator = Coalesce +case 
class SQLCoalesce(values: List[PainlessScript]) extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = Coalesce override def identifier: SQLIdentifier = SQLIdentifier("") @@ -605,8 +605,8 @@ case class SQLCoalesce(values: List[PainlessScript]) extends SQLLogicalFunction[ } case class SQLNullIf(expr1: PainlessScript, expr2: PainlessScript) - extends SQLLogicalFunction[SQLAny] { - override def operator: SQLLogicalOperator = NullIf + extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = NullIf override def identifier: SQLIdentifier = SQLIdentifier("") diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 0de0b03d..023db2dc 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -1,12 +1,14 @@ package app.softnetwork.elastic.sql -trait SQLOperator extends SQLToken with PainlessScript { +trait SQLOperator extends SQLToken with PainlessScript with SQLRegex { override def painless: String = this match { case And => "&&" case Or => "||" case Not => "!" 
case In => ".contains" case Like | Match => ".matches" + case Eq => "==" + case Ne => "!=" case _ => sql } } @@ -24,12 +26,6 @@ case object Modulo extends SQLExpr("%") with ArithmeticOperator sealed trait SQLExpressionOperator extends SQLOperator sealed trait SQLComparisonOperator extends SQLExpressionOperator with PainlessScript { - override def painless: String = this match { - case Eq => "==" - case Ne => "!=" - case other => other.sql - } - def not: SQLComparisonOperator = this match { case Eq => Ne case Ne | Diff => Eq @@ -47,22 +43,31 @@ case object Ge extends SQLExpr(">=") with SQLComparisonOperator case object Gt extends SQLExpr(">") with SQLComparisonOperator case object Le extends SQLExpr("<=") with SQLComparisonOperator case object Lt extends SQLExpr("<") with SQLComparisonOperator +case object In extends SQLExpr("in") with SQLComparisonOperator +case object Like extends SQLExpr("like") with SQLComparisonOperator +case object Between extends SQLExpr("between") with SQLComparisonOperator +case object IsNull extends SQLExpr("is null") with SQLComparisonOperator with SQLConditionalOperator +case object IsNotNull + extends SQLExpr("is not null") + with SQLComparisonOperator + with SQLConditionalOperator +case object Match extends SQLExpr("match") with SQLComparisonOperator +case object Against extends SQLExpr("against") with SQLRegex -sealed trait SQLLogicalOperator extends SQLExpressionOperator with SQLRegex +sealed trait SQLLogicalOperator extends SQLExpressionOperator -case object In extends SQLExpr("in") with SQLLogicalOperator -case object Like extends SQLExpr("like") with SQLLogicalOperator -case object Between extends SQLExpr("between") with SQLLogicalOperator -case object IsNull extends SQLExpr("is null") with SQLLogicalOperator -case object IsNotNull extends SQLExpr("is not null") with SQLLogicalOperator case object Not extends SQLExpr("not") with SQLLogicalOperator -case object Match extends SQLExpr("match") with SQLLogicalOperator -case object 
Coalesce extends SQLExpr("coalesce") with SQLLogicalOperator -case object NullIf extends SQLExpr("nullif") with SQLLogicalOperator -case object Exists extends SQLExpr("exists") with SQLLogicalOperator -case object Cast extends SQLExpr("cast") with SQLLogicalOperator -case object Against extends SQLExpr("against") with SQLRegex +sealed trait SQLConditionalOperator extends SQLExpressionOperator +case object Coalesce extends SQLExpr("coalesce") with SQLConditionalOperator +case object NullIf extends SQLExpr("nullif") with SQLConditionalOperator +case object Exists extends SQLExpr("exists") with SQLConditionalOperator +case object Cast extends SQLExpr("cast") with SQLConditionalOperator +case object Case extends SQLExpr("case") with SQLConditionalOperator +case object When extends SQLExpr("when") with SQLConditionalOperator +case object Then extends SQLExpr("then") with SQLConditionalOperator +case object Else extends SQLExpr("else") with SQLConditionalOperator +case object End extends SQLExpr("end") with SQLConditionalOperator sealed trait SQLPredicateOperator extends SQLLogicalOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 1d00f37a..de4a3ab5 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -290,12 +290,12 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLIdentifier("", functions = dd :: Nil) } - def is_null: PackratParser[SQLLogicalFunction[_]] = + def is_null: PackratParser[SQLConditionalFunction[_]] = "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) } - def is_notnull: PackratParser[SQLLogicalFunction[_]] = + def is_notnull: PackratParser[SQLConditionalFunction[_]] = "(?i)isnotnull".r ~ start ~ 
(identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) } @@ -610,7 +610,7 @@ trait SQLWhereParser { def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) def logical_criteria: PackratParser[SQLCriteria] = - (is_null | is_notnull) ^^ { case SQLLogicalFunctionAsCriteria(c) => + (is_null | is_notnull) ^^ { case SQLConditionalFunctionAsCriteria(c) => c } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index c32568a1..bb544630 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -297,17 +297,17 @@ case class SQLIsNotNull(identifier: SQLIdentifier) extends Expression { override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } -sealed trait SQLCriteriaWithLogicalFunction[In <: SQLType] extends Expression { - def logicalFunction: SQLLogicalFunction[In] +sealed trait SQLCriteriaWithConditionalFunction[In <: SQLType] extends Expression { + def conditionalFunction: SQLConditionalFunction[In] override def maybeValue: Option[SQLToken] = None override def maybeNot: Option[Not.type] = None override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - override val functions: List[SQLFunction] = List(logicalFunction) - override def sql = s"${logicalFunction.sql}($identifier)" + override val functions: List[SQLFunction] = List(conditionalFunction) + override def sql = s"${conditionalFunction.sql}($identifier)" } -object SQLLogicalFunctionAsCriteria { - def unapply(f: SQLLogicalFunction[_]): Option[SQLCriteria] = f match { +object SQLConditionalFunctionAsCriteria { + def unapply(f: SQLConditionalFunction[_]): Option[SQLCriteria] = f match { case SQLIsNullFunction(id) => Some(SQLIsNullCriteria(id)) case SQLIsNotNullFunction(id) => 
Some(SQLIsNotNullCriteria(id)) case _ => None @@ -315,8 +315,8 @@ object SQLLogicalFunctionAsCriteria { } case class SQLIsNullCriteria(identifier: SQLIdentifier) - extends SQLCriteriaWithLogicalFunction[SQLAny] { - override val logicalFunction: SQLLogicalFunction[SQLAny] = SQLIsNullFunction(identifier) + extends SQLCriteriaWithConditionalFunction[SQLAny] { + override val conditionalFunction: SQLConditionalFunction[SQLAny] = SQLIsNullFunction(identifier) override val operator: SQLOperator = IsNull override def update(request: SQLSearchRequest): SQLCriteria = { val updated = this.copy(identifier = identifier.update(request)) @@ -328,8 +328,10 @@ case class SQLIsNullCriteria(identifier: SQLIdentifier) } case class SQLIsNotNullCriteria(identifier: SQLIdentifier) - extends SQLCriteriaWithLogicalFunction[SQLAny] { - override val logicalFunction: SQLLogicalFunction[SQLAny] = SQLIsNotNullFunction(identifier) + extends SQLCriteriaWithConditionalFunction[SQLAny] { + override val conditionalFunction: SQLConditionalFunction[SQLAny] = SQLIsNotNullFunction( + identifier + ) override val operator: SQLOperator = IsNotNull override def update(request: SQLSearchRequest): SQLCriteria = { val updated = this.copy(identifier = identifier.update(request)) From 04c35089929ca7ab25ead6902e088540396da828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 12 Sep 2025 16:47:16 +0200 Subject: [PATCH 09/18] fix sql types applied for interval functions, finalize implementation for CAST --- .../elastic/sql/SQLQuerySpec.scala | 45 ++++++++++++++ .../elastic/sql/SQLQuerySpec.scala | 46 ++++++++++++++ .../softnetwork/elastic/sql/SQLFunction.scala | 31 ++++++++-- .../softnetwork/elastic/sql/SQLParser.scala | 62 +++++++++++++++++-- .../elastic/sql/SQLTypeUtils.scala | 35 ++++++----- .../elastic/sql/SQLParserSpec.scala | 9 +++ 6 files changed, 205 insertions(+), 23 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala 
b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index edaa0b80..3a153086 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1834,4 +1834,49 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ZonedDateTime", " ZonedDateTime") } + it should "handle cast function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(cast) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + } } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index b2ce0edf..37f16e24 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1823,4 +1823,50 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ZonedDateTime", " ZonedDateTime") } + + it should "handle cast function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(cast) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 78b6aac7..26635eab 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -254,18 +254,29 @@ object TimeInterval { } } -sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLDateTime, SQLDateTime] { +sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLTemporal] { def interval: TimeInterval - override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLDateTime = SQLTypes.DateTime + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLTemporal = SQLTypes.Temporal override def script: String = s"${operator.script}${interval.script}" + private[this] var _out: SQLType = outputType + override def out: SQLType = _out - override def applyType(in: SQLType): SQLType = interval.applyType(in).getOrElse(out) + override def applyType(in: SQLType): SQLType = { + _out = interval.applyType(in).getOrElse(out) + _out + } override def validate(): Either[String, Unit] = interval.applyType(out) match { case Left(err) => Left(err) case Right(_) => Right(()) } + + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? 
${SQLTypeUtils.coerce(s"e$idx", expr.out, out, nullable = false)}$painless : null)" + else + s"${SQLTypeUtils.coerce(base, expr.out, out, nullable = expr.nullable)}$painless" } case class SQLAddInterval(interval: TimeInterval) @@ -636,5 +647,15 @@ case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = tru override def sql: String = s"$Cast(${value.sql} ${if (as) s"$Alias " else ""}${targetType.typeId})" - override def painless: String = SQLTypeUtils.coerce(value, targetType) + override def toSQL(base: String): String = sql + + override def painless: String = + SQLTypeUtils.coerce(value, targetType) + + override def toPainless(base: String, idx: Int): String = + SQLTypeUtils.coerce(base, value.out, targetType, value.nullable) + /*if (nullable) + s"(def e$idx = $base; e$idx != null ? ${SQLTypeUtils.coerce(s"e$idx", value.out, out, nullable = false)}$painless : null)" + else + s"${SQLTypeUtils.coerce(base, value.out, targetType, nullable = value.nullable)}$painless"*/ } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index de4a3ab5..e6e36281 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -153,7 +153,7 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLSubtractInterval(it) } - def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = + def intervalFunction: PackratParser[SQLArithmeticFunction[SQLTemporal, SQLTemporal]] = addInterval | substractInterval def identifierWithSystemFunction: PackratParser[SQLIdentifier] = @@ -373,7 +373,31 @@ trait SQLParser extends RegexParsers with PackratParsers { "day", "hour", "minute", - "second" + "second", + "quarter", + "string", + "int", + "integer", + "long", + "double", + "boolean", + "time", + "date", + "datetime", + "timestamp", + "and", + "or", + "not", + "like", + "in", + 
"between", + "distinct", + "cast", + "count", + "min", + "max", + "avg", + "sum" ) private val identifierRegexStr = @@ -392,6 +416,36 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } + def string_type: PackratParser[SQLTypes.String.type] = "(?i)string".r ^^ (_ => SQLTypes.String) + + def date_type: PackratParser[SQLTypes.Date.type] = "(?i)date".r ^^ (_ => SQLTypes.Date) + + def time_type: PackratParser[SQLTypes.Time.type] = "(?i)time".r ^^ (_ => SQLTypes.Time) + + def datetime_type: PackratParser[SQLTypes.DateTime.type] = + "(?i)(datetime)".r ^^ (_ => SQLTypes.DateTime) + + def timestamp_type: PackratParser[SQLTypes.Timestamp.type] = + "(?i)(timestamp)".r ^^ (_ => SQLTypes.Timestamp) + + def boolean_type: PackratParser[SQLTypes.Boolean.type] = + "(?i)boolean".r ^^ (_ => SQLTypes.Boolean) + + def long_type: PackratParser[SQLTypes.Long.type] = "(?i)long".r ^^ (_ => SQLTypes.Long) + + def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^ (_ => SQLTypes.Double) + + def int_type: PackratParser[SQLTypes.Int.type] = "(?i)(int|integer)".r ^^ (_ => SQLTypes.Int) + + def sql_type: PackratParser[SQLType] = + string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type + + private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = + "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ^^ { + case _ ~ _ ~ i ~ as ~ t ~ _ => + i.copy(functions = SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) + } + private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = (parse_date | format_date | date_add | date_sub) ~ arithmeticFunction.? 
^^ { case t ~ af => af match { @@ -409,13 +463,13 @@ trait SQLParser extends RegexParsers with PackratParsers { } } - private[this] def logicalFunctionWithIdentifier: PackratParser[SQLIdentifier] = + private[this] def conditionalFunctionWithIdentifier: PackratParser[SQLIdentifier] = (is_null | is_notnull | coalesce | nullif) ^^ { t => t.identifier.copy(functions = t +: t.identifier.functions) } def identifierWithTransformation: PackratParser[SQLIdentifier] = - logicalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala index 9d177502..ae558147 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -29,6 +29,8 @@ object SQLTypeUtils { .contains( out.typeId )) || + (out.typeId == String.typeId && in.typeId == String.typeId) || + (out.typeId == Boolean.typeId && in.typeId == Boolean.typeId) || out.typeId == Any.typeId || in.typeId == Any.typeId || out.typeId == Null.typeId || in.typeId == Null.typeId @@ -37,29 +39,29 @@ object SQLTypeUtils { if (distinct.size == 1) return distinct.head // 1. String - if (distinct.exists(matches(SQLTypes.String, _))) return SQLTypes.String + if (distinct.contains(SQLTypes.String)) return SQLTypes.String // 2. 
Number - if (distinct.exists(matches(SQLTypes.Double, _))) return SQLTypes.Double - if (distinct.exists(matches(SQLTypes.Long, _))) return SQLTypes.Long - if (distinct.exists(matches(SQLTypes.Int, _))) return SQLTypes.Int - if (distinct.exists(matches(SQLTypes.Number, _))) return SQLTypes.Number + if (distinct.contains(SQLTypes.Double)) return SQLTypes.Double + if (distinct.contains(SQLTypes.Long)) return SQLTypes.Long + if (distinct.contains(SQLTypes.Int)) return SQLTypes.Int + if (distinct.contains(SQLTypes.Number)) return SQLTypes.Number // 3. Temporal - if (distinct.exists(matches(SQLTypes.Timestamp, _))) return SQLTypes.Timestamp - if (distinct.exists(matches(SQLTypes.DateTime, _))) return SQLTypes.DateTime + if (distinct.contains(SQLTypes.Timestamp)) return SQLTypes.Timestamp + if (distinct.contains(SQLTypes.DateTime)) return SQLTypes.DateTime // mixed case DATE + TIME → DATETIME - if (distinct.exists(matches(SQLTypes.Date, _)) && distinct.exists(matches(SQLTypes.Time, _))) + if (distinct.contains(SQLTypes.Date) && distinct.contains(SQLTypes.Time)) return SQLTypes.DateTime - if (distinct.exists(matches(SQLTypes.Date, _))) return SQLTypes.Date - if (distinct.exists(matches(SQLTypes.Time, _))) return SQLTypes.Time - if (distinct.exists(matches(SQLTypes.Temporal, _))) return SQLTypes.Timestamp + if (distinct.contains(SQLTypes.Date)) return SQLTypes.Date + if (distinct.contains(SQLTypes.Time)) return SQLTypes.Time + if (distinct.contains(SQLTypes.Temporal)) return SQLTypes.Timestamp // 4. Null or Any - if (distinct.exists(matches(SQLTypes.Null, _))) return SQLTypes.Any - if (distinct.exists(matches(SQLTypes.Any, _))) return SQLTypes.Any + if (distinct.contains(SQLTypes.Null)) return SQLTypes.Any + if (distinct.contains(SQLTypes.Any)) return SQLTypes.Any // 5. 
Fallback SQLTypes.Any @@ -68,6 +70,11 @@ object SQLTypeUtils { def coerce(in: PainlessScript, to: SQLType): String = { val expr = in.painless val from = in.out + val nullable = in.nullable + coerce(expr, from, to, nullable) + } + + def coerce(expr: String, from: SQLType, to: SQLType, nullable: Boolean): String = { val ret = { (from, to) match { // ---- DATE & TIME ---- @@ -105,7 +112,7 @@ object SQLTypeUtils { return expr // fallback } } - if (!in.nullable) + if (!nullable) return ret s"($expr != null ? $ret : null)" } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 22e6da55..0341943c 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -141,6 +141,8 @@ object Queries { "select coalesce(createdAt - interval 35 minute, current_date) as c, identifier from Table" val nullif: String = "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" + val cast: String = + "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) long) as c, identifier from Table" } /** Created by smanciot on 15/02/17. 
@@ -556,4 +558,11 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { nullif ) } + + it should "parse cast function" in { + val result = SQLParser(cast) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + cast + ) + } } From d4f33e7dc7dc150b207db6823ebfe23f82cc87dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sat, 13 Sep 2025 12:03:05 +0200 Subject: [PATCH 10/18] add generic AddInterval and SubtractInterval traits, implements painless for all criteria, update painless for generic expression --- .../elastic/sql/SQLQuerySpec.scala | 65 ++++++----- .../elastic/sql/SQLQuerySpec.scala | 76 +++++++------ .../softnetwork/elastic/sql/SQLFunction.scala | 40 ++++--- .../softnetwork/elastic/sql/SQLWhere.scala | 101 ++++++++++++------ 4 files changed, 169 insertions(+), 113 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 3a153086..c8d0d477 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -984,38 +984,47 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { println(query) query shouldBe """{ - | "query": { - | "bool": { - | "filter": [ - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" - | } - | } - | }, - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" - | } - | } - | } - | ] - | } - | }, - | "_source": { - | "includes": [ - | "*" - | ] - | } - |}""".stripMargin + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = 
(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin .replaceAll("\\s", "") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll(">=", " >= ") .replaceAll("<", " < ") + .replaceAll("\\|\\|", " || ") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll(":null", " : null ") + .replaceAll("\\?", " ? ") + .replaceAll("==", " == ") + .replaceAll("\\);", "); ") + .replaceAll("=\\(", " = (") + .replaceAll("defl", "def l") } it should "handle having with date functions" in { diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 37f16e24..6d0b9204 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -980,40 +980,48 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { SQLQuery(filterWithTimeAndInterval) val query = select.query println(query) - query shouldBe - """{ - | "query": { - | "bool": { - | "filter": [ - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "doc['createdAt'].value.toLocalTime() < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" - | } - | } - | }, - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "doc['createdAt'].value.toLocalTime() >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" - | } 
- | } - | } - | ] - | } - | }, - | "_source": { - | "includes": [ - | "*" - | ] - | } - |}""".stripMargin - .replaceAll("\\s", "") - .replaceAll("ChronoUnit", " ChronoUnit") - .replaceAll(">=", " >= ") - .replaceAll("<", " < ") + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("\\|\\|", " || ") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll(":null", " : null ") + .replaceAll("\\?", " ? 
") + .replaceAll("==", " == ") + .replaceAll("\\);", "); ") + .replaceAll("=\\(", " = (") + .replaceAll("defl", "def l") } it should "handle having with date functions" in { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 26635eab..57c9ee29 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -211,7 +211,7 @@ sealed trait TimeInterval extends PainlessScript with MathScript { override def script: String = TimeInterval.script(this) - def applyType(in: SQLType): Either[String, SQLType] = { + def checkType(in: SQLType): Either[String, SQLType] = { import TimeUnit._ in match { case SQLTypes.Date => @@ -254,20 +254,18 @@ object TimeInterval { } } -sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLTemporal] { +sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunction[IO, IO] { def interval: TimeInterval - override def inputType: SQLTemporal = SQLTypes.Temporal - override def outputType: SQLTemporal = SQLTypes.Temporal override def script: String = s"${operator.script}${interval.script}" private[this] var _out: SQLType = outputType override def out: SQLType = _out override def applyType(in: SQLType): SQLType = { - _out = interval.applyType(in).getOrElse(out) + _out = interval.checkType(in).getOrElse(out) _out } - override def validate(): Either[String, Unit] = interval.applyType(out) match { + override def validate(): Either[String, Unit] = interval.checkType(out) match { case Left(err) => Left(err) case Right(_) => Right(()) } @@ -279,20 +277,30 @@ sealed trait SQLIntervalFunction extends SQLArithmeticFunction[SQLTemporal, SQLT s"${SQLTypeUtils.coerce(base, expr.out, out, nullable = expr.nullable)}$painless" } -case class SQLAddInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with SQLIntervalFunction { 
+sealed trait AddInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { override def operator: ArithmeticOperator = Add override def painless: String = s".plus(${interval.painless})" } -case class SQLSubtractInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with SQLIntervalFunction { +sealed trait SubtractInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { override def operator: ArithmeticOperator = Subtract override def painless: String = s".minus(${interval.painless})" } +case class SQLAddInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with AddInterval[SQLTemporal] { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLTemporal = SQLTypes.Temporal +} + +case class SQLSubtractInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SubtractInterval[SQLTemporal] { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLTemporal = SQLTypes.Temporal +} + sealed trait DateTimeFunction extends SQLFunction { def now: String = "ZonedDateTime.now(ZoneId.of('Z'))" override def out: SQLType = SQLTypes.DateTime @@ -418,6 +426,7 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_add") with DateFunction + with AddInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date @@ -425,12 +434,12 @@ case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".plus(${interval.painless})" } case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_sub") with DateFunction + with SubtractInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def 
inputType: SQLDate = SQLTypes.Date @@ -438,7 +447,6 @@ case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".minus(${interval.painless})" } case class ParseDate(identifier: SQLIdentifier, format: String) @@ -480,6 +488,7 @@ case class FormatDate(identifier: SQLIdentifier, format: String) case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_add") with DateTimeFunction + with AddInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime @@ -487,12 +496,12 @@ case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".plus(${interval.painless})" } case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_sub") with DateTimeFunction + with SubtractInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime @@ -500,7 +509,6 @@ case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".minus(${interval.painless})" } case class ParseDateTime(identifier: SQLIdentifier, format: String) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index bb544630..82a5fc36 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -1,8 +1,10 @@ package app.softnetwork.elastic.sql +import scala.annotation.tailrec + case object Where extends 
SQLExpr("where") with SQLRegex -sealed trait SQLCriteria extends Updateable { +sealed trait SQLCriteria extends Updateable with PainlessScript { def operator: SQLOperator def nested: Boolean = false @@ -28,6 +30,28 @@ sealed trait SQLCriteria extends Updateable { } } + private[this] def asGroup(script: String): String = if (group) s"($script)" else script + + override def out: SQLType = SQLTypes.Boolean + + override def painless: String = this match { + case SQLPredicate(left, op, right, maybeNot, group) => + val leftStr = left.painless + val rightStr = right.painless + val opStr = op match { + case And | Or => op.painless + case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") + } + val not = maybeNot.nonEmpty + if (group || not) + s"${maybeNot.map(_.painless).getOrElse("")}($leftStr $opStr $rightStr)" + else + s"$leftStr $opStr $rightStr" + case relation: ElasticRelation => asGroup(relation.criteria.painless) + case m: SQLMatch => asGroup(m.criteria.painless) + case expr: Expression => asGroup(expr.painless) + case _ => throw new IllegalArgumentException(s"Unsupported criteria: $this") + } } case class SQLPredicate( @@ -94,15 +118,6 @@ case class SQLPredicate( sealed trait ElasticFilter -sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLFunctionChain { - def identifier: SQLIdentifier - override def nested: Boolean = identifier.nested - override def group: Boolean = false - override lazy val limit: Option[SQLLimit] = identifier.limit - override val functions: List[SQLFunction] = identifier.functions - override def validate(): Either[String, Unit] = identifier.validate() -} - case class ElasticBoolQuery( var innerFilters: Seq[ElasticFilter] = Nil, var mustFilters: Seq[ElasticFilter] = Nil, @@ -157,7 +172,16 @@ case class ElasticBoolQuery( } -sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter with PainlessScript { +sealed trait Expression + extends SQLCriteria + with SQLFunctionChain + with 
ElasticFilter + with PainlessScript { + def identifier: SQLIdentifier + override def nested: Boolean = identifier.nested + override def group: Boolean = false + override lazy val limit: Option[SQLLimit] = identifier.limit + override val functions: List[SQLFunction] = identifier.functions def maybeValue: Option[SQLToken] def maybeNot: Option[Not.type] def notAsString: String = maybeNot.map(v => s"$v ").getOrElse("") @@ -193,27 +217,30 @@ sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter wit } } - private[this] lazy val base: String = - out match { - case SQLTypes.Time if identifier.paramName.nonEmpty => - s"${identifier.paramName}.toLocalTime()" - case _ => identifier.paramName - } - - override def painless: String = - s"$painlessNot${identifier.toPainless(base)} $painlessOp $painlessValue" - - override def out: SQLType = - (identifier.out, maybeValue) match { - case (idType, Some(v)) => + private[this] lazy val left: String = { + val targetedType = maybeValue match { + case Some(v) => v match { case value: SQLValue[_] => value.out case values: SQLValues[_, _] => values.out - case id: SQLIdentifier => id.out - case _ => idType + case other => other.out } - case (idType, None) => idType + case None => identifier.out } + SQLTypeUtils.coerce(identifier, targetedType) + } + + private[this] lazy val check: String = + operator match { + case _: SQLComparisonOperator => s" $painlessOp $painlessValue" + case _ => s"$painlessOp($painlessValue)" + } + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left == null ? 
false : ${painlessNot}left$check" + } + s"$painlessNot$left$check" + } override def validate(): Either[String, Unit] = { for { @@ -366,7 +393,7 @@ case class SQLIn[R, +T <: SQLValue[R]]( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - override def painless: String = s"$painlessNot${identifier.painless} $painlessOp($painlessValue)" + override def painless: String = s"$painlessNot${identifier.painless}$painlessOp($painlessValue)" } case class SQLBetween[+T]( @@ -425,15 +452,19 @@ case class SQLMatch( override lazy val nested: Boolean = identifiers.forall(_.nested) + @tailrec + private[this] def toCriteria(matches: List[ElasticMatch], curr: SQLCriteria): SQLCriteria = + matches match { + case Nil => curr + case single :: Nil => SQLPredicate(curr, Or, single, group = true) + case first :: rest => toCriteria(rest, SQLPredicate(curr, Or, first, group = true)) + } + lazy val criteria: SQLCriteria = { identifiers.map(id => ElasticMatch(id, value, None)) match { case Nil => throw new IllegalArgumentException("No identifiers for MATCH") case single :: Nil => single - case first :: second :: rest => - val initial: SQLCriteria = SQLPredicate(first, Or, second) - rest.foldLeft(initial) { (acc, next) => - SQLPredicate(acc, Or, next) - } + case first :: rest => toCriteria(rest, first) } } @@ -466,7 +497,7 @@ case class ElasticMatch( override def matchCriteria: Boolean = true - override def painless: String = s"$painlessNot${identifier.painless} $painlessOp($painlessValue)" + override def painless: String = s"$painlessNot${identifier.painless}$painlessOp($painlessValue)" } sealed abstract class ElasticRelation(val criteria: SQLCriteria, val operator: ElasticOperator) @@ -499,7 +530,7 @@ case class ElasticNested(override val criteria: SQLCriteria, override val limit: private[this] def name(criteria: SQLCriteria): Option[String] = criteria match { case SQLPredicate(left, _, right, _, _) => name(left).orElse(name(right)) - case c: 
SQLCriteriaWithIdentifier => + case c: Expression => c.identifier.innerHitsName.orElse(c.identifier.name.split('.').headOption) case n: ElasticNested => name(n.criteria) case _ => None From 2ca9f7c9c6ec2bbb783cbd0c3d0c309475290ba2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sat, 13 Sep 2025 14:12:04 +0200 Subject: [PATCH 11/18] fix match criteria with group --- .../scala/app/softnetwork/elastic/sql/SQLWhere.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 82a5fc36..676b3a61 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -456,17 +456,19 @@ case class SQLMatch( private[this] def toCriteria(matches: List[ElasticMatch], curr: SQLCriteria): SQLCriteria = matches match { case Nil => curr - case single :: Nil => SQLPredicate(curr, Or, single, group = true) - case first :: rest => toCriteria(rest, SQLPredicate(curr, Or, first, group = true)) + case single :: Nil => SQLPredicate(curr, Or, single) + case first :: rest => toCriteria(rest, SQLPredicate(curr, Or, first)) } - lazy val criteria: SQLCriteria = { - identifiers.map(id => ElasticMatch(id, value, None)) match { + lazy val criteria: SQLCriteria = + (identifiers.map(id => ElasticMatch(id, value, None)) match { case Nil => throw new IllegalArgumentException("No identifiers for MATCH") case single :: Nil => single case first :: rest => toCriteria(rest, first) + }) match { + case p: SQLPredicate => p.copy(group = true) + case other => other } - } override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = criteria match { case predicate: SQLPredicate => predicate.copy(group = true).asFilter(currentQuery) From 1dc428d22302b4b76b886998e3f5a648430b8693 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sat, 13 Sep 2025 14:36:16 +0200 Subject: [PATCH 12/18] add null value, fix null expr within function --- .../scala/app/softnetwork/elastic/sql/SQLFunction.scala | 2 +- .../main/scala/app/softnetwork/elastic/sql/package.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 57c9ee29..3d7689e8 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -6,7 +6,7 @@ import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql def applyType(in: SQLType): SQLType = out - var expr: SQLToken = _ + var expr: SQLToken = SQLNull def applyTo(expr: SQLToken): Unit = { this.expr = expr } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index fac88a1e..ded40c0e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -80,6 +80,13 @@ package object sql { override def nullable: Boolean = false } + case object SQLNull extends SQLValue[Null](null) { + override def sql: String = "null" + override def painless: String = "null" + override def nullable: Boolean = true + override def out: SQLType = SQLTypes.Null + } + case class SQLBoolean(override val value: Boolean) extends SQLValue[Boolean](value) { override def sql: String = value.toString override def out: SQLType = SQLTypes.Boolean From b6605e7c6ce1cec320760f5bbda4e65d2d9d4a29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sun, 14 Sep 2025 08:53:59 +0200 Subject: [PATCH 13/18] update sql types --- .../elastic/sql/bridge/package.scala | 6 +-
.../elastic/sql/bridge/package.scala | 6 +- .../softnetwork/elastic/sql/SQLFunction.scala | 24 ++--- .../softnetwork/elastic/sql/SQLOperator.scala | 8 +- .../softnetwork/elastic/sql/SQLParser.scala | 35 ++++---- .../app/softnetwork/elastic/sql/SQLType.scala | 23 ++++- .../elastic/sql/SQLTypeUtils.scala | 39 +++++--- .../softnetwork/elastic/sql/SQLTypes.scala | 32 +++++-- .../softnetwork/elastic/sql/SQLWhere.scala | 10 +-- .../app/softnetwork/elastic/sql/package.scala | 90 ++++++++++++++----- ...ralSpec.scala => SQLStringValueSpec.scala} | 4 +- 11 files changed, 188 insertions(+), 89 deletions(-) rename sql/src/test/scala/app/softnetwork/elastic/sql/{SQLLiteralSpec.scala => SQLStringValueSpec.scala} (81%) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 52410f0d..36efad37 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -136,7 +136,7 @@ package object bridge { ) } - def applyNumericOp[A](n: SQLNumeric[_])( + def applyNumericOp[A](n: SQLNumericValue[_])( longOp: Long => A, doubleOp: Double => A ): A = n.toEither.fold(longOp, doubleOp) @@ -149,7 +149,7 @@ package object bridge { return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } value match { - case n: SQLNumeric[_] => + case n: SQLNumericValue[_] => operator match { case Ge => maybeNot match { @@ -231,7 +231,7 @@ package object bridge { } case _ => matchAllQuery() } - case l: SQLLiteral => + case l: SQLStringValue => operator match { case Like => maybeNot match { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 4089249f..84d5b845 100644 --- 
a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -137,7 +137,7 @@ package object bridge { ) } - def applyNumericOp[A](n: SQLNumeric[_])( + def applyNumericOp[A](n: SQLNumericValue[_])( longOp: Long => A, doubleOp: Double => A ): A = n.toEither.fold(longOp, doubleOp) @@ -150,7 +150,7 @@ package object bridge { return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } value match { - case n: SQLNumeric[_] => + case n: SQLNumericValue[_] => operator match { case Ge => maybeNot match { @@ -232,7 +232,7 @@ package object bridge { } case _ => matchAllQuery() } - case l: SQLLiteral => + case l: SQLStringValue => operator match { case Like => maybeNot match { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 3d7689e8..f3995369 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -368,10 +368,10 @@ case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) case class Extract(unit: TimeUnit, override val sql: String = "extract") extends SQLExpr(sql) with DateTimeFunction - with SQLTransformFunction[SQLTemporal, SQLNumber] + with SQLTransformFunction[SQLTemporal, SQLNumeric] with ParametrizedFunction { override def inputType: SQLTemporal = SQLTypes.Temporal - override def outputType: SQLNumber = SQLTypes.Number + override def outputType: SQLNumeric = SQLTypes.Numeric override def params: Seq[String] = Seq(unit.sql) override def painless: String = s".get(${unit.painless})" } @@ -403,9 +403,9 @@ object SECOND extends Extract(Second, Second.sql) { case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) extends SQLExpr("date_diff") with DateTimeFunction - with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumber] + 
with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumeric] with PainlessScript { - override def outputType: SQLNumber = SQLTypes.Number + override def outputType: SQLNumeric = SQLTypes.Numeric override def left: PainlessScript = end override def right: PainlessScript = start override def toSQL(base: String): String = { @@ -452,9 +452,9 @@ case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) case class ParseDate(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_date") with DateFunction - with SQLTransformFunction[SQLString, SQLDate] + with SQLTransformFunction[SQLVarchar, SQLDate] with SQLFunctionWithIdentifier { - override def inputType: SQLString = SQLTypes.String + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDate = SQLTypes.Date override def toSQL(base: String): String = { s"$sql($base, '$format')" @@ -470,10 +470,10 @@ case class ParseDate(identifier: SQLIdentifier, format: String) case class FormatDate(identifier: SQLIdentifier, format: String) extends SQLExpr("format_date") with DateFunction - with SQLTransformFunction[SQLDate, SQLString] + with SQLTransformFunction[SQLDate, SQLVarchar] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date - override def outputType: SQLString = SQLTypes.String + override def outputType: SQLVarchar = SQLTypes.Varchar override def toSQL(base: String): String = { s"$sql($base, '$format')" } @@ -514,9 +514,9 @@ case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) case class ParseDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_datetime") with DateTimeFunction - with SQLTransformFunction[SQLString, SQLDateTime] + with SQLTransformFunction[SQLVarchar, SQLDateTime] with SQLFunctionWithIdentifier { - override def inputType: SQLString = SQLTypes.String + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDateTime = SQLTypes.DateTime override def toSQL(base: 
String): String = { s"$sql($base, '$format')" @@ -532,10 +532,10 @@ case class ParseDateTime(identifier: SQLIdentifier, format: String) case class FormatDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("format_datetime") with DateTimeFunction - with SQLTransformFunction[SQLDateTime, SQLString] + with SQLTransformFunction[SQLDateTime, SQLVarchar] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLString = SQLTypes.String + override def outputType: SQLVarchar = SQLTypes.Varchar override def toSQL(base: String): String = { s"$sql($base, '$format')" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 023db2dc..1e9eb1fb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -64,10 +64,10 @@ case object NullIf extends SQLExpr("nullif") with SQLConditionalOperator case object Exists extends SQLExpr("exists") with SQLConditionalOperator case object Cast extends SQLExpr("cast") with SQLConditionalOperator case object Case extends SQLExpr("case") with SQLConditionalOperator -case object When extends SQLExpr("when") with SQLConditionalOperator -case object Then extends SQLExpr("then") with SQLConditionalOperator -case object Else extends SQLExpr("else") with SQLConditionalOperator -case object End extends SQLExpr("end") with SQLConditionalOperator +case object When extends SQLExpr("when") with SQLRegex +case object Then extends SQLExpr("then") with SQLRegex +case object Else extends SQLExpr("else") with SQLRegex +case object End extends SQLExpr("end") with SQLRegex sealed trait SQLPredicateOperator extends SQLLogicalOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index e6e36281..e32dd01f 100644 
--- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -59,12 +59,14 @@ case class SQLParserError(msg: String) extends SQLCompilationError trait SQLParser extends RegexParsers with PackratParsers { - def literal: PackratParser[SQLLiteral] = - """"[^"]*"|'[^']*'""".r ^^ (str => SQLLiteral(str.substring(1, str.length - 1))) + def literal: PackratParser[SQLStringValue] = + """"[^"]*"|'[^']*'""".r ^^ (str => SQLStringValue(str.substring(1, str.length - 1))) - def long: PackratParser[SQLLong] = """(-)?(0|[1-9]\d*)""".r ^^ (str => SQLLong(str.toLong)) + def long: PackratParser[SQLLongValue] = + """(-)?(0|[1-9]\d*)""".r ^^ (str => SQLLongValue(str.toLong)) - def double: PackratParser[SQLDouble] = """(-)?(\d+\.\d+)""".r ^^ (str => SQLDouble(str.toDouble)) + def double: PackratParser[SQLDoubleValue] = + """(-)?(\d+\.\d+)""".r ^^ (str => SQLDoubleValue(str.toDouble)) def boolean: PackratParser[SQLBoolean] = """(true|false)""".r ^^ (bool => SQLBoolean(bool.toBoolean)) @@ -171,29 +173,30 @@ trait SQLParser extends RegexParsers with PackratParsers { DateTrunc(i, u) } - def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => Extract(u) } - def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Year.regex ^^ (_ => YEAR) - def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Month.regex ^^ (_ => MONTH) - def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) + def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + Day.regex ^^ (_ => DAY) - def extract_hour: 
PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Hour.regex ^^ (_ => HOUR) - def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Minute.regex ^^ (_ => MINUTE) - def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Second.regex ^^ (_ => SECOND) - def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = @@ -212,7 +215,7 @@ trait SQLParser extends RegexParsers with PackratParsers { "(?i)parse_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ li ~ _ ~ f ~ _ => li match { - case l: SQLLiteral => + case l: SQLStringValue => ParseDate(SQLIdentifier("", functions = l :: Nil), f.value) case i: SQLIdentifier => ParseDate(i, f.value) @@ -416,7 +419,7 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } - def string_type: PackratParser[SQLTypes.String.type] = "(?i)string".r ^^ (_ => SQLTypes.String) + def string_type: PackratParser[SQLTypes.Varchar.type] = "(?i)string".r ^^ (_ => SQLTypes.Varchar) def date_type: PackratParser[SQLTypes.Date.type] = "(?i)date".r ^^ (_ => SQLTypes.Date) @@ -431,7 +434,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def boolean_type: PackratParser[SQLTypes.Boolean.type] = "(?i)boolean".r ^^ (_ => SQLTypes.Boolean) - def long_type: PackratParser[SQLTypes.Long.type] = "(?i)long".r ^^ (_ => SQLTypes.Long) + def long_type: 
PackratParser[SQLTypes.BigInt.type] = "(?i)long".r ^^ (_ => SQLTypes.BigInt) def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^ (_ => SQLTypes.Double) @@ -600,7 +603,7 @@ trait SQLWhereParser { case i ~ n ~ _ ~ _ ~ v ~ _ => SQLIn( i, - SQLLiteralValues(v), + SQLStringValues(v), n ) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala index cc8bda0c..b0aea9da 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -3,10 +3,29 @@ package app.softnetwork.elastic.sql sealed trait SQLType { def typeId: String } trait SQLAny extends SQLType + trait SQLTemporal extends SQLType + trait SQLDate extends SQLTemporal trait SQLTime extends SQLTemporal trait SQLDateTime extends SQLTemporal -trait SQLNumber extends SQLType -trait SQLString extends SQLType +trait SQLTimestamp extends SQLDateTime + +trait SQLNumeric extends SQLType + +trait SQLTinyInt extends SQLNumeric +trait SQLSmallInt extends SQLNumeric +trait SQLInt extends SQLNumeric +trait SQLBigInt extends SQLNumeric +trait SQLDouble extends SQLNumeric +trait SQLReal extends SQLNumeric + +trait SQLLiteral extends SQLType +trait SQLVarchar extends SQLLiteral +trait SQLChar extends SQLLiteral + trait SQLBool extends SQLType + +trait SQLArray extends SQLType { def elementType: SQLType } + +trait SQLStruct extends SQLType diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala index ae558147..7dc0911b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -21,15 +21,29 @@ object SQLTypeUtils { ).contains( out.typeId )) || - (out.typeId == Number.typeId && Set(Int.typeId, Long.typeId, Double.typeId, Float.typeId) + (out.typeId == Numeric.typeId && 
Set( + TinyInt.typeId, + SmallInt.typeId, + Int.typeId, + BigInt.typeId, + Double.typeId, + Real.typeId + ) .contains( in.typeId )) || - (in.typeId == Number.typeId && Set(Int.typeId, Long.typeId, Double.typeId, Float.typeId) + (in.typeId == Numeric.typeId && Set( + TinyInt.typeId, + SmallInt.typeId, + Int.typeId, + BigInt.typeId, + Double.typeId, + Real.typeId + ) .contains( out.typeId )) || - (out.typeId == String.typeId && in.typeId == String.typeId) || + (out.typeId == Varchar.typeId && in.typeId == Varchar.typeId) || (out.typeId == Boolean.typeId && in.typeId == Boolean.typeId) || out.typeId == Any.typeId || in.typeId == Any.typeId || out.typeId == Null.typeId || in.typeId == Null.typeId @@ -39,13 +53,16 @@ object SQLTypeUtils { if (distinct.size == 1) return distinct.head // 1. String - if (distinct.contains(SQLTypes.String)) return SQLTypes.String + if (distinct.contains(SQLTypes.Varchar)) return SQLTypes.Varchar // 2. Number if (distinct.contains(SQLTypes.Double)) return SQLTypes.Double - if (distinct.contains(SQLTypes.Long)) return SQLTypes.Long + if (distinct.contains(SQLTypes.Real)) return SQLTypes.Real + if (distinct.contains(SQLTypes.BigInt)) return SQLTypes.BigInt if (distinct.contains(SQLTypes.Int)) return SQLTypes.Int - if (distinct.contains(SQLTypes.Number)) return SQLTypes.Number + if (distinct.contains(SQLTypes.SmallInt)) return SQLTypes.SmallInt + if (distinct.contains(SQLTypes.TinyInt)) return SQLTypes.TinyInt + if (distinct.contains(SQLTypes.Numeric)) return SQLTypes.Numeric // 3. 
Temporal if (distinct.contains(SQLTypes.Timestamp)) return SQLTypes.Timestamp @@ -86,21 +103,21 @@ object SQLTypeUtils { s"($expr).toLocalTime()" // ---- NUMERIQUES ---- - case (SQLTypes.Int, SQLTypes.Long) => + case (SQLTypes.Int, SQLTypes.BigInt) => s"((long) $expr)" case (SQLTypes.Int, SQLTypes.Double) => s"((double) $expr)" - case (SQLTypes.Long, SQLTypes.Double) => + case (SQLTypes.BigInt, SQLTypes.Double) => s"((double) $expr)" // ---- NUMERIC <-> TEMPORAL ---- - case (SQLTypes.Long, SQLTypes.Timestamp) => + case (SQLTypes.BigInt, SQLTypes.Timestamp) => s"Instant.ofEpochMilli($expr).atZone(ZoneId.of('Z'))" - case (SQLTypes.Timestamp, SQLTypes.Long) => + case (SQLTypes.Timestamp, SQLTypes.BigInt) => s"$expr.toInstant().toEpochMilli()" // ---- BOOLEEN -> NUMERIC ---- - case (SQLTypes.Boolean, SQLTypes.Number) => + case (SQLTypes.Boolean, SQLTypes.Numeric) => s"($expr ? 1 : 0)" // ---- IDENTITY ---- diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index 2dea0813..df82351e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -2,17 +2,35 @@ package app.softnetwork.elastic.sql object SQLTypes { case object Any extends SQLAny { val typeId = "any" } + case object Null extends SQLAny { val typeId = "null" } + case object Temporal extends SQLTemporal { val typeId = "temporal" } + case object Date extends SQLTemporal with SQLDate { val typeId = "date" } case object Time extends SQLTemporal with SQLTime { val typeId = "time" } case object DateTime extends SQLTemporal with SQLDateTime { val typeId = "datetime" } - case object Timestamp extends SQLTemporal with SQLDateTime { val typeId = "timestamp" } - case object Number extends SQLNumber { val typeId = "number" } - case object Int extends SQLNumber { val typeId = "integer" } - case object Long extends SQLNumber { val typeId = "long" } - case 
object Double extends SQLNumber { val typeId = "double" } - case object Float extends SQLNumber { val typeId = "float" } - case object String extends SQLString { val typeId = "string" } + case object Timestamp extends SQLTimestamp { val typeId = "timestamp" } + + case object Numeric extends SQLNumeric { val typeId = "numeric" } + + case object TinyInt extends SQLTinyInt { val typeId = "tinyint" } + case object SmallInt extends SQLSmallInt { val typeId = "smallint" } + case object Int extends SQLInt { val typeId = "int" } + case object BigInt extends SQLBigInt { val typeId = "bigint" } + case object Double extends SQLDouble { val typeId = "double" } + case object Real extends SQLReal { val typeId = "float" } + + case object Literal extends SQLLiteral { val typeId = "literal" } + + case object Char extends SQLChar { val typeId = "char" } + case object Varchar extends SQLVarchar { val typeId = "varchar" } + case object Boolean extends SQLBool { val typeId = "boolean" } + + case class Array(elementType: SQLType) extends SQLArray { + val typeId = s"array<${elementType.typeId}>" + } + + case object Struct extends SQLStruct { val typeId = "struct" } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 676b3a61..11589d16 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -423,9 +423,9 @@ case class SQLBetween[+T]( case class ElasticGeoDistance( identifier: SQLIdentifier, - distance: SQLLiteral, - lat: SQLDouble, - lon: SQLDouble + distance: SQLStringValue, + lat: SQLDoubleValue, + lon: SQLDoubleValue ) extends Expression { override def sql = s"$Distance($identifier,($lat,$lon)) $operator $distance" override val functions: List[SQLFunction] = List(Distance) @@ -442,7 +442,7 @@ case class ElasticGeoDistance( case class SQLMatch( identifiers: Seq[SQLIdentifier], - value: SQLLiteral + value: 
SQLStringValue ) extends SQLCriteria { override def sql: String = s"$operator (${identifiers.mkString(",")}) $Against ($value)" @@ -482,7 +482,7 @@ case class SQLMatch( case class ElasticMatch( identifier: SQLIdentifier, - value: SQLLiteral, + value: SQLStringValue, options: Option[String] ) extends Expression { override def sql: String = diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index ded40c0e..89585f8d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -43,11 +43,6 @@ package object sql { case object Distinct extends SQLExpr("distinct") with SQLRegex - case object Empty extends SQLExpr("") with PainlessScript { - override def painless: String = "" - override def nullable: Boolean = false - } - abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) extends SQLToken with PainlessScript @@ -92,7 +87,12 @@ package object sql { override def out: SQLType = SQLTypes.Boolean } - case class SQLLiteral(override val value: String) extends SQLValue[String](value) { + case class SQLCharValue(override val value: Char) extends SQLValue[Char](value) { + override def sql: String = s"""'$value'""" + override def out: SQLType = SQLTypes.Char + } + + case class SQLStringValue(override val value: String) extends SQLValue[String](value) { override def sql: String = s"""'$value'""" import SQLImplicits._ private lazy val pattern: Pattern = value.pattern @@ -118,10 +118,10 @@ package object sql { case _ => super.choose(values, operator, separator) } } - override def out: SQLType = SQLTypes.String + override def out: SQLType = SQLTypes.Varchar } - sealed abstract class SQLNumeric[T: Numeric](override val value: T)(implicit + sealed abstract class SQLNumericValue[T: Numeric](override val value: T)(implicit ev$1: T => Ordered[T] ) extends SQLValue[T](value) { override def sql: String = 
value.toString @@ -152,22 +152,38 @@ package object sql { def ne: Seq[T] => Boolean = { _.forall { _ != value } } - override def out: SQLType = SQLTypes.Number + override def out: SQLNumeric = SQLTypes.Numeric } - case class SQLLong(override val value: Long) extends SQLNumeric[Long](value) { - override def out: SQLType = SQLTypes.Long + case class SQLByteValue(override val value: Byte) extends SQLNumericValue[Byte](value) { + override def out: SQLNumeric = SQLTypes.TinyInt } - case class SQLDouble(override val value: Double) extends SQLNumeric[Double](value) { - override def out: SQLType = SQLTypes.Double + case class SQLShortValue(override val value: Short) extends SQLNumericValue[Short](value) { + override def out: SQLNumeric = SQLTypes.SmallInt + } + + case class SQLIntValue(override val value: Int) extends SQLNumericValue[Int](value) { + override def out: SQLNumeric = SQLTypes.Int + } + + case class SQLLongValue(override val value: Long) extends SQLNumericValue[Long](value) { + override def out: SQLNumeric = SQLTypes.BigInt + } + + case class SQLFloatValue(override val value: Float) extends SQLNumericValue[Float](value) { + override def out: SQLNumeric = SQLTypes.Real + } + + case class SQLDoubleValue(override val value: Double) extends SQLNumericValue[Double](value) { + override def out: SQLNumeric = SQLTypes.Double } sealed abstract class SQLFromTo[+T](val from: SQLValue[T], val to: SQLValue[T]) extends SQLToken { override def sql = s"${from.sql} and ${to.sql}" } - case class SQLLiteralFromTo(override val from: SQLLiteral, override val to: SQLLiteral) + case class SQLLiteralFromTo(override val from: SQLStringValue, override val to: SQLStringValue) extends SQLFromTo[String](from, to) { def between: Seq[String] => Boolean = { _.exists { s => s >= from.value && s <= to.value } @@ -177,7 +193,7 @@ package object sql { } } - case class SQLLongFromTo(override val from: SQLLong, override val to: SQLLong) + case class SQLLongFromTo(override val from: SQLLongValue, 
override val to: SQLLongValue) extends SQLFromTo[Long](from, to) { def between: Seq[Long] => Boolean = { _.exists { n => n >= from.value && n <= to.value } @@ -187,7 +203,7 @@ package object sql { } } - case class SQLDoubleFromTo(override val from: SQLDouble, override val to: SQLDouble) + case class SQLDoubleFromTo(override val from: SQLDoubleValue, override val to: SQLDoubleValue) extends SQLFromTo[Double](from, to) { def between: Seq[Double] => Boolean = { _.exists { n => n >= from.value && n <= to.value } @@ -204,9 +220,10 @@ package object sql { override def painless: String = s"[${values.map(_.painless).mkString(",")}]" lazy val innerValues: Seq[R] = values.map(_.value) override def nullable: Boolean = values.exists(_.nullable) + override def out: SQLArray = SQLTypes.Array(SQLTypes.Any) } - case class SQLLiteralValues(override val values: Seq[SQLLiteral]) + case class SQLStringValues(override val values: Seq[SQLStringValue]) extends SQLValues[String, SQLValue[String]](values) { def eq: Seq[String] => Boolean = { _.exists { s => innerValues.exists(_.contentEquals(s)) } @@ -214,22 +231,49 @@ package object sql { def ne: Seq[String] => Boolean = { _.forall { s => innerValues.forall(!_.contentEquals(s)) } } + override def out: SQLArray = SQLTypes.Array(SQLTypes.Varchar) } - class SQLNumericValues[R: TypeTag](override val values: Seq[SQLNumeric[R]]) - extends SQLValues[R, SQLNumeric[R]](values) { + class SQLNumericValues[R: TypeTag](override val values: Seq[SQLNumericValue[R]]) + extends SQLValues[R, SQLNumericValue[R]](values) { def eq: Seq[R] => Boolean = { _.exists { n => innerValues.contains(n) } } def ne: Seq[R] => Boolean = { _.forall { n => !innerValues.contains(n) } } + override def out: SQLArray = SQLTypes.Array(SQLTypes.Numeric) + } + + case class SQLByteValues(override val values: Seq[SQLByteValue]) + extends SQLNumericValues[Byte](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.TinyInt) + } + + case class SQLShortValues(override val 
values: Seq[SQLShortValue]) + extends SQLNumericValues[Short](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.SmallInt) + } + + case class SQLIntValues(override val values: Seq[SQLIntValue]) + extends SQLNumericValues[Int](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Int) + } + + case class SQLLongValues(override val values: Seq[SQLLongValue]) + extends SQLNumericValues[Long](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.BigInt) } - case class SQLLongValues(override val values: Seq[SQLLong]) extends SQLNumericValues[Long](values) + case class SQLFloatValues(override val values: Seq[SQLFloatValue]) + extends SQLNumericValues[Float](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Real) + } - case class SQLDoubleValues(override val values: Seq[SQLDouble]) - extends SQLNumericValues[Double](values) + case class SQLDoubleValues(override val values: Seq[SQLDoubleValue]) + extends SQLNumericValues[Double](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Double) + } def choose[T]( values: Seq[T], @@ -386,6 +430,4 @@ package object sql { } } } - - case class SQLScript(script: String) extends SQLExpr(script) with MathScript } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala similarity index 81% rename from sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala rename to sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala index 16b10632..7f524305 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala @@ -5,10 +5,10 @@ import org.scalatest.matchers.should.Matchers /** Created by smanciot on 17/02/17. 
*/ -class SQLLiteralSpec extends AnyFlatSpec with Matchers { +class SQLStringValueSpec extends AnyFlatSpec with Matchers { "SQLLiteral" should "perform sql like" in { - val l = SQLLiteral("%dummy%") + val l = SQLStringValue("%dummy%") l.like(Seq("dummy")) should ===(true) l.like(Seq("aa dummy")) should ===(true) l.like(Seq("dummy bbb")) should ===(true) From ad71c925d072ebf7c9f576a9ce04e647633324aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sun, 14 Sep 2025 09:32:14 +0200 Subject: [PATCH 14/18] to fix parser specifications --- sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala | 3 ++- .../test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index e32dd01f..4af4d9c6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -382,6 +382,7 @@ trait SQLParser extends RegexParsers with PackratParsers { "int", "integer", "long", + "bigint", "double", "boolean", "time", @@ -434,7 +435,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def boolean_type: PackratParser[SQLTypes.Boolean.type] = "(?i)boolean".r ^^ (_ => SQLTypes.Boolean) - def long_type: PackratParser[SQLTypes.BigInt.type] = "(?i)long".r ^^ (_ => SQLTypes.BigInt) + def long_type: PackratParser[SQLTypes.BigInt.type] = "(?i)long|bigint".r ^^ (_ => SQLTypes.BigInt) def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^ (_ => SQLTypes.Double) diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 0341943c..fefba19b 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ 
b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -142,7 +142,7 @@ object Queries { val nullif: String = "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" val cast: String = - "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) long) as c, identifier from Table" + "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) bigint) as c, identifier from Table" } /** Created by smanciot on 15/02/17. From fba02c9fa8054248c95f6e56313e767e158a63a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 15 Sep 2025 09:03:53 +0200 Subject: [PATCH 15/18] implements case when --- .../elastic/sql/SQLQuerySpec.scala | 102 ++++++++++++++ .../elastic/sql/SQLQuerySpec.scala | 102 ++++++++++++++ .../elastic/sql/SQLDelimiter.scala | 4 + .../softnetwork/elastic/sql/SQLFunction.scala | 124 ++++++++++++++++-- .../softnetwork/elastic/sql/SQLParser.scala | 58 ++++++-- .../softnetwork/elastic/sql/SQLWhere.scala | 32 +++-- .../app/softnetwork/elastic/sql/package.scala | 10 +- .../elastic/sql/SQLParserSpec.scala | 18 +++ 8 files changed, 416 insertions(+), 34 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index c8d0d477..c393b79b 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1888,4 +1888,106 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll(",LocalDate", ", LocalDate") .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") } + + it should "handle case function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhen) + val query = select.query 
+ println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle case with expression function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhenExpr) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? 
null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("=ZonedDateTime", " = ZonedDateTime") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 6d0b9204..f4f249de 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1877,4 +1877,106 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll(",LocalDate", ", LocalDate") .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") } + + it should "handle case function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhen) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + 
| }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle case with expression function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhenExpr) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? 
null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("=ZonedDateTime", " = ZonedDateTime") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala index cc6a7dc4..1b0b791b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala @@ -4,7 +4,11 @@ sealed trait SQLDelimiter extends SQLToken sealed trait StartDelimiter extends SQLDelimiter case object StartPredicate extends SQLExpr("(") with StartDelimiter +case object StartCase extends SQLExpr("case") with StartDelimiter +case object WhenCase extends SQLExpr("when") with StartDelimiter sealed trait EndDelimiter extends SQLDelimiter case object EndPredicate extends SQLExpr(")") with EndDelimiter case 
object Separator extends SQLExpr(",") with EndDelimiter +case object EndCase extends SQLExpr("end") with EndDelimiter +case object ThenCase extends SQLExpr("then") with EndDelimiter diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index f3995369..8f4814df 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -1,16 +1,16 @@ package app.softnetwork.elastic.sql -import scala.util.Try import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql def applyType(in: SQLType): SQLType = out - var expr: SQLToken = SQLNull - def applyTo(expr: SQLToken): Unit = { - this.expr = expr + private[this] var _expr: SQLToken = SQLNull + def expr_=(e: SQLToken): Unit = { + _expr = e } - override def nullable: Boolean = Try(expr.nullable).getOrElse(true) + def expr: SQLToken = _expr + override def nullable: Boolean = expr.nullable } sealed trait SQLFunctionWithIdentifier extends SQLFunction { @@ -72,11 +72,10 @@ trait SQLFunctionChain extends SQLFunction { override def system: Boolean = functions.lastOption.exists(_.system) - override def applyTo(expr: SQLToken): Unit = { - super.applyTo(expr) - val orderedFunctions = functions.reverse - orderedFunctions.foldLeft(expr) { (currentExpr, fun) => - fun.applyTo(currentExpr) + def applyTo(expr: SQLToken): Unit = { + this.expr = expr + functions.reverse.foldLeft(expr) { (currentExpr, fun) => + fun.expr = currentExpr fun } } @@ -662,8 +661,105 @@ case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = tru override def toPainless(base: String, idx: Int): String = SQLTypeUtils.coerce(base, value.out, targetType, value.nullable) - /*if (nullable) - s"(def e$idx = $base; e$idx != null ? 
${SQLTypeUtils.coerce(s"e$idx", value.out, out, nullable = false)}$painless : null)" - else - s"${SQLTypeUtils.coerce(base, value.out, targetType, nullable = value.nullable)}$painless"*/ +} + +case class SQLCaseWhen( + expression: Option[PainlessScript], + conditions: List[(PainlessScript, PainlessScript)], + default: Option[PainlessScript] +) extends SQLTransformFunction[SQLAny, SQLAny] { + override def inputType: SQLAny = SQLTypes.Any + override def outputType: SQLAny = SQLTypes.Any + + override def sql: String = { + val exprPart = expression.map(e => s"$Case ${e.sql}").getOrElse(Case.sql) + val whenThen = conditions + .map { case (cond, res) => s"$When ${cond.sql} $Then ${res.sql}" } + .mkString(" ") + val elsePart = default.map(d => s" $Else ${d.sql}").getOrElse("") + s"$exprPart $whenThen$elsePart $End" + } + + override def out: SQLType = + SQLTypeUtils.leastCommonSuperType( + conditions.map(_._2.out) ++ default.map(_.out).toList + ) + + override def applyType(in: SQLType): SQLType = out + + override def validate(): Either[String, Unit] = { + if (conditions.isEmpty) Left("CASE WHEN requires at least one condition") + else if ( + expression.isEmpty && conditions.exists { case (cond, _) => cond.out != SQLTypes.Boolean } + ) + Left("CASE WHEN conditions must be of type BOOLEAN") + else if ( + expression.isDefined && conditions.exists { case (cond, _) => + !SQLTypeUtils.matches(cond.out, expression.get.out) + } + ) + Left("CASE WHEN conditions must be of the same type as the expression") + else Right(()) + } + + override def painless: String = { + val base = + expression match { + case Some(expr) => + s"def expr = ${SQLTypeUtils.coerce(expr, expr.out)}; " + case _ => "" + } + val cases = conditions.zipWithIndex + .map { case ((cond, res), zindex) => + expression match { + case Some(expr) => + val c = SQLTypeUtils.coerce(cond, expr.out) + if (cond.sql == res.sql) { + s"def val$zindex = $c; if (expr == val$zindex) return val$zindex;" + } else { + val _res = { + res 
match { + case i: Identifier => + val name = i.name + cond match { + case e: Expression if e.identifier.name == name => + e.identifier.nullable = false + e + case i: Identifier if i.name == name => + i.nullable = false + i + case _ => res + } + case _ => res + } + } + val r = SQLTypeUtils.coerce(_res, out) + s"if (expr == $c) return $r;" + } + case None => + val c = SQLTypeUtils.coerce(cond, SQLTypes.Boolean) + val r = + cond match { + case e: Expression => + val name = e.identifier.name + res match { + case i: Identifier if i.name == name => "left" + case _ => SQLTypeUtils.coerce(res, out) + } + case _ => SQLTypeUtils.coerce(res, out) + } + s"if ($c) return $r;" + } + } + .mkString(" ") + val defaultCase = default + .map(d => s"def dval = ${SQLTypeUtils.coerce(d, out)}; return dval;") + .getOrElse("return null;") + s"{ $base$cases $defaultCase }" + } + + override def toPainless(base: String, idx: Int): String = s"$base$painless" + + override def nullable: Boolean = + conditions.exists { case (_, res) => res.nullable } || default.forall(_.nullable) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 4af4d9c6..94e977ef 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -57,7 +57,7 @@ trait SQLCompilationError case class SQLParserError(msg: String) extends SQLCompilationError -trait SQLParser extends RegexParsers with PackratParsers { +trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def literal: PackratParser[SQLStringValue] = """"[^"]*"|'[^']*'""".r ^^ (str => SQLStringValue(str.substring(1, str.length - 1))) @@ -145,18 +145,18 @@ trait SQLParser extends RegexParsers with PackratParsers { def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator - def addInterval: PackratParser[SQLAddInterval] = + def add_interval: 
PackratParser[SQLAddInterval] = add ~ interval ^^ { case _ ~ it => SQLAddInterval(it) } - def substractInterval: PackratParser[SQLSubtractInterval] = + def substract_interval: PackratParser[SQLSubtractInterval] = subtract ~ interval ^^ { case _ ~ it => SQLSubtractInterval(it) } def intervalFunction: PackratParser[SQLArithmeticFunction[SQLTemporal, SQLTemporal]] = - addInterval | substractInterval + add_interval | substract_interval def identifierWithSystemFunction: PackratParser[SQLIdentifier] = (current_date | current_time | current_timestamp | now) ~ intervalFunction.? ^^ { @@ -293,6 +293,10 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLIdentifier("", functions = dd :: Nil) } + def case_when_identifier: Parser[SQLIdentifier] = case_when ^^ { cw => + SQLIdentifier("", functions = cw :: Nil) + } + def is_null: PackratParser[SQLConditionalFunction[_]] = "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) @@ -303,7 +307,7 @@ trait SQLParser extends RegexParsers with PackratParsers { case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) } - private[this] def valueExpr: PackratParser[PainlessScript] = + def valueExpr: PackratParser[PainlessScript] = // les plus spécifiques en premier identifierWithTransformation | // transformations appliquées à un identifier date_diff_identifier | // date_diff(...) 
retournant un identifier-like @@ -330,8 +334,37 @@ trait SQLParser extends RegexParsers with PackratParsers { case _ ~ _ ~ id1 ~ _ ~ id2 ~ _ => SQLNullIf(id1, id2) } + def start_case: PackratParser[StartCase.type] = Case.regex ^^ (_ => StartCase) + + def when_case: PackratParser[WhenCase.type] = When.regex ^^ (_ => WhenCase) + + def then_case: PackratParser[ThenCase.type] = Then.regex ^^ (_ => ThenCase) + + def else_case: PackratParser[Else.type] = Else.regex ^^ (_ => Else) + + def end_case: PackratParser[EndCase.type] = End.regex ^^ (_ => EndCase) + + def case_condition: Parser[(PainlessScript, PainlessScript)] = + when_case ~ (whereCriteria | valueExpr) ~ then_case.? ~ valueExpr ^^ { case _ ~ c ~ _ ~ r => + c match { + case p: PainlessScript => p -> r + case rawTokens: List[SQLToken] => + processTokens(rawTokens) match { + case Some(criteria) => criteria -> r + case _ => SQLNull -> r + } + } + } + + def case_else: Parser[PainlessScript] = else_case ~ valueExpr ^^ { case _ ~ r => r } + + def case_when: PackratParser[SQLCaseWhen] = + start_case ~ valueExpr.? ~ rep1(case_condition) ~ case_else.? ~ end_case ^^ { + case _ ~ e ~ c ~ r ~ _ => SQLCaseWhen(e, c, r) + } + def logical_functions: PackratParser[SQLTransformFunction[_, _]] = - is_null | is_notnull | coalesce | nullif + is_null | is_notnull | coalesce | nullif | case_when def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions @@ -401,7 +434,12 @@ trait SQLParser extends RegexParsers with PackratParsers { "min", "max", "avg", - "sum" + "sum", + "case", + "when", + "then", + "else", + "end" ) private val identifierRegexStr = @@ -512,7 +550,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def alias: PackratParser[SQLAlias] = Alias.regex.? 
~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ alias.? ^^ { + (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | case_when_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } @@ -714,7 +752,7 @@ trait SQLWhereParser { nestedCriteria | childCriteria | parentCriteria | criteria def whereCriteria: PackratParser[List[SQLToken]] = rep1( - allPredicate | allCriteria | start | or | and | end + allPredicate | allCriteria | start | or | and | end | then_case ) def where: PackratParser[SQLWhere] = @@ -804,6 +842,8 @@ trait SQLWhereParser { case _ => throw SQLValidationError("Invalid stack state for predicate creation") } + case ThenCase :: _ => + processTokensHelper(Nil, stack) // exit processing on THEN case (_: EndDelimiter) :: rest => processTokensHelper(rest, stack) // Ignore and move on case _ => processTokensHelper(Nil, stack) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 11589d16..9e244acb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -172,11 +172,7 @@ case class ElasticBoolQuery( } -sealed trait Expression - extends SQLCriteria - with SQLFunctionChain - with ElasticFilter - with PainlessScript { +sealed trait Expression extends SQLFunctionChain with ElasticFilter with SQLCriteria { // to fix output type as Boolean def identifier: SQLIdentifier override def nested: Boolean = identifier.nested override def group: Boolean = false @@ -210,14 +206,14 @@ sealed trait Expression case v: SQLIdentifier => v.painless case 
v => v.sql } - .getOrElse { + .getOrElse("") /*{ operator match { case IsNull | IsNotNull => "null" case _ => "" } - } + }*/ - private[this] lazy val left: String = { + protected lazy val left: String = { val targetedType = maybeValue match { case Some(v) => v match { @@ -230,11 +226,12 @@ sealed trait Expression SQLTypeUtils.coerce(identifier, targetedType) } - private[this] lazy val check: String = + protected lazy val check: String = operator match { case _: SQLComparisonOperator => s" $painlessOp $painlessValue" case _ => s"$painlessOp($painlessValue)" } + override def painless: String = { if (identifier.nullable) { return s"def left = $left; left == null ? false : ${painlessNot}left$check" @@ -250,7 +247,7 @@ sealed trait Expression v.validate() match { case Left(err) => Left(s"$err in expression: $this") case Right(_) => - SQLValidator.validateTypesMatching(out, v.out) match { + SQLValidator.validateTypesMatching(identifier.out, v.out) match { case Left(_) => Left( s"Type mismatch: '${out.typeId}' is not compatible with '${v.out.typeId}' in expression: $this" @@ -352,6 +349,13 @@ case class SQLIsNullCriteria(identifier: SQLIdentifier) } else updated } + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left == null" + } + s"$painlessNot$left$check" + } + } case class SQLIsNotNullCriteria(identifier: SQLIdentifier) @@ -367,6 +371,14 @@ case class SQLIsNotNullCriteria(identifier: SQLIdentifier) } else updated } + + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left != null" + } + s"$painlessNot$left$check" + } + } case class SQLIn[R, +T <: SQLValue[R]]( diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 89585f8d..8ba264a1 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -370,7 +370,15 @@ 
package object sql { paramName ) - override def nullable: Boolean = this.name.nonEmpty && (!aggregation || functions.size > 1) + private[this] var _nullable = + this.name.nonEmpty && (!aggregation || functions.size > 1) + + def nullable_=(b: Boolean): Unit = { + _nullable = b + } + + override def nullable: Boolean = _nullable + } case class SQLIdentifier( diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index fefba19b..3e5ba1dc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -143,6 +143,10 @@ object Queries { "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" val cast: String = "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) bigint) as c, identifier from Table" + val caseWhen: String = + "select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen else createdAt end as c, identifier from Table" + val caseWhenExpr: String = + "select case now - interval 7 day when lastUpdated then lastUpdated when lastSeen then lastSeen else createdAt end as c, identifier from Table" } /** Created by smanciot on 15/02/17. 
@@ -565,4 +569,18 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { cast ) } + + it should "parse case when expression" in { + val result = SQLParser(caseWhen) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + caseWhen + ) + } + + it should "parse case when with expression" in { + val result = SQLParser(caseWhenExpr) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + caseWhenExpr + ) + } } From 9906da0f1a81ae3dc88e3a7b5fc08d9adc7817b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 15 Sep 2025 13:03:32 +0200 Subject: [PATCH 16/18] update painless for case when --- .../elastic/sql/SQLQuerySpec.scala | 5 +- .../elastic/sql/SQLQuerySpec.scala | 5 +- .../softnetwork/elastic/sql/SQLFunction.scala | 52 +++++++++---------- .../softnetwork/elastic/sql/SQLParser.scala | 8 +-- .../app/softnetwork/elastic/sql/package.scala | 7 ++- .../elastic/sql/SQLParserSpec.scala | 4 +- 6 files changed, 45 insertions(+), 36 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index c393b79b..827cb07b 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1903,7 +1903,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); return dval; }" + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1953,7 +1953,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1988,6 +1988,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("if \\(\\s*def", "if (def") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll("=ZonedDateTime", " = ZonedDateTime") + .replaceAll("=e", " = e") } } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index f4f249de..bc17b6cc 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1892,7 +1892,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1942,7 +1942,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS); def val0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); if (expr == val0) return val0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1; def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1977,6 +1977,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("if \\(\\s*def", "if (def") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll("=ZonedDateTime", " = ZonedDateTime") + .replaceAll("=e", " = e") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 8f4814df..0f2af173 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -710,42 +710,42 @@ case class SQLCaseWhen( case _ => "" } val cases = conditions.zipWithIndex - .map { case ((cond, res), zindex) => + .map { case ((cond, res), idx) => + val name = + cond match { + case e: Expression => + e.identifier.name + case i: Identifier => + i.name + case _ => "" + } expression match { case Some(expr) => val c = SQLTypeUtils.coerce(cond, expr.out) if (cond.sql == res.sql) { - s"def val$zindex = $c; if (expr == val$zindex) return val$zindex;" + s"def val$idx = $c; if (expr == val$idx) return val$idx;" } else { - val _res = { - res match { - case i: Identifier => - val name = i.name - cond match { - case e: Expression if e.identifier.name == name => - e.identifier.nullable = false - e - case i: Identifier if i.name == name => - i.nullable = false - i - case _ => res - } - case _ => res - } + res match { + case i: Identifier if i.name == name && cond.isInstanceOf[Identifier] => + i.nullable = false + if (cond.asInstanceOf[Identifier].functions.isEmpty) + s"def val$idx = $c; if (expr == val$idx) return ${SQLTypeUtils.coerce(i.toPainless(s"val$idx"), i.out, out, nullable = false)};" + else { + cond.asInstanceOf[Identifier].nullable = false + s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? 
${SQLTypeUtils.coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils + .coerce(i.toPainless(s"e$idx"), i.out, out, nullable = false)};" + } + case _ => + s"if (expr == $c) return ${SQLTypeUtils.coerce(res, out)};" } - val r = SQLTypeUtils.coerce(_res, out) - s"if (expr == $c) return $r;" } case None => val c = SQLTypeUtils.coerce(cond, SQLTypes.Boolean) val r = - cond match { - case e: Expression => - val name = e.identifier.name - res match { - case i: Identifier if i.name == name => "left" - case _ => SQLTypeUtils.coerce(res, out) - } + res match { + case i: Identifier if i.name == name && cond.isInstanceOf[Expression] => + i.nullable = false + SQLTypeUtils.coerce(i.toPainless("left"), i.out, out, nullable = false) case _ => SQLTypeUtils.coerce(res, out) } s"if ($c) return $r;" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 94e977ef..123313de 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -483,9 +483,11 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = - "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? 
~ sql_type ~ end ^^ { - case _ ~ _ ~ i ~ as ~ t ~ _ => - i.copy(functions = SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) + "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ { + case _ ~ _ ~ i ~ as ~ t ~ _ ~ a => + i.copy(functions = + (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) ++ a.toList + ) } private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 8ba264a1..48db25a6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -363,9 +363,14 @@ package object sql { expr } + def checkNotNull: String = + if (name.isEmpty) "" + else + s"(!doc.containsKey('$name') || doc['$name'].empty ? $nullValue : doc['$name'].value)" + override def painless: String = toPainless( if (nullable) - s"(!doc.containsKey('$name') || doc['$name'].empty ? 
$nullValue : doc['$name'].value)" + checkNotNull else paramName ) diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 3e5ba1dc..2f3ee749 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -144,9 +144,9 @@ object Queries { val cast: String = "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) bigint) as c, identifier from Table" val caseWhen: String = - "select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen else createdAt end as c, identifier from Table" + "select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen + interval 2 day else createdAt end as c, identifier from Table" val caseWhenExpr: String = - "select case now - interval 7 day when lastUpdated then lastUpdated when lastSeen then lastSeen else createdAt end as c, identifier from Table" + "select case current_date - interval 7 day when cast(lastUpdated as date) - interval 3 day then lastUpdated when lastSeen then lastSeen + interval 2 day else createdAt end as c, identifier from Table" } /** Created by smanciot on 15/02/17. 
From 2100d827da64e7a28029887fe512a99f63650509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 15 Sep 2025 14:51:40 +0200 Subject: [PATCH 17/18] get rid of ParametrizedFunction, fix extract function, fix cast with interval function --- .../elastic/sql/SQLQuerySpec.scala | 72 ++++++++++++++++++- .../elastic/sql/SQLQuerySpec.scala | 72 ++++++++++++++++++- .../softnetwork/elastic/sql/SQLFunction.scala | 32 +++------ .../softnetwork/elastic/sql/SQLParser.scala | 18 ++--- .../elastic/sql/SQLParserSpec.scala | 14 +++- 5 files changed, 175 insertions(+), 33 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 827cb07b..c8a9b8e3 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1953,7 +1953,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? 
(e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1991,4 +1991,74 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("=e", " = e") } + it should "handle extract function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(extract) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "day": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)" + | } + | }, + | "month": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)" + | } + | }, + | "year": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)" + | } + | }, + | "hour": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.HOURS) : null)" + | } + | }, + | "minute": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.get(ChronoUnit.MINUTES) : null)" + | } + | }, + | "second": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index bc17b6cc..4037983e 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1942,7 +1942,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? ((e0).atStartOfDay(ZoneId.of('Z')).minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); return dval; }" + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? (e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" | } | } | }, @@ -1980,4 +1980,74 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("=e", " = e") } + it should "handle extract function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(extract) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "day": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)" + | } + | }, + | "month": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)" + | } + | }, + | "year": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)" + | } + | }, + | "hour": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.get(ChronoUnit.HOURS) : null)" + | } + | }, + | "minute": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MINUTES) : null)" + | } + | }, + | "second": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 0f2af173..cbdd1552 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -137,18 +137,6 @@ sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] override def applyType(in: SQLType): SQLType = in } -sealed trait ParametrizedFunction extends SQLFunction { - def params: Seq[String] - override def toSQL(base: String): String = { - params match { - case Nil => s"$sql($base)" - case _ => - val paramsStr = params.mkString(", ") - s"$sql($paramsStr)($base)" - } - } -} - sealed trait AggregateFunction extends SQLFunction case object Count extends SQLExpr("count") with AggregateFunction case object Min 
extends SQLExpr("min") with AggregateFunction @@ -367,36 +355,35 @@ case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) case class Extract(unit: TimeUnit, override val sql: String = "extract") extends SQLExpr(sql) with DateTimeFunction - with SQLTransformFunction[SQLTemporal, SQLNumeric] - with ParametrizedFunction { + with SQLTransformFunction[SQLTemporal, SQLNumeric] { override def inputType: SQLTemporal = SQLTypes.Temporal override def outputType: SQLNumeric = SQLTypes.Numeric - override def params: Seq[String] = Seq(unit.sql) + override def toSQL(base: String): String = s"$sql(${unit.sql} from $base)" override def painless: String = s".get(${unit.painless})" } object YEAR extends Extract(Year, Year.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object MONTH extends Extract(Month, Month.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object DAY extends Extract(Day, Day.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object HOUR extends Extract(Hour, Hour.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object MINUTE extends Extract(Minute, Minute.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object SECOND extends Extract(Second, Second.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) @@ -732,7 +719,8 @@ case class SQLCaseWhen( s"def val$idx = $c; if (expr == val$idx) return ${SQLTypeUtils.coerce(i.toPainless(s"val$idx"), i.out, out, nullable = false)};" else { cond.asInstanceOf[Identifier].nullable = false - s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? 
${SQLTypeUtils.coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils + s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? ${SQLTypeUtils + .coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils .coerce(i.toPainless(s"e$idx"), i.out, out, nullable = false)};" } case _ => diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 123313de..30cab644 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -173,9 +173,10 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => DateTrunc(i, u) } - def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = - "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => - Extract(u) + def extract_identifier: PackratParser[SQLIdentifier] = + "(?i)extract".r ~ start ~ time_unit ~ "(?i)from".r ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case _ ~ _ ~ u ~ _ ~ i ~ _ => + i.copy(functions = Extract(u) +: i.functions) } def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = @@ -197,7 +198,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => Second.regex ^^ (_ => SECOND) def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = - extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second + extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = "(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | 
identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { @@ -311,6 +312,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => // les plus spécifiques en premier identifierWithTransformation | // transformations appliquées à un identifier date_diff_identifier | // date_diff(...) retournant un identifier-like + extract_identifier | identifierWithSystemFunction | // CURRENT_DATE, NOW, etc. (+/- interval) identifierWithArithmeticFunction | // foo - interval ... identifierWithTemporalFunction | // chaîne de fonctions appliquées à un identifier @@ -483,10 +485,10 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = - "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ { + "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ { case _ ~ _ ~ i ~ as ~ t ~ _ ~ a => i.copy(functions = - (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) ++ a.toList + a.toList ++ (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) ) } @@ -552,7 +554,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def alias: PackratParser[SQLAlias] = Alias.regex.? 
~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | case_when_identifier | identifier) ~ alias.? ^^ { + (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | case_when_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } @@ -612,7 +614,7 @@ trait SQLWhereParser { private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) private def any_identifier: PackratParser[SQLIdentifier] = - identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier + identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier private def equality: PackratParser[SQLExpression] = not.? 
~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long | any_identifier) ^^ { diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 2f3ee749..3e00e98b 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -147,6 +147,9 @@ object Queries { "select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen + interval 2 day else createdAt end as c, identifier from Table" val caseWhenExpr: String = "select case current_date - interval 7 day when cast(lastUpdated as date) - interval 3 day then lastUpdated when lastSeen then lastSeen + interval 2 day else createdAt end as c, identifier from Table" + + val extract: String = + "select extract(day from createdAt) as day, extract(month from createdAt) as month, extract(year from createdAt) as year, extract(hour from createdAt) as hour, extract(minute from createdAt) as minute, extract(second from createdAt) as second from Table" } /** Created by smanciot on 15/02/17. 
@@ -579,8 +582,17 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { it should "parse case when with expression" in { val result = SQLParser(caseWhenExpr) + result.toOption + .flatMap(_.left.toOption.map(_.sql)) + .getOrElse("") + .equalsIgnoreCase(caseWhenExpr) shouldBe true + } + + it should "parse extract function" in { + val result = SQLParser(extract) result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( - caseWhenExpr + extract ) } + } From a96e93b21851e7a16e37a40f5207a63662bd97b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 15 Sep 2025 15:19:09 +0200 Subject: [PATCH 18/18] add all casts --- .../softnetwork/elastic/sql/SQLParser.scala | 25 ++++++++++++++++--- .../softnetwork/elastic/sql/SQLTypes.scala | 2 +- .../elastic/sql/SQLParserSpec.scala | 9 +++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 30cab644..ff6db34c 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -413,11 +413,18 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => "minute", "second", "quarter", + "char", "string", + "byte", + "tinyint", + "short", + "smallint", "int", "integer", "long", "bigint", + "real", + "float", "double", "boolean", "time", @@ -460,7 +467,11 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => ) } - def string_type: PackratParser[SQLTypes.Varchar.type] = "(?i)string".r ^^ (_ => SQLTypes.Varchar) + def char_type: PackratParser[SQLTypes.Char.type] = + "(?i)char".r ^^ (_ => SQLTypes.Char) + + def string_type: PackratParser[SQLTypes.Varchar.type] = + "(?i)varchar|string".r ^^ (_ => SQLTypes.Varchar) def date_type: PackratParser[SQLTypes.Date.type] = "(?i)date".r ^^ (_ => SQLTypes.Date) @@ -475,14 
+486,22 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def boolean_type: PackratParser[SQLTypes.Boolean.type] = "(?i)boolean".r ^^ (_ => SQLTypes.Boolean) + def byte_type: PackratParser[SQLTypes.TinyInt.type] = + "(?i)(byte|tinyint)".r ^^ (_ => SQLTypes.TinyInt) + + def short_type: PackratParser[SQLTypes.SmallInt.type] = + "(?i)(short|smallint)".r ^^ (_ => SQLTypes.SmallInt) + + def int_type: PackratParser[SQLTypes.Int.type] = "(?i)(int|integer)".r ^^ (_ => SQLTypes.Int) + def long_type: PackratParser[SQLTypes.BigInt.type] = "(?i)long|bigint".r ^^ (_ => SQLTypes.BigInt) def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^ (_ => SQLTypes.Double) - def int_type: PackratParser[SQLTypes.Int.type] = "(?i)(int|integer)".r ^^ (_ => SQLTypes.Int) + def float_type: PackratParser[SQLTypes.Real.type] = "(?i)float|real".r ^^ (_ => SQLTypes.Real) def sql_type: PackratParser[SQLType] = - string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | int_type + char_type | string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | float_type | int_type | short_type | byte_type private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? 
^^ { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index df82351e..d067b650 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -19,7 +19,7 @@ object SQLTypes { case object Int extends SQLInt { val typeId = "int" } case object BigInt extends SQLBigInt { val typeId = "bigint" } case object Double extends SQLDouble { val typeId = "double" } - case object Real extends SQLReal { val typeId = "float" } + case object Real extends SQLReal { val typeId = "real" } case object Literal extends SQLLiteral { val typeId = "literal" } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 3e00e98b..972ba528 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -143,6 +143,8 @@ object Queries { "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" val cast: String = "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) bigint) as c, identifier from Table" + val allCasts = + "select cast(identifier as int) as c1, cast(identifier as bigint) as c2, cast(identifier as double) as c3, cast(identifier as real) as c4, cast(identifier as boolean) as c5, cast(identifier as char) as c6, cast(identifier as varchar) as c7, cast(createdAt as date) as c8, cast(createdAt as time) as c9, cast(createdAt as datetime) as c10, cast(createdAt as timestamp) as c11, cast(identifier as smallint) as c12, cast(identifier as tinyint) as c13 from Table" val caseWhen: String = "select case when lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen + 
interval 2 day else createdAt end as c, identifier from Table" val caseWhenExpr: String = @@ -573,6 +575,13 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { ) } + it should "parse all casts function" in { + val result = SQLParser(allCasts) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + allCasts + ) + } + it should "parse case when expression" in { val result = SQLParser(caseWhen) result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(