diff --git a/build.sbt b/build.sbt index 6f0014dc..bae919ad 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.4.0" +ThisBuild / version := "0.5.0" ThisBuild / scalaVersion := scala213 diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index bb6fbcb2..86aeec7c 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -7,11 +7,12 @@ import app.softnetwork.elastic.sql.{ BucketSelectorScript, Count, ElasticBoolQuery, + Field, Max, Min, SQLBucket, SQLCriteria, - SQLField, + SQLFunctionUtils, SortOrder, Sum } @@ -57,7 +58,7 @@ case class ElasticAggregation( object ElasticAggregation { def apply( - sqlAgg: SQLField, + sqlAgg: Field, having: Option[SQLCriteria], bucketsDirection: Map[String, SortOrder] ): ElasticAggregation = { @@ -88,6 +89,23 @@ object ElasticAggregation { var aggPath = Seq[String]() + val (aggFuncs, transformFuncs) = SQLFunctionUtils.aggregateAndTransformFunctions(identifier) + + require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") + + def aggWithFieldOrScript( + buildField: (String, String) => Aggregation, + buildScript: (String, Script) => Aggregation + ): Aggregation = { + if (transformFuncs.nonEmpty) { + val scriptSrc = identifier.painless + val script = Script(scriptSrc).lang("painless") + buildScript(aggName, script) + } else { + buildField(aggName, sourceField) + } + } + val _agg = aggType match { case Count => @@ -96,10 +114,10 @@ object ElasticAggregation { else { valueCountAgg(aggName, sourceField) } - case Min => minAgg(aggName, sourceField) - case Max => maxAgg(aggName, sourceField) - case Avg => avgAgg(aggName, sourceField) - case Sum => sumAgg(aggName, sourceField) + case Min => aggWithFieldOrScript(minAgg, (name, s) => minAgg(name, sourceField).script(s)) + case Max => aggWithFieldOrScript(maxAgg, (name, s) => maxAgg(name, sourceField).script(s)) + case Avg => aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s)) + case Sum => aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s)) } val filteredAggName = "filtered_agg" diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index ec1d8bd7..c4943563 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,12 +9,13 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, + SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, SQLIsNull } -import com.sksamuel.elastic4s.ElasticApi.{bool, _} +import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.searches.queries.Query case class ElasticQuery(filter: ElasticFilter) { @@ -69,6 +70,7 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression + case dateMath: SQLComparisonDateMath => dateMath case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala index 35950d9c..3c451a43 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala @@ -1,11 +1,11 @@ package app.softnetwork.elastic.sql.bridge -import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, SQLField} +import app.softnetwork.elastic.sql.{Field, SQLBucket, SQLCriteria, SQLExcept} import com.sksamuel.elastic4s.searches.SearchRequest import com.sksamuel.elastic4s.http.search.SearchBodyBuilderFn case class ElasticSearchRequest( - fields: Seq[SQLField], + fields: Seq[Field], except: Option[SQLExcept], sources: Seq[String], criteria: Option[SQLCriteria], diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index b1167f7c..160ee4e2 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -4,6 +4,7 @@ import com.sksamuel.elastic4s.ElasticApi import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.http.ElasticDsl.BuildableTermsNoOp import com.sksamuel.elastic4s.http.search.SearchBodyBuilderFn +import com.sksamuel.elastic4s.script.Script import com.sksamuel.elastic4s.searches.aggs.Aggregation import com.sksamuel.elastic4s.searches.queries.Query import com.sksamuel.elastic4s.searches.{MultiSearchRequest, SearchRequest} @@ -95,6 +96,17 @@ package object bridge { } } + _search = scriptFields match { + case Nil => _search + case _ => + _search scriptfields scriptFields.map { field => + scriptField( + field.scriptName, + Script(script = field.painless).lang("painless").scriptType("source") + ) + } + } + _search = orderBy match { case Some(o) if aggregates.isEmpty && buckets.isEmpty => _search sortBy o.sorts.map(sort => @@ -134,7 +146,7 @@ package object bridge { value match { case n: SQLNumeric[_] if !aggregation => operator match { - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -147,7 +159,7 @@ package object bridge { d => rangeQuery(identifier.name) gte d ) } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -160,7 +172,7 @@ package object bridge { d => rangeQuery(identifier.name) gt d ) } - case _: Le.type => + case Le => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -173,7 +185,7 @@ package object bridge { d => rangeQuery(identifier.name) lte d ) } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -186,7 +198,7 @@ package object bridge { d => rangeQuery(identifier.name) lt d ) } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -199,7 +211,7 @@ package object bridge { d => termQuery(identifier.name, d) ) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -216,49 +228,49 @@ package object bridge { } case l: SQLLiteral if !aggregation => operator match { - case _: Like.type => + case Like => maybeNot match { case Some(_) => not(regexQuery(identifier.name, toRegex(l.value))) case _ => regexQuery(identifier.name, toRegex(l.value)) } - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => rangeQuery(identifier.name) lt l.value case _ => rangeQuery(identifier.name) gte l.value } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => rangeQuery(identifier.name) lte l.value case _ => rangeQuery(identifier.name) gt l.value } - case _: Le.type => + case Le => maybeNot match { case Some(_) => rangeQuery(identifier.name) gt l.value case _ => rangeQuery(identifier.name) lte l.value } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => rangeQuery(identifier.name) gte l.value case _ => rangeQuery(identifier.name) lt l.value } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, l.value)) case _ => termQuery(identifier.name, l.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, l.value) @@ -269,14 +281,14 @@ package object bridge { } case b: SQLBoolean if !aggregation => operator match { - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, b.value)) case _ => termQuery(identifier.name, b.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, b.value) @@ -289,6 +301,26 @@ package object bridge { } } + implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { + import dateMath._ + if (aggregation) + return matchAllQuery() + dateTimeFunction match { + case _: CurrentTimeFunction => + scriptQuery(Script(script = script).lang("painless").scriptType("source")) + case _ => + val op = if (maybeNot.isDefined) operator.not else operator + op match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + } + } + implicit def isNullToQuery( isNull: SQLIsNull ): Query = { diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 58f76b96..97892ef4 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -6,7 +6,7 @@ import com.google.gson.{JsonArray, JsonObject, JsonParser, JsonPrimitive} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ /** Created by smanciot on 13/04/17. */ @@ -538,8 +538,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "country": { - | "value": "usa" + | "Country": { + | "value": "USA" | } | } | } @@ -551,8 +551,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "city": { - | "value": "berlin" + | "City": { + | "value": "Berlin" | } | } | } @@ -566,17 +566,17 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } | }, | "aggs": { - | "country": { + | "Country": { | "terms": { - | "field": "country.keyword", + | "field": "Country.keyword", | "order": { - | "country": "asc" + | "Country": "asc" | } | }, | "aggs": { - | "city": { + | "City": { | "terms": { - | "field": "city.keyword", + | "field": "City.keyword", | "order": { | "cnt": "desc" | } @@ -584,7 +584,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "cnt": { | "value_count": { - | "field": "customerid" + | "field": "CustomerID" | } | }, | "having_filter": { @@ -868,4 +868,611 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { } + it should "add script fields" in { + val select: ElasticSearchRequest = + SQLQuery(fieldsWithInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "ct": { + | "script": { + | "lang": "painless", + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | } + | } + | }, + | "_source": { + | "includes": ["identifier"] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now/d" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d/d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("return", "return ") + } + + it should "handle having with date functions" in { + val select: ElasticSearchRequest = + SQLQuery("""SELECT userId, MAX(createdAt) as lastSeen + |FROM table + |GROUP BY userId + |HAVING MAX(createdAt) > now - interval 7 day""".stripMargin) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "match_all": {} + | }, + | "aggs": { + | "userId": { + | "terms": { + | "field": "userId.keyword" + | }, + | "aggs": { + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "lastSeen": "lastSeen" + | }, + | "script": { + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle group by with having and date time functions" in { + val select: ElasticSearchRequest = + SQLQuery(groupByWithHavingAndDateTimeFunctions) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "Country": { + | "value": "USA" + | } + | } + | } + | ] + | } + | }, + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "City": { + | "value": "Berlin" + | } + | } + | } + | ] + | } + | }, + | { + | "match_all": {} + | }, + | { + | "range": { + | "lastSeen": { + | "gt": "now-7d" + | } + | } + | } + | ] + | } + | }, + | "aggs": { + | "Country": { + | "terms": { + | "field": "Country.keyword", + | "order": { + | "Country": "asc" + | } + | }, + | "aggs": { + | "City": { + | "terms": { + | "field": "City.keyword" + | }, + | "aggs": { + | "cnt": { + | "value_count": { + | "field": "CustomerID" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "cnt": "cnt" + | }, + | "script": { + | "source": "1 == 1 && 1 == 1 && params.cnt > 1 && 1 == 1" + | } + | } + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle parse_date function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDate) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",ChronoUnit", ", ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } + + it should "handle parse_datetime function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDateTime) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",ZonedDateTime", ", ZonedDateTime") + } + + it should "handle date_diff function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "diff": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + + it should "handle aggregation with date_diff function" in { + val select: ElasticSearchRequest = + SQLQuery(aggregationWithDateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword" + | }, + | "aggs": { + | "max_diff": { + | "max": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") + } + + it should "handle date_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle date_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index cf365736..1bedbaa4 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -7,11 +7,12 @@ import app.softnetwork.elastic.sql.{ BucketSelectorScript, Count, ElasticBoolQuery, + Field, Max, Min, SQLBucket, SQLCriteria, - SQLField, + SQLFunctionUtils, SortOrder, Sum } @@ -56,7 +57,7 @@ case class ElasticAggregation( object ElasticAggregation { def apply( - sqlAgg: SQLField, + sqlAgg: Field, having: Option[SQLCriteria], bucketsDirection: Map[String, SortOrder] ): ElasticAggregation = { @@ -87,6 +88,23 @@ object ElasticAggregation { var aggPath = Seq[String]() + val (aggFuncs, transformFuncs) = SQLFunctionUtils.aggregateAndTransformFunctions(identifier) + + require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") + + def aggWithFieldOrScript( + buildField: (String, String) => Aggregation, + buildScript: (String, Script) => Aggregation + ): Aggregation = { + if (transformFuncs.nonEmpty) { + val scriptSrc = identifier.painless + val script = Script(scriptSrc).lang("painless") + buildScript(aggName, script) + } else { + buildField(aggName, sourceField) + } + } + val _agg = aggType match { case Count => @@ -95,10 +113,10 @@ object ElasticAggregation { else { valueCountAgg(aggName, sourceField) } - case Min => minAgg(aggName, sourceField) - case Max => maxAgg(aggName, sourceField) - case Avg => avgAgg(aggName, sourceField) - case Sum => sumAgg(aggName, sourceField) + case Min => aggWithFieldOrScript(minAgg, (name, s) => minAgg(name, sourceField).script(s)) + case Max => aggWithFieldOrScript(maxAgg, (name, s) => maxAgg(name, sourceField).script(s)) + case Avg => aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s)) + case Sum => aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s)) } val filteredAggName = "filtered_agg" diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index 746d4af8..9de90dd8 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,12 +9,13 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, + SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, SQLIsNull } -import com.sksamuel.elastic4s.ElasticApi.{bool, _} +import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.requests.searches.queries.Query case class ElasticQuery(filter: ElasticFilter) { @@ -69,6 +70,7 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression + case dateMath: SQLComparisonDateMath => dateMath case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala index bfd1bc7f..adcf87e1 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala @@ -1,10 +1,10 @@ package app.softnetwork.elastic.sql.bridge -import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, SQLField} +import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, Field} import com.sksamuel.elastic4s.requests.searches.{SearchBodyBuilderFn, SearchRequest} case class ElasticSearchRequest( - fields: Seq[SQLField], + fields: Seq[Field], except: Option[SQLExcept], sources: Seq[String], criteria: Option[SQLCriteria], diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index e051841d..b2edb050 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -2,6 +2,7 @@ package app.softnetwork.elastic.sql import com.sksamuel.elastic4s.ElasticApi import com.sksamuel.elastic4s.ElasticApi._ +import com.sksamuel.elastic4s.requests.script.Script import com.sksamuel.elastic4s.requests.searches.aggs.Aggregation import com.sksamuel.elastic4s.requests.searches.queries.Query import com.sksamuel.elastic4s.requests.searches.sort.FieldSort @@ -96,6 +97,17 @@ package object bridge { } } + _search = scriptFields match { + case Nil => _search + case _ => + _search scriptfields scriptFields.map { field => + scriptField( + field.scriptName, + Script(script = field.painless).lang("painless").scriptType("source") + ) + } + } + _search = orderBy match { case Some(o) if aggregates.isEmpty && buckets.isEmpty => _search sortBy o.sorts.map(sort => @@ -135,7 +147,7 @@ package object bridge { value match { case n: SQLNumeric[_] if !aggregation => operator match { - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -148,7 +160,7 @@ package object bridge { d => rangeQuery(identifier.name) gte d ) } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -161,7 +173,7 @@ package object bridge { d => rangeQuery(identifier.name) gt d ) } - case _: Le.type => + case Le => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -174,7 +186,7 @@ package object bridge { d => rangeQuery(identifier.name) lte d ) } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -187,7 +199,7 @@ package object bridge { d => rangeQuery(identifier.name) lt d ) } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -200,7 +212,7 @@ package object bridge { d => termQuery(identifier.name, d) ) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -217,49 +229,49 @@ package object bridge { } case l: SQLLiteral if !aggregation => operator match { - case _: Like.type => + case Like => maybeNot match { case Some(_) => not(regexQuery(identifier.name, toRegex(l.value))) case _ => regexQuery(identifier.name, toRegex(l.value)) } - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => rangeQuery(identifier.name) lt l.value case _ => rangeQuery(identifier.name) gte l.value } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => rangeQuery(identifier.name) lte l.value case _ => rangeQuery(identifier.name) gt l.value } - case _: Le.type => + case Le => maybeNot match { case Some(_) => rangeQuery(identifier.name) gt l.value case _ => rangeQuery(identifier.name) lte l.value } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => rangeQuery(identifier.name) gte l.value case _ => rangeQuery(identifier.name) lt l.value } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, l.value)) case _ => termQuery(identifier.name, l.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, l.value) @@ -270,14 +282,14 @@ package object bridge { } case b: SQLBoolean if !aggregation => operator match { - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, b.value)) case _ => termQuery(identifier.name, b.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, b.value) @@ -290,6 +302,26 @@ package object bridge { } } + implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { + import dateMath._ + if (aggregation) + return matchAllQuery() + dateTimeFunction match { + case _: CurrentTimeFunction => + scriptQuery(Script(script = script).lang("painless").scriptType("source")) + case _ => + val op = if (maybeNot.isDefined) operator.not else operator + op match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) + } + } + } + implicit def isNullToQuery( isNull: SQLIsNull ): Query = { diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index ce96ddf5..d7f25e09 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -6,7 +6,7 @@ import com.google.gson.{JsonArray, JsonObject, JsonParser, JsonPrimitive} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ /** Created by smanciot on 13/04/17. */ @@ -538,8 +538,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "country": { - | "value": "usa" + | "Country": { + | "value": "USA" | } | } | } @@ -551,8 +551,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "city": { - | "value": "berlin" + | "City": { + | "value": "Berlin" | } | } | } @@ -566,17 +566,17 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } | }, | "aggs": { - | "country": { + | "Country": { | "terms": { - | "field": "country.keyword", + | "field": "Country.keyword", | "order": { - | "country": "asc" + | "Country": "asc" | } | }, | "aggs": { - | "city": { + | "City": { | "terms": { - | "field": "city.keyword", + | "field": "City.keyword", | "order": { | "cnt": "desc" | } @@ -584,7 +584,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "cnt": { | "value_count": { - | "field": "customerid" + | "field": "CustomerID" | } | }, | "having_filter": { @@ -867,4 +867,609 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { } + it should "add script fields" in { + val select: ElasticSearchRequest = + SQLQuery(fieldsWithInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "ct": { + | "script": { + | "lang": "painless", + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" + | } + | } + | }, + | "_source": { + | "includes": ["identifier"] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now/d" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d/d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("return", "return ") + } + + it should "handle having with date functions" in { + val select: ElasticSearchRequest = + SQLQuery("""SELECT userId, MAX(createdAt) as lastSeen + |FROM table + |GROUP BY userId + |HAVING MAX(createdAt) > now - interval 7 day""".stripMargin) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "match_all": {} + | }, + | "aggs": { + | "userId": { + | "terms": { + | "field": "userId.keyword" + | }, + | "aggs": { + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "lastSeen": "lastSeen" + | }, + | "script": { + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle group by with having and date time functions" in { + val select: ElasticSearchRequest = + SQLQuery(groupByWithHavingAndDateTimeFunctions) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "Country": { + | "value": "USA" + | } + | } + | } + | ] + | } + | }, + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "City": { + | "value": "Berlin" + | } + | } + | } + | ] + | } + | }, + | { + | "match_all": {} + | }, + | { + | "range": { + | "lastSeen": { + | "gt": "now-7d" + | } + | } + | } + | ] + | } + | }, + | "aggs": { + | "Country": { + | "terms": { + | "field": "Country.keyword", + | "order": { + | "Country": "asc" + | } + | }, + | "aggs": { + | "City": { + | "terms": { + | "field": "City.keyword" + | }, + | "aggs": { + | "cnt": { + | "value_count": { + | "field": "CustomerID" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "cnt": "cnt" + | }, + | "script": { + | "source": "1 == 1 && 1 == 1 && params.cnt > 1 && 1 == 1" + | } + | } + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle parse_date function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDate) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } + + it should "handle parse_datetime function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDateTime) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",ZonedDateTime", ", ZonedDateTime") + } + + it should "handle date_diff function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "diff": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + + it should "handle aggregation with date_diff function" in { + val select: ElasticSearchRequest = + SQLQuery(aggregationWithDateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword" + | }, + | "aggs": { + | "max_diff": { + | "max": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") + } + + it should "handle date_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle date_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 31b1fe81..96ca2453 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -1,6 +1,88 @@ package app.softnetwork.elastic.sql -sealed trait SQLFunction extends SQLRegex +import scala.util.matching.Regex + +sealed trait SQLFunction extends SQLRegex { + def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql +} + +sealed trait SQLFunctionWithIdentifier extends SQLFunction { + def identifier: SQLIdentifier +} + +object SQLFunctionUtils { + def aggregateAndTransformFunctions( + identifier: Identifier + ): (List[SQLFunction], List[SQLFunction]) = { + identifier.functions.partition { + case _: AggregateFunction => true + case _ => false + } + } + + def transformFunctions(identifier: Identifier): List[SQLFunction] = { + aggregateAndTransformFunctions(identifier)._2 + } + +} + +trait SQLFunctionChain extends SQLFunction with SQLValidation { + def functions: List[SQLFunction] + + override def validate(): Either[String, Unit] = + SQLValidator.validateChain(functions) + + override def toSQL(base: String): String = + functions.reverse.foldLeft(base)((expr, fun) => { + fun.toSQL(expr) + }) + + lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { + case Some(af: AggregateFunction) => Some(af) + case _ => None + } + + lazy val aggregation: Boolean = aggregateFunction.isDefined +} + +sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] + extends SQLFunction + with PainlessScript { + def inputType: In + def outputType: Out +} + +sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] + extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction => + + override def inputType: SQLAny = SQLTypes.Any + + def left: PainlessScript + def right: PainlessScript + +} + +sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] { + def toPainless(base: String): String = s"$base$painless" +} + +sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] + extends SQLTransformFunction[In, Out] { + def operator: ArithmeticOperator + override def toSQL(base: String): String = s"$base$operator$sql" +} + +sealed trait ParametrizedFunction extends SQLFunction { + def params: Seq[String] + override def toSQL(base: String): String = { + params match { + case Nil => s"$sql($base)" + case _ => + val paramsStr = params.mkString(", ") + s"$sql($paramsStr)($base)" + } + } +} sealed trait AggregateFunction extends SQLFunction case object Count extends SQLExpr("count") with AggregateFunction @@ -10,3 +92,307 @@ case object Avg extends SQLExpr("avg") with AggregateFunction case object Sum extends SQLExpr("sum") with AggregateFunction case object Distance extends SQLExpr("distance") with SQLFunction with SQLOperator + +sealed trait TimeUnit extends PainlessScript with MathScript { + lazy val regex: Regex = s"\\b(?i)$sql(s)?\\b".r + + override def painless: String = s"ChronoUnit.${sql.toUpperCase()}S" +} + +sealed trait CalendarUnit extends TimeUnit +sealed trait FixedUnit extends TimeUnit + +object TimeUnit { + case object Year extends SQLExpr("year") with CalendarUnit { + override def script: String = "y" + } + case object Month extends SQLExpr("month") with CalendarUnit { + override def script: String = "M" + } + case object Quarter extends SQLExpr("quarter") with CalendarUnit { + override def script: String = throw new IllegalArgumentException( + "Quarter must be converted to months (value * 3) before creating date-math" + ) + } + case object Week extends SQLExpr("week") with CalendarUnit { + override def script: String = "w" + } + + case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit { + override def script: String = "d" + } + + case object Hour extends SQLExpr("hour") with FixedUnit { + override def script: String = "H" + } + case object Minute extends SQLExpr("minute") with FixedUnit { + override def script: String = "m" + } + case object Second extends SQLExpr("second") with FixedUnit { + override def script: String = "s" + } + +} + +case object Interval extends SQLExpr("interval") with SQLFunction with SQLRegex + +sealed trait TimeInterval extends PainlessScript with MathScript { + def value: Int + def unit: TimeUnit + override def sql: String = s"$Interval $value ${unit.sql}" + + override def painless: String = s"$value, ${unit.painless}" + + override def script: String = TimeInterval.script(this) +} + +import TimeUnit._ + +case class CalendarInterval(value: Int, unit: CalendarUnit) extends TimeInterval +case class FixedInterval(value: Int, unit: FixedUnit) extends TimeInterval + +object TimeInterval { + def apply(value: Int, unit: TimeUnit): TimeInterval = unit match { + case cu: CalendarUnit => CalendarInterval(value, cu) + case fu: FixedUnit => FixedInterval(value, fu) + } + def script(interval: TimeInterval): String = interval match { + case CalendarInterval(v, Quarter) => s"${v * 3}M" + case CalendarInterval(v, u) => s"$v${u.script}" + case FixedInterval(v, u) => s"$v${u.script}" + } +} + +case class SQLAddInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] + with MathScript { + override def operator: ArithmeticOperator = Add + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def painless: String = s".plus(${interval.painless})" + override def script: String = s"${operator.script}${interval.script}" +} + +case class SQLSubstractInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] + with MathScript { + override def operator: ArithmeticOperator = Subtract + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def painless: String = s".minus(${interval.painless})" + override def script: String = s"${operator.script}${interval.script}" +} + +sealed trait DateTimeFunction extends SQLFunction + +sealed trait DateFunction extends DateTimeFunction + +sealed trait TimeFunction extends DateTimeFunction + +sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" + override def script: String = "now" +} + +sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" +} + +sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" +} + +case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction + +case object CurentDateWithParens extends SQLExpr("current_date()") with CurrentDateFunction + +case object CurrentTime extends SQLExpr("current_time") with CurrentTimeFunction + +case object CurrentTimeWithParens extends SQLExpr("current_time()") with CurrentTimeFunction + +case object CurrentTimestamp extends SQLExpr("current_timestamp") with CurrentDateTimeFunction + +case object CurrentTimestampWithParens + extends SQLExpr("current_timestamp()") + with CurrentDateTimeFunction + +case object Now extends SQLExpr("now") with CurrentDateTimeFunction + +case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction + +case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) + extends SQLExpr("date_trunc") + with DateTimeFunction + with SQLTransformFunction[SQLTemporal, SQLTemporal] + with SQLFunctionWithIdentifier { + override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut + override def outputType: SQLTemporal = SQLTypes.Temporal // idem + override def toSQL(base: String): String = { + s"$sql($base, ${unit.sql})" + } + override def painless: String = s".truncatedTo(${unit.painless})" +} + +case class Extract(unit: TimeUnit, override val sql: String = "extract") + extends SQLExpr(sql) + with DateTimeFunction + with SQLTransformFunction[SQLTemporal, SQLNumber] + with ParametrizedFunction { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLNumber = SQLTypes.Number + override def params: Seq[String] = Seq(unit.sql) + override def painless: String = s".get(${unit.painless})" +} + +object YEAR extends Extract(Year, Year.sql) { + override def params: Seq[String] = Seq.empty +} + +object MONTH extends Extract(Month, Month.sql) { + override def params: Seq[String] = Seq.empty +} + +object DAY extends Extract(Day, Day.sql) { + override def params: Seq[String] = Seq.empty +} + +object HOUR extends Extract(Hour, Hour.sql) { + override def params: Seq[String] = Seq.empty +} + +object MINUTE extends Extract(Minute, Minute.sql) { + override def params: Seq[String] = Seq.empty +} + +object SECOND extends Extract(Second, Second.sql) { + override def params: Seq[String] = Seq.empty +} + +case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) + extends SQLExpr("date_diff") + with DateTimeFunction + with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumber] + with PainlessScript { + override def outputType: SQLNumber = SQLTypes.Number + override def left: PainlessScript = end + override def right: PainlessScript = start + override def toSQL(base: String): String = { + s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" + } + override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" +} + +case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) + extends SQLExpr("date_add") + with DateFunction + with SQLTransformFunction[SQLDate, SQLDate] + with SQLFunctionWithIdentifier { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLDate = SQLTypes.Date + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } + override def painless: String = s".plus(${interval.painless})" +} + +case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) + extends SQLExpr("date_sub") + with DateFunction + with SQLTransformFunction[SQLDate, SQLDate] + with SQLFunctionWithIdentifier { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLDate = SQLTypes.Date + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } + override def painless: String = s".minus(${interval.painless})" +} + +case class ParseDate(identifier: SQLIdentifier, format: String) + extends SQLExpr("parse_date") + with DateFunction + with SQLTransformFunction[SQLString, SQLDate] + with SQLFunctionWithIdentifier { + override def inputType: SQLString = SQLTypes.String + override def outputType: SQLDate = SQLTypes.Date + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } + override def painless: String = throw new NotImplementedError("Use toPainless instead") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" +} + +case class FormatDate(identifier: SQLIdentifier, format: String) + extends SQLExpr("format_date") + with DateFunction + with SQLTransformFunction[SQLDate, SQLString] + with SQLFunctionWithIdentifier { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLString = SQLTypes.String + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } + override def painless: String = throw new NotImplementedError("Use toPainless instead") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').format($base)" +} + +case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) + extends SQLExpr("datetime_add") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLDateTime] + with SQLFunctionWithIdentifier { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } + override def painless: String = s".plus(${interval.painless})" +} + +case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) + extends SQLExpr("datetime_sub") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLDateTime] + with SQLFunctionWithIdentifier { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } + override def painless: String = s".minus(${interval.painless})" +} + +case class ParseDateTime(identifier: SQLIdentifier, format: String) + extends SQLExpr("parse_datetime") + with DateTimeFunction + with SQLTransformFunction[SQLString, SQLDateTime] + with SQLFunctionWithIdentifier { + override def inputType: SQLString = SQLTypes.String + override def outputType: SQLDateTime = SQLTypes.DateTime + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } + override def painless: String = throw new NotImplementedError("Use toPainless instead") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" +} + +case class FormatDateTime(identifier: SQLIdentifier, format: String) + extends SQLExpr("format_datetime") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLString] + with SQLFunctionWithIdentifier { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLString = SQLTypes.String + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } + override def painless: String = throw new NotImplementedError("Use toPainless instead") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').format($base)" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 4903c3ac..e30e7f73 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -3,7 +3,7 @@ package app.softnetwork.elastic.sql case object GroupBy extends SQLExpr("group by") with SQLRegex case class SQLGroupBy(buckets: Seq[SQLBucket]) extends Updateable { - override def sql: String = s" $GroupBy ${buckets.mkString(",")}" + override def sql: String = s" $GroupBy ${buckets.mkString(", ")}" def update(request: SQLSearchRequest): SQLGroupBy = this.copy(buckets = buckets.map(_.update(request))) lazy val bucketNames: Map[String, SQLBucket] = buckets.map { b => @@ -34,7 +34,7 @@ case class SQLBucket( object BucketSelectorScript { private[this] def painlessIn(param: String, values: Seq[SQLValue[_]], not: Boolean): String = { - val ret = s"[${values.map { _.painlessValue }.mkString(", ")}].contains($param)" + val ret = s"[${values.map { _.painless }.mkString(", ")}].contains($param)" if (not) s"!$ret" else ret } @@ -44,7 +44,7 @@ object BucketSelectorScript { upper: SQLValue[_], not: Boolean ): String = { - val ret = s"($param >= ${lower.painlessValue} && $param <= ${upper.painlessValue})" + val ret = s"($param >= ${lower.painless} && $param <= ${upper.painless})" if (not) s"!$ret" else ret } @@ -55,44 +55,22 @@ object BucketSelectorScript { not: Boolean ): String = { operator match { - case _: SQLComparisonOperator => + case o: SQLComparisonOperator => val valueStr = value match { - case v: SQLBoolean => v.painlessValue - case v: SQLDouble => v.painlessValue - case v: SQLLiteral => v.painlessValue - case v: SQLLong => v.painlessValue + case v: SQLBoolean => v.painless + case v: SQLDouble => v.painless + case v: SQLLiteral => v.painless + case v: SQLLong => v.painless case _ => throw new IllegalArgumentException( s"Unsupported value type in bucket_selector: $value" ) } - if (not) { - operator match { - case Eq => s"$param != $valueStr" - case Ne => s"$param == $valueStr" - case Gt => s"$param <= $valueStr" - case Ge => s"$param < $valueStr" - case Lt => s"$param >= $valueStr" - case Le => s"$param > $valueStr" - case _ => - throw new IllegalArgumentException( - s"Unsupported comparison operator in bucket_selector: $operator" - ) - } - } else - operator match { - case Eq => s"$param == $valueStr" - case Ne => s"$param != $valueStr" - case Gt => s"$param > $valueStr" - case Ge => s"$param >= $valueStr" - case Lt => s"$param < $valueStr" - case Le => s"$param <= $valueStr" - case _ => - throw new IllegalArgumentException( - s"Unsupported comparison operator in bucket_selector: $operator" - ) - } + if (not) + s"$param ${o.not.painless} $valueStr" + else + s"$param ${o.painless} $valueStr" case In => value match { case SQLDoubleValues(vals) => painlessIn(param, vals, not) @@ -117,17 +95,20 @@ object BucketSelectorScript { extractBucketsPath(left) ++ extractBucketsPath(right) case relation: ElasticRelation => extractBucketsPath(relation.criteria) case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector - case e: Expression => + case b: BinaryExpression => + import b._ + if (left.aggregation && right.aggregation) + Map(left.aliasOrName -> left.aliasOrName, right.aliasOrName -> right.aliasOrName) + else if (left.aggregation) + Map(left.aliasOrName -> left.aliasOrName) + else if (right.aggregation) + Map(right.aliasOrName -> right.aliasOrName) + else + Map.empty + case e: Expression if e.aggregation => import e._ - val name = identifier.fieldAlias.getOrElse(identifier.name) - if (e.aggregation) { - Map(name -> name) - } /*else if (e.identifier.bucket.isDefined) { - Map(name -> "_key") - }*/ - else { - Map.empty // for performance, we only allow aggregation here - } + Map(identifier.aliasOrName -> identifier.aliasOrName) + case _ => Map.empty } def toPainless(expr: SQLCriteria): String = expr match { @@ -147,26 +128,46 @@ object BucketSelectorScript { case relation: ElasticRelation => toPainless(relation.criteria) + case SQLComparisonDateMath(identifier, op, dateFunc, arithOp, interval, maybeNot) + if identifier.aggregation => + val painlessOp = if (maybeNot.nonEmpty) op.not.painless else op.painless + val paramName = identifier.aliasOrName + // always use a correct "now" creation + val now = "ZonedDateTime.now(ZoneId.of('Z'))" + + // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) + val rightBase = (arithOp, interval) match { + case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" + case (Some(Subtract), Some(i)) => s"$now.minus(${i.painless})" + case _ => now + } + + val rightZdt = dateFunc match { + // truncate only after arithmetic for CurrentDate + case _: CurrentDateFunction => s"$rightBase.truncatedTo(ChronoUnit.DAYS)" + case _: CurrentTimeFunction => s"$rightBase.truncatedTo(ChronoUnit.SECONDS)" + case _ => rightBase + } + + // protect against null params and compare epoch millis + s"(params.$paramName != null) && (params.$paramName $painlessOp $rightZdt.toInstant().toEpochMilli())" + case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector - case e: Expression => - if (e.aggregation /*|| e.identifier.bucket.isDefined*/ ) { // for performance, we only allow aggregation here - val param = - s"params.${e.identifier.fieldAlias.getOrElse(e.identifier.name)}" - e.maybeValue match { - case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) - case None => - e.operator match { - case IsNull => s"$param == null" - case IsNotNull => s"$param != null" - case _ => - throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") - } - } - } else { - "1 == 1" + case e: Expression if e.aggregation => + val param = + s"params.${e.identifier.aliasOrName}" + e.maybeValue match { + case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) + case None => + e.operator match { + case IsNull => s"$param == null" + case IsNotNull => s"$param != null" + case _ => + throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") + } } - case _ => throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") + case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 5df94ea8..21342118 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -2,12 +2,38 @@ package app.softnetwork.elastic.sql trait SQLOperator extends SQLToken +sealed trait ArithmeticOperator extends SQLOperator with MathScript { + override def toString: String = s" $sql " + override def script: String = sql +} +case object Add extends SQLExpr("+") with ArithmeticOperator +case object Subtract extends SQLExpr("-") with ArithmeticOperator +case object Multiply extends SQLExpr("*") with ArithmeticOperator +case object Divide extends SQLExpr("/") with ArithmeticOperator +case object Modulo extends SQLExpr("%") with ArithmeticOperator + sealed trait SQLExpressionOperator extends SQLOperator -sealed trait SQLComparisonOperator extends SQLExpressionOperator +sealed trait SQLComparisonOperator extends SQLExpressionOperator with PainlessScript { + override def painless: String = this match { + case Eq => "==" + case Ne => "!=" + case other => other.sql + } + + def not: SQLComparisonOperator = this match { + case Eq => Ne + case Ne | Diff => Eq + case Ge => Lt + case Gt => Le + case Le => Gt + case Lt => Ge + } +} case object Eq extends SQLExpr("=") with SQLComparisonOperator case object Ne extends SQLExpr("<>") with SQLComparisonOperator +case object Diff extends SQLExpr("!=") with SQLComparisonOperator case object Ge extends SQLExpr(">=") with SQLComparisonOperator case object Gt extends SQLExpr(">") with SQLComparisonOperator case object Le extends SQLExpr("<=") with SQLComparisonOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala index 74f3110a..4a04f005 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala @@ -11,17 +11,13 @@ case object Asc extends SQLExpr("asc") with SortOrder case class SQLFieldSort( field: String, order: Option[SortOrder], - function: Option[SQLFunction] = None -) extends SQLTokenWithFunction { - private[this] lazy val fieldWithFunction: String = function match { - case Some(f) => s"$f($field)" - case _ => field - } + functions: List[SQLFunction] = List.empty +) extends SQLFunctionChain { lazy val direction: SortOrder = order.getOrElse(Asc) - lazy val name: String = fieldWithFunction + lazy val name: String = toSQL(field) override def sql: String = s"$name $direction" } case class SQLOrderBy(sorts: Seq[SQLFieldSort]) extends SQLToken { - override def sql: String = s" $OrderBy ${sorts.mkString(",")}" + override def sql: String = s" $OrderBy ${sorts.mkString(", ")}" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 00922443..ad1b1bd4 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -2,6 +2,7 @@ package app.softnetwork.elastic.sql import scala.util.parsing.combinator.{PackratParsers, RegexParsers} import scala.util.parsing.input.CharSequenceReader +import TimeUnit._ /** Created by smanciot on 27/06/2018. * @@ -85,19 +86,206 @@ trait SQLParser extends RegexParsers with PackratParsers { def sum: PackratParser[AggregateFunction] = Sum.regex ^^ (_ => Sum) - def aggregateFunction: PackratParser[AggregateFunction] = count | min | max | avg | sum + def year: PackratParser[TimeUnit] = Year.regex ^^ (_ => Year) - def distanceFunction: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) + def month: PackratParser[TimeUnit] = Month.regex ^^ (_ => Month) - def sqlFunction: PackratParser[SQLFunction] = aggregateFunction | distanceFunction + def quarter: PackratParser[TimeUnit] = Quarter.regex ^^ (_ => Quarter) - private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + def week: PackratParser[TimeUnit] = Week.regex ^^ (_ => Week) - def identifierWithFunction: PackratParser[SQLIdentifier] = - sqlFunction ~ start ~ identifier ~ end ^^ { case f ~ _ ~ i ~ _ => - i.copy(function = Some(f)) + def day: PackratParser[TimeUnit] = Day.regex ^^ (_ => Day) + + def hour: PackratParser[TimeUnit] = Hour.regex ^^ (_ => Hour) + + def minute: PackratParser[TimeUnit] = Minute.regex ^^ (_ => Minute) + + def second: PackratParser[TimeUnit] = Second.regex ^^ (_ => Second) + + def time_unit: PackratParser[TimeUnit] = + year | month | quarter | week | day | hour | minute | second + + def interval: PackratParser[TimeInterval] = + Interval.regex ~ long ~ time_unit ^^ { case _ ~ l ~ u => + TimeInterval(l.value.toInt, u) + } + + def current_date: PackratParser[CurrentDateTimeFunction] = + CurrentDate.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate + } + + def current_time: PackratParser[CurrentDateTimeFunction] = + CurrentTime.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime + } + + def current_timestamp: PackratParser[CurrentDateTimeFunction] = + CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp + } + + def now: PackratParser[CurrentDateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { + case _ ~ s ~ t => + if (s.isDefined && t.isDefined) NowWithParens else Now + } + + def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) + + def substract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) + + def intervalOperator: PackratParser[ArithmeticOperator] = add | substract + + def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator + + def addInterval: PackratParser[SQLAddInterval] = + add ~ interval ^^ { case _ ~ it => + SQLAddInterval(it) + } + + def substractInterval: PackratParser[SQLSubstractInterval] = + substract ~ interval ^^ { case _ ~ it => + SQLSubstractInterval(it) + } + + def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = + addInterval | substractInterval + + def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction + + def identifierWithSystemFunction: PackratParser[SQLIdentifier] = + (current_date | current_time | current_timestamp | now) ~ intervalFunction.? ^^ { + case f1 ~ f2 => + f2 match { + case Some(f) => SQLIdentifier("", functions = List(f, f1)) + case None => SQLIdentifier("", functions = List(f1)) + } + } + + def date_trunc: PackratParser[SQLFunctionWithIdentifier] = + "(?i)date_trunc".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ u ~ _ => + DateTrunc(i, u) + } + + def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => + Extract(u) + } + + def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + Year.regex ^^ (_ => YEAR) + + def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + Month.regex ^^ (_ => MONTH) + + def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) + + def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + Hour.regex ^^ (_ => HOUR) + + def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + Minute.regex ^^ (_ => MINUTE) + + def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + Second.regex ^^ (_ => SECOND) + + def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = + extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second + + def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)date_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateAdd(i, t) + } + + def date_sub: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)date_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateSub(i, t) + } + + def parse_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)parse_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + ParseDate(i, f.value) + } + + def format_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)format_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + FormatDate(i, f.value) + } + + def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date + + def datetime_add: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)datetime_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateTimeAdd(i, t) + } + + def datetime_sub: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)datetime_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateTimeSub(i, t) + } + + def parse_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)parse_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + ParseDateTime(i, f.value) + } + + def format_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)format_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + FormatDateTime(i, f.value) + } + + def datetime_functions: PackratParser[DateTimeFunction] = + datetime_add | datetime_sub | parse_datetime | format_datetime + + def aggregates: PackratParser[AggregateFunction] = count | min | max | avg | sum + + def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) + + def painless_identifier: PackratParser[SQLIdentifier] = + repsep( + date_trunc | extractors | date_functions | datetime_functions, + start + ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep( + end + ) ^^ { case f ~ _ ~ i ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + i match { + case Some(id) => id.copy(functions = id.functions ++ f) + case None => + f.lastOption match { + case Some(fi: SQLFunctionWithIdentifier) => + fi.identifier.copy(functions = f ++ fi.identifier.functions) + case _ => SQLIdentifier("", functions = f) + } + } } + def date_diff: PackratParser[SQLBinaryFunction[_, _, _]] = + "(?i)date_diff".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) + } + + def date_diff_identifier: PackratParser[SQLIdentifier] = date_diff ^^ { dd => + SQLIdentifier("", functions = dd :: Nil) + } + + def sql_functions: PackratParser[SQLFunction] = + aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions + + private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" + def identifier: PackratParser[SQLIdentifier] = Distinct.regex.? ~ regexIdentifier.r ^^ { case d ~ i => SQLIdentifier( @@ -107,15 +295,62 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } + private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (parse_date | format_date | date_add | date_sub) ^^ { t => + t.identifier.copy(functions = t +: t.identifier.functions) + } + + private[this] def dateTimeFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ^^ { t => + t.identifier.copy(functions = t +: t.identifier.functions) + } + + def identifierWithTransformation: PackratParser[SQLIdentifier] = + dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + + def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = + (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => + i.copy(functions = af +: i.functions) + } + + def identifierWithAggregation: PackratParser[SQLIdentifier] = + aggregates ~ start ~ (identifierWithFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case a ~ _ ~ i ~ _ => + i.copy(functions = a +: i.functions) + } + + def identifierWithFunction: PackratParser[SQLIdentifier] = + rep1sep( + sql_functions, + start + ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep1( + end + ) ^^ { case f ~ _ ~ i ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + i match { + case None => + f.lastOption match { + case Some(fi: SQLFunctionWithIdentifier) => + fi.identifier.copy(functions = f ++ fi.identifier.functions) + case _ => SQLIdentifier("", functions = f) + } + case Some(id) => id.copy(functions = id.functions ++ f) + } + } + private val regexAlias = """\b(?!(?i)as\b)\b(?!(?i)except\b)\b(?!(?i)where\b)\b(?!(?i)filter\b)\b(?!(?i)from\b)\b(?!(?i)group\b)\b(?!(?i)having\b)\b(?!(?i)order\b)\b(?!(?i)limit\b)[a-zA-Z0-9_]*""" def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } - def field: PackratParser[SQLField] = (identifierWithFunction | identifier) ~ alias.? ^^ { - case i ~ a => - SQLField(i, a) - } + def field: PackratParser[Field] = + (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifierWithTransformation | date_diff_identifier | identifier) ~ alias.? ^^ { + case i ~ a => + SQLField(i, a) + } } @@ -128,7 +363,10 @@ trait SQLSelectParser { } def select: PackratParser[SQLSelect] = - Select.regex ~ rep1sep(field, separator) ~ except.? ^^ { case _ ~ fields ~ e => + Select.regex ~ rep1sep( + field, + separator + ) ~ except.? ^^ { case _ ~ fields ~ e => SQLSelect(fields, e) } @@ -162,30 +400,33 @@ trait SQLWhereParser { SQLIsNotNull(i) } - private def eq: PackratParser[SQLExpressionOperator] = Eq.sql ^^ (_ => Eq) + private def eq: PackratParser[SQLComparisonOperator] = Eq.sql ^^ (_ => Eq) - private def ne: PackratParser[SQLExpressionOperator] = Ne.sql ^^ (_ => Ne) + private def ne: PackratParser[SQLComparisonOperator] = Ne.sql ^^ (_ => Ne) + + private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) private def equality: PackratParser[SQLExpression] = - not.? ~ (identifierWithFunction | identifier) ~ (eq | ne) ~ (boolean | literal | double | long) ^^ { + not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def like: PackratParser[SQLExpression] = - (identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { case i ~ n ~ _ ~ v => - SQLExpression(i, Like, v, n) + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { + case i ~ n ~ _ ~ v => + SQLExpression(i, Like, v, n) } - private def ge: PackratParser[SQLExpressionOperator] = Ge.sql ^^ (_ => Ge) + private def ge: PackratParser[SQLComparisonOperator] = Ge.sql ^^ (_ => Ge) - def gt: PackratParser[SQLExpressionOperator] = Gt.sql ^^ (_ => Gt) + def gt: PackratParser[SQLComparisonOperator] = Gt.sql ^^ (_ => Gt) - private def le: PackratParser[SQLExpressionOperator] = Le.sql ^^ (_ => Le) + private def le: PackratParser[SQLComparisonOperator] = Le.sql ^^ (_ => Le) - def lt: PackratParser[SQLExpressionOperator] = Lt.sql ^^ (_ => Lt) + def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ (identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { + not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -202,7 +443,7 @@ trait SQLWhereParser { } private def inDoubles: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( double, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -214,7 +455,7 @@ trait SQLWhereParser { } private def inLongs: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( long, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -226,22 +467,22 @@ trait SQLWhereParser { } def between: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLiteralFromTo(from, to), n) } def betweenLongs: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLongFromTo(from, to), n) } def betweenDoubles: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLDoubleFromTo(from, to), n) } - def distance: PackratParser[SQLCriteria] = - distanceFunction ~ start ~ identifier ~ separator ~ start ~ double ~ separator ~ double ~ end ~ end ~ le ~ literal ^^ { + def sql_distance: PackratParser[SQLCriteria] = + distance ~ start ~ identifier ~ separator ~ start ~ double ~ separator ~ double ~ end ~ end ~ le ~ literal ^^ { case _ ~ _ ~ i ~ _ ~ _ ~ lat ~ _ ~ lon ~ _ ~ _ ~ _ ~ d => ElasticGeoDistance(i, d, lat, lon) } @@ -253,6 +494,13 @@ trait SQLWhereParser { SQLMatch(i, l) } + private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = { + // identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifier + not.? ~ (identifierWithAggregation | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) + } + } + def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) def or: PackratParser[SQLPredicateOperator] = Or.regex ^^ (_ => Or) @@ -260,7 +508,7 @@ trait SQLWhereParser { def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) def criteria: PackratParser[SQLCriteria] = - (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | distance | matchCriteria) ^^ ( + (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria) ^^ ( c => c ) @@ -389,7 +637,7 @@ trait SQLWhereParser { case _ :: Nil => processTokensHelper(rest, op :: stack) case _ => - throw new IllegalStateException("Invalid stack state for predicate creation") + throw SQLValidationError("Invalid stack state for predicate creation") } case (_: EndDelimiter) :: rest => processTokensHelper(rest, stack) // Ignore and move on @@ -420,7 +668,7 @@ trait SQLWhereParser { */ private def processSubTokens(tokens: List[SQLToken]): SQLCriteria = { processTokensHelper(tokens, Nil).getOrElse( - throw new IllegalStateException("Empty sub-expression") + throw SQLValidationError("Empty sub-expression") ) } @@ -443,7 +691,7 @@ trait SQLWhereParser { subTokens: List[SQLToken] = Nil ): (List[SQLToken], List[SQLToken]) = { tokens match { - case Nil => throw new IllegalStateException("Unbalanced parentheses") + case Nil => throw SQLValidationError("Unbalanced parentheses") case (start: StartDelimiter) :: rest => extractSubTokens(rest, openCount + 1, start :: subTokens) case (end: EndDelimiter) :: rest => @@ -458,14 +706,10 @@ trait SQLWhereParser { trait SQLGroupByParser { self: SQLParser with SQLWhereParser => - private def having: PackratParser[SQLHaving] = Having.regex ~> whereCriteria ^^ { rawTokens => - SQLHaving( - processTokens(rawTokens) - ) + def bucket: PackratParser[SQLBucket] = identifier ^^ { i => + SQLBucket(i) } - def bucket: PackratParser[SQLBucket] = identifier ^^ (i => SQLBucket(i)) - def groupBy: PackratParser[SQLGroupBy] = GroupBy.regex ~ rep1sep(bucket, separator) ^^ { case _ ~ buckets => SQLGroupBy(buckets) @@ -494,16 +738,20 @@ trait SQLOrderByParser { private def fieldName: PackratParser[String] = """\b(?!(?i)limit\b)[a-zA-Z_][a-zA-Z0-9_]*""".r ^^ (f => f) - def fieldWithFunction: PackratParser[(String, SQLFunction)] = - sqlFunction ~ start ~ fieldName ~ end ^^ { case f ~ _ ~ n ~ _ => + def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = + rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw SQLValidationError(error) + case _ => + } (n, f) } def sort: PackratParser[SQLFieldSort] = (fieldWithFunction | fieldName) ~ (asc | desc).? ^^ { case f ~ o => f match { - case i: (String, SQLFunction) => SQLFieldSort(i._1, o, Some(i._2)) - case s: String => SQLFieldSort(s, o, None) + case i: (String, List[SQLFunction]) => SQLFieldSort(i._1, o, i._2) + case s: String => SQLFieldSort(s, o, List.empty) } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index fb415157..a578cc53 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -30,14 +30,19 @@ case class SQLSearchRequest( ) } + lazy val scriptFields: Seq[Field] = select.fields.filter(_.isScriptField) + lazy val fields: Seq[String] = { if (aggregates.isEmpty && buckets.isEmpty) - select.fields.map(_.sourceField).filterNot(f => excludes.contains(f)) + select.fields + .filterNot(_.isScriptField) + .map(_.sourceField) + .filterNot(f => excludes.contains(f)) else Seq.empty } - lazy val aggregates: Seq[SQLField] = select.fields.filter(_.aggregation) + lazy val aggregates: Seq[Field] = select.fields.filter(_.aggregation) lazy val excludes: Seq[String] = select.except.map(_.fields.map(_.sourceField)).getOrElse(Nil) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 18e0da9d..e2991f9d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -2,41 +2,57 @@ package app.softnetwork.elastic.sql case object Select extends SQLExpr("select") with SQLRegex -case class SQLField( - identifier: SQLIdentifier, - fieldAlias: Option[SQLAlias] = None -) extends Updateable - with SQLTokenWithFunction { +sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript { + def identifier: Identifier + def fieldAlias: Option[SQLAlias] + def isScriptField: Boolean = functions.nonEmpty && !aggregation && identifier.bucket.isEmpty override def sql: String = s"$identifier${asString(fieldAlias)}" - def update(request: SQLSearchRequest): SQLField = - this.copy(identifier = identifier.update(request)) lazy val sourceField: String = if (identifier.nested) { identifier.tableAlias .orElse(fieldAlias.map(_.alias)) .map(a => s"$a.") - .getOrElse("") + identifier.name.split("\\.").tail.mkString(".") + .getOrElse("") + identifier.name + .replace("(", "") + .replace(")", "") + .split("\\.") + .tail + .mkString(".") } else { - identifier.name + identifier.name.replace("(", "").replace(")", "") } - override def function: Option[SQLFunction] = identifier.function + override def functions: List[SQLFunction] = identifier.functions + + def update(request: SQLSearchRequest): Field + + def painless: String = identifier.painless + + lazy val scriptName: String = fieldAlias.map(_.alias).getOrElse(sourceField) +} + +case class SQLField( + identifier: SQLIdentifier, + fieldAlias: Option[SQLAlias] = None +) extends Field { + def update(request: SQLSearchRequest): SQLField = + this.copy(identifier = identifier.update(request)) } case object Except extends SQLExpr("except") with SQLRegex -case class SQLExcept(fields: Seq[SQLField]) extends Updateable { +case class SQLExcept(fields: Seq[Field]) extends Updateable { override def sql: String = s" $Except(${fields.mkString(",")})" def update(request: SQLSearchRequest): SQLExcept = this.copy(fields = fields.map(_.update(request))) } case class SQLSelect( - fields: Seq[SQLField] = Seq(SQLField(identifier = SQLIdentifier("*"))), + fields: Seq[Field] = Seq(SQLField(identifier = SQLIdentifier("*"))), except: Option[SQLExcept] = None ) extends Updateable { override def sql: String = - s"$Select ${fields.mkString(",")}${except.getOrElse("")}" + s"$Select ${fields.mkString(", ")}${except.getOrElse("")}" lazy val fieldAliases: Map[String, String] = fields.flatMap { field => field.fieldAlias.map(a => field.identifier.identifierName -> a.alias) }.toMap diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala new file mode 100644 index 00000000..ff4cebf7 --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -0,0 +1,18 @@ +package app.softnetwork.elastic.sql + +sealed trait SQLType { def typeId: String } + +trait SQLAny extends SQLType +trait SQLTemporal extends SQLType +trait SQLDate extends SQLTemporal +trait SQLDateTime extends SQLTemporal +trait SQLNumber extends SQLType +trait SQLString extends SQLType + +object SQLTypeCompatibility { + def matches(out: SQLType, in: SQLType): Boolean = + out.typeId == in.typeId || + (out.typeId == "temporal" && Set("date", "datetime").contains(in.typeId)) || + (in.typeId == "temporal" && Set("date", "datetime").contains(out.typeId)) || + out.typeId == "any" || in.typeId == "any" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala new file mode 100644 index 00000000..131c9a01 --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -0,0 +1,10 @@ +package app.softnetwork.elastic.sql + +object SQLTypes { + case object Any extends SQLAny { val typeId = "any" } + case object Temporal extends SQLTemporal { val typeId = "temporal" } + case object Date extends SQLDate { val typeId = "date" } + case object DateTime extends SQLDateTime { val typeId = "datetime" } + case object Number extends SQLNumber { val typeId = "number" } + case object String extends SQLString { val typeId = "string" } +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala new file mode 100644 index 00000000..776bab53 --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -0,0 +1,26 @@ +package app.softnetwork.elastic.sql + +object SQLValidator { + + def validateChain(functions: List[SQLFunction]): Either[String, Unit] = { + // validate function chain type compatibility + val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } + unaryFuncs.sliding(2).foreach { + case Seq(f1, f2) => + if (!SQLTypeCompatibility.matches(f2.outputType, f1.inputType)) { + return Left( + s"Type mismatch: output '${f2.outputType.typeId}' of `${f2.sql}` " + + s"is not compatible with input '${f1.inputType.typeId}' of `${f1.sql}`" + ) + } + case _ => // ok + } + Right(()) + } +} + +trait SQLValidation { + def validate(): Either[String, Unit] = Right(()) +} + +case class SQLValidationError(message: String) extends Exception(message) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 076ca868..b40a09d9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -89,12 +89,12 @@ case class SQLPredicate( sealed trait ElasticFilter -sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLTokenWithFunction { +sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLFunctionChain { def identifier: SQLIdentifier override def nested: Boolean = identifier.nested override def group: Boolean = false override lazy val limit: Option[SQLLimit] = identifier.limit - override val function: Option[SQLFunction] = identifier.function + override val functions: List[SQLFunction] = identifier.functions } case class ElasticBoolQuery( @@ -151,7 +151,7 @@ case class ElasticBoolQuery( } -trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { +sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { def maybeValue: Option[SQLToken] def maybeNot: Option[Not.type] def notAsString: String = maybeNot.map(v => s"$v ").getOrElse("") @@ -178,6 +178,33 @@ case class SQLExpression( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } +sealed trait BinaryExpression extends Expression { + def left: SQLIdentifier + def right: SQLIdentifier + override def identifier: SQLIdentifier = left + override def maybeValue: Option[SQLToken] = Some(right) + override lazy val aggregation: Boolean = left.aggregation || right.aggregation +} + +case class SQLBinaryExpression( + left: SQLIdentifier, + operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... + right: SQLIdentifier, + maybeNot: Option[Not.type] = None +) extends BinaryExpression + with PainlessScript { + override def update(request: SQLSearchRequest): SQLCriteria = + this.copy(left = left.update(request), right = right.update(request)) + + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def painless: String = { + val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless + s"${left.painless} $painlessOp ${right.painless}" + } + +} + case class SQLIsNull(identifier: SQLIdentifier) extends Expression { override val operator: SQLOperator = IsNull @@ -219,7 +246,7 @@ case class SQLIn[R, +T <: SQLValue[R]]( values: SQLValues[R, T], maybeNot: Option[Not.type] = None ) extends Expression { this: SQLIn[R, T] => - private[this] lazy val id = function match { + private[this] lazy val id = functions.headOption match { case Some(f) => s"$f($identifier)" case _ => s"$identifier" } @@ -244,7 +271,7 @@ case class SQLBetween[+T]( fromTo: SQLFromTo[T], maybeNot: Option[Not.type] ) extends Expression { - private[this] lazy val id = function match { + private[this] lazy val id = functions.headOption match { case Some(f) => s"$f($identifier)" case _ => s"$identifier" } @@ -271,7 +298,7 @@ case class ElasticGeoDistance( lon: SQLDouble ) extends Expression { override def sql = s"$Distance($identifier,($lat,$lon)) $operator $distance" - override val function: Option[SQLFunction] = Some(Distance) + override val functions: List[SQLFunction] = List(Distance) override def operator: SQLOperator = Le override def update(request: SQLSearchRequest): ElasticGeoDistance = this.copy(identifier = identifier.update(request)) @@ -317,6 +344,56 @@ case class SQLMatch( override def group: Boolean = false } +case class SQLComparisonDateMath( + identifier: SQLIdentifier, + operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... + dateTimeFunction: CurrentDateTimeFunction, // CurrentDate, Now, CurrentTimestamp, CurrentTime, ... + arithmeticOperator: Option[ArithmeticOperator] = + None, // Plus or Minus between dateTimeFunction and interval + interval: Option[TimeInterval] = None, // optional interval + maybeNot: Option[Not.type] = None +) extends Expression + with MathScript { + override def sql: String = { + s"$identifier ${operator.sql} $dateTimeFunction${asString(arithmeticOperator)}${asString(interval)}" + } + override def update(request: SQLSearchRequest): SQLCriteria = + this.copy(identifier = identifier.update(request)) + + override def maybeValue: Option[SQLToken] = Some(SQLScript(script)) + + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def script: String = { + dateTimeFunction match { + case _: CurrentTimeFunction => + val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless + (arithmeticOperator, interval) match { + case (Some(Add), Some(i)) => // compare doc time with now + interval + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().plus(${i.value}, ${i.unit.painless});" + + case (Some(Subtract), Some(i)) => // compare doc time with now + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().minus(${i.value}, ${i.unit.painless});" + + case _ => + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now();" + } + case _ => + val base = s"${dateTimeFunction.script}" + val dateMath = + (arithmeticOperator, interval) match { + case (Some(Add), Some(i)) => s"$base+${i.script}" + case (Some(Subtract), Some(i)) => s"$base-${i.script}" + case _ => base + } + dateTimeFunction match { + case _: CurrentDateFunction => s"$dateMath/d" + case _ => dateMath + } + } + } +} + case class ElasticMatch( identifier: SQLIdentifier, value: SQLLiteral, diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index e7316baa..35448221 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -12,7 +12,7 @@ package object sql { import scala.language.implicitConversions implicit def asString(token: Option[_ <: SQLToken]): String = token match { - case Some(t) => t.sql + case Some(t) => t.toString case _ => "" } @@ -21,15 +21,12 @@ package object sql { override def toString: String = sql } - trait SQLTokenWithFunction extends SQLToken { - def function: Option[SQLFunction] - - lazy val aggregateFunction: Option[AggregateFunction] = function match { - case Some(af: AggregateFunction) => Some(af) - case _ => None - } + trait PainlessScript extends SQLToken { + def painless: String + } - lazy val aggregation: Boolean = aggregateFunction.isDefined + trait MathScript extends SQLToken { + def script: String } trait Updateable extends SQLToken { @@ -40,7 +37,9 @@ package object sql { case object Distinct extends SQLExpr("distinct") with SQLRegex - abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) extends SQLToken { + abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) + extends SQLToken + with PainlessScript { def choose[R >: T]( values: Seq[R], operator: Option[SQLExpressionOperator], @@ -50,16 +49,16 @@ package object sql { None else operator match { - case Some(_: Eq.type) => values.find(_ == value) - case Some(_: Ne.type) => values.find(_ != value) - case Some(_: Ge.type) => values.filter(_ >= value).sorted.reverse.headOption - case Some(_: Gt.type) => values.filter(_ > value).sorted.reverse.headOption - case Some(_: Le.type) => values.filter(_ <= value).sorted.headOption - case Some(_: Lt.type) => values.filter(_ < value).sorted.headOption - case _ => values.headOption + case Some(Eq) => values.find(_ == value) + case Some(Ne | Diff) => values.find(_ != value) + case Some(Ge) => values.filter(_ >= value).sorted.reverse.headOption + case Some(Gt) => values.filter(_ > value).sorted.reverse.headOption + case Some(Le) => values.filter(_ <= value).sorted.headOption + case Some(Lt) => values.filter(_ < value).sorted.headOption + case _ => values.headOption } } - def painlessValue: String = value match { + def painless: String = value match { case s: String => s""""$s"""" case b: Boolean => b.toString case n: Number => n.toString @@ -90,11 +89,11 @@ package object sql { separator: String = "|" )(implicit ev: R => Ordered[R]): Option[R] = { operator match { - case Some(_: Eq.type) => values.find(v => v.toString contentEquals value) - case Some(_: Ne.type) => values.find(v => !(v.toString contentEquals value)) - case Some(_: Like.type) => values.find(v => pattern.matcher(v.toString).matches()) - case None => Some(values.mkString(separator)) - case _ => super.choose(values, operator, separator) + case Some(Eq) => values.find(v => v.toString contentEquals value) + case Some(Ne | Diff) => values.find(v => !(v.toString contentEquals value)) + case Some(Like) => values.find(v => pattern.matcher(v.toString).matches()) + case None => Some(values.mkString(separator)) + case _ => super.choose(values, operator, separator) } } } @@ -211,10 +210,10 @@ package object sql { value.choose[T](values, Some(operator)) case _ => function match { - case Some(_: Min.type) => Some(values.min) - case Some(_: Max.type) => Some(values.max) - // FIXME case Some(_: SQLSum.type) => Some(values.sum) - // FIXME case Some(_: SQLAvg.type) => Some(values.sum / values.length ) + case Some(Min) => Some(values.min) + case Some(Max) => Some(values.max) + // FIXME case Some(SQLSum) => Some(values.sum) + // FIXME case Some(SQLAvg) => Some(values.sum / values.length ) case _ => values.headOption } } @@ -248,13 +247,45 @@ package object sql { def update(request: SQLSearchRequest): SQLSource } + trait Identifier extends SQLToken with SQLSource with SQLFunctionChain with PainlessScript { + def name: String + + def tableAlias: Option[String] + def distinct: Boolean + def nested: Boolean + def fieldAlias: Option[String] + def bucket: Option[SQLBucket] + + lazy val identifierName: String = + functions.reverse.foldLeft(name)((expr, fun) => { + fun.toSQL(expr) + }) + + lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None + + lazy val innerHitsName: Option[String] = if (nested) tableAlias else None + + lazy val aliasOrName: String = fieldAlias.getOrElse(name) + + override def painless: String = { + val base = if (name.nonEmpty) s"doc['$name'].value" else "" + val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse + orderedFunctions.foldLeft(base) { + case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) + case (expr, f: PainlessScript) => s"$expr${f.painless}" + case (expr, f) => f.toSQL(expr) // fallback + } + } + + } + case class SQLIdentifier( name: String, tableAlias: Option[String] = None, distinct: Boolean = false, nested: Boolean = false, limit: Option[SQLLimit] = None, - function: Option[SQLFunction] = None, + functions: List[SQLFunction] = List.empty, fieldAlias: Option[String] = None, bucket: Option[SQLBucket] = None ) extends SQLExpr({ @@ -270,26 +301,11 @@ package object sql { parts.mkString(".").trim } } - function match { - case Some(f) => s"$f($sql)" - case _ => sql - } + functions.reverse.foldLeft(sql)((expr, fun) => { + fun.toSQL(expr) + }) }) - with SQLSource - with SQLTokenWithFunction { - - lazy val aggregationName: Option[String] = - if (aggregation) fieldAlias.orElse(Option(name)) else None - - lazy val identifierName: String = - (function match { - case Some(f) => s"${f.sql}($name)" - case _ => name - }).toLowerCase - - lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None - - lazy val innerHitsName: Option[String] = if (nested) tableAlias else None + with Identifier { def update(request: SQLSearchRequest): SQLIdentifier = { val parts: Seq[String] = name.split("\\.").toSeq @@ -320,4 +336,6 @@ package object sql { } } } + + case class SQLScript(script: String) extends SQLExpr(script) } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala new file mode 100644 index 00000000..2017093c --- /dev/null +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -0,0 +1,95 @@ +package app.softnetwork.elastic.sql + +import org.scalatest.funsuite.AnyFunSuite +import TimeUnit._ + +class SQLDateTimeFunctionSuite extends AnyFunSuite { + + // Base d'exemple + val baseDate = "doc['createdAt'].value" + + // Liste de toutes les fonctions transformables avec leurs types + val transformFunctions: Seq[SQLTransformFunction[_, _]] = Seq( + ParseDate(SQLIdentifier(""), "yyyy-MM-dd"), + ParseDateTime(SQLIdentifier(""), "yyyy-MM-dd HH:mm:ss"), + DateAdd(SQLIdentifier(""), TimeInterval(1, Day)), + DateSub(SQLIdentifier(""), TimeInterval(2, Month)), + DateTimeAdd(SQLIdentifier(""), TimeInterval(3, Hour)), + DateTimeSub(SQLIdentifier(""), TimeInterval(30, Minute)), + DateTrunc(SQLIdentifier(""), Day), + Extract(Day), + FormatDate(SQLIdentifier(""), "yyyy-MM-dd"), + FormatDateTime(SQLIdentifier(""), "yyyy-MM-dd HH:mm:ss"), + YEAR, + MONTH, + DAY, + HOUR, + MINUTE, + SECOND + ) + + // Fonction pour chaîner une séquence de transformations en vérifiant les types + def chainTransformsTyped( + base: String, + transforms: Seq[SQLTransformFunction[_, _]] + ): String = { + require(transforms.nonEmpty, "No transforms provided") + + val initial: (String, SQLType) = + (transforms.head.toPainless(base), transforms.head.outputType.asInstanceOf[SQLType]) + + val (finalExpr, _) = transforms.tail.foldLeft(initial) { + case ((expr, currentType), t: SQLUnaryFunction[_, _]) => + if (!currentType.getClass.isAssignableFrom(t.inputType.getClass)) { + throw new IllegalArgumentException( + s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" + ) + } + (t.toPainless(expr), t.outputType.asInstanceOf[SQLType]) + } + + finalExpr + } + + // Générer dynamiquement tous les chaînages valides jusqu'à N fonctions + def generateChains( + functions: Seq[SQLTransformFunction[_, _]], + maxLength: Int + ): Seq[Seq[SQLTransformFunction[_, _]]] = { + if (maxLength <= 1) functions.map(Seq(_)) + else { + val shorter = generateChains(functions, maxLength - 1) + for { + chain <- shorter + f <- functions + if f.inputType.getClass.isAssignableFrom(chain.last.outputType.getClass) + } yield chain :+ f + } + } + + // Tester tous les chaînages pour N=2 et N=3 + val chains2: Seq[Seq[SQLTransformFunction[_, _]]] = + generateChains(transformFunctions, 2) + val chains3: Seq[Seq[SQLTransformFunction[_, _]]] = + generateChains(transformFunctions, 3) + + (chains2 ++ chains3).zipWithIndex.foreach { case (chain, idx) => + val names = chain.map(_.sql).mkString(" -> ") + test(s"Valid chain $idx: $names") { + val chained = chainTransformsTyped(baseDate, chain) + val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate)) { (expr, f) => + f.toPainless(expr) + } + // On ne teste que la génération de code Painless sans évaluer le résultat + assert(chained.nonEmpty) + } + } + + // Test simple pour chaque fonction individuelle + transformFunctions.foreach { f => + test(s"Single transformation ${f.sql}") { + val result = f.toPainless(baseDate) + assert(result.nonEmpty) + } + } +} diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 93758fba..e961a9fc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -4,7 +4,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers object Queries { - val numericalEq = "select t.col1,t.col2 from Table as t where t.identifier = 1.0" + val numericalEq = "select t.col1, t.col2 from Table as t where t.identifier = 1.0" val numericalLt = "select * from Table where identifier < 1" val numericalLe = "select * from Table where identifier <= 1" val numericalGt = "select * from Table where identifier > 1" @@ -63,22 +63,76 @@ object Queries { val matchCriteria = "select * from Table where match (identifier1,identifier2,identifier3) against (\"value\")" val groupBy = - "select identifier,count(identifier) from Table where identifier is not null group by identifier" + "select identifier, count(identifier2) from Table where identifier2 is not null group by identifier" val orderBy = "select * from Table order by identifier desc" val limit = "select * from Table limit 10" val groupByWithOrderByAndLimit: String = - """select identifier,count(identifier) + """select identifier, count(identifier2) |from Table |where identifier is not null |group by identifier - |order by identifier desc + |order by identifier2 desc |limit 10""".stripMargin.replaceAll("\n", " ") val groupByWithHaving: String = - """SELECT COUNT(CustomerID) as cnt,City,Country - |FROM Customers - |GROUP BY Country,City - |HAVING Country <> "USA" AND City <> "Berlin" AND COUNT(CustomerID) > 1 - |ORDER BY COUNT(CustomerID) DESC,Country asc""".stripMargin.replaceAll("\n", " ").toLowerCase + """select count(CustomerID) as cnt, City, Country + |from Customers + |group by Country, City + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 + |order by count(CustomerID) desc, Country asc""".stripMargin.replaceAll("\n", " ") + val dateTimeWithIntervalFields: String = + "select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" + val fieldsWithInterval: String = + "select createdAt - interval 35 minute as ct, identifier from Table" + val filterWithDateTimeAndInterval: String = + "select * from Table where createdAt < current_timestamp() and createdAt >= current_timestamp() - interval 10 day" + val filterWithDateAndInterval: String = + "select * from Table where createdAt < current_date and createdAt >= current_date() - interval 10 day" + val filterWithTimeAndInterval: String = + "select * from Table where createdAt < current_time and createdAt >= current_time() - interval 10 minute" + val groupByWithHavingAndDateTimeFunctions: String = + """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen + |from Table + |group by Country, City + |having Country <> "USA" and City != "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |order by Country asc""".stripMargin + .replaceAll("\n", " ") + val parseDate = + "select identifier, count(identifier2) as ct, max(parse_date(createdAt, 'yyyy-MM-dd')) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + val parseDateTime: String = + """select identifier, count(identifier2) as ct, + |max( + |year( + |date_trunc( + |parse_datetime( + |createdAt, + |'yyyy-MM-ddTHH:mm:ssZ' + |), minute))) as lastSeen + |from Table + |where identifier2 is not null + |group by identifier + |order by count(identifier2) desc""".stripMargin + .replaceAll("\n", " ") + .replaceAll("\\( ", "(") + .replaceAll(" \\)", ")") + + val dateDiff = "select date_diff(createdAt, updatedAt, day) as diff, identifier from Table" + + val aggregationWithDateDiff = + "select max(date_diff(parse_datetime(createdAt, 'yyyy-MM-ddTHH:mm:ssZ'), updatedAt, day)) as max_diff from Table group by identifier" + + val formatDate = + "select identifier, format_date(date_trunc(lastUpdated, month), 'yyyy-MM-dd') as lastSeen from Table where identifier2 is not null" + val formatDateTime = + "select identifier, format_datetime(date_trunc(lastUpdated, month), 'yyyy-MM-ddThh:mm:ssZ') as lastSeen from Table where identifier2 is not null" + val dateAdd = + "select identifier, date_add(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateSub = + "select identifier, date_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateTimeAdd = + "select identifier, datetime_add(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateTimeSub = + "select identifier, datetime_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + } /** Created by smanciot on 15/02/17. @@ -343,4 +397,113 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(groupByWithHaving) } + it should "parse date time fields" in { + val result = SQLParser(dateTimeWithIntervalFields) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeWithIntervalFields + ) + } + + it should "parse fields with interval" in { + val result = SQLParser(fieldsWithInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(fieldsWithInterval) + } + + it should "parse filter with date time and interval" in { + val result = SQLParser(filterWithDateTimeAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithDateTimeAndInterval + ) + } + + it should "parse filter with date and interval" in { + val result = SQLParser(filterWithDateAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithDateAndInterval + ) + } + + it should "parse filter with time and interval" in { + val result = SQLParser(filterWithTimeAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithTimeAndInterval + ) + } + + it should "parse group by with having and date time functions" in { + val result = SQLParser(groupByWithHavingAndDateTimeFunctions) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + groupByWithHavingAndDateTimeFunctions + ) + } + + it should "parse parse_date function" in { + val result = SQLParser(parseDate) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + parseDate + ) + } + + it should "parse parse_date_time function" in { + val result = SQLParser(parseDateTime) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + parseDateTime + ) + } + + it should "parse date_diff function" in { + val result = SQLParser(dateDiff) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateDiff + ) + } + + it should "parse date_diff function with aggregation" in { + val result = SQLParser(aggregationWithDateDiff) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + aggregationWithDateDiff + ) + } + + it should "parse format_date function" in { + val result = SQLParser(formatDate) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + formatDate + ) + } + + it should "parse format_datetime function" in { + val result = SQLParser(formatDateTime) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + formatDateTime + ) + } + + it should "parse date_add function" in { + val result = SQLParser(dateAdd) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateAdd + ) + } + + it should "parse date_sub function" in { + val result = SQLParser(dateSub) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateSub + ) + } + + it should "parse datetime_add function" in { + val result = SQLParser(dateTimeAdd) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeAdd + ) + } + + it should "parse datetime_sub function" in { + val result = SQLParser(dateTimeSub) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeSub + ) + } }