diff --git a/build.sbt b/build.sbt index bae919ad..4b12cadf 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.5.0" +ThisBuild / version := "0.6.0" ThisBuild / scalaVersion := scala213 diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index c4943563..61fd88f1 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,11 +9,12 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, - SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, - SQLIsNull + SQLIsNotNullCriteria, + SQLIsNull, + SQLIsNullCriteria } import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.searches.queries.Query @@ -70,7 +71,8 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression - case dateMath: SQLComparisonDateMath => dateMath + case isNull: SQLIsNullCriteria => isNull + case isNotNull: SQLIsNotNullCriteria => isNotNull case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 160ee4e2..36efad37 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -136,15 +136,20 @@ package object bridge { ) } - def applyNumericOp[A](n: SQLNumeric[_])( + def applyNumericOp[A](n: SQLNumericValue[_])( 
longOp: Long => A, doubleOp: Double => A ): A = n.toEither.fold(longOp, doubleOp) implicit def expressionToQuery(expression: SQLExpression): Query = { import expression._ + if (aggregation) + return matchAllQuery() + if (identifier.functions.nonEmpty) { + return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } value match { - case n: SQLNumeric[_] if !aggregation => + case n: SQLNumericValue[_] => operator match { case Ge => maybeNot match { @@ -226,7 +231,7 @@ package object bridge { } case _ => matchAllQuery() } - case l: SQLLiteral if !aggregation => + case l: SQLStringValue => operator match { case Like => maybeNot match { @@ -279,7 +284,7 @@ package object bridge { } case _ => matchAllQuery() } - case b: SQLBoolean if !aggregation => + case b: SQLBoolean => operator match { case Eq => maybeNot match { @@ -297,27 +302,27 @@ package object bridge { } case _ => matchAllQuery() } - case _ => matchAllQuery() - } - } - - implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { - import dateMath._ - if (aggregation) - return matchAllQuery() - dateTimeFunction match { - case _: CurrentTimeFunction => - scriptQuery(Script(script = script).lang("painless").scriptType("source")) - case _ => - val op = if (maybeNot.isDefined) operator.not else operator - op match { - case Gt => rangeQuery(identifier.name) gt script - case Ge => rangeQuery(identifier.name) gte script - case Lt => rangeQuery(identifier.name) lt script - case Le => rangeQuery(identifier.name) lte script - case Eq => rangeQuery(identifier.name) gte script lte script - case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + case i: SQLIdentifier => + operator match { + case op: SQLComparisonOperator => + i.toScript match { + case Some(script) => + val o = if (maybeNot.isDefined) op.not else op + o match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => 
rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } + case _ => matchAllQuery() } } @@ -335,6 +340,20 @@ package object bridge { existsQuery(identifier.name) } + implicit def isNullCriteriaToQuery( + isNull: SQLIsNullCriteria + ): Query = { + import isNull._ + not(existsQuery(identifier.name)) + } + + implicit def isNotNullCriteriaToQuery( + isNotNull: SQLIsNotNullCriteria + ): Query = { + import isNotNull._ + existsQuery(identifier.name) + } + implicit def inToQuery[R, T <: SQLValue[R]](in: SQLIn[R, T]): Query = { import in._ val _values: Seq[Any] = values.innerValues diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 97892ef4..c8a9b8e3 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -882,14 +882,29 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": "painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.minus(35, ChronoUnit.MINUTES) : null)" | } | } | }, | "_source": { - | "includes": ["identifier"] + | "includes": [ + | "identifier" + | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("!=null", " != null") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "filter with date time and interval" in { @@ -969,39 +984,47 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { println(query) query shouldBe """{ - | "query": { - | "bool": { - | "filter": [ - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" - | } - | } - | }, - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" - | } - | } - | } - | ] - | } - | }, - | "_source": { - | "includes": [ - | "*" - | ] - | } - |}""".stripMargin + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? 
false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin .replaceAll("\\s", "") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll(">=", " >= ") .replaceAll("<", " < ") - .replaceAll("return", "return ") + .replaceAll("\\|\\|", " || ") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll(":null", " : null ") + .replaceAll("\\?", " ? ") + .replaceAll("==", " == ") + .replaceAll("\\);", "); ") + .replaceAll("=\\(", " = (") + .replaceAll("defl", "def l") } it should "handle having with date functions" in { @@ -1041,7 +1064,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | "source": "params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)" | } | } | } @@ -1203,7 +1226,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(e0, LocalDate::from) : null)" | } | } | } @@ -1212,10 +1235,20 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll(",ChronoUnit", ", ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } @@ -1259,7 +1292,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "(def e2 = (def e1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); e1 != null ? e1.truncatedTo(ChronoUnit.MINUTES) : null); e2 != null ? e2.get(ChronoUnit.YEARS) : null)" | } | } | } @@ -1268,9 +1301,19 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",ZonedDateTime", ", ZonedDateTime") } @@ -1289,7 +1332,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? 
null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | }, @@ -1300,7 +1343,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle aggregation with date_diff function" in { @@ -1325,7 +1382,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | } @@ -1334,8 +1391,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") - .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll("ZonedDateTime", " ZonedDateTime") } @@ -1361,7 +1431,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1370,7 +1440,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle date_sub function as script field" in { @@ -1395,7 +1482,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? 
e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1404,7 +1491,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_add function as script field" in { @@ -1429,7 +1533,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1438,7 +1542,24 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_sub function as script field" in { @@ -1463,7 +1584,88 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.minus(10, ChronoUnit.DAYS) : null)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle is_null function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? 
null : doc['identifier'].value); e0 == null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_notnull function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnotnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value); e0 != null)" | } | } | }, @@ -1472,7 +1674,391 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_null criteria as must_not exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle is_notnull criteria as exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNotNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle coalesce function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(coalesce) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.minus(35, ChronoUnit.MINUTES) : null);if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll(";}", "; }") + .replaceAll(";e", "; e") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle nullif function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(nullif) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";def", "; def") + .replaceAll(";return", "; return") + .replaceAll("returnv", " return v") + .replaceAll("returne", " return e") + .replaceAll(";}", "; }") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") + } + + it should "handle cast function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(cast) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + } + + it should "handle case function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhen) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle case with expression function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhenExpr) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? (e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("=ZonedDateTime", " = ZonedDateTime") + .replaceAll("=e", " = e") + } + + it should "handle extract function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(extract) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "day": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)" + | } + | }, + | "month": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)" + | } + | }, + | "year": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)" + | } + | }, + | "hour": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.HOURS) : null)" + | } + | }, + | "minute": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.get(ChronoUnit.MINUTES) : null)" + | } + | }, + | "second": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index 9de90dd8..3a532263 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,11 +9,12 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, - SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, - SQLIsNull + SQLIsNotNullCriteria, + SQLIsNull, + SQLIsNullCriteria } import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.requests.searches.queries.Query @@ -70,7 +71,8 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression - case dateMath: SQLComparisonDateMath => dateMath + case isNull: SQLIsNullCriteria => isNull + case isNotNull: SQLIsNotNullCriteria => isNotNull case other => throw new 
IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index b2edb050..84d5b845 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -137,15 +137,20 @@ package object bridge { ) } - def applyNumericOp[A](n: SQLNumeric[_])( + def applyNumericOp[A](n: SQLNumericValue[_])( longOp: Long => A, doubleOp: Double => A ): A = n.toEither.fold(longOp, doubleOp) implicit def expressionToQuery(expression: SQLExpression): Query = { import expression._ + if (aggregation) + return matchAllQuery() + if (identifier.functions.nonEmpty) { + return scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } value match { - case n: SQLNumeric[_] if !aggregation => + case n: SQLNumericValue[_] => operator match { case Ge => maybeNot match { @@ -227,7 +232,7 @@ package object bridge { } case _ => matchAllQuery() } - case l: SQLLiteral if !aggregation => + case l: SQLStringValue => operator match { case Like => maybeNot match { @@ -280,7 +285,7 @@ package object bridge { } case _ => matchAllQuery() } - case b: SQLBoolean if !aggregation => + case b: SQLBoolean => operator match { case Eq => maybeNot match { @@ -298,27 +303,27 @@ package object bridge { } case _ => matchAllQuery() } - case _ => matchAllQuery() - } - } - - implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { - import dateMath._ - if (aggregation) - return matchAllQuery() - dateTimeFunction match { - case _: CurrentTimeFunction => - scriptQuery(Script(script = script).lang("painless").scriptType("source")) - case _ => - val op = if (maybeNot.isDefined) operator.not else operator - op match { - case Gt => rangeQuery(identifier.name) gt script - case Ge => 
rangeQuery(identifier.name) gte script - case Lt => rangeQuery(identifier.name) lt script - case Le => rangeQuery(identifier.name) lte script - case Eq => rangeQuery(identifier.name) gte script lte script - case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + case i: SQLIdentifier => + operator match { + case op: SQLComparisonOperator => + i.toScript match { + case Some(script) => + val o = if (maybeNot.isDefined) op.not else op + o match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) + } + case _ => + scriptQuery(Script(script = painless).lang("painless").scriptType("source")) } + case _ => matchAllQuery() } } @@ -336,6 +341,20 @@ package object bridge { existsQuery(identifier.name) } + implicit def isNullCriteriaToQuery( + isNull: SQLIsNullCriteria + ): Query = { + import isNull._ + not(existsQuery(identifier.name)) + } + + implicit def isNotNullCriteriaToQuery( + isNotNull: SQLIsNotNullCriteria + ): Query = { + import isNotNull._ + existsQuery(identifier.name) + } + implicit def inToQuery[R, T <: SQLValue[R]](in: SQLIn[R, T]): Query = { import in._ val _values: Seq[Any] = values.innerValues diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index d7f25e09..4037983e 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -881,14 +881,28 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": 
"painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.minus(35, ChronoUnit.MINUTES) : null)" | } | } | }, | "_source": { - | "includes": ["identifier"] + | "includes": [ + | "identifier" + | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("!=null", " != null") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "filter with date time and interval" in { @@ -966,41 +980,48 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { SQLQuery(filterWithTimeAndInterval) val query = select.query println(query) - query shouldBe - """{ - | "query": { - | "bool": { - | "filter": [ - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" - | } - | } - | }, - | { - | "script": { - | "script": { - | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" - | } - | } - | } - | ] - | } - | }, - | "_source": { - | "includes": [ - | "*" - | ] - | } - |}""".stripMargin - .replaceAll("\\s", "") - .replaceAll("ChronoUnit", " ChronoUnit") - .replaceAll(">=", " >= ") - .replaceAll("<", " < ") - .replaceAll("return", "return ") + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? 
false : left < ZonedDateTime.now(ZoneId.of('Z')).toLocalTime()" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def left = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); left == null ? false : left >= ZonedDateTime.now(ZoneId.of('Z')).toLocalTime().minus(10, ChronoUnit.MINUTES)" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("\\|\\|", " || ") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll(":null", " : null ") + .replaceAll("\\?", " ? ") + .replaceAll("==", " == ") + .replaceAll("\\);", "); ") + .replaceAll("=\\(", " = (") + .replaceAll("defl", "def l") } it should "handle having with date functions" in { @@ -1040,7 +1061,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | "source": "params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)" | } | } | } @@ -1200,7 +1221,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(e0, LocalDate::from) : null)" | } | } | } @@ -1209,10 +1230,20 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } @@ -1256,7 +1287,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "(def e2 = (def e1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); e1 != null ? e1.truncatedTo(ChronoUnit.MINUTES) : null); e2 != null ? e2.get(ChronoUnit.YEARS) : null)" | } | } | } @@ -1265,9 +1296,19 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll(">", " > ") .replaceAll(",ZonedDateTime", ", ZonedDateTime") } @@ -1286,7 +1327,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | }, @@ -1297,7 +1338,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") } it should "handle aggregation with date_diff function" in { @@ -1322,7 +1377,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? 
null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" | } | } | } @@ -1331,8 +1386,21 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",doc", ", doc") - .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") .replaceAll("ZonedDateTime", " ZonedDateTime") } @@ -1358,7 +1426,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1367,7 +1435,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle date_sub function as script field" in { @@ -1392,7 +1476,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1401,7 +1485,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_add function as script field" in { @@ -1426,7 +1526,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? 
e0.plus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1435,7 +1535,23 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") } it should "handle datetime_sub function as script field" in { @@ -1460,7 +1576,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": { | "script": { | "lang": "painless", - | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | "source": "(def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); e0 != null ? e0.minus(10, ChronoUnit.DAYS) : null)" | } | } | }, @@ -1469,7 +1585,469 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle is_null function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); e0 == null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_notnull function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(isnotnull) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "flag": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('identifier2') || doc['identifier2'].empty ? 
null : doc['identifier2'].value); e0 != null)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defs", "def s") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle is_null criteria as must_not exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle is_notnull criteria as exists" in { + val select: ElasticSearchRequest = + SQLQuery(isNotNullCriteria) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier" + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + } + + it should "handle coalesce function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(coalesce) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.minus(35, ChronoUnit.MINUTES) : null);if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll(";}", "; }") + .replaceAll(";e", "; e") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle nullif function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(nullif) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";def", "; def") + .replaceAll(";return", "; return") + .replaceAll("returnv", " return v") + .replaceAll("returne", " return e") + .replaceAll(";}", "; }") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") + } + + + it should "handle cast function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(cast) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defe", " def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + } + + it should "handle case function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhen) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ if (def left = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); left == null ? false : left > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS)) return left; if (def left = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); left != null) return left.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle case with expression function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(caseWhenExpr) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "c": { + | "script": { + | "lang": "painless", + | "source": "{ def expr = ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().minus(7, ChronoUnit.DAYS); def e0 = (!doc.containsKey('lastUpdated') || doc['lastUpdated'].empty ? null : doc['lastUpdated'].value); def val0 = e0 != null ? (e0.minus(3, ChronoUnit.DAYS)).atStartOfDay(ZoneId.of('Z')) : null; if (expr == val0) return e0; def val1 = (!doc.containsKey('lastSeen') || doc['lastSeen'].empty ? null : doc['lastSeen'].value); if (expr == val1) return val1.plus(2, ChronoUnit.DAYS); def dval = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); return dval; }" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defd", " def d") + .replaceAll("defe", " def e") + .replaceAll("defl", " def l") + .replaceAll("if\\(", "if (") + .replaceAll("\\{if", "{ if") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? 
") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("false:", "false : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("=ZonedDateTime", " = ZonedDateTime") + .replaceAll("=e", " = e") + } + + it should "handle extract function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(extract) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "day": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.DAYS) : null)" + | } + | }, + | "month": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.MONTHS) : null)" + | } + | }, + | "year": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.YEARS) : null)" + | } + | }, + | "hour": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.HOURS) : null)" + | } + | }, + | "minute": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? 
e0.get(ChronoUnit.MINUTES) : null)" + | } + | }, + | "second": { + | "script": { + | "lang": "painless", + | "source": "(def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? e0.get(ChronoUnit.SECONDS) : null)" + | } + | } + | }, + | "_source": true + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defe", "def e") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll(";if", "; if") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll(">", " > ") + .replaceAll("if \\(\\s*def", "if (def") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala index cc6a7dc4..1b0b791b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLDelimiter.scala @@ -4,7 +4,11 @@ sealed trait SQLDelimiter extends SQLToken sealed trait StartDelimiter extends SQLDelimiter case object StartPredicate extends SQLExpr("(") with StartDelimiter +case object StartCase extends SQLExpr("case") with StartDelimiter +case object WhenCase extends SQLExpr("when") with StartDelimiter sealed trait EndDelimiter extends SQLDelimiter case object EndPredicate extends SQLExpr(")") with EndDelimiter case object Separator extends SQLExpr(",") with EndDelimiter +case object EndCase extends SQLExpr("end") with EndDelimiter +case object ThenCase extends SQLExpr("then") with EndDelimiter diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala index 15f36837..dc7e65a9 100644 
--- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFrom.scala @@ -27,4 +27,12 @@ case class SQLFrom(tables: Seq[SQLTable]) extends Updateable { } def update(request: SQLSearchRequest): SQLFrom = this.copy(tables = tables.map(_.update(request))) + + override def validate(): Either[String, Unit] = { + if (tables.isEmpty) { + Left("At least one table is required in FROM clause") + } else { + Right(()) + } + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 96ca2453..cbdd1552 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -4,45 +4,99 @@ import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql + def applyType(in: SQLType): SQLType = out + private[this] var _expr: SQLToken = SQLNull + def expr_=(e: SQLToken): Unit = { + _expr = e + } + def expr: SQLToken = _expr + override def nullable: Boolean = expr.nullable } sealed trait SQLFunctionWithIdentifier extends SQLFunction { def identifier: SQLIdentifier } +trait SQLFunctionWithValue[+T] extends SQLFunction { + def value: T +} + object SQLFunctionUtils { def aggregateAndTransformFunctions( - identifier: Identifier + chain: SQLFunctionChain ): (List[SQLFunction], List[SQLFunction]) = { - identifier.functions.partition { + chain.functions.partition { case _: AggregateFunction => true case _ => false } } - def transformFunctions(identifier: Identifier): List[SQLFunction] = { - aggregateAndTransformFunctions(identifier)._2 + def transformFunctions(chain: SQLFunctionChain): List[SQLFunction] = { + aggregateAndTransformFunctions(chain)._2 } } -trait SQLFunctionChain extends SQLFunction with SQLValidation { +trait SQLFunctionChain extends SQLFunction { def 
functions: List[SQLFunction] - override def validate(): Either[String, Unit] = - SQLValidator.validateChain(functions) + override def validate(): Either[String, Unit] = { + if (aggregations.size > 1) { + Left("Only one aggregation function is allowed in a function chain") + } else if (aggregations.size == 1 && !functions.head.isInstanceOf[AggregateFunction]) { + Left("Aggregation function must be the first function in the chain") + } else { + SQLValidator.validateChain(functions) + } + } override def toSQL(base: String): String = functions.reverse.foldLeft(base)((expr, fun) => { fun.toSQL(expr) }) - lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { - case Some(af: AggregateFunction) => Some(af) - case _ => None + def toScript: Option[String] = { + val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse + orderedFunctions.foldLeft(Option("")) { + case (expr, f: MathScript) if expr.isDefined => Option(s"${expr.get}${f.script}") + case (_, _) => None // ignore non math scripts + } match { + case Some(s) if s.nonEmpty => + out match { + case SQLTypes.Date => Some(s"$s/d") + case _ => Some(s) + } + case _ => None + } + } + + override def system: Boolean = functions.lastOption.exists(_.system) + + def applyTo(expr: SQLToken): Unit = { + this.expr = expr + functions.reverse.foldLeft(expr) { (currentExpr, fun) => + fun.expr = currentExpr + fun + } + } + + private[this] lazy val aggregations = functions.collect { case af: AggregateFunction => + af } + lazy val aggregateFunction: Option[AggregateFunction] = aggregations.headOption + lazy val aggregation: Boolean = aggregateFunction.isDefined + + override def in: SQLType = functions.lastOption.map(_.in).getOrElse(super.in) + + override def out: SQLType = { + val baseType = functions.lastOption.map(_.in).getOrElse(super.baseType) + functions.reverse.foldLeft(baseType) { (currentType, fun) => + fun.applyType(currentType) + } + } + } sealed trait SQLUnaryFunction[In <: SQLType, 
Out <: SQLType] @@ -50,6 +104,9 @@ sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] with PainlessScript { def inputType: In def outputType: Out + override def in: SQLType = inputType + override def out: SQLType = outputType + override def applyType(in: SQLType): SQLType = outputType } sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] @@ -60,28 +117,24 @@ sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] def left: PainlessScript def right: PainlessScript + override def nullable: Boolean = left.nullable || right.nullable } sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] { - def toPainless(base: String): String = s"$base$painless" + def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; e$idx != null ? e$idx$painless : null)" + else + s"$base$painless" + } } sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] - extends SQLTransformFunction[In, Out] { + extends SQLTransformFunction[In, Out] + with MathScript { def operator: ArithmeticOperator override def toSQL(base: String): String = s"$base$operator$sql" -} - -sealed trait ParametrizedFunction extends SQLFunction { - def params: Seq[String] - override def toSQL(base: String): String = { - params match { - case Nil => s"$sql($base)" - case _ => - val paramsStr = params.mkString(", ") - s"$sql($paramsStr)($base)" - } - } + override def applyType(in: SQLType): SQLType = in } sealed trait AggregateFunction extends SQLFunction @@ -144,6 +197,31 @@ sealed trait TimeInterval extends PainlessScript with MathScript { override def painless: String = s"$value, ${unit.painless}" override def script: String = TimeInterval.script(this) + + def checkType(in: SQLType): Either[String, SQLType] = { + import TimeUnit._ + in match { + case SQLTypes.Date => + unit match { + case Year | Month | Day => Right(SQLTypes.Date) + case Hour | Minute | Second => 
Right(SQLTypes.Timestamp) + case _ => Left(s"Invalid interval unit $unit for DATE") + } + case SQLTypes.Time => + unit match { + case Hour | Minute | Second => Right(SQLTypes.Time) + case _ => Left(s"Invalid interval unit $unit for TIME") + } + case SQLTypes.DateTime => + Right(SQLTypes.Timestamp) + case SQLTypes.Timestamp => + Right(SQLTypes.Timestamp) + case SQLTypes.Temporal => + Right(SQLTypes.Timestamp) + case _ => + Left(s"Intervals not supported for type $in") + } + } } import TimeUnit._ @@ -163,46 +241,84 @@ object TimeInterval { } } -case class SQLAddInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] - with MathScript { +sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunction[IO, IO] { + def interval: TimeInterval + override def script: String = s"${operator.script}${interval.script}" + private[this] var _out: SQLType = outputType + override def out: SQLType = _out + + override def applyType(in: SQLType): SQLType = { + _out = interval.checkType(in).getOrElse(out) + _out + } + + override def validate(): Either[String, Unit] = interval.checkType(out) match { + case Left(err) => Left(err) + case Right(_) => Right(()) + } + + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? 
${SQLTypeUtils.coerce(s"e$idx", expr.out, out, nullable = false)}$painless : null)" + else + s"${SQLTypeUtils.coerce(base, expr.out, out, nullable = expr.nullable)}$painless" +} + +sealed trait AddInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { override def operator: ArithmeticOperator = Add - override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLDateTime = SQLTypes.DateTime override def painless: String = s".plus(${interval.painless})" - override def script: String = s"${operator.script}${interval.script}" } -case class SQLSubstractInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with SQLArithmeticFunction[SQLDateTime, SQLDateTime] - with MathScript { +sealed trait SubtractInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { override def operator: ArithmeticOperator = Subtract - override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLDateTime = SQLTypes.DateTime override def painless: String = s".minus(${interval.painless})" - override def script: String = s"${operator.script}${interval.script}" } -sealed trait DateTimeFunction extends SQLFunction +case class SQLAddInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with AddInterval[SQLTemporal] { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLTemporal = SQLTypes.Temporal +} + +case class SQLSubtractInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SubtractInterval[SQLTemporal] { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLTemporal = SQLTypes.Temporal +} + +sealed trait DateTimeFunction extends SQLFunction { + def now: String = "ZonedDateTime.now(ZoneId.of('Z'))" + override def out: SQLType = SQLTypes.DateTime +} + +sealed trait DateFunction extends DateTimeFunction { + override def out: SQLType = SQLTypes.Date +} -sealed trait DateFunction extends DateTimeFunction +sealed trait TimeFunction 
extends DateTimeFunction { + override def out: SQLType = SQLTypes.Time +} -sealed trait TimeFunction extends DateTimeFunction +sealed trait SystemFunction extends SQLFunction { + override def system: Boolean = true +} -sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +sealed trait CurrentFunction extends SystemFunction with PainlessScript + +sealed trait CurrentDateTimeFunction extends DateTimeFunction with CurrentFunction with MathScript { + override def painless: String = now override def script: String = "now" } -sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" +sealed trait CurrentDateFunction extends DateFunction with CurrentFunction with MathScript { + override def painless: String = s"$now.toLocalDate()" + override def script: String = "now" } -sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" +sealed trait CurrentTimeFunction extends TimeFunction with CurrentFunction { + override def painless: String = s"$now.toLocalTime()" } case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction @@ -239,55 +355,64 @@ case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) case class Extract(unit: TimeUnit, override val sql: String = "extract") extends SQLExpr(sql) with DateTimeFunction - with SQLTransformFunction[SQLTemporal, SQLNumber] - with ParametrizedFunction { + with SQLTransformFunction[SQLTemporal, SQLNumeric] { override def inputType: SQLTemporal = SQLTypes.Temporal - override def outputType: SQLNumber = SQLTypes.Number - override def params: Seq[String] = Seq(unit.sql) + override def outputType: SQLNumeric 
= SQLTypes.Numeric + override def toSQL(base: String): String = s"$sql(${unit.sql} from $base)" override def painless: String = s".get(${unit.painless})" } object YEAR extends Extract(Year, Year.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object MONTH extends Extract(Month, Month.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object DAY extends Extract(Day, Day.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object HOUR extends Extract(Hour, Hour.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object MINUTE extends Extract(Minute, Minute.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } object SECOND extends Extract(Second, Second.sql) { - override def params: Seq[String] = Seq.empty + override def toSQL(base: String): String = s"$sql($base)" } case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) extends SQLExpr("date_diff") with DateTimeFunction - with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumber] + with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumeric] with PainlessScript { - override def outputType: SQLNumber = SQLTypes.Number + override def outputType: SQLNumeric = SQLTypes.Numeric override def left: PainlessScript = end override def right: PainlessScript = start override def toSQL(base: String): String = { s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" } - override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" + override def painless: String = { + if (start.nullable && end.nullable) + s"(def s = ${start.painless}; def e = ${end.painless}; s != null && e != null ? 
${unit.painless}.between(s, e) : null)" + else if (start.nullable) + s"(def s = ${start.painless}; s != null ? ${unit.painless}.between(s, ${end.painless}) : null)" + else if (end.nullable) + s"(def e = ${end.painless}; e != null ? ${unit.painless}.between(${start.painless}, e) : null)" + else + s"${unit.painless}.between(${start.painless}, ${end.painless})" + } } case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_add") with DateFunction + with AddInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date @@ -295,12 +420,12 @@ case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".plus(${interval.painless})" } case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_sub") with DateFunction + with SubtractInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date @@ -308,42 +433,48 @@ case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".minus(${interval.painless})" } case class ParseDate(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_date") with DateFunction - with SQLTransformFunction[SQLString, SQLDate] + with SQLTransformFunction[SQLVarchar, SQLDate] with SQLFunctionWithIdentifier { - override def inputType: SQLString = SQLTypes.String + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDate = SQLTypes.Date override def toSQL(base: String): String = { s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = 
- s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? DateTimeFormatter.ofPattern('$format').parse(e$idx, LocalDate::from) : null)" + else + s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } case class FormatDate(identifier: SQLIdentifier, format: String) extends SQLExpr("format_date") with DateFunction - with SQLTransformFunction[SQLDate, SQLString] + with SQLTransformFunction[SQLDate, SQLVarchar] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date - override def outputType: SQLString = SQLTypes.String + override def outputType: SQLVarchar = SQLTypes.Varchar override def toSQL(base: String): String = { s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').format($base)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? 
DateTimeFormatter.ofPattern('$format').format(e$idx) : null)" + else + s"DateTimeFormatter.ofPattern('$format').format($base)" } case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_add") with DateTimeFunction + with AddInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime @@ -351,12 +482,12 @@ case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".plus(${interval.painless})" } case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_sub") with DateTimeFunction + with SubtractInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime @@ -364,35 +495,259 @@ case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } - override def painless: String = s".minus(${interval.painless})" } case class ParseDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_datetime") with DateTimeFunction - with SQLTransformFunction[SQLString, SQLDateTime] + with SQLTransformFunction[SQLVarchar, SQLDateTime] with SQLFunctionWithIdentifier { - override def inputType: SQLString = SQLTypes.String + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDateTime = SQLTypes.DateTime override def toSQL(base: String): String = { s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" + override def toPainless(base: String, idx: Int): String = 
+ if (nullable) + s"(def e$idx = $base; e$idx != null ? DateTimeFormatter.ofPattern('$format').parse(e$idx, ZonedDateTime::from) : null)" + else + s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" } case class FormatDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("format_datetime") with DateTimeFunction - with SQLTransformFunction[SQLDateTime, SQLString] + with SQLTransformFunction[SQLDateTime, SQLVarchar] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime - override def outputType: SQLString = SQLTypes.String + override def outputType: SQLVarchar = SQLTypes.Varchar override def toSQL(base: String): String = { s"$sql($base, '$format')" } override def painless: String = throw new NotImplementedError("Use toPainless instead") - override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').format($base)" + override def toPainless(base: String, idx: Int): String = + if (nullable) + s"(def e$idx = $base; e$idx != null ? 
DateTimeFormatter.ofPattern('$format').format(e$idx) : null)" + else + s"DateTimeFormatter.ofPattern('$format').format($base)" +} + +sealed trait SQLConditionalFunction[In <: SQLType] + extends SQLTransformFunction[In, SQLBool] + with SQLFunctionWithIdentifier { + def operator: SQLConditionalOperator + override def outputType: SQLBool = SQLTypes.Boolean + override def toPainless(base: String, idx: Int): String = s"($base$painless)" +} + +case class SQLIsNullFunction(identifier: SQLIdentifier) + extends SQLExpr("isnull") + with SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNull + override def inputType: SQLAny = SQLTypes.Any + override def painless: String = s" == null" + override def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; e$idx$painless)" + else + s"$base$painless" + } +} + +case class SQLIsNotNullFunction(identifier: SQLIdentifier) + extends SQLExpr("isnotnull") + with SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNotNull + override def inputType: SQLAny = SQLTypes.Any + override def painless: String = s" != null" + override def toPainless(base: String, idx: Int): String = { + if (nullable) + s"(def e$idx = $base; e$idx$painless)" + else + s"$base$painless" + } +} + +case class SQLCoalesce(values: List[PainlessScript]) extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = Coalesce + + override def identifier: SQLIdentifier = SQLIdentifier("") + + override def inputType: SQLAny = SQLTypes.Any + + override lazy val sql: String = s"$Coalesce(${values.map(_.sql).mkString(", ")})" + + // Reprend l’idée de SQLValues mais pour n’importe quel token + override def out: SQLType = SQLTypeUtils.leastCommonSuperType(values.map(_.out).distinct) + + override def applyType(in: SQLType): SQLType = out + + override def validate(): Either[String, Unit] = { + if (values.isEmpty) Left("COALESCE requires at least one 
argument") + else Right(()) + } + + override def toPainless(base: String, idx: Int): String = s"$base$painless" + + override def painless: String = { + require(values.nonEmpty, "COALESCE requires at least one argument") + + val checks = values + .take(values.length - 1) + .zipWithIndex + .map { case (v, index) => + var check = s"def v$index = ${SQLTypeUtils.coerce(v, out)};" + check += s"if (v$index != null) return v$index;" + check + } + .mkString(" ") + // final fallback + s"{ $checks return ${SQLTypeUtils.coerce(values.last, out)}; }" + } + + override def nullable: Boolean = values.forall(_.nullable) +} + +case class SQLNullIf(expr1: PainlessScript, expr2: PainlessScript) + extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = NullIf + + override def identifier: SQLIdentifier = SQLIdentifier("") + + override def inputType: SQLAny = SQLTypes.Any + + override def sql: String = s"$NullIf(${expr1.sql}, ${expr2.sql})" + + override def out: SQLType = expr1.out + + override def applyType(in: SQLType): SQLType = out + + override def painless: String = { + val e1 = expr1.painless + val e2 = expr2.painless + s"""{ def e1 = $e1; + |def e2 = $e2; + |return e1 == e2 ? 
null : e1; + |}""".stripMargin.replaceAll("\n", " ") + } +} + +case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = true) + extends SQLTransformFunction[SQLType, SQLType] { + override def inputType: SQLType = value.out + override def outputType: SQLType = targetType + + override def sql: String = + s"$Cast(${value.sql} ${if (as) s"$Alias " else ""}${targetType.typeId})" + + override def toSQL(base: String): String = sql + + override def painless: String = + SQLTypeUtils.coerce(value, targetType) + + override def toPainless(base: String, idx: Int): String = + SQLTypeUtils.coerce(base, value.out, targetType, value.nullable) +} + +case class SQLCaseWhen( + expression: Option[PainlessScript], + conditions: List[(PainlessScript, PainlessScript)], + default: Option[PainlessScript] +) extends SQLTransformFunction[SQLAny, SQLAny] { + override def inputType: SQLAny = SQLTypes.Any + override def outputType: SQLAny = SQLTypes.Any + + override def sql: String = { + val exprPart = expression.map(e => s"$Case ${e.sql}").getOrElse(Case.sql) + val whenThen = conditions + .map { case (cond, res) => s"$When ${cond.sql} $Then ${res.sql}" } + .mkString(" ") + val elsePart = default.map(d => s" $Else ${d.sql}").getOrElse("") + s"$exprPart $whenThen$elsePart $End" + } + + override def out: SQLType = + SQLTypeUtils.leastCommonSuperType( + conditions.map(_._2.out) ++ default.map(_.out).toList + ) + + override def applyType(in: SQLType): SQLType = out + + override def validate(): Either[String, Unit] = { + if (conditions.isEmpty) Left("CASE WHEN requires at least one condition") + else if ( + expression.isEmpty && conditions.exists { case (cond, _) => cond.out != SQLTypes.Boolean } + ) + Left("CASE WHEN conditions must be of type BOOLEAN") + else if ( + expression.isDefined && conditions.exists { case (cond, _) => + !SQLTypeUtils.matches(cond.out, expression.get.out) + } + ) + Left("CASE WHEN conditions must be of the same type as the expression") + else Right(()) + 
} + + override def painless: String = { + val base = + expression match { + case Some(expr) => + s"def expr = ${SQLTypeUtils.coerce(expr, expr.out)}; " + case _ => "" + } + val cases = conditions.zipWithIndex + .map { case ((cond, res), idx) => + val name = + cond match { + case e: Expression => + e.identifier.name + case i: Identifier => + i.name + case _ => "" + } + expression match { + case Some(expr) => + val c = SQLTypeUtils.coerce(cond, expr.out) + if (cond.sql == res.sql) { + s"def val$idx = $c; if (expr == val$idx) return val$idx;" + } else { + res match { + case i: Identifier if i.name == name && cond.isInstanceOf[Identifier] => + i.nullable = false + if (cond.asInstanceOf[Identifier].functions.isEmpty) + s"def val$idx = $c; if (expr == val$idx) return ${SQLTypeUtils.coerce(i.toPainless(s"val$idx"), i.out, out, nullable = false)};" + else { + cond.asInstanceOf[Identifier].nullable = false + s"def e$idx = ${i.checkNotNull}; def val$idx = e$idx != null ? ${SQLTypeUtils + .coerce(cond.asInstanceOf[Identifier].toPainless(s"e$idx"), cond.out, out, nullable = false)} : null; if (expr == val$idx) return ${SQLTypeUtils + .coerce(i.toPainless(s"e$idx"), i.out, out, nullable = false)};" + } + case _ => + s"if (expr == $c) return ${SQLTypeUtils.coerce(res, out)};" + } + } + case None => + val c = SQLTypeUtils.coerce(cond, SQLTypes.Boolean) + val r = + res match { + case i: Identifier if i.name == name && cond.isInstanceOf[Expression] => + i.nullable = false + SQLTypeUtils.coerce(i.toPainless("left"), i.out, out, nullable = false) + case _ => SQLTypeUtils.coerce(res, out) + } + s"if ($c) return $r;" + } + } + .mkString(" ") + val defaultCase = default + .map(d => s"def dval = ${SQLTypeUtils.coerce(d, out)}; return dval;") + .getOrElse("return null;") + s"{ $base$cases $defaultCase }" + } + + override def toPainless(base: String, idx: Int): String = s"$base$painless" + + override def nullable: Boolean = + conditions.exists { case (_, res) => res.nullable } || 
default.forall(_.nullable) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index e30e7f73..39220f68 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -9,6 +9,14 @@ case class SQLGroupBy(buckets: Seq[SQLBucket]) extends Updateable { lazy val bucketNames: Map[String, SQLBucket] = buckets.map { b => b.identifier.identifierName -> b }.toMap + + override def validate(): Either[String, Unit] = { + if (buckets.isEmpty) { + Left("At least one bucket is required in GROUP BY clause") + } else { + Right(()) + } + } } case class SQLBucket( @@ -33,81 +41,18 @@ case class SQLBucket( object BucketSelectorScript { - private[this] def painlessIn(param: String, values: Seq[SQLValue[_]], not: Boolean): String = { - val ret = s"[${values.map { _.painless }.mkString(", ")}].contains($param)" - if (not) s"!$ret" else ret - } - - private[this] def painlessBetween( - param: String, - lower: SQLValue[_], - upper: SQLValue[_], - not: Boolean - ): String = { - val ret = s"($param >= ${lower.painless} && $param <= ${upper.painless})" - if (not) s"!$ret" else ret - } - - private[this] def toPainless( - param: String, - operator: SQLOperator, - value: SQLToken, - not: Boolean - ): String = { - operator match { - case o: SQLComparisonOperator => - val valueStr = - value match { - case v: SQLBoolean => v.painless - case v: SQLDouble => v.painless - case v: SQLLiteral => v.painless - case v: SQLLong => v.painless - case _ => - throw new IllegalArgumentException( - s"Unsupported value type in bucket_selector: $value" - ) - } - if (not) - s"$param ${o.not.painless} $valueStr" - else - s"$param ${o.painless} $valueStr" - case In => - value match { - case SQLDoubleValues(vals) => painlessIn(param, vals, not) - case SQLLiteralValues(vals) => painlessIn(param, vals, not) - case SQLLongValues(vals) => 
painlessIn(param, vals, not) - case _ => throw new IllegalArgumentException("IN requires a list") - } - case Between => - value match { - case SQLDoubleFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case SQLLiteralFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case SQLLongFromTo(lower, upper) => painlessBetween(param, lower, upper, not) - case _ => throw new IllegalArgumentException("BETWEEN requires two values") - } - case _ => - throw new IllegalArgumentException(s"Unsupported operator in bucket_selector: $operator") - } - } - def extractBucketsPath(criteria: SQLCriteria): Map[String, String] = criteria match { case SQLPredicate(left, _, right, _, _) => extractBucketsPath(left) ++ extractBucketsPath(right) case relation: ElasticRelation => extractBucketsPath(relation.criteria) case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector - case b: BinaryExpression => - import b._ - if (left.aggregation && right.aggregation) - Map(left.aliasOrName -> left.aliasOrName, right.aliasOrName -> right.aliasOrName) - else if (left.aggregation) - Map(left.aliasOrName -> left.aliasOrName) - else if (right.aggregation) - Map(right.aliasOrName -> right.aliasOrName) - else - Map.empty case e: Expression if e.aggregation => import e._ - Map(identifier.aliasOrName -> identifier.aliasOrName) + maybeValue match { + case Some(v: SQLIdentifier) if v.aggregation => + Map(identifier.aliasOrName -> identifier.aliasOrName, v.aliasOrName -> v.aliasOrName) + case _ => Map(identifier.aliasOrName -> identifier.aliasOrName) + } case _ => Map.empty } @@ -116,9 +61,8 @@ object BucketSelectorScript { val leftStr = toPainless(left) val rightStr = toPainless(right) val opStr = op match { - case And => "&&" - case Or => "||" - case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") + case And | Or => op.painless + case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") } val not = 
maybeNot.nonEmpty if (group || not) @@ -128,44 +72,20 @@ object BucketSelectorScript { case relation: ElasticRelation => toPainless(relation.criteria) - case SQLComparisonDateMath(identifier, op, dateFunc, arithOp, interval, maybeNot) - if identifier.aggregation => - val painlessOp = if (maybeNot.nonEmpty) op.not.painless else op.painless - val paramName = identifier.aliasOrName - // always use a correct "now" creation - val now = "ZonedDateTime.now(ZoneId.of('Z'))" - - // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) - val rightBase = (arithOp, interval) match { - case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" - case (Some(Subtract), Some(i)) => s"$now.minus(${i.painless})" - case _ => now - } - - val rightZdt = dateFunc match { - // truncate only after arithmetic for CurrentDate - case _: CurrentDateFunction => s"$rightBase.truncatedTo(ChronoUnit.DAYS)" - case _: CurrentTimeFunction => s"$rightBase.truncatedTo(ChronoUnit.SECONDS)" - case _ => rightBase - } - - // protect against null params and compare epoch millis - s"(params.$paramName != null) && (params.$paramName $painlessOp $rightZdt.toInstant().toEpochMilli())" - case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector case e: Expression if e.aggregation => - val param = - s"params.${e.identifier.aliasOrName}" - e.maybeValue match { - case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) - case None => - e.operator match { - case IsNull => s"$param == null" - case IsNotNull => s"$param != null" - case _ => - throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") - } + val paramName = e.identifier.paramName + e.out match { + case SQLTypes.Date if e.operator.isInstanceOf[SQLComparisonOperator] => + // protect against null params and compare epoch millis + s"($paramName != null) && (${e.painless}.truncatedTo(ChronoUnit.DAYS).toInstant().toEpochMilli())" + case SQLTypes.Time if 
e.operator.isInstanceOf[SQLComparisonOperator] => + s"($paramName != null) && (${e.painless}.truncatedTo(ChronoUnit.SECONDS).toInstant().toEpochMilli())" + case SQLTypes.DateTime if e.operator.isInstanceOf[SQLComparisonOperator] => + s"($paramName != null) && (${e.painless}.toInstant().toEpochMilli())" + case _ => + e.painless } case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala index a96351da..97ed5dc9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLHaving.scala @@ -9,4 +9,6 @@ case class SQLHaving(criteria: Option[SQLCriteria]) extends Updateable { } def update(request: SQLSearchRequest): SQLHaving = this.copy(criteria = criteria.map(_.update(request))) + + override def validate(): Either[String, Unit] = criteria.map(_.validate()).getOrElse(Right(())) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala index 9184eef0..ed13841d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLMultiSearchRequest.scala @@ -5,4 +5,7 @@ case class SQLMultiSearchRequest(requests: Seq[SQLSearchRequest]) extends SQLTok def update(): SQLMultiSearchRequest = this.copy(requests = requests.map(_.update())) + override def validate(): Either[String, Unit] = { + requests.map(_.validate()).find(_.isLeft).getOrElse(Right(())) + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 21342118..1e9eb1fb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ 
b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -1,6 +1,17 @@ package app.softnetwork.elastic.sql -trait SQLOperator extends SQLToken +trait SQLOperator extends SQLToken with PainlessScript with SQLRegex { + override def painless: String = this match { + case And => "&&" + case Or => "||" + case Not => "!" + case In => ".contains" + case Like | Match => ".matches" + case Eq => "==" + case Ne => "!=" + case _ => sql + } +} sealed trait ArithmeticOperator extends SQLOperator with MathScript { override def toString: String = s" $sql " @@ -15,12 +26,6 @@ case object Modulo extends SQLExpr("%") with ArithmeticOperator sealed trait SQLExpressionOperator extends SQLOperator sealed trait SQLComparisonOperator extends SQLExpressionOperator with PainlessScript { - override def painless: String = this match { - case Eq => "==" - case Ne => "!=" - case other => other.sql - } - def not: SQLComparisonOperator = this match { case Eq => Ne case Ne | Diff => Eq @@ -38,18 +43,31 @@ case object Ge extends SQLExpr(">=") with SQLComparisonOperator case object Gt extends SQLExpr(">") with SQLComparisonOperator case object Le extends SQLExpr("<=") with SQLComparisonOperator case object Lt extends SQLExpr("<") with SQLComparisonOperator +case object In extends SQLExpr("in") with SQLComparisonOperator +case object Like extends SQLExpr("like") with SQLComparisonOperator +case object Between extends SQLExpr("between") with SQLComparisonOperator +case object IsNull extends SQLExpr("is null") with SQLComparisonOperator with SQLConditionalOperator +case object IsNotNull + extends SQLExpr("is not null") + with SQLComparisonOperator + with SQLConditionalOperator +case object Match extends SQLExpr("match") with SQLComparisonOperator +case object Against extends SQLExpr("against") with SQLRegex -sealed trait SQLLogicalOperator extends SQLExpressionOperator with SQLRegex +sealed trait SQLLogicalOperator extends SQLExpressionOperator -case object In extends SQLExpr("in") with 
SQLLogicalOperator -case object Like extends SQLExpr("like") with SQLLogicalOperator -case object Between extends SQLExpr("between") with SQLLogicalOperator -case object IsNull extends SQLExpr("is null") with SQLLogicalOperator -case object IsNotNull extends SQLExpr("is not null") with SQLLogicalOperator case object Not extends SQLExpr("not") with SQLLogicalOperator -case object Match extends SQLExpr("match") with SQLLogicalOperator -case object Against extends SQLExpr("against") with SQLRegex +sealed trait SQLConditionalOperator extends SQLExpressionOperator +case object Coalesce extends SQLExpr("coalesce") with SQLConditionalOperator +case object NullIf extends SQLExpr("nullif") with SQLConditionalOperator +case object Exists extends SQLExpr("exists") with SQLConditionalOperator +case object Cast extends SQLExpr("cast") with SQLConditionalOperator +case object Case extends SQLExpr("case") with SQLConditionalOperator +case object When extends SQLExpr("when") with SQLRegex +case object Then extends SQLExpr("then") with SQLRegex +case object Else extends SQLExpr("else") with SQLRegex +case object End extends SQLExpr("end") with SQLRegex sealed trait SQLPredicateOperator extends SQLLogicalOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index ad1b1bd4..ff6db34c 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -7,11 +7,6 @@ import TimeUnit._ /** Created by smanciot on 27/06/2018. * * SQL Parser for ElasticSearch - * - * TODO implements SQL : - * - JOIN, - * - GROUP BY, - * - HAVING, etc. */ object SQLParser extends SQLParser @@ -27,8 +22,12 @@ object SQLParser def request: PackratParser[SQLSearchRequest] = { phrase(select ~ from ~ where.? ~ groupBy.? ~ having.? ~ orderBy.? ~ limit.?) 
^^ { case s ~ f ~ w ~ g ~ h ~ o ~ l => - SQLSearchRequest(s, f, w, g, h, o, l) - .update() + val request = SQLSearchRequest(s, f, w, g, h, o, l).update() + request.validate() match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + request } } @@ -58,18 +57,24 @@ trait SQLCompilationError case class SQLParserError(msg: String) extends SQLCompilationError -trait SQLParser extends RegexParsers with PackratParsers { +trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => - def literal: PackratParser[SQLLiteral] = - """"[^"]*"|'[^']*'""".r ^^ (str => SQLLiteral(str.substring(1, str.length - 1))) + def literal: PackratParser[SQLStringValue] = + """"[^"]*"|'[^']*'""".r ^^ (str => SQLStringValue(str.substring(1, str.length - 1))) - def long: PackratParser[SQLLong] = """(-)?(0|[1-9]\d*)""".r ^^ (str => SQLLong(str.toLong)) + def long: PackratParser[SQLLongValue] = + """(-)?(0|[1-9]\d*)""".r ^^ (str => SQLLongValue(str.toLong)) - def double: PackratParser[SQLDouble] = """(-)?(\d+\.\d+)""".r ^^ (str => SQLDouble(str.toDouble)) + def double: PackratParser[SQLDoubleValue] = + """(-)?(\d+\.\d+)""".r ^^ (str => SQLDoubleValue(str.toDouble)) def boolean: PackratParser[SQLBoolean] = """(true|false)""".r ^^ (bool => SQLBoolean(bool.toBoolean)) + /*def value_identifier: PackratParser[SQLIdentifier] = (literal | long | double | boolean) ^^ { v => + SQLIdentifier("", functions = v :: Nil) + }*/ + def start: PackratParser[SQLDelimiter] = "(" ^^ (_ => StartPredicate) def end: PackratParser[SQLDelimiter] = ")" ^^ (_ => EndPredicate) @@ -110,48 +115,48 @@ trait SQLParser extends RegexParsers with PackratParsers { TimeInterval(l.value.toInt, u) } - def current_date: PackratParser[CurrentDateTimeFunction] = - CurrentDate.regex ~ start.? ~ end.? 
^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate + def parens: PackratParser[List[SQLDelimiter]] = + start ~ end ^^ { case s ~ e => s :: e :: Nil } + + def current_date: PackratParser[CurrentFunction] = + CurrentDate.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurentDateWithParens else CurrentDate } - def current_time: PackratParser[CurrentDateTimeFunction] = - CurrentTime.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime + def current_time: PackratParser[CurrentFunction] = + CurrentTime.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurrentTimeWithParens else CurrentTime } - def current_timestamp: PackratParser[CurrentDateTimeFunction] = - CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp + def current_timestamp: PackratParser[CurrentFunction] = + CurrentTimestamp.regex ~ parens.? ^^ { case _ ~ p => + if (p.isDefined) CurrentTimestampWithParens else CurrentTimestamp } - def now: PackratParser[CurrentDateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { - case _ ~ s ~ t => - if (s.isDefined && t.isDefined) NowWithParens else Now + def now: PackratParser[CurrentFunction] = Now.regex ~ parens.? 
^^ { case _ ~ p => + if (p.isDefined) NowWithParens else Now } def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) - def substract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) + def subtract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) - def intervalOperator: PackratParser[ArithmeticOperator] = add | substract + def intervalOperator: PackratParser[ArithmeticOperator] = add | subtract def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator - def addInterval: PackratParser[SQLAddInterval] = + def add_interval: PackratParser[SQLAddInterval] = add ~ interval ^^ { case _ ~ it => SQLAddInterval(it) } - def substractInterval: PackratParser[SQLSubstractInterval] = - substract ~ interval ^^ { case _ ~ it => - SQLSubstractInterval(it) + def substract_interval: PackratParser[SQLSubtractInterval] = + subtract ~ interval ^^ { case _ ~ it => + SQLSubtractInterval(it) } - def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = - addInterval | substractInterval - - def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction + def intervalFunction: PackratParser[SQLArithmeticFunction[SQLTemporal, SQLTemporal]] = + add_interval | substract_interval def identifierWithSystemFunction: PackratParser[SQLIdentifier] = (current_date | current_time | current_timestamp | now) ~ intervalFunction.? 
^^ { @@ -163,56 +168,63 @@ trait SQLParser extends RegexParsers with PackratParsers { } def date_trunc: PackratParser[SQLFunctionWithIdentifier] = - "(?i)date_trunc".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_trunc".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ i ~ _ ~ u ~ _ => DateTrunc(i, u) } - def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = - "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => - Extract(u) + def extract_identifier: PackratParser[SQLIdentifier] = + "(?i)extract".r ~ start ~ time_unit ~ "(?i)from".r ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case _ ~ _ ~ u ~ _ ~ i ~ _ => + i.copy(functions = Extract(u) +: i.functions) } - def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Year.regex ^^ (_ => YEAR) - def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Month.regex ^^ (_ => MONTH) - def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) + def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + Day.regex ^^ (_ => DAY) - def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Hour.regex ^^ (_ => HOUR) - def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Minute.regex ^^ (_ => MINUTE) - def extract_second: 
PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = Second.regex ^^ (_ => SECOND) - def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = - extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second + def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateAdd(i, t) } def date_sub: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateSub(i, t) } def parse_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)parse_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { - case _ ~ _ ~ i ~ _ ~ f ~ _ => - ParseDate(i, f.value) + "(?i)parse_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ li ~ _ ~ f ~ _ => + li match { + case l: SQLStringValue => + 
ParseDate(SQLIdentifier("", functions = l :: Nil), f.value) + case i: SQLIdentifier => + ParseDate(i, f.value) + } } def format_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)format_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDate(i, f.value) } @@ -220,25 +232,30 @@ trait SQLParser extends RegexParsers with PackratParsers { def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date def datetime_add: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeAdd(i, t) } def datetime_sub: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeSub(i, t) } def parse_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)parse_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { - case _ ~ _ ~ 
i ~ _ ~ f ~ _ => - ParseDateTime(i, f.value) + "(?i)parse_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ li ~ _ ~ f ~ _ => + li match { + case l: SQLStringValue => + ParseDateTime(SQLIdentifier("", functions = l :: Nil), f.value) + case i: SQLIdentifier => + ParseDateTime(i, f.value) + } } def format_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)format_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDateTime(i, f.value) } @@ -250,17 +267,13 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def painless_identifier: PackratParser[SQLIdentifier] = - repsep( + def identifierWithTemporalFunction: PackratParser[SQLIdentifier] = + rep1sep( date_trunc | extractors | date_functions | datetime_functions, start - ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep( + ) ~ start.? ~ (identifierWithSystemFunction | identifier).?
~ rep( end ) ^^ { case f ~ _ ~ i ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } i match { case Some(id) => id.copy(functions = id.functions ++ f) case None => @@ -273,7 +286,7 @@ trait SQLParser extends RegexParsers with PackratParsers { } def date_diff: PackratParser[SQLBinaryFunction[_, _, _]] = - "(?i)date_diff".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_diff".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) } @@ -281,13 +294,172 @@ trait SQLParser extends RegexParsers with PackratParsers { SQLIdentifier("", functions = dd :: Nil) } + def case_when_identifier: Parser[SQLIdentifier] = case_when ^^ { cw => + SQLIdentifier("", functions = cw :: Nil) + } + + def is_null: PackratParser[SQLConditionalFunction[_]] = + "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { + case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) + } + + def is_notnull: PackratParser[SQLConditionalFunction[_]] = + "(?i)isnotnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { + case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) + } + + def valueExpr: PackratParser[PainlessScript] = + // most specific alternatives first + identifierWithTransformation | // transformations applied to an identifier + date_diff_identifier | // date_diff(...)
returning an identifier-like value + extract_identifier | + identifierWithSystemFunction | // CURRENT_DATE, NOW, etc. (+/- interval) + identifierWithArithmeticFunction | // foo - interval ... + identifierWithTemporalFunction | // chain of functions applied to an identifier + identifierWithFunction | // functions applied to an identifier + literal | // 'string' + long | + double | + boolean | + identifier + + def coalesce: PackratParser[SQLCoalesce] = + Coalesce.regex ~ start ~ rep1sep( + valueExpr, + separator + ) ~ end ^^ { case _ ~ _ ~ ids ~ _ => + SQLCoalesce(ids) + } + + def nullif: PackratParser[SQLNullIf] = + NullIf.regex ~ start ~ valueExpr ~ separator ~ valueExpr ~ end ^^ { + case _ ~ _ ~ id1 ~ _ ~ id2 ~ _ => SQLNullIf(id1, id2) + } + + def start_case: PackratParser[StartCase.type] = Case.regex ^^ (_ => StartCase) + + def when_case: PackratParser[WhenCase.type] = When.regex ^^ (_ => WhenCase) + + def then_case: PackratParser[ThenCase.type] = Then.regex ^^ (_ => ThenCase) + + def else_case: PackratParser[Else.type] = Else.regex ^^ (_ => Else) + + def end_case: PackratParser[EndCase.type] = End.regex ^^ (_ => EndCase) + + def case_condition: Parser[(PainlessScript, PainlessScript)] = + when_case ~ (whereCriteria | valueExpr) ~ then_case.? ~ valueExpr ^^ { case _ ~ c ~ _ ~ r => + c match { + case p: PainlessScript => p -> r + case rawTokens: List[SQLToken] => + processTokens(rawTokens) match { + case Some(criteria) => criteria -> r + case _ => SQLNull -> r + } + } + } + + def case_else: Parser[PainlessScript] = else_case ~ valueExpr ^^ { case _ ~ r => r } + + def case_when: PackratParser[SQLCaseWhen] = + start_case ~ valueExpr.? ~ rep1(case_condition) ~ case_else.?
~ end_case ^^ { + case _ ~ e ~ c ~ r ~ _ => SQLCaseWhen(e, c, r) + } + + def logical_functions: PackratParser[SQLTransformFunction[_, _]] = + is_null | is_notnull | coalesce | nullif | case_when + def sql_functions: PackratParser[SQLFunction] = - aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions + aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions + + //private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + + private val reservedKeywords = Seq( + "select", + "from", + "where", + "group", + "having", + "order", + "limit", + "as", + "by", + "except", + "unnest", + "current_date", + "current_time", + "current_datetime", + "current_timestamp", + "now", + "coalesce", + "nullif", + "isnull", + "isnotnull", + "date_add", + "date_sub", + "parse_date", + "parse_datetime", + "format_date", + "format_datetime", + "date_trunc", + "extract", + "date_diff", + "datetime_add", + "datetime_sub", + "interval", + "year", + "month", + "day", + "hour", + "minute", + "second", + "quarter", + "char", + "string", + "byte", + "tinyint", + "short", + "smallint", + "int", + "integer", + "long", + "bigint", + "real", + "float", + "double", + "boolean", + "time", + "date", + "datetime", + "timestamp", + "and", + "or", + "not", + "like", + "in", + "between", + "distinct", + "cast", + "count", + "min", + "max", + "avg", + "sum", + "case", + "when", + "then", + "else", + "end" + ) - private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + private val identifierRegexStr = + s"""(?i)(?!(?:${reservedKeywords.mkString( + "|" + )})\\b)[\\*a-zA-Z_\\-][a-zA-Z0-9_\\-.\\[\\]\\*]*""" + + private val identifierRegex = identifierRegexStr.r // scala.util.matching.Regex def identifier: PackratParser[SQLIdentifier] = - Distinct.regex.? ~ regexIdentifier.r ^^ { case d ~ i => + Distinct.regex.? 
~ identifierRegex ^^ { case d ~ i => SQLIdentifier( i, None, @@ -295,18 +467,76 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } + def char_type: PackratParser[SQLTypes.Char.type] = + "(?i)char".r ^^ (_ => SQLTypes.Char) + + def string_type: PackratParser[SQLTypes.Varchar.type] = + "(?i)varchar|string".r ^^ (_ => SQLTypes.Varchar) + + def date_type: PackratParser[SQLTypes.Date.type] = "(?i)date".r ^^ (_ => SQLTypes.Date) + + def time_type: PackratParser[SQLTypes.Time.type] = "(?i)time".r ^^ (_ => SQLTypes.Time) + + def datetime_type: PackratParser[SQLTypes.DateTime.type] = + "(?i)(datetime)".r ^^ (_ => SQLTypes.DateTime) + + def timestamp_type: PackratParser[SQLTypes.Timestamp.type] = + "(?i)(timestamp)".r ^^ (_ => SQLTypes.Timestamp) + + def boolean_type: PackratParser[SQLTypes.Boolean.type] = + "(?i)boolean".r ^^ (_ => SQLTypes.Boolean) + + def byte_type: PackratParser[SQLTypes.TinyInt.type] = + "(?i)(byte|tinyint)".r ^^ (_ => SQLTypes.TinyInt) + + def short_type: PackratParser[SQLTypes.SmallInt.type] = + "(?i)(short|smallint)".r ^^ (_ => SQLTypes.SmallInt) + + def int_type: PackratParser[SQLTypes.Int.type] = "(?i)(integer|int)".r ^^ (_ => SQLTypes.Int) + + def long_type: PackratParser[SQLTypes.BigInt.type] = "(?i)long|bigint".r ^^ (_ => SQLTypes.BigInt) + + def double_type: PackratParser[SQLTypes.Double.type] = "(?i)double".r ^^ (_ => SQLTypes.Double) + + def float_type: PackratParser[SQLTypes.Real.type] = "(?i)float|real".r ^^ (_ => SQLTypes.Real) + + def sql_type: PackratParser[SQLType] = + char_type | string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | float_type | int_type | short_type | byte_type + + private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = + "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~
Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ { + case _ ~ _ ~ i ~ as ~ t ~ _ ~ a => + i.copy(functions = + a.toList ++ (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) + ) + } + private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (parse_date | format_date | date_add | date_sub) ^^ { t => - t.identifier.copy(functions = t +: t.identifier.functions) + (parse_date | format_date | date_add | date_sub) ~ arithmeticFunction.? ^^ { case t ~ af => + af match { + case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) + case None => t.identifier.copy(functions = t +: t.identifier.functions) + } } private[this] def dateTimeFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ^^ { t => + (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ~ arithmeticFunction.? ^^ { + case t ~ af => + af match { + case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) + case None => t.identifier.copy(functions = t +: t.identifier.functions) + } + } + + private[this] def conditionalFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (is_null | is_notnull | coalesce | nullif) ^^ { t => t.identifier.copy(functions = t +: t.identifier.functions) } def identifierWithTransformation: PackratParser[SQLIdentifier] = - dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + + def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => @@ -326,10 +556,6 @@ trait SQLParser extends RegexParsers with PackratParsers { ) ~ start.? 
~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep1( end ) ^^ { case f ~ _ ~ i ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } i match { case None => f.lastOption match { @@ -347,7 +573,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifierWithTransformation | date_diff_identifier | identifier) ~ alias.? ^^ { + (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | case_when_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } @@ -406,15 +632,17 @@ trait SQLWhereParser { private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) + private def any_identifier: PackratParser[SQLIdentifier] = + identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier + private def equality: PackratParser[SQLExpression] = - not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { + not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def like: PackratParser[SQLExpression] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { - case i ~ n ~ _ ~ v => - SQLExpression(i, Like, v, n) + any_identifier ~ not.? 
~ Like.regex ~ literal ^^ { case i ~ n ~ _ ~ v => + SQLExpression(i, Like, v, n) } private def ge: PackratParser[SQLComparisonOperator] = Ge.sql ^^ (_ => Ge) @@ -426,24 +654,24 @@ trait SQLWhereParser { def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { + not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | long | literal | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def in: PackratParser[SQLExpressionOperator] = In.regex ^^ (_ => In) private def inLiteral: PackratParser[SQLCriteria] = - identifier ~ not.? ~ in ~ start ~ rep1sep(literal, separator) ~ end ^^ { + any_identifier ~ not.? ~ in ~ start ~ rep1sep(literal, separator) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => SQLIn( i, - SQLLiteralValues(v), + SQLStringValues(v), n ) } private def inDoubles: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + any_identifier ~ not.? ~ in ~ start ~ rep1sep( double, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -455,7 +683,7 @@ trait SQLWhereParser { } private def inLongs: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + any_identifier ~ not.? ~ in ~ start ~ rep1sep( long, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -467,17 +695,17 @@ trait SQLWhereParser { } def between: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { + any_identifier ~ not.? 
~ Between.regex ~ literal ~ and ~ literal ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLiteralFromTo(from, to), n) } def betweenLongs: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { + any_identifier ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLongFromTo(from, to), n) } def betweenDoubles: PackratParser[SQLCriteria] = - (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { + any_identifier ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLDoubleFromTo(from, to), n) } @@ -488,27 +716,25 @@ trait SQLWhereParser { def matchCriteria: PackratParser[SQLMatch] = Match.regex ~ start ~ rep1sep( - identifier, + any_identifier, separator ) ~ end ~ Against.regex ~ start ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ _ ~ _ ~ l ~ _ => SQLMatch(i, l) } - private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = { - // identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifier - not.? ~ (identifierWithAggregation | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? 
^^ { - case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) - } - } - def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) def or: PackratParser[SQLPredicateOperator] = Or.regex ^^ (_ => Or) def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) + def logical_criteria: PackratParser[SQLCriteria] = + (is_null | is_notnull) ^^ { case SQLConditionalFunctionAsCriteria(c) => + c + } + def criteria: PackratParser[SQLCriteria] = - (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria) ^^ ( + (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | /*coalesce | nullif |*/ sql_distance | matchCriteria | logical_criteria) ^^ ( c => c ) @@ -549,7 +775,7 @@ trait SQLWhereParser { nestedCriteria | childCriteria | parentCriteria | criteria def whereCriteria: PackratParser[List[SQLToken]] = rep1( - allPredicate | allCriteria | start | or | and | end + allPredicate | allCriteria | start | or | and | end | then_case ) def where: PackratParser[SQLWhere] = @@ -639,6 +865,8 @@ trait SQLWhereParser { case _ => throw SQLValidationError("Invalid stack state for predicate creation") } + case ThenCase :: _ => + processTokensHelper(Nil, stack) // exit processing on THEN case (_: EndDelimiter) :: rest => processTokensHelper(rest, stack) // Ignore and move on case _ => processTokensHelper(Nil, stack) @@ -740,10 +968,6 @@ trait SQLOrderByParser { def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = rep1sep(sql_functions, start) ~ start.? 
~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } (n, f) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index a578cc53..87ec07da 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -51,4 +51,44 @@ case class SQLSearchRequest( } lazy val buckets: Seq[SQLBucket] = groupBy.map(_.buckets).getOrElse(Seq.empty) + + override def validate(): Either[String, Unit] = { + for { + _ <- from.validate() + _ <- select.validate() + _ <- where.map(_.validate()).getOrElse(Right(())) + _ <- groupBy.map(_.validate()).getOrElse(Right(())) + _ <- having.map(_.validate()).getOrElse(Right(())) + _ <- orderBy.map(_.validate()).getOrElse(Right(())) + _ <- limit.map(_.validate()).getOrElse(Right(())) + /*_ <- { + // validate that having clauses are only applied when group by is present + if (having.isDefined && groupBy.isEmpty) { + Left("HAVING clauses can only be applied when GROUP BY is present") + } else { + Right(()) + } + }*/ + _ <- { + // validate that non-aggregated fields are not present when group by is present + if (groupBy.isDefined) { + val nonAggregatedFields = select.fields.filterNot(f => f.aggregation || f.isScriptField) + val invalidFields = nonAggregatedFields.filterNot(f => + buckets.exists(b => + b.name == f.fieldAlias.map(_.alias).getOrElse(f.sourceField.replace(".", "_")) + ) + ) + if (invalidFields.nonEmpty) { + Left( + s"Non-aggregated fields ${invalidFields.map(_.sql).mkString(", ")} cannot be selected when GROUP BY is present" + ) + } else { + Right(()) + } + } else { + Right(()) + } + } + } yield () + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala 
b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index e2991f9d..5aa56a26 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -29,6 +29,8 @@ sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript def painless: String = identifier.painless lazy val scriptName: String = fieldAlias.map(_.alias).getOrElse(sourceField) + + override def validate(): Either[String, Unit] = identifier.validate() } case class SQLField( @@ -58,4 +60,11 @@ case class SQLSelect( }.toMap def update(request: SQLSearchRequest): SQLSelect = this.copy(fields = fields.map(_.update(request)), except = except.map(_.update(request))) + + override def validate(): Either[String, Unit] = + if (fields.isEmpty) { + Left("At least one field is required in SELECT clause") + } else { + fields.map(_.validate()).find(_.isLeft).getOrElse(Right(())) + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala index ff4cebf7..b0aea9da 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -3,16 +3,29 @@ package app.softnetwork.elastic.sql sealed trait SQLType { def typeId: String } trait SQLAny extends SQLType + trait SQLTemporal extends SQLType + trait SQLDate extends SQLTemporal +trait SQLTime extends SQLTemporal trait SQLDateTime extends SQLTemporal -trait SQLNumber extends SQLType -trait SQLString extends SQLType - -object SQLTypeCompatibility { - def matches(out: SQLType, in: SQLType): Boolean = - out.typeId == in.typeId || - (out.typeId == "temporal" && Set("date", "datetime").contains(in.typeId)) || - (in.typeId == "temporal" && Set("date", "datetime").contains(out.typeId)) || - out.typeId == "any" || in.typeId == "any" -} +trait SQLTimestamp extends SQLDateTime + +trait SQLNumeric extends SQLType + +trait 
SQLTinyInt extends SQLNumeric +trait SQLSmallInt extends SQLNumeric +trait SQLInt extends SQLNumeric +trait SQLBigInt extends SQLNumeric +trait SQLDouble extends SQLNumeric +trait SQLReal extends SQLNumeric + +trait SQLLiteral extends SQLType +trait SQLVarchar extends SQLLiteral +trait SQLChar extends SQLLiteral + +trait SQLBool extends SQLType + +trait SQLArray extends SQLType { def elementType: SQLType } + +trait SQLStruct extends SQLType diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala new file mode 100644 index 00000000..7dc0911b --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -0,0 +1,137 @@ +package app.softnetwork.elastic.sql +import SQLTypes._ + +object SQLTypeUtils { + + def matches(out: SQLType, in: SQLType): Boolean = + out.typeId == in.typeId || + (out.typeId == Temporal.typeId && Set( + Date.typeId, + DateTime.typeId, + Time.typeId, + Timestamp.typeId + ).contains( + in.typeId + )) || + (in.typeId == Temporal.typeId && Set( + Date.typeId, + DateTime.typeId, + Time.typeId, + Timestamp.typeId + ).contains( + out.typeId + )) || + (out.typeId == Numeric.typeId && Set( + TinyInt.typeId, + SmallInt.typeId, + Int.typeId, + BigInt.typeId, + Double.typeId, + Real.typeId + ) + .contains( + in.typeId + )) || + (in.typeId == Numeric.typeId && Set( + TinyInt.typeId, + SmallInt.typeId, + Int.typeId, + BigInt.typeId, + Double.typeId, + Real.typeId + ) + .contains( + out.typeId + )) || + (out.typeId == Varchar.typeId && in.typeId == Varchar.typeId) || + (out.typeId == Boolean.typeId && in.typeId == Boolean.typeId) || + out.typeId == Any.typeId || in.typeId == Any.typeId || + out.typeId == Null.typeId || in.typeId == Null.typeId + + def leastCommonSuperType(types: List[SQLType]): SQLType = { + val distinct = types.distinct + if (distinct.size == 1) return distinct.head + + // 1. 
String + if (distinct.contains(SQLTypes.Varchar)) return SQLTypes.Varchar + + // 2. Number + if (distinct.contains(SQLTypes.Double)) return SQLTypes.Double + if (distinct.contains(SQLTypes.Real)) return SQLTypes.Real + if (distinct.contains(SQLTypes.BigInt)) return SQLTypes.BigInt + if (distinct.contains(SQLTypes.Int)) return SQLTypes.Int + if (distinct.contains(SQLTypes.SmallInt)) return SQLTypes.SmallInt + if (distinct.contains(SQLTypes.TinyInt)) return SQLTypes.TinyInt + if (distinct.contains(SQLTypes.Numeric)) return SQLTypes.Numeric + + // 3. Temporal + if (distinct.contains(SQLTypes.Timestamp)) return SQLTypes.Timestamp + if (distinct.contains(SQLTypes.DateTime)) return SQLTypes.DateTime + + // mixed case DATE + TIME → DATETIME + if (distinct.contains(SQLTypes.Date) && distinct.contains(SQLTypes.Time)) + return SQLTypes.DateTime + + if (distinct.contains(SQLTypes.Date)) return SQLTypes.Date + if (distinct.contains(SQLTypes.Time)) return SQLTypes.Time + if (distinct.contains(SQLTypes.Temporal)) return SQLTypes.Timestamp + + // 4. Null or Any + if (distinct.contains(SQLTypes.Null)) return SQLTypes.Any + if (distinct.contains(SQLTypes.Any)) return SQLTypes.Any + + // 5. 
Fallback + SQLTypes.Any + } + + def coerce(in: PainlessScript, to: SQLType): String = { + val expr = in.painless + val from = in.out + val nullable = in.nullable + coerce(expr, from, to, nullable) + } + + def coerce(expr: String, from: SQLType, to: SQLType, nullable: Boolean): String = { + val ret = { + (from, to) match { + // ---- DATE & TIME ---- + case (SQLTypes.Date, SQLTypes.DateTime | SQLTypes.Timestamp) => + s"($expr).atStartOfDay(ZoneId.of('Z'))" + case (SQLTypes.DateTime | SQLTypes.Timestamp, SQLTypes.Date) => + s"($expr).toLocalDate()" + case (SQLTypes.DateTime | SQLTypes.Timestamp, SQLTypes.Time) => + s"($expr).toLocalTime()" + + // ---- NUMERIQUES ---- + case (SQLTypes.Int, SQLTypes.BigInt) => + s"((long) $expr)" + case (SQLTypes.Int, SQLTypes.Double) => + s"((double) $expr)" + case (SQLTypes.BigInt, SQLTypes.Double) => + s"((double) $expr)" + + // ---- NUMERIC <-> TEMPORAL ---- + case (SQLTypes.BigInt, SQLTypes.Timestamp) => + s"Instant.ofEpochMilli($expr).atZone(ZoneId.of('Z'))" + case (SQLTypes.Timestamp, SQLTypes.BigInt) => + s"$expr.toInstant().toEpochMilli()" + + // ---- BOOLEEN -> NUMERIC ---- + case (SQLTypes.Boolean, SQLTypes.Numeric) => + s"($expr ? 1 : 0)" + + // ---- IDENTITY ---- + case (_, _) if from == to => + return expr + + // ---- PAR DEFAUT ---- + case _ => + return expr // fallback + } + } + if (!nullable) + return ret + s"($expr != null ? 
$ret : null)" + } + +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index 131c9a01..d067b650 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -2,9 +2,35 @@ package app.softnetwork.elastic.sql object SQLTypes { case object Any extends SQLAny { val typeId = "any" } + + case object Null extends SQLAny { val typeId = "null" } + case object Temporal extends SQLTemporal { val typeId = "temporal" } - case object Date extends SQLDate { val typeId = "date" } - case object DateTime extends SQLDateTime { val typeId = "datetime" } - case object Number extends SQLNumber { val typeId = "number" } - case object String extends SQLString { val typeId = "string" } + + case object Date extends SQLTemporal with SQLDate { val typeId = "date" } + case object Time extends SQLTemporal with SQLTime { val typeId = "time" } + case object DateTime extends SQLTemporal with SQLDateTime { val typeId = "datetime" } + case object Timestamp extends SQLTimestamp { val typeId = "timestamp" } + + case object Numeric extends SQLNumeric { val typeId = "numeric" } + + case object TinyInt extends SQLTinyInt { val typeId = "tinyint" } + case object SmallInt extends SQLSmallInt { val typeId = "smallint" } + case object Int extends SQLInt { val typeId = "int" } + case object BigInt extends SQLBigInt { val typeId = "bigint" } + case object Double extends SQLDouble { val typeId = "double" } + case object Real extends SQLReal { val typeId = "real" } + + case object Literal extends SQLLiteral { val typeId = "literal" } + + case object Char extends SQLChar { val typeId = "char" } + case object Varchar extends SQLVarchar { val typeId = "varchar" } + + case object Boolean extends SQLBool { val typeId = "boolean" } + + case class Array(elementType: SQLType) extends SQLArray { + val typeId = s"array<${elementType.typeId}>" + } + 
+ case object Struct extends SQLStruct { val typeId = "struct" } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index 776bab53..15994d96 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -4,19 +4,30 @@ object SQLValidator { def validateChain(functions: List[SQLFunction]): Either[String, Unit] = { // validate function chain type compatibility + functions match { + case Nil => return Right(()) + case _ => + } + functions.map(_.validate()).find(_.isLeft) match { + case Some(left) => return left + case None => + } val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } unaryFuncs.sliding(2).foreach { case Seq(f1, f2) => - if (!SQLTypeCompatibility.matches(f2.outputType, f1.inputType)) { - return Left( - s"Type mismatch: output '${f2.outputType.typeId}' of `${f2.sql}` " + - s"is not compatible with input '${f1.inputType.typeId}' of `${f1.sql}`" - ) - } + validateTypesMatching(f2.outputType, f1.inputType) case _ => // ok } Right(()) } + + def validateTypesMatching(out: SQLType, in: SQLType): Either[String, Unit] = { + if (SQLTypeUtils.matches(out, in)) { + Right(()) + } else { + Left(s"Type mismatch: output '${out.typeId}' is not compatible with input '${in.typeId}'") + } + } } trait SQLValidation { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index b40a09d9..9e244acb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -1,8 +1,10 @@ package app.softnetwork.elastic.sql +import scala.annotation.tailrec + case object Where extends SQLExpr("where") with SQLRegex -sealed trait SQLCriteria extends Updateable { +sealed trait SQLCriteria extends Updateable with 
PainlessScript { def operator: SQLOperator def nested: Boolean = false @@ -28,6 +30,28 @@ sealed trait SQLCriteria extends Updateable { } } + private[this] def asGroup(script: String): String = if (group) s"($script)" else script + + override def out: SQLType = SQLTypes.Boolean + + override def painless: String = this match { + case SQLPredicate(left, op, right, maybeNot, group) => + val leftStr = left.painless + val rightStr = right.painless + val opStr = op match { + case And | Or => op.painless + case _ => throw new IllegalArgumentException(s"Unsupported logical operator: $op") + } + val not = maybeNot.nonEmpty + if (group || not) + s"${maybeNot.map(_.painless).getOrElse("")}($leftStr $opStr $rightStr)" + else + s"$leftStr $opStr $rightStr" + case relation: ElasticRelation => asGroup(relation.criteria.painless) + case m: SQLMatch => asGroup(m.criteria.painless) + case expr: Expression => asGroup(expr.painless) + case _ => throw new IllegalArgumentException(s"Unsupported criteria: $this") + } } case class SQLPredicate( @@ -85,18 +109,15 @@ case class SQLPredicate( override def matchCriteria: Boolean = leftCriteria.matchCriteria || rightCriteria.matchCriteria + override def validate(): Either[String, Unit] = + for { + _ <- leftCriteria.validate() + _ <- rightCriteria.validate() + } yield () } sealed trait ElasticFilter -sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLFunctionChain { - def identifier: SQLIdentifier - override def nested: Boolean = identifier.nested - override def group: Boolean = false - override lazy val limit: Option[SQLLimit] = identifier.limit - override val functions: List[SQLFunction] = identifier.functions -} - case class ElasticBoolQuery( var innerFilters: Seq[ElasticFilter] = Nil, var mustFilters: Seq[ElasticFilter] = Nil, @@ -151,12 +172,93 @@ case class ElasticBoolQuery( } -sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { +sealed trait Expression extends SQLFunctionChain with ElasticFilter 
with SQLCriteria { // to fix output type as Boolean + def identifier: SQLIdentifier + override def nested: Boolean = identifier.nested + override def group: Boolean = false + override lazy val limit: Option[SQLLimit] = identifier.limit + override val functions: List[SQLFunction] = identifier.functions def maybeValue: Option[SQLToken] def maybeNot: Option[Not.type] def notAsString: String = maybeNot.map(v => s"$v ").getOrElse("") def valueAsString: String = maybeValue.map(v => s" $v").getOrElse("") override def sql = s"$identifier $notAsString$operator$valueAsString" + + override lazy val aggregation: Boolean = maybeValue match { + case Some(v: SQLFunctionChain) => identifier.aggregation || v.aggregation + case _ => identifier.aggregation + } + + def painlessNot: String = operator match { + case _: SQLComparisonOperator => "" + case _ => maybeNot.map(_.painless).getOrElse("") + } + + def painlessOp: String = operator match { + case o: SQLComparisonOperator if maybeNot.isDefined => o.not.painless + case _ => operator.painless + } + + def painlessValue: String = maybeValue + .map { + case v: SQLValue[_] => v.painless + case v: SQLValues[_, _] => v.painless + case v: SQLIdentifier => v.painless + case v => v.sql + } + .getOrElse("") /*{ + operator match { + case IsNull | IsNotNull => "null" + case _ => "" + } + }*/ + + protected lazy val left: String = { + val targetedType = maybeValue match { + case Some(v) => + v match { + case value: SQLValue[_] => value.out + case values: SQLValues[_, _] => values.out + case other => other.out + } + case None => identifier.out + } + SQLTypeUtils.coerce(identifier, targetedType) + } + + protected lazy val check: String = + operator match { + case _: SQLComparisonOperator => s" $painlessOp $painlessValue" + case _ => s"$painlessOp($painlessValue)" + } + + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left == null ? 
false : ${painlessNot}left$check" + } + s"$painlessNot$left$check" + } + + override def validate(): Either[String, Unit] = { + for { + _ <- identifier.validate() + _ <- maybeValue match { + case Some(v) => + v.validate() match { + case Left(err) => Left(s"$err in expression: $this") + case Right(_) => + SQLValidator.validateTypesMatching(identifier.out, v.out) match { + case Left(_) => + Left( + s"Type mismatch: '${out.typeId}' is not compatible with '${v.out.typeId}' in expression: $this" + ) + case Right(_) => Right(()) + } + } + case _ => Right(()) + } + } yield () + } } case class SQLExpression( @@ -168,7 +270,12 @@ case class SQLExpression( override def maybeValue: Option[SQLToken] = Option(value) override def update(request: SQLSearchRequest): SQLCriteria = { - val updated = this.copy(identifier = identifier.update(request)) + val updated = + value match { + case id: SQLIdentifier => + this.copy(identifier = identifier.update(request), value = id.update(request)) + case _ => this.copy(identifier = identifier.update(request)) + } if (updated.nested) { ElasticNested(updated, limit) } else @@ -178,33 +285,6 @@ case class SQLExpression( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } -sealed trait BinaryExpression extends Expression { - def left: SQLIdentifier - def right: SQLIdentifier - override def identifier: SQLIdentifier = left - override def maybeValue: Option[SQLToken] = Some(right) - override lazy val aggregation: Boolean = left.aggregation || right.aggregation -} - -case class SQLBinaryExpression( - left: SQLIdentifier, - operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... 
- right: SQLIdentifier, - maybeNot: Option[Not.type] = None -) extends BinaryExpression - with PainlessScript { - override def update(request: SQLSearchRequest): SQLCriteria = - this.copy(left = left.update(request), right = right.update(request)) - - override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - - override def painless: String = { - val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless - s"${left.painless} $painlessOp ${right.painless}" - } - -} - case class SQLIsNull(identifier: SQLIdentifier) extends Expression { override val operator: SQLOperator = IsNull @@ -241,6 +321,66 @@ case class SQLIsNotNull(identifier: SQLIdentifier) extends Expression { override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } +sealed trait SQLCriteriaWithConditionalFunction[In <: SQLType] extends Expression { + def conditionalFunction: SQLConditionalFunction[In] + override def maybeValue: Option[SQLToken] = None + override def maybeNot: Option[Not.type] = None + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + override val functions: List[SQLFunction] = List(conditionalFunction) + override def sql = s"${conditionalFunction.sql}($identifier)" +} + +object SQLConditionalFunctionAsCriteria { + def unapply(f: SQLConditionalFunction[_]): Option[SQLCriteria] = f match { + case SQLIsNullFunction(id) => Some(SQLIsNullCriteria(id)) + case SQLIsNotNullFunction(id) => Some(SQLIsNotNullCriteria(id)) + case _ => None + } +} + +case class SQLIsNullCriteria(identifier: SQLIdentifier) + extends SQLCriteriaWithConditionalFunction[SQLAny] { + override val conditionalFunction: SQLConditionalFunction[SQLAny] = SQLIsNullFunction(identifier) + override val operator: SQLOperator = IsNull + override def update(request: SQLSearchRequest): SQLCriteria = { + val updated = this.copy(identifier = identifier.update(request)) + if (updated.nested) { + ElasticNested(updated, limit) + } 
else + updated + } + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left == null" + } + s"$painlessNot$left$check" + } + +} + +case class SQLIsNotNullCriteria(identifier: SQLIdentifier) + extends SQLCriteriaWithConditionalFunction[SQLAny] { + override val conditionalFunction: SQLConditionalFunction[SQLAny] = SQLIsNotNullFunction( + identifier + ) + override val operator: SQLOperator = IsNotNull + override def update(request: SQLSearchRequest): SQLCriteria = { + val updated = this.copy(identifier = identifier.update(request)) + if (updated.nested) { + ElasticNested(updated, limit) + } else + updated + } + + override def painless: String = { + if (identifier.nullable) { + return s"def left = $left; left != null" + } + s"$painlessNot$left$check" + } + +} + case class SQLIn[R, +T <: SQLValue[R]]( identifier: SQLIdentifier, values: SQLValues[R, T], @@ -264,6 +404,8 @@ case class SQLIn[R, +T <: SQLValue[R]]( override def maybeValue: Option[SQLToken] = Some(values) override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def painless: String = s"$painlessNot${identifier.painless}$painlessOp($painlessValue)" } case class SQLBetween[+T]( @@ -293,9 +435,9 @@ case class SQLBetween[+T]( case class ElasticGeoDistance( identifier: SQLIdentifier, - distance: SQLLiteral, - lat: SQLDouble, - lon: SQLDouble + distance: SQLStringValue, + lat: SQLDoubleValue, + lon: SQLDoubleValue ) extends Expression { override def sql = s"$Distance($identifier,($lat,$lon)) $operator $distance" override val functions: List[SQLFunction] = List(Distance) @@ -312,7 +454,7 @@ case class ElasticGeoDistance( case class SQLMatch( identifiers: Seq[SQLIdentifier], - value: SQLLiteral + value: SQLStringValue ) extends SQLCriteria { override def sql: String = s"$operator (${identifiers.mkString(",")}) $Against ($value)" @@ -322,17 +464,23 @@ case class SQLMatch( override lazy val nested: Boolean = identifiers.forall(_.nested) - 
lazy val criteria: SQLCriteria = { - identifiers.map(id => ElasticMatch(id, value, None)) match { + @tailrec + private[this] def toCriteria(matches: List[ElasticMatch], curr: SQLCriteria): SQLCriteria = + matches match { + case Nil => curr + case single :: Nil => SQLPredicate(curr, Or, single) + case first :: rest => toCriteria(rest, SQLPredicate(curr, Or, first)) + } + + lazy val criteria: SQLCriteria = + (identifiers.map(id => ElasticMatch(id, value, None)) match { case Nil => throw new IllegalArgumentException("No identifiers for MATCH") case single :: Nil => single - case first :: second :: rest => - val initial: SQLCriteria = SQLPredicate(first, Or, second) - rest.foldLeft(initial) { (acc, next) => - SQLPredicate(acc, Or, next) - } + case first :: rest => toCriteria(rest, first) + }) match { + case p: SQLPredicate => p.copy(group = true) + case other => other } - } override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = criteria match { case predicate: SQLPredicate => predicate.copy(group = true).asFilter(currentQuery) @@ -344,59 +492,9 @@ case class SQLMatch( override def group: Boolean = false } -case class SQLComparisonDateMath( - identifier: SQLIdentifier, - operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... - dateTimeFunction: CurrentDateTimeFunction, // CurrentDate, Now, CurrentTimestamp, CurrentTime, ... 
- arithmeticOperator: Option[ArithmeticOperator] = - None, // Plus or Minus between dateTimeFunction and interval - interval: Option[TimeInterval] = None, // optional interval - maybeNot: Option[Not.type] = None -) extends Expression - with MathScript { - override def sql: String = { - s"$identifier ${operator.sql} $dateTimeFunction${asString(arithmeticOperator)}${asString(interval)}" - } - override def update(request: SQLSearchRequest): SQLCriteria = - this.copy(identifier = identifier.update(request)) - - override def maybeValue: Option[SQLToken] = Some(SQLScript(script)) - - override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this - - override def script: String = { - dateTimeFunction match { - case _: CurrentTimeFunction => - val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless - (arithmeticOperator, interval) match { - case (Some(Add), Some(i)) => // compare doc time with now + interval - s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().plus(${i.value}, ${i.unit.painless});" - - case (Some(Subtract), Some(i)) => // compare doc time with now - s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().minus(${i.value}, ${i.unit.painless});" - - case _ => - s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now();" - } - case _ => - val base = s"${dateTimeFunction.script}" - val dateMath = - (arithmeticOperator, interval) match { - case (Some(Add), Some(i)) => s"$base+${i.script}" - case (Some(Subtract), Some(i)) => s"$base-${i.script}" - case _ => base - } - dateTimeFunction match { - case _: CurrentDateFunction => s"$dateMath/d" - case _ => dateMath - } - } - } -} - case class ElasticMatch( identifier: SQLIdentifier, - value: SQLLiteral, + value: SQLStringValue, options: Option[String] ) extends Expression { override def sql: String = @@ -412,6 +510,8 @@ case class ElasticMatch( override def asFilter(currentQuery: 
Option[ElasticBoolQuery]): ElasticFilter = this override def matchCriteria: Boolean = true + + override def painless: String = s"$painlessNot${identifier.painless}$painlessOp($painlessValue)" } sealed abstract class ElasticRelation(val criteria: SQLCriteria, val operator: ElasticOperator) @@ -444,7 +544,7 @@ case class ElasticNested(override val criteria: SQLCriteria, override val limit: private[this] def name(criteria: SQLCriteria): Option[String] = criteria match { case SQLPredicate(left, _, right, _, _) => name(left).orElse(name(right)) - case c: SQLCriteriaWithIdentifier => + case c: Expression => c.identifier.innerHitsName.orElse(c.identifier.name.split('.').headOption) case n: ElasticNested => name(n.criteria) case _ => None @@ -473,4 +573,9 @@ case class SQLWhere(criteria: Option[SQLCriteria]) extends Updateable { def update(request: SQLSearchRequest): SQLWhere = this.copy(criteria = criteria.map(_.update(request))) + override def validate(): Either[String, Unit] = criteria match { + case Some(c) => c.validate() + case _ => Right(()) + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 35448221..48db25a6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -16,13 +16,19 @@ package object sql { case _ => "" } - trait SQLToken extends Serializable { + trait SQLToken extends Serializable with SQLValidation { def sql: String override def toString: String = sql + def baseType: SQLType = SQLTypes.Any + def in: SQLType = baseType + def out: SQLType = baseType + def system: Boolean = false + def nullable: Boolean = !system } trait PainlessScript extends SQLToken { def painless: String + def nullValue: String = "null" } trait MathScript extends SQLToken { @@ -39,7 +45,8 @@ package object sql { abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) extends 
SQLToken - with PainlessScript { + with PainlessScript + with SQLFunctionWithValue[T] { def choose[R >: T]( values: Seq[R], operator: Option[SQLExpressionOperator], @@ -58,20 +65,35 @@ package object sql { case _ => values.headOption } } - def painless: String = value match { + override def painless: String = value match { case s: String => s""""$s"""" case b: Boolean => b.toString case n: Number => n.toString case _ => value.toString } + + override def nullable: Boolean = false + } + + case object SQLNull extends SQLValue[Null](null) { + override def sql: String = "null" + override def painless: String = "null" + override def nullable: Boolean = true + override def out: SQLType = SQLTypes.Null } case class SQLBoolean(override val value: Boolean) extends SQLValue[Boolean](value) { override def sql: String = value.toString + override def out: SQLType = SQLTypes.Boolean + } + + case class SQLCharValue(override val value: Char) extends SQLValue[Char](value) { + override def sql: String = s"""'$value'""" + override def out: SQLType = SQLTypes.Char } - case class SQLLiteral(override val value: String) extends SQLValue[String](value) { - override def sql: String = s""""$value"""" + case class SQLStringValue(override val value: String) extends SQLValue[String](value) { + override def sql: String = s"""'$value'""" import SQLImplicits._ private lazy val pattern: Pattern = value.pattern def like: Seq[String] => Boolean = { @@ -96,9 +118,10 @@ package object sql { case _ => super.choose(values, operator, separator) } } + override def out: SQLType = SQLTypes.Varchar } - sealed abstract class SQLNumeric[T: Numeric](override val value: T)(implicit + sealed abstract class SQLNumericValue[T: Numeric](override val value: T)(implicit ev$1: T => Ordered[T] ) extends SQLValue[T](value) { override def sql: String = value.toString @@ -129,17 +152,38 @@ package object sql { def ne: Seq[T] => Boolean = { _.forall { _ != value } } + override def out: SQLNumeric = SQLTypes.Numeric + } + + 
case class SQLByteValue(override val value: Byte) extends SQLNumericValue[Byte](value) { + override def out: SQLNumeric = SQLTypes.TinyInt + } + + case class SQLShortValue(override val value: Short) extends SQLNumericValue[Short](value) { + override def out: SQLNumeric = SQLTypes.SmallInt } - case class SQLLong(override val value: Long) extends SQLNumeric[Long](value) + case class SQLIntValue(override val value: Int) extends SQLNumericValue[Int](value) { + override def out: SQLNumeric = SQLTypes.Int + } + + case class SQLLongValue(override val value: Long) extends SQLNumericValue[Long](value) { + override def out: SQLNumeric = SQLTypes.BigInt + } + + case class SQLFloatValue(override val value: Float) extends SQLNumericValue[Float](value) { + override def out: SQLNumeric = SQLTypes.Real + } - case class SQLDouble(override val value: Double) extends SQLNumeric[Double](value) + case class SQLDoubleValue(override val value: Double) extends SQLNumericValue[Double](value) { + override def out: SQLNumeric = SQLTypes.Double + } sealed abstract class SQLFromTo[+T](val from: SQLValue[T], val to: SQLValue[T]) extends SQLToken { override def sql = s"${from.sql} and ${to.sql}" } - case class SQLLiteralFromTo(override val from: SQLLiteral, override val to: SQLLiteral) + case class SQLLiteralFromTo(override val from: SQLStringValue, override val to: SQLStringValue) extends SQLFromTo[String](from, to) { def between: Seq[String] => Boolean = { _.exists { s => s >= from.value && s <= to.value } @@ -149,7 +193,7 @@ package object sql { } } - case class SQLLongFromTo(override val from: SQLLong, override val to: SQLLong) + case class SQLLongFromTo(override val from: SQLLongValue, override val to: SQLLongValue) extends SQLFromTo[Long](from, to) { def between: Seq[Long] => Boolean = { _.exists { n => n >= from.value && n <= to.value } @@ -159,7 +203,7 @@ package object sql { } } - case class SQLDoubleFromTo(override val from: SQLDouble, override val to: SQLDouble) + case class 
SQLDoubleFromTo(override val from: SQLDoubleValue, override val to: SQLDoubleValue) extends SQLFromTo[Double](from, to) { def between: Seq[Double] => Boolean = { _.exists { n => n >= from.value && n <= to.value } @@ -170,12 +214,16 @@ package object sql { } sealed abstract class SQLValues[+R: TypeTag, +T <: SQLValue[R]](val values: Seq[T]) - extends SQLToken { + extends SQLToken + with PainlessScript { override def sql = s"(${values.map(_.sql).mkString(",")})" + override def painless: String = s"[${values.map(_.painless).mkString(",")}]" lazy val innerValues: Seq[R] = values.map(_.value) + override def nullable: Boolean = values.exists(_.nullable) + override def out: SQLArray = SQLTypes.Array(SQLTypes.Any) } - case class SQLLiteralValues(override val values: Seq[SQLLiteral]) + case class SQLStringValues(override val values: Seq[SQLStringValue]) extends SQLValues[String, SQLValue[String]](values) { def eq: Seq[String] => Boolean = { _.exists { s => innerValues.exists(_.contentEquals(s)) } @@ -183,22 +231,49 @@ package object sql { def ne: Seq[String] => Boolean = { _.forall { s => innerValues.forall(!_.contentEquals(s)) } } + override def out: SQLArray = SQLTypes.Array(SQLTypes.Varchar) } - class SQLNumericValues[R: TypeTag](override val values: Seq[SQLNumeric[R]]) - extends SQLValues[R, SQLNumeric[R]](values) { + class SQLNumericValues[R: TypeTag](override val values: Seq[SQLNumericValue[R]]) + extends SQLValues[R, SQLNumericValue[R]](values) { def eq: Seq[R] => Boolean = { _.exists { n => innerValues.contains(n) } } def ne: Seq[R] => Boolean = { _.forall { n => !innerValues.contains(n) } } + override def out: SQLArray = SQLTypes.Array(SQLTypes.Numeric) + } + + case class SQLByteValues(override val values: Seq[SQLByteValue]) + extends SQLNumericValues[Byte](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.TinyInt) } - case class SQLLongValues(override val values: Seq[SQLLong]) extends SQLNumericValues[Long](values) + case class 
SQLShortValues(override val values: Seq[SQLShortValue]) + extends SQLNumericValues[Short](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.SmallInt) + } + + case class SQLIntValues(override val values: Seq[SQLIntValue]) + extends SQLNumericValues[Int](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Int) + } + + case class SQLLongValues(override val values: Seq[SQLLongValue]) + extends SQLNumericValues[Long](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.BigInt) + } - case class SQLDoubleValues(override val values: Seq[SQLDouble]) - extends SQLNumericValues[Double](values) + case class SQLFloatValues(override val values: Seq[SQLFloatValue]) + extends SQLNumericValues[Float](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Real) + } + + case class SQLDoubleValues(override val values: Seq[SQLDoubleValue]) + extends SQLNumericValues[Double](values) { + override def out: SQLArray = SQLTypes.Array(SQLTypes.Double) + } def choose[T]( values: Seq[T], @@ -256,6 +331,8 @@ package object sql { def fieldAlias: Option[String] def bucket: Option[SQLBucket] + applyTo(this) + lazy val identifierName: String = functions.reverse.foldLeft(name)((expr, fun) => { fun.toSQL(expr) @@ -267,16 +344,46 @@ package object sql { lazy val aliasOrName: String = fieldAlias.getOrElse(name) - override def painless: String = { - val base = if (name.nonEmpty) s"doc['$name'].value" else "" + def paramName: String = + if (aggregation && functions.size == 1) s"params.$aliasOrName" + else if (name.nonEmpty) + s"doc['$name'].value" + else "" + + def toPainless(base: String): String = { val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse - orderedFunctions.foldLeft(base) { - case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (expr, f: PainlessScript) => s"$expr${f.painless}" - case (expr, f) => f.toSQL(expr) // fallback + var expr = base + orderedFunctions.zipWithIndex.foreach { case (f, idx) => + f 
match { + case f: SQLTransformFunction[_, _] => expr = f.toPainless(expr, idx) + case f: PainlessScript => expr = s"$expr${f.painless}" + case f => expr = f.toSQL(expr) // fallback + } } + expr + } + + def checkNotNull: String = + if (name.isEmpty) "" + else + s"(!doc.containsKey('$name') || doc['$name'].empty ? $nullValue : doc['$name'].value)" + + override def painless: String = toPainless( + if (nullable) + checkNotNull + else + paramName + ) + + private[this] var _nullable = + this.name.nonEmpty && (!aggregation || functions.size > 1) + + def nullable_=(b: Boolean): Unit = { + _nullable = b } + override def nullable: Boolean = _nullable + } case class SQLIdentifier( @@ -336,6 +443,4 @@ package object sql { } } } - - case class SQLScript(script: String) extends SQLExpr(script) } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala index 2017093c..dc9c05fc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -36,7 +36,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { require(transforms.nonEmpty, "No transforms provided") val initial: (String, SQLType) = - (transforms.head.toPainless(base), transforms.head.outputType.asInstanceOf[SQLType]) + (transforms.head.toPainless(base, 0), transforms.head.outputType.asInstanceOf[SQLType]) val (finalExpr, _) = transforms.tail.foldLeft(initial) { case ((expr, currentType), t: SQLUnaryFunction[_, _]) => @@ -45,7 +45,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" ) } - (t.toPainless(expr), t.outputType.asInstanceOf[SQLType]) + (t.toPainless(expr, 0), t.outputType.asInstanceOf[SQLType]) } finalExpr @@ -77,8 +77,8 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite 
{ val names = chain.map(_.sql).mkString(" -> ") test(s"Valid chain $idx: $names") { val chained = chainTransformsTyped(baseDate, chain) - val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate)) { (expr, f) => - f.toPainless(expr) + val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate, 0)) { (expr, f) => + f.toPainless(expr, 0) } // On ne teste que la génération de code Painless sans évaluer le résultat assert(chained.nonEmpty) @@ -88,7 +88,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { // Test simple pour chaque fonction individuelle transformFunctions.foreach { f => test(s"Single transformation ${f.sql}") { - val result = f.toPainless(baseDate) + val result = f.toPainless(baseDate, 0) assert(result.nonEmpty) } } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index e961a9fc..972ba528 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -10,17 +10,17 @@ object Queries { val numericalGt = "select * from Table where identifier > 1" val numericalGe = "select * from Table where identifier >= 1" val numericalNe = "select * from Table where identifier <> 1" - val literalEq = """select * from Table where identifier = "un"""" - val literalLt = "select * from Table where createdAt < \"now-35M/M\"" - val literalLe = "select * from Table where createdAt <= \"now-35M/M\"" - val literalGt = "select * from Table where createdAt > \"now-35M/M\"" - val literalGe = "select * from Table where createdAt >= \"now-35M/M\"" - val literalNe = """select * from Table where identifier <> "un"""" + val literalEq = """select * from Table where identifier = 'un'""" + val literalLt = "select * from Table where createdAt < 'now-35M/M'" + val literalLe = "select * from Table where createdAt <= 'now-35M/M'" + val literalGt = "select * from Table 
where createdAt > 'now-35M/M'" + val literalGe = "select * from Table where createdAt >= 'now-35M/M'" + val literalNe = """select * from Table where identifier <> 'un'""" val boolEq = """select * from Table where identifier = true""" val boolNe = """select * from Table where identifier <> false""" - val literalLike = """select * from Table where identifier like "%un%"""" - val literalNotLike = """select * from Table where identifier not like "%un%"""" - val betweenExpression = """select * from Table where identifier between "1" and "2"""" + val literalLike = """select * from Table where identifier like '%un%'""" + val literalNotLike = """select * from Table where identifier not like '%un%'""" + val betweenExpression = """select * from Table where identifier between '1' and '2'""" val andPredicate = "select * from Table where identifier1 = 1 and identifier2 > 2" val orPredicate = "select * from Table where identifier1 = 1 or identifier2 > 2" val leftPredicate = @@ -40,28 +40,28 @@ object Queries { "select * from Table where identifier1 = 1 and parent(parent.identifier2 > 2 or parent.identifier3 = 3)" val parentCriteria = "select * from Table where identifier1 = 1 and parent(parent.identifier3 = 3)" - val inLiteralExpression = "select * from Table where identifier in (\"val1\",\"val2\",\"val3\")" + val inLiteralExpression = "select * from Table where identifier in ('val1','val2','val3')" val inNumericalExpressionWithIntValues = "select * from Table where identifier in (1,2,3)" val inNumericalExpressionWithDoubleValues = "select * from Table where identifier in (1.0,2.1,3.4)" val notInLiteralExpression = - "select * from Table where identifier not in (\"val1\",\"val2\",\"val3\")" + "select * from Table where identifier not in ('val1','val2','val3')" val notInNumericalExpressionWithIntValues = "select * from Table where identifier not in (1,2,3)" val notInNumericalExpressionWithDoubleValues = "select * from Table where identifier not in (1.0,2.1,3.4)" val 
nestedWithBetween = - "select * from Table where nested(ciblage.Archivage_CreationDate between \"now-3M/M\" and \"now\" and ciblage.statutComportement = 1)" - val count = "select count(t.id) as c1 from Table as t where t.nom = \"Nom\"" - val countDistinct = "select count(distinct t.id) as c2 from Table as t where t.nom = \"Nom\"" + "select * from Table where nested(ciblage.Archivage_CreationDate between 'now-3M/M' and 'now' and ciblage.statutComportement = 1)" + val count = "select count(t.id) as c1 from Table as t where t.nom = 'Nom'" + val countDistinct = "select count(distinct t.id) as c2 from Table as t where t.nom = 'Nom'" val countNested = - "select count(email.value) as email from crmgp where profile.postalCode in (\"75001\",\"75002\")" + "select count(email.value) as email from crmgp where profile.postalCode in ('75001','75002')" val isNull = "select * from Table where identifier is null" val isNotNull = "select * from Table where identifier is not null" val geoDistanceCriteria = - "select * from Table where distance(profile.location,(-70.0,40.0)) <= \"5km\"" + "select * from Table where distance(profile.location,(-70.0,40.0)) <= '5km'" val except = "select * except(col1,col2) from Table" val matchCriteria = - "select * from Table where match (identifier1,identifier2,identifier3) against (\"value\")" + "select * from Table where match (identifier1,identifier2,identifier3) against ('value')" val groupBy = "select identifier, count(identifier2) from Table where identifier2 is not null group by identifier" val orderBy = "select * from Table order by identifier desc" @@ -77,7 +77,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country |from Customers |group by Country, City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 + |having Country <> 'USA' and City <> 'Berlin' and count(CustomerID) > 1 |order by count(CustomerID) desc, Country asc""".stripMargin.replaceAll("\n", " ") val dateTimeWithIntervalFields: String = 
"select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" @@ -93,7 +93,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table |group by Country, City - |having Country <> "USA" and City != "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |having Country <> 'USA' and City != 'Berlin' and count(CustomerID) > 1 and lastSeen > now - interval 7 day |order by Country asc""".stripMargin .replaceAll("\n", " ") val parseDate = @@ -133,6 +133,25 @@ object Queries { val dateTimeSub = "select identifier, datetime_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val isnull = "select isnull(identifier) as flag from Table" + val isnotnull = "select identifier, isnotnull(identifier2) as flag from Table" + val isNullCriteria = "select * from Table where isnull(identifier)" + val isNotNullCriteria = "select * from Table where isnotnull(identifier)" + val coalesce: String = + "select coalesce(createdAt - interval 35 minute, current_date) as c, identifier from Table" + val nullif: String = + "select coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd') - interval 2 day), current_date) as c, identifier from Table" + val cast: String = + "select cast(coalesce(nullif(createdAt, parse_date('2025-09-11', 'yyyy-MM-dd')), current_date - interval 2 hour) bigint) as c, identifier from Table" + val allCasts = + "select cast(identifier as int) as c1, cast(identifier as bigint) as c2, cast(identifier as double) as c3, cast(identifier as real) as c4, cast(identifier as boolean) as c5, cast(identifier as char) as c6, cast(identifier as varchar) as c7, cast(createdAt as date) as c8, cast(createdAt as time) as c9, cast(createdAt as datetime) as c10, cast(createdAt as timestamp) as c11, cast(identifier as smallint) as c12, cast(identifier as tinyint) as c13 from Table" + val caseWhen: String = + "select case when 
lastUpdated > now - interval 7 day then lastUpdated when isnotnull(lastSeen) then lastSeen + interval 2 day else createdAt end as c, identifier from Table" + val caseWhenExpr: String = + "select case current_date - interval 7 day when cast(lastUpdated as date) - interval 3 day then lastUpdated when lastSeen then lastSeen + interval 2 day else createdAt end as c, identifier from Table" + + val extract: String = + "select extract(day from createdAt) as day, extract(month from createdAt) as month, extract(year from createdAt) as year, extract(hour from createdAt) as hour, extract(minute from createdAt) as minute, extract(second from createdAt) as second from Table" } /** Created by smanciot on 15/02/17. @@ -506,4 +525,83 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { dateTimeSub ) } + + it should "parse isnull function" in { + val result = SQLParser(isnull) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isnull + ) + } + + it should "parse isnotnull function" in { + val result = SQLParser(isnotnull) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isnotnull + ) + } + + it should "parse isnull criteria" in { + val result = SQLParser(isNullCriteria) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isNullCriteria + ) + } + + it should "parse isnotnull criteria" in { + val result = SQLParser(isNotNullCriteria) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + isNotNullCriteria + ) + } + + it should "parse coalesce function" in { + val result = SQLParser(coalesce) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + coalesce + ) + } + + it should "parse nullif function" in { + val result = SQLParser(nullif) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + nullif + ) + } + + it should "parse cast function" in { + val result = SQLParser(cast) + 
result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + cast + ) + } + + it should "parse all casts function" in { + val result = SQLParser(allCasts) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + allCasts + ) + } + + it should "parse case when expression" in { + val result = SQLParser(caseWhen) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + caseWhen + ) + } + + it should "parse case when with expression" in { + val result = SQLParser(caseWhenExpr) + result.toOption + .flatMap(_.left.toOption.map(_.sql)) + .getOrElse("") + .equalsIgnoreCase(caseWhenExpr) shouldBe true + } + + it should "parse extract function" in { + val result = SQLParser(extract) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + extract + ) + } + } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala similarity index 81% rename from sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala rename to sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala index 16b10632..7f524305 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLLiteralSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLStringValueSpec.scala @@ -5,10 +5,10 @@ import org.scalatest.matchers.should.Matchers /** Created by smanciot on 17/02/17. */ -class SQLLiteralSpec extends AnyFlatSpec with Matchers { +class SQLStringValueSpec extends AnyFlatSpec with Matchers { "SQLLiteral" should "perform sql like" in { - val l = SQLLiteral("%dummy%") + val l = SQLStringValue("%dummy%") l.like(Seq("dummy")) should ===(true) l.like(Seq("aa dummy")) should ===(true) l.like(Seq("dummy bbb")) should ===(true)