From d8e42b82f692c19cd86d4061cd73dd516cccf355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Thu, 4 Sep 2025 13:08:53 +0200 Subject: [PATCH 01/22] add Field, ScriptField and PainlessScript traits, add date time functions, update parser for date time script fields --- .../sql/bridge/ElasticAggregation.scala | 4 +- .../sql/bridge/ElasticSearchRequest.scala | 4 +- .../sql/bridge/ElasticAggregation.scala | 4 +- .../sql/bridge/ElasticSearchRequest.scala | 4 +- .../softnetwork/elastic/sql/SQLFunction.scala | 95 +++++++++++++++++++ .../softnetwork/elastic/sql/SQLGroupBy.scala | 12 +-- .../softnetwork/elastic/sql/SQLOperator.scala | 9 ++ .../softnetwork/elastic/sql/SQLParser.scala | 74 ++++++++++++++- .../elastic/sql/SQLSearchRequest.scala | 2 +- .../softnetwork/elastic/sql/SQLSelect.scala | 57 +++++++++-- .../app/softnetwork/elastic/sql/package.scala | 12 ++- .../elastic/sql/SQLParserSpec.scala | 27 +++++- 12 files changed, 269 insertions(+), 35 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index bb6fbcb2..ff17e7e8 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -7,11 +7,11 @@ import app.softnetwork.elastic.sql.{ BucketSelectorScript, Count, ElasticBoolQuery, + Field, Max, Min, SQLBucket, SQLCriteria, - SQLField, SortOrder, Sum } @@ -57,7 +57,7 @@ case class ElasticAggregation( object ElasticAggregation { def apply( - sqlAgg: SQLField, + sqlAgg: Field, having: Option[SQLCriteria], bucketsDirection: Map[String, SortOrder] ): ElasticAggregation = { diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala index 35950d9c..3c451a43 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala @@ -1,11 +1,11 @@ package app.softnetwork.elastic.sql.bridge -import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, SQLField} +import app.softnetwork.elastic.sql.{Field, SQLBucket, SQLCriteria, SQLExcept} import com.sksamuel.elastic4s.searches.SearchRequest import com.sksamuel.elastic4s.http.search.SearchBodyBuilderFn case class ElasticSearchRequest( - fields: Seq[SQLField], + fields: Seq[Field], except: Option[SQLExcept], sources: Seq[String], criteria: Option[SQLCriteria], diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index cf365736..4d0894f4 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -7,11 +7,11 @@ import app.softnetwork.elastic.sql.{ BucketSelectorScript, Count, ElasticBoolQuery, + Field, Max, Min, SQLBucket, SQLCriteria, - SQLField, SortOrder, Sum } @@ -56,7 +56,7 @@ case class ElasticAggregation( object ElasticAggregation { def apply( - sqlAgg: SQLField, + sqlAgg: Field, having: Option[SQLCriteria], bucketsDirection: Map[String, SortOrder] ): ElasticAggregation = { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala index bfd1bc7f..adcf87e1 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala @@ -1,10 +1,10 @@ package app.softnetwork.elastic.sql.bridge -import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, SQLField} +import app.softnetwork.elastic.sql.{SQLBucket, SQLCriteria, SQLExcept, Field} import com.sksamuel.elastic4s.requests.searches.{SearchBodyBuilderFn, SearchRequest} case class ElasticSearchRequest( - fields: Seq[SQLField], + fields: Seq[Field], except: Option[SQLExcept], sources: Seq[String], criteria: Option[SQLCriteria], diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 31b1fe81..1712e01f 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -1,5 +1,7 @@ package app.softnetwork.elastic.sql +import scala.util.matching.Regex + sealed trait SQLFunction extends SQLRegex sealed trait AggregateFunction extends SQLFunction @@ -10,3 +12,96 @@ case object Avg extends SQLExpr("avg") with AggregateFunction case object Sum extends SQLExpr("sum") with AggregateFunction case object Distance extends SQLExpr("distance") with SQLFunction with SQLOperator + +sealed trait TimeUnit extends PainlessScript { + lazy val regex: Regex = s"\\b(?i)${sql}s?\\b".r + + override def painless: String = s"ChronoUnit.${sql.toUpperCase()}" +} + +sealed trait CalendarUnit extends TimeUnit +sealed trait FixedUnit extends TimeUnit + +case object Year extends SQLExpr("year") with CalendarUnit +case object Month extends SQLExpr("month") with CalendarUnit +case object Quarter extends SQLExpr("quarter") with CalendarUnit +case object Week extends SQLExpr("week") with CalendarUnit + +case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit + +case object Hour extends SQLExpr("hour") with FixedUnit +case object Minute extends SQLExpr("minute") with FixedUnit +case object Second extends SQLExpr("second") with FixedUnit + +case object Interval extends SQLExpr("interval") with SQLFunction with SQLRegex + +sealed trait TimeInterval extends PainlessScript { + def value: Int + def unit: TimeUnit + override def sql: String = s"$Interval $value ${unit.sql}" + + override def painless: String = s"$value, ${unit.painless}" +} + +case class CalendarInterval(value: Int, unit: CalendarUnit) extends TimeInterval +case class FixedInterval(value: Int, unit: FixedUnit) extends TimeInterval + +object TimeInterval { + def apply(value: Int, unit: TimeUnit): TimeInterval = unit match { + case cu: CalendarUnit => CalendarInterval(value, cu) + case fu: FixedUnit => FixedInterval(value, fu) + } +} + +sealed trait DateTimeFunction extends SQLFunction + +case object CurrentDate extends SQLExpr("current_date") with DateTimeFunction with PainlessScript { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" +} +case object CurentDateWithParens + extends SQLExpr("current_date()") + with DateTimeFunction + with PainlessScript { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" +} +case object CurrentTime extends SQLExpr("current_time") with DateTimeFunction with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalTime()" +} +case object CurrentTimeWithParens + extends SQLExpr("current_time()") + with DateTimeFunction + with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalTime()" +} +case object CurrentTimestamp + extends SQLExpr("current_timestamp") + with DateTimeFunction + with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +} +case object CurrentTimestampWithParens + extends SQLExpr("current_timestamp()") + with DateTimeFunction + with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +} +case object Now extends SQLExpr("now") with DateTimeFunction with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +} +case object NowWithParens extends SQLExpr("now()") with DateTimeFunction with PainlessScript { + override def painless: String = + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" +} + +case class DateAdd(interval: TimeInterval) extends SQLExpr("date_add") with DateTimeFunction +case class DateDiff(interval: TimeInterval) extends SQLExpr("date_diff") with DateTimeFunction +case class DateSub(interval: TimeInterval) extends SQLExpr("date_sub") with DateTimeFunction +case class DateTrunc(unit: TimeUnit) extends SQLExpr("date_trunc") with DateTimeFunction +case class Extract(unit: TimeUnit) extends SQLExpr("extract") with DateTimeFunction +case class FormatDate(format: String) extends SQLExpr("format_date") with DateTimeFunction +case class ParseDate(format: String) extends SQLExpr("parse_date") with DateTimeFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 4903c3ac..ee6bc3e6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -34,7 +34,7 @@ case class SQLBucket( object BucketSelectorScript { private[this] def painlessIn(param: String, values: Seq[SQLValue[_]], not: Boolean): String = { - val ret = s"[${values.map { _.painlessValue }.mkString(", ")}].contains($param)" + val ret = s"[${values.map { _.painless }.mkString(", ")}].contains($param)" if (not) s"!$ret" else ret } @@ -44,7 +44,7 @@ object BucketSelectorScript { upper: SQLValue[_], not: Boolean ): String = { - val ret = s"($param >= ${lower.painlessValue} && $param <= ${upper.painlessValue})" + val ret = s"($param >= ${lower.painless} && $param <= ${upper.painless})" if (not) s"!$ret" else ret } @@ -58,10 +58,10 @@ object BucketSelectorScript { case _: SQLComparisonOperator => val valueStr = value match { - case v: SQLBoolean => v.painlessValue - case v: SQLDouble => v.painlessValue - case v: SQLLiteral => v.painlessValue - case v: SQLLong => v.painlessValue + case v: SQLBoolean => v.painless + case v: SQLDouble => v.painless + case v: SQLLiteral => v.painless + case v: SQLLong => v.painless case _ => throw new IllegalArgumentException( s"Unsupported value type in bucket_selector: $value" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 5df94ea8..e91cd82b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -2,6 +2,15 @@ package app.softnetwork.elastic.sql trait SQLOperator extends SQLToken +sealed trait ArithmeticOperator extends SQLOperator { + override def toString: String = s" $sql " +} +case object Plus extends SQLExpr("+") with ArithmeticOperator +case object Minus extends SQLExpr("-") with ArithmeticOperator +case object Multiply extends SQLExpr("*") with ArithmeticOperator +case object Divide extends SQLExpr("/") with ArithmeticOperator +case object Modulo extends SQLExpr("%") with ArithmeticOperator + sealed trait SQLExpressionOperator extends SQLOperator sealed trait SQLComparisonOperator extends SQLExpressionOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 00922443..c260c87e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -85,6 +85,63 @@ trait SQLParser extends RegexParsers with PackratParsers { def sum: PackratParser[AggregateFunction] = Sum.regex ^^ (_ => Sum) + def year: PackratParser[TimeUnit] = Year.regex ^^ (_ => Year) + + def month: PackratParser[TimeUnit] = Month.regex ^^ (_ => Year) + + def quarter: PackratParser[TimeUnit] = Quarter.regex ^^ (_ => Quarter) + + def week: PackratParser[TimeUnit] = Week.regex ^^ (_ => Week) + + def day: PackratParser[TimeUnit] = Day.regex ^^ (_ => Day) + + def hour: PackratParser[TimeUnit] = Hour.regex ^^ (_ => Hour) + + def minute: PackratParser[TimeUnit] = Minute.regex ^^ (_ => Minute) + + def second: PackratParser[TimeUnit] = Second.regex ^^ (_ => Second) + + def interval: PackratParser[TimeInterval] = + Interval.regex ~ long ~ (year | month | quarter | week | day | hour | minute | second) ^^ { + case _ ~ l ~ u => + TimeInterval(l.value.toInt, u) + } + + def current_date: PackratParser[DateTimeFunction] = CurrentDate.regex ~ start.? ~ end.? ^^ { + case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate + } + + def current_time: PackratParser[DateTimeFunction] = CurrentTime.regex ~ start.? ~ end.? ^^ { + case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime + } + + def current_timestamp: PackratParser[DateTimeFunction] = + CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp + } + + def now: PackratParser[DateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => + if (s.isDefined && t.isDefined) NowWithParens else Now + } + + def plus: PackratParser[ArithmeticOperator] = Plus.sql ^^ (_ => Plus) + + def minus: PackratParser[ArithmeticOperator] = Minus.sql ^^ (_ => Minus) + + def arithmeticOperator: PackratParser[ArithmeticOperator] = plus | minus + + def dateTimeWithInterval: PackratParser[SQLDateTimeField] = + (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + case f ~ o ~ i => + SQLDateTimeField( + SQLIdentifier(f.sql), + o, + i + ) + } + def aggregateFunction: PackratParser[AggregateFunction] = count | min | max | avg | sum def distanceFunction: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) @@ -107,16 +164,29 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } + def identifierWithInterval: PackratParser[SQLDateTimeField] = + identifier ~ arithmeticOperator ~ interval ^^ { case f ~ o ~ i => + SQLDateTimeField( + f, + Some(o), + Some(i) + ) + } + private val regexAlias = """\b(?!(?i)as\b)\b(?!(?i)except\b)\b(?!(?i)where\b)\b(?!(?i)filter\b)\b(?!(?i)from\b)\b(?!(?i)group\b)\b(?!(?i)having\b)\b(?!(?i)order\b)\b(?!(?i)limit\b)[a-zA-Z0-9_]*""" def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } - def field: PackratParser[SQLField] = (identifierWithFunction | identifier) ~ alias.? ^^ { + def field: PackratParser[Field] = (identifierWithFunction | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } + def scriptField: PackratParser[ScriptField] = + (dateTimeWithInterval | identifierWithInterval) ~ alias.? ^^ { case d ~ a => + d.copy(fieldAlias = a) + } } trait SQLSelectParser { @@ -128,7 +198,7 @@ trait SQLSelectParser { } def select: PackratParser[SQLSelect] = - Select.regex ~ rep1sep(field, separator) ~ except.? ^^ { case _ ~ fields ~ e => + Select.regex ~ rep1sep(scriptField | field, separator) ~ except.? ^^ { case _ ~ fields ~ e => SQLSelect(fields, e) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index fb415157..3ccc834e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -37,7 +37,7 @@ case class SQLSearchRequest( Seq.empty } - lazy val aggregates: Seq[SQLField] = select.fields.filter(_.aggregation) + lazy val aggregates: Seq[Field] = select.fields.filter(_.aggregation) lazy val excludes: Seq[String] = select.except.map(_.fields.map(_.sourceField)).getOrElse(Nil) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 18e0da9d..ae5b41dc 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -2,14 +2,11 @@ package app.softnetwork.elastic.sql case object Select extends SQLExpr("select") with SQLRegex -case class SQLField( - identifier: SQLIdentifier, - fieldAlias: Option[SQLAlias] = None -) extends Updateable - with SQLTokenWithFunction { +sealed trait Field extends Updateable with SQLTokenWithFunction { + def identifier: SQLIdentifier + def fieldAlias: Option[SQLAlias] + def isScriptField: Boolean = false override def sql: String = s"$identifier${asString(fieldAlias)}" - def update(request: SQLSearchRequest): SQLField = - this.copy(identifier = identifier.update(request)) lazy val sourceField: String = if (identifier.nested) { identifier.tableAlias @@ -21,22 +18,62 @@ case class SQLField( } override def function: Option[SQLFunction] = identifier.function + + def update(request: SQLSearchRequest): Field +} + +case class SQLField( + identifier: SQLIdentifier, + fieldAlias: Option[SQLAlias] = None +) extends Field { + def update(request: SQLSearchRequest): SQLField = + this.copy(identifier = identifier.update(request)) +} + +sealed trait ScriptField extends Field with PainlessScript { + override def isScriptField: Boolean = true + + def update(request: SQLSearchRequest): ScriptField +} + +case class SQLDateTimeField( + identifier: SQLIdentifier, + operator: Option[ArithmeticOperator] = None, + interval: Option[TimeInterval], + fieldAlias: Option[SQLAlias] = None +) extends ScriptField { + override def sql: String = + s"$identifier${asString(operator)}${asString(interval)}${asString(fieldAlias)}" + def update(request: SQLSearchRequest): SQLDateTimeField = + this.copy(identifier = identifier.update(request)) + override def painless: String = { + val base = identifier.function match { + case f @ Some(CurrentDate | CurrentTime | CurrentTimestamp | Now) => + f.asInstanceOf[PainlessScript].painless + case _ => s"doc['$sourceField'].value" + } + (operator, interval) match { + case (Some(Minus), Some(i)) => s"$base.minus(${i.painless})" + case (Some(Plus), Some(i)) => s"$base.plus(${i.painless})" + case _ => base + } + } } case object Except extends SQLExpr("except") with SQLRegex -case class SQLExcept(fields: Seq[SQLField]) extends Updateable { +case class SQLExcept(fields: Seq[Field]) extends Updateable { override def sql: String = s" $Except(${fields.mkString(",")})" def update(request: SQLSearchRequest): SQLExcept = this.copy(fields = fields.map(_.update(request))) } case class SQLSelect( - fields: Seq[SQLField] = Seq(SQLField(identifier = SQLIdentifier("*"))), + fields: Seq[Field] = Seq(SQLField(identifier = SQLIdentifier("*"))), except: Option[SQLExcept] = None ) extends Updateable { override def sql: String = - s"$Select ${fields.mkString(",")}${except.getOrElse("")}" + s"$Select ${fields.mkString(", ")}${except.getOrElse("")}" lazy val fieldAliases: Map[String, String] = fields.flatMap { field => field.fieldAlias.map(a => field.identifier.identifierName -> a.alias) }.toMap diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index e7316baa..063f70c1 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -12,7 +12,7 @@ package object sql { import scala.language.implicitConversions implicit def asString(token: Option[_ <: SQLToken]): String = token match { - case Some(t) => t.sql + case Some(t) => t.toString case _ => "" } @@ -21,6 +21,10 @@ package object sql { override def toString: String = sql } + trait PainlessScript extends SQLToken { + def painless: String + } + trait SQLTokenWithFunction extends SQLToken { def function: Option[SQLFunction] @@ -40,7 +44,9 @@ package object sql { case object Distinct extends SQLExpr("distinct") with SQLRegex - abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) extends SQLToken { + abstract class SQLValue[+T](val value: T)(implicit ev$1: T => Ordered[T]) + extends SQLToken + with PainlessScript { def choose[R >: T]( values: Seq[R], operator: Option[SQLExpressionOperator], @@ -59,7 +65,7 @@ package object sql { case _ => values.headOption } } - def painlessValue: String = value match { + def painless: String = value match { case s: String => s""""$s"""" case b: Boolean => b.toString case n: Number => n.toString diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 93758fba..91d7e38c 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -4,7 +4,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers object Queries { - val numericalEq = "select t.col1,t.col2 from Table as t where t.identifier = 1.0" + val numericalEq = "select t.col1, t.col2 from Table as t where t.identifier = 1.0" val numericalLt = "select * from Table where identifier < 1" val numericalLe = "select * from Table where identifier <= 1" val numericalGt = "select * from Table where identifier > 1" @@ -63,22 +63,28 @@ object Queries { val matchCriteria = "select * from Table where match (identifier1,identifier2,identifier3) against (\"value\")" val groupBy = - "select identifier,count(identifier) from Table where identifier is not null group by identifier" + "select identifier, count(identifier2) from Table where identifier2 is not null group by identifier" val orderBy = "select * from Table order by identifier desc" val limit = "select * from Table limit 10" val groupByWithOrderByAndLimit: String = - """select identifier,count(identifier) + """select identifier, count(identifier2) |from Table |where identifier is not null |group by identifier - |order by identifier desc + |order by identifier2 desc |limit 10""".stripMargin.replaceAll("\n", " ") val groupByWithHaving: String = - """SELECT COUNT(CustomerID) as cnt,City,Country + """SELECT COUNT(CustomerID) as cnt, City, Country |FROM Customers |GROUP BY Country,City |HAVING Country <> "USA" AND City <> "Berlin" AND COUNT(CustomerID) > 1 |ORDER BY COUNT(CustomerID) DESC,Country asc""".stripMargin.replaceAll("\n", " ").toLowerCase + val dateTimeWithIntervalFields: String = + "select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" + val fieldsWithInterval: String = + "select createdAt - interval 35 minute as ct, identifier from Table" + + //TODO "select * from Table where createdAt <= current_timestamp() and createdAt >= current_timestamp() - interval 35 minute" } /** Created by smanciot on 15/02/17. @@ -343,4 +349,15 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(groupByWithHaving) } + it should "parse date time fields" in { + val result = SQLParser(dateTimeWithIntervalFields) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeWithIntervalFields + ) + } + + it should "parse fields with interval" in { + val result = SQLParser(fieldsWithInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(fieldsWithInterval) + } } From f1c3944fada2d9afe53d6c8d68ffd78a4a88e238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Thu, 4 Sep 2025 16:11:03 +0200 Subject: [PATCH 02/22] add script fields to elasticsearch request --- .../elastic/sql/bridge/package.scala | 12 +++++++++ .../elastic/sql/SQLQuerySpec.scala | 26 ++++++++++++++++++- .../elastic/sql/bridge/package.scala | 12 +++++++++ .../elastic/sql/SQLQuerySpec.scala | 26 ++++++++++++++++++- .../elastic/sql/SQLSearchRequest.scala | 10 ++++++- .../softnetwork/elastic/sql/SQLSelect.scala | 11 ++++++-- 6 files changed, 92 insertions(+), 5 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index b1167f7c..ff88cac6 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -4,6 +4,7 @@ import com.sksamuel.elastic4s.ElasticApi import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.http.ElasticDsl.BuildableTermsNoOp import com.sksamuel.elastic4s.http.search.SearchBodyBuilderFn +import com.sksamuel.elastic4s.script.Script import com.sksamuel.elastic4s.searches.aggs.Aggregation import com.sksamuel.elastic4s.searches.queries.Query import com.sksamuel.elastic4s.searches.{MultiSearchRequest, SearchRequest} @@ -95,6 +96,17 @@ package object bridge { } } + _search = scriptFields match { + case Nil => _search + case _ => + _search scriptfields scriptFields.map { field => + scriptField( + field.name, + Script(script = field.painless).lang("painless").scriptType("source") + ) + } + } + _search = orderBy match { case Some(o) if aggregates.isEmpty && buckets.isEmpty => _search sortBy o.sorts.map(sort => diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 58f76b96..b5c44631 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -6,7 +6,7 @@ import com.google.gson.{JsonArray, JsonObject, JsonParser, JsonPrimitive} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ /** Created by smanciot on 13/04/17. */ @@ -868,4 +868,28 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { } + it should "add script fields" in { + val select: ElasticSearchRequest = + SQLQuery(fieldsWithInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "ct": { + | "script": { + | "lang": "painless", + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTE)" + | } + | } + | }, + | "_source": { + | "includes": ["identifier"] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + + } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index e051841d..ac9c468c 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -2,6 +2,7 @@ package app.softnetwork.elastic.sql import com.sksamuel.elastic4s.ElasticApi import com.sksamuel.elastic4s.ElasticApi._ +import com.sksamuel.elastic4s.requests.script.Script import com.sksamuel.elastic4s.requests.searches.aggs.Aggregation import com.sksamuel.elastic4s.requests.searches.queries.Query import com.sksamuel.elastic4s.requests.searches.sort.FieldSort @@ -96,6 +97,17 @@ package object bridge { } } + _search = scriptFields match { + case Nil => _search + case _ => + _search scriptfields scriptFields.map { field => + scriptField( + field.name, + Script(script = field.painless).lang("painless").scriptType("source") + ) + } + } + _search = orderBy match { case Some(o) if aggregates.isEmpty && buckets.isEmpty => _search sortBy o.sorts.map(sort => diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index ce96ddf5..6e5163de 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -6,7 +6,7 @@ import com.google.gson.{JsonArray, JsonObject, JsonParser, JsonPrimitive} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ /** Created by smanciot on 13/04/17. */ @@ -867,4 +867,28 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { } + it should "add script fields" in { + val select: ElasticSearchRequest = + SQLQuery(fieldsWithInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "ct": { + | "script": { + | "lang": "painless", + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTE)" + | } + | } + | }, + | "_source": { + | "includes": ["identifier"] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index 3ccc834e..40c87efb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -30,9 +30,17 @@ case class SQLSearchRequest( ) } + lazy val scriptFields: Seq[ScriptField] = select.fields.flatMap { + case s: ScriptField => Some(s) + case _ => None + } + lazy val fields: Seq[String] = { if (aggregates.isEmpty && buckets.isEmpty) - select.fields.map(_.sourceField).filterNot(f => excludes.contains(f)) + select.fields + .filterNot(_.isScriptField) + .map(_.sourceField) + .filterNot(f => excludes.contains(f)) else Seq.empty } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index ae5b41dc..1f788fc2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -12,9 +12,14 @@ sealed trait Field extends Updateable with SQLTokenWithFunction { identifier.tableAlias .orElse(fieldAlias.map(_.alias)) .map(a => s"$a.") - .getOrElse("") + identifier.name.split("\\.").tail.mkString(".") + .getOrElse("") + identifier.name + .replace("(", "") + .replace(")", "") + .split("\\.") + .tail + .mkString(".") } else { - identifier.name + identifier.name.replace("(", "").replace(")", "") } override def function: Option[SQLFunction] = identifier.function @@ -34,6 +39,8 @@ sealed trait ScriptField extends Field with PainlessScript { override def isScriptField: Boolean = true def update(request: SQLSearchRequest): ScriptField + + lazy val name: String = fieldAlias.map(_.alias).getOrElse(sourceField) } case class SQLDateTimeField( From 8c3f8ad123e1f2571e5ce10bd529ba4054c9aa1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Thu, 4 Sep 2025 22:24:38 +0200 Subject: [PATCH 03/22] add support to handle date time comparisons criteria --- .../elastic/sql/bridge/ElasticQuery.scala | 4 +- .../elastic/sql/bridge/package.scala | 18 +++ .../elastic/sql/SQLQuerySpec.scala | 111 ++++++++++++++++++ .../elastic/sql/bridge/ElasticQuery.scala | 4 +- .../elastic/sql/bridge/package.scala | 18 +++ .../elastic/sql/SQLQuerySpec.scala | 111 ++++++++++++++++++ .../softnetwork/elastic/sql/SQLFunction.scala | 106 +++++++++-------- .../softnetwork/elastic/sql/SQLGroupBy.scala | 32 +---- .../softnetwork/elastic/sql/SQLOperator.scala | 17 ++- .../softnetwork/elastic/sql/SQLParser.scala | 38 +++--- .../softnetwork/elastic/sql/SQLWhere.scala | 50 ++++++++ .../app/softnetwork/elastic/sql/package.scala | 4 + .../elastic/sql/SQLParserSpec.scala | 29 ++++- 13 files changed, 448 insertions(+), 94 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index ec1d8bd7..c4943563 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,12 +9,13 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, + SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, SQLIsNull } -import com.sksamuel.elastic4s.ElasticApi.{bool, _} +import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.searches.queries.Query case class ElasticQuery(filter: ElasticFilter) { @@ -69,6 +70,7 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression + case dateMath: SQLComparisonDateMath => dateMath case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index ff88cac6..650c19e7 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -301,6 +301,24 @@ package object bridge { } } + implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { + import dateMath._ + dateTimeFunction match { + case _: CurrentTimeFunction => + scriptQuery(Script(script = script).lang("painless").scriptType("source")) + case _ => + val op = if (maybeNot.isDefined) operator.not else operator + op match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne => rangeQuery(identifier.name) lt script gt script + } + } + } + implicit def isNullToQuery( isNull: SQLIsNull ): Query = { diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index b5c44631..881a62d6 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -890,6 +890,117 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "includes": ["identifier"] | } |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + it should "filter with date and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now/d" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d/d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTE);" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("return", "return ") } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala index 746d4af8..9de90dd8 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticQuery.scala @@ -9,12 +9,13 @@ import app.softnetwork.elastic.sql.{ ElasticNested, ElasticParent, SQLBetween, + SQLComparisonDateMath, SQLExpression, SQLIn, SQLIsNotNull, SQLIsNull } -import com.sksamuel.elastic4s.ElasticApi.{bool, _} +import com.sksamuel.elastic4s.ElasticApi._ import com.sksamuel.elastic4s.requests.searches.queries.Query case class ElasticQuery(filter: ElasticFilter) { @@ -69,6 +70,7 @@ case class ElasticQuery(filter: ElasticFilter) { case between: SQLBetween[Double] => between case geoDistance: ElasticGeoDistance => geoDistance case matchExpression: ElasticMatch => matchExpression + case dateMath: SQLComparisonDateMath => dateMath case other => throw new IllegalArgumentException(s"Unsupported filter type: ${other.getClass.getName}") } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index ac9c468c..12587c94 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -302,6 +302,24 @@ package object bridge { } } + implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { + import dateMath._ + dateTimeFunction match { + case _: CurrentTimeFunction => + scriptQuery(Script(script = script).lang("painless").scriptType("source")) + case _ => + val op = if (maybeNot.isDefined) operator.not else operator + op match { + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne => rangeQuery(identifier.name) lt script gt script + } + } + } + implicit def isNullToQuery( isNull: SQLIsNull ): Query = { diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 6e5163de..c6cdfb3c 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -889,6 +889,117 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "includes": ["identifier"] | } |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "filter with date and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithDateAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "range": { + | "createdAt": { + | "lt": "now/d" + | } + | } + | }, + | { + | "range": { + | "createdAt": { + | "gte": "now-10d/d" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + it should "filter with time and interval" in { + val select: ElasticSearchRequest = + SQLQuery(filterWithTimeAndInterval) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() < LocalTime.now();" + | } + | } + | }, + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTE);" + | } + | } + | } + | ] + | } + | }, + | "_source": { + | "includes": [ + | "*" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(">=", " >= ") + .replaceAll("<", " < ") + .replaceAll("return", "return ") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 1712e01f..3d3d9936 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -13,7 +13,7 @@ case object Sum extends SQLExpr("sum") with AggregateFunction case object Distance extends SQLExpr("distance") with SQLFunction with SQLOperator -sealed trait TimeUnit extends PainlessScript { +sealed trait TimeUnit extends PainlessScript with MathScript { lazy val regex: Regex = s"\\b(?i)${sql}s?\\b".r override def painless: String = s"ChronoUnit.${sql.toUpperCase()}" @@ -22,25 +22,45 @@ sealed trait TimeUnit extends PainlessScript { sealed trait CalendarUnit extends TimeUnit sealed trait FixedUnit extends TimeUnit -case object Year extends SQLExpr("year") with CalendarUnit -case object Month extends SQLExpr("month") with CalendarUnit -case object Quarter extends SQLExpr("quarter") with CalendarUnit -case object Week extends SQLExpr("week") with CalendarUnit +case object Year extends SQLExpr("year") with CalendarUnit { + override def script: String = "y" +} +case object Month extends SQLExpr("month") with CalendarUnit { + override def script: String = "M" +} +case object Quarter extends SQLExpr("quarter") with CalendarUnit { + override def script: String = throw new IllegalArgumentException( + "Quarter must be converted to months (value * 3) before creating date-math" + ) +} +case object Week extends SQLExpr("week") with CalendarUnit { + override def script: String = "w" +} -case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit +case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit { + override def script: String = "d" +} -case object Hour extends SQLExpr("hour") with FixedUnit -case object Minute extends SQLExpr("minute") with FixedUnit -case object Second extends SQLExpr("second") with FixedUnit +case object Hour extends SQLExpr("hour") with FixedUnit { + override def script: String = "H" +} +case object Minute extends SQLExpr("minute") with FixedUnit { + override def script: String = "m" +} +case object Second extends SQLExpr("second") with FixedUnit { + override def script: String = "s" +} case object Interval extends SQLExpr("interval") with SQLFunction with SQLRegex -sealed trait TimeInterval extends PainlessScript { +sealed trait TimeInterval extends PainlessScript with MathScript { def value: Int def unit: TimeUnit override def sql: String = s"$Interval $value ${unit.sql}" override def painless: String = s"$value, ${unit.painless}" + + override def script: String = TimeInterval.script(this) } case class CalendarInterval(value: Int, unit: CalendarUnit) extends TimeInterval @@ -51,52 +71,46 @@ object TimeInterval { case cu: CalendarUnit => CalendarInterval(value, cu) case fu: FixedUnit => FixedInterval(value, fu) } + def script(interval: TimeInterval): String = interval match { + case CalendarInterval(v, Quarter) => s"${v * 3}M" + case CalendarInterval(v, u) => s"$v${u.script}" + case FixedInterval(v, u) => s"$v${u.script}" + } } sealed trait DateTimeFunction extends SQLFunction -case object CurrentDate extends SQLExpr("current_date") with DateTimeFunction with PainlessScript { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" -} -case object CurentDateWithParens - extends SQLExpr("current_date()") - with DateTimeFunction - with PainlessScript { - override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" -} -case object CurrentTime extends SQLExpr("current_time") with DateTimeFunction with PainlessScript { +sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript { override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalTime()" + "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" + override def script: String = "now" } -case object CurrentTimeWithParens - extends SQLExpr("current_time()") - with DateTimeFunction - with PainlessScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalTime()" + +sealed trait CurrentDateFunction extends CurrentDateTimeFunction { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" } -case object CurrentTimestamp - extends SQLExpr("current_timestamp") - with DateTimeFunction - with PainlessScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" + +sealed trait CurrentTimeFunction extends CurrentDateTimeFunction { + override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" } + +case object CurrentDate extends SQLExpr("current_date") with CurrentDateFunction + +case object CurentDateWithParens extends SQLExpr("current_date()") with CurrentDateFunction + +case object CurrentTime extends SQLExpr("current_time") with CurrentTimeFunction + +case object CurrentTimeWithParens extends SQLExpr("current_time()") with CurrentTimeFunction + +case object CurrentTimestamp extends SQLExpr("current_timestamp") with CurrentDateTimeFunction + case object CurrentTimestampWithParens extends SQLExpr("current_timestamp()") - with DateTimeFunction - with PainlessScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" -} -case object Now extends SQLExpr("now") with DateTimeFunction with PainlessScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" -} -case object NowWithParens extends SQLExpr("now()") with DateTimeFunction with PainlessScript { - override def painless: String = - "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" -} + with CurrentDateTimeFunction + +case object Now extends SQLExpr("now") with CurrentDateTimeFunction + +case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction case class DateAdd(interval: TimeInterval) extends SQLExpr("date_add") with DateTimeFunction case class DateDiff(interval: TimeInterval) extends SQLExpr("date_diff") with DateTimeFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index ee6bc3e6..73e977eb 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -55,7 +55,7 @@ object BucketSelectorScript { not: Boolean ): String = { operator match { - case _: SQLComparisonOperator => + case o: SQLComparisonOperator => val valueStr = value match { case v: SQLBoolean => v.painless @@ -67,32 +67,10 @@ object BucketSelectorScript { s"Unsupported value type in bucket_selector: $value" ) } - if (not) { - operator match { - case Eq => s"$param != $valueStr" - case Ne => s"$param == $valueStr" - case Gt => s"$param <= $valueStr" - case Ge => s"$param < $valueStr" - case Lt => s"$param >= $valueStr" - case Le => s"$param > $valueStr" - case _ => - throw new IllegalArgumentException( - s"Unsupported comparison operator in bucket_selector: $operator" - ) - } - } else - operator match { - case Eq => s"$param == $valueStr" - case Ne => s"$param != $valueStr" - case Gt => s"$param > $valueStr" - case Ge => s"$param >= $valueStr" - case Lt => s"$param < $valueStr" - case Le => s"$param <= $valueStr" - case _ => - throw new IllegalArgumentException( - s"Unsupported comparison operator in bucket_selector: $operator" - ) - } + if (not) + s"$param ${o.not.painless} $valueStr" + else + s"$param ${o.painless} $valueStr" case In => value match { case SQLDoubleValues(vals) => painlessIn(param, vals, not) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index e91cd82b..2bc52639 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -13,7 +13,22 @@ case object Modulo extends SQLExpr("%") with ArithmeticOperator sealed trait SQLExpressionOperator extends SQLOperator -sealed trait SQLComparisonOperator extends SQLExpressionOperator +sealed trait SQLComparisonOperator extends SQLExpressionOperator with PainlessScript { + override def painless: String = this match { + case Eq => "==" + case Ne => "!=" + case other => other.sql + } + + def not: SQLComparisonOperator = this match { + case Eq => Ne + case Ne => Eq + case Ge => Lt + case Gt => Le + case Le => Gt + case Lt => Ge + } +} case object Eq extends SQLExpr("=") with SQLComparisonOperator case object Ne extends SQLExpr("<>") with SQLComparisonOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index c260c87e..6ff28331 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -107,23 +107,24 @@ trait SQLParser extends RegexParsers with PackratParsers { TimeInterval(l.value.toInt, u) } - def current_date: PackratParser[DateTimeFunction] = CurrentDate.regex ~ start.? ~ end.? ^^ { - case _ ~ s ~ t => + def current_date: PackratParser[CurrentDateTimeFunction] = + CurrentDate.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurentDateWithParens else CurrentDate - } + } - def current_time: PackratParser[DateTimeFunction] = CurrentTime.regex ~ start.? ~ end.? ^^ { - case _ ~ s ~ t => + def current_time: PackratParser[CurrentDateTimeFunction] = + CurrentTime.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurrentTimeWithParens else CurrentTime - } + } - def current_timestamp: PackratParser[DateTimeFunction] = + def current_timestamp: PackratParser[CurrentDateTimeFunction] = CurrentTimestamp.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => if (s.isDefined && t.isDefined) CurrentTimestampWithParens else CurrentTimestamp } - def now: PackratParser[DateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { case _ ~ s ~ t => - if (s.isDefined && t.isDefined) NowWithParens else Now + def now: PackratParser[CurrentDateTimeFunction] = Now.regex ~ start.? ~ end.? ^^ { + case _ ~ s ~ t => + if (s.isDefined && t.isDefined) NowWithParens else Now } def plus: PackratParser[ArithmeticOperator] = Plus.sql ^^ (_ => Plus) @@ -232,9 +233,9 @@ trait SQLWhereParser { SQLIsNotNull(i) } - private def eq: PackratParser[SQLExpressionOperator] = Eq.sql ^^ (_ => Eq) + private def eq: PackratParser[SQLComparisonOperator] = Eq.sql ^^ (_ => Eq) - private def ne: PackratParser[SQLExpressionOperator] = Ne.sql ^^ (_ => Ne) + private def ne: PackratParser[SQLComparisonOperator] = Ne.sql ^^ (_ => Ne) private def equality: PackratParser[SQLExpression] = not.? ~ (identifierWithFunction | identifier) ~ (eq | ne) ~ (boolean | literal | double | long) ^^ { @@ -246,13 +247,13 @@ trait SQLWhereParser { SQLExpression(i, Like, v, n) } - private def ge: PackratParser[SQLExpressionOperator] = Ge.sql ^^ (_ => Ge) + private def ge: PackratParser[SQLComparisonOperator] = Ge.sql ^^ (_ => Ge) - def gt: PackratParser[SQLExpressionOperator] = Gt.sql ^^ (_ => Gt) + def gt: PackratParser[SQLComparisonOperator] = Gt.sql ^^ (_ => Gt) - private def le: PackratParser[SQLExpressionOperator] = Le.sql ^^ (_ => Le) + private def le: PackratParser[SQLComparisonOperator] = Le.sql ^^ (_ => Le) - def lt: PackratParser[SQLExpressionOperator] = Lt.sql ^^ (_ => Lt) + def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = not.? ~ (identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { @@ -323,6 +324,11 @@ trait SQLWhereParser { SQLMatch(i, l) } + private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = + not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) + } + def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) def or: PackratParser[SQLPredicateOperator] = Or.regex ^^ (_ => Or) @@ -330,7 +336,7 @@ trait SQLWhereParser { def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) def criteria: PackratParser[SQLCriteria] = - (equality | like | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | distance | matchCriteria) ^^ ( + (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | distance | matchCriteria) ^^ ( c => c ) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 076ca868..4f3992a8 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -317,6 +317,56 @@ case class SQLMatch( override def group: Boolean = false } +case class SQLComparisonDateMath( + identifier: SQLIdentifier, + operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... + dateTimeFunction: CurrentDateTimeFunction, // CurrentDate, Now, CurrentTimestamp, CurrentTime, ... + arithmeticOperator: Option[ArithmeticOperator] = + None, // Plus or Minus between dateTimeFunction and interval + interval: Option[TimeInterval] = None, // optional interval + maybeNot: Option[Not.type] = None +) extends Expression + with MathScript { + override def sql: String = { + s"$identifier ${operator.sql} $dateTimeFunction${asString(arithmeticOperator)}${asString(interval)}" + } + override def update(request: SQLSearchRequest): SQLCriteria = + this.copy(identifier = identifier.update(request)) + + override def maybeValue: Option[SQLToken] = Some(dateTimeFunction) + + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def script: String = { + dateTimeFunction match { + case _: CurrentTimeFunction => + val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless + (arithmeticOperator, interval) match { + case (Some(Plus), Some(i)) => // compare doc time with now + interval + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().plus(${i.value}, ${i.unit.painless});" + + case (Some(Minus), Some(i)) => // compare doc time with now + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().minus(${i.value}, ${i.unit.painless});" + + case _ => + s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now();" + } + case _ => + val base = s"${dateTimeFunction.script}" + val dateMath = + (arithmeticOperator, interval) match { + case (Some(Plus), Some(i)) => s"$base+${i.script}" + case (Some(Minus), Some(i)) => s"$base-${i.script}" + case _ => base + } + dateTimeFunction match { + case _: CurrentDateFunction => s"$dateMath/d" + case _ => dateMath + } + } + } +} + case class ElasticMatch( identifier: SQLIdentifier, value: SQLLiteral, diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 063f70c1..47ba30fe 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -25,6 +25,10 @@ package object sql { def painless: String } + trait MathScript extends SQLToken { + def script: String + } + trait SQLTokenWithFunction extends SQLToken { def function: Option[SQLFunction] diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 91d7e38c..d52d3e8d 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -83,8 +83,12 @@ object Queries { "select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" val fieldsWithInterval: String = "select createdAt - interval 35 minute as ct, identifier from Table" - - //TODO "select * from Table where createdAt <= current_timestamp() and createdAt >= current_timestamp() - interval 35 minute" + val filterWithDateTimeAndInterval: String = + "select * from Table where createdAt < current_timestamp() and createdAt >= current_timestamp() - interval 10 day" + val filterWithDateAndInterval: String = + "select * from Table where createdAt < current_date and createdAt >= current_date() - interval 10 day" + val filterWithTimeAndInterval: String = + "select * from Table where createdAt < current_time and createdAt >= current_time() - interval 10 minute" } /** Created by smanciot on 15/02/17. @@ -360,4 +364,25 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { val result = SQLParser(fieldsWithInterval) result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===(fieldsWithInterval) } + + it should "parse filter with date time and interval" in { + val result = SQLParser(filterWithDateTimeAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithDateTimeAndInterval + ) + } + + it should "parse filter with date and interval" in { + val result = SQLParser(filterWithDateAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithDateAndInterval + ) + } + + it should "parse filter with time and interval" in { + val result = SQLParser(filterWithTimeAndInterval) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + filterWithTimeAndInterval + ) + } } From 1ae37929c8a54019e59cb4102c1073eeab4e6e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Thu, 4 Sep 2025 22:40:36 +0200 Subject: [PATCH 04/22] fix parser for Month --- sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 6ff28331..77930cf6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -87,7 +87,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def year: PackratParser[TimeUnit] = Year.regex ^^ (_ => Year) - def month: PackratParser[TimeUnit] = Month.regex ^^ (_ => Year) + def month: PackratParser[TimeUnit] = Month.regex ^^ (_ => Month) def quarter: PackratParser[TimeUnit] = Quarter.regex ^^ (_ => Quarter) From a546276a30796c8c8ae63a73eb50c6dc818461da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Thu, 4 Sep 2025 22:50:09 +0200 Subject: [PATCH 05/22] fix ne comparison for date math range query --- .../main/scala/app/softnetwork/elastic/sql/bridge/package.scala | 2 +- .../main/scala/app/softnetwork/elastic/sql/bridge/package.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 650c19e7..0f386732 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -314,7 +314,7 @@ package object bridge { case Lt => rangeQuery(identifier.name) lt script case Le => rangeQuery(identifier.name) lte script case Eq => rangeQuery(identifier.name) gte script lte script - case Ne => rangeQuery(identifier.name) lt script gt script + case Ne => not(rangeQuery(identifier.name) gte script lte script) } } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 12587c94..34776e36 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -315,7 +315,7 @@ package object bridge { case Lt => rangeQuery(identifier.name) lt script case Le => rangeQuery(identifier.name) lte script case Eq => rangeQuery(identifier.name) gte script lte script - case Ne => rangeQuery(identifier.name) lt script gt script + case Ne => not(rangeQuery(identifier.name) gte script lte script) } } } From bfc1426320fce50b27e30e0e3d3b7865b248c217 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 01:19:52 +0200 Subject: [PATCH 06/22] add aliasOrName for all identifiers, handle date functions used with aggregations within having --- .../elastic/sql/bridge/package.scala | 2 + .../elastic/sql/SQLQuerySpec.scala | 157 ++++++++++++++++++ .../elastic/sql/bridge/package.scala | 2 + .../elastic/sql/SQLQuerySpec.scala | 155 +++++++++++++++++ .../softnetwork/elastic/sql/SQLGroupBy.scala | 69 +++++--- .../app/softnetwork/elastic/sql/package.scala | 2 + .../elastic/sql/SQLParserSpec.scala | 13 ++ 7 files changed, 373 insertions(+), 27 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 0f386732..0cc2adca 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -303,6 +303,8 @@ package object bridge { implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { import dateMath._ + if (aggregation) + return matchAllQuery() dateTimeFunction match { case _: CurrentTimeFunction => scriptQuery(Script(script = script).lang("painless").scriptType("source")) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 881a62d6..3f96b33b 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1003,4 +1003,161 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("<", " < ") .replaceAll("return", "return ") } + + it should "handle having with date functions" in { + val select: ElasticSearchRequest = + SQLQuery("""SELECT userId, MAX(createdAt) as lastSeen + |FROM table + |GROUP BY userId + |HAVING MAX(createdAt) > now - interval 7 day""".stripMargin) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "match_all": {} + | }, + | "aggs": { + | "userId": { + | "terms": { + | "field": "userId.keyword" + | }, + | "aggs": { + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "lastSeen": "lastSeen" + | }, + | "script": { + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAY).toInstant().toEpochMilli())" + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle group by with having and date time functions" in { + val select: ElasticSearchRequest = + SQLQuery(groupByWithHavingAndDateTimeFunctions) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "Country": { + | "value": "USA" + | } + | } + | } + | ] + | } + | }, + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "City": { + | "value": "Berlin" + | } + | } + | } + | ] + | } + | }, + | { + | "match_all": {} + | }, + | { + | "range": { + | "lastSeen": { + | "gt": "now-7d" + | } + | } + | } + | ] + | } + | }, + | "aggs": { + | "Country": { + | "terms": { + | "field": "Country.keyword" + | }, + | "aggs": { + | "City": { + | "terms": { + | "field": "City.keyword" + | }, + | "aggs": { + | "cnt": { + | "value_count": { + | "field": "CustomerID" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "cnt": "cnt" + | }, + | "script": { + | "source": "1 == 1 && 1 == 1 && params.cnt > 1 && 1 == 1" + | } + | } + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 34776e36..ea9d4695 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -304,6 +304,8 @@ package object bridge { implicit def dateMathToQuery(dateMath: SQLComparisonDateMath): Query = { import dateMath._ + if (aggregation) + return matchAllQuery() dateTimeFunction match { case _: CurrentTimeFunction => scriptQuery(Script(script = script).lang("painless").scriptType("source")) diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index c6cdfb3c..12ae803f 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1002,4 +1002,159 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("<", " < ") .replaceAll("return", "return ") } + + it should "handle having with date functions" in { + val select: ElasticSearchRequest = + SQLQuery("""SELECT userId, MAX(createdAt) as lastSeen + |FROM table + |GROUP BY userId + |HAVING MAX(createdAt) > now - interval 7 day""".stripMargin) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "match_all": {} + | }, + | "aggs": { + | "userId": { + | "terms": { + | "field": "userId.keyword" + | }, + | "aggs": { + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "lastSeen": "lastSeen" + | }, + | "script": { + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAY).toInstant().toEpochMilli())" + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } + + it should "handle group by with having and date time functions" in { + val select: ElasticSearchRequest = + SQLQuery(groupByWithHavingAndDateTimeFunctions) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "filtered_agg": { + | "filter": { + | "bool": { + | "filter": [ + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "Country": { + | "value": "USA" + | } + | } + | } + | ] + | } + | }, + | { + | "bool": { + | "must_not": [ + | { + | "term": { + | "City": { + | "value": "Berlin" + | } + | } + | } + | ] + | } + | }, + | { + | "match_all": {} + | }, + | { + | "range": { + | "lastSeen": { + | "gt": "now-7d" + | } + | } + | } + | ] + | } + | }, + | "aggs": { + | "Country": { + | "terms": { + | "field": "Country.keyword" + | }, + | "aggs": { + | "City": { + | "terms": { + | "field": "City.keyword" + | }, + | "aggs": { + | "cnt": { + | "value_count": { + | "field": "CustomerID" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt" + | } + | }, + | "having_filter": { + | "bucket_selector": { + | "buckets_path": { + | "cnt": "cnt" + | }, + | "script": { + | "source": "1 == 1 && 1 == 1 && params.cnt > 1 && 1 == 1" + | } + | } + | } + | } + | } + | } + | } + | } + | } + | } + |}""".stripMargin.replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 73e977eb..bac04eb1 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -95,17 +95,12 @@ object BucketSelectorScript { extractBucketsPath(left) ++ extractBucketsPath(right) case relation: ElasticRelation => extractBucketsPath(relation.criteria) case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector - case e: Expression => + case SQLComparisonDateMath(identifier, _, _, _, _, _) if identifier.aggregation => + Map(identifier.aliasOrName -> identifier.aliasOrName) + case e: Expression if e.aggregation => import e._ - val name = identifier.fieldAlias.getOrElse(identifier.name) - if (e.aggregation) { - Map(name -> name) - } /*else if (e.identifier.bucket.isDefined) { - Map(name -> "_key") - }*/ - else { - Map.empty // for performance, we only allow aggregation here - } + Map(identifier.aliasOrName -> identifier.aliasOrName) + case _ => Map.empty } def toPainless(expr: SQLCriteria): String = expr match { @@ -125,26 +120,46 @@ object BucketSelectorScript { case relation: ElasticRelation => toPainless(relation.criteria) + case SQLComparisonDateMath(identifier, op, dateFunc, arithOp, interval, maybeNot) + if identifier.aggregation => + val painlessOp = if (maybeNot.nonEmpty) op.not.painless else op.painless + val paramName = identifier.aliasOrName + // always use a correct "now" creation + val now = "ZonedDateTime.now(ZoneId.of('Z'))" + + // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) + val rightBase = (arithOp, interval) match { + case (Some(Plus), Some(i)) => s"$now.plus(${i.painless})" + case (Some(Minus), Some(i)) => s"$now.minus(${i.painless})" + case _ => now + } + + val rightZdt = dateFunc match { + // truncate only after arithmetic for CurrentDate + case _: CurrentDateFunction => s"$rightBase.truncatedTo(ChronoUnit.DAYS)" + case _: CurrentTimeFunction => s"$rightBase.truncatedTo(ChronoUnit.SECONDS)" + case _ => rightBase + } + + // protect against null params and compare epoch millis + s"(params.$paramName != null) && (params.$paramName $painlessOp $rightZdt.toInstant().toEpochMilli())" + case _: SQLMatch => "1 == 1" //MATCH is not supported in bucket_selector - case e: Expression => - if (e.aggregation /*|| e.identifier.bucket.isDefined*/ ) { // for performance, we only allow aggregation here - val param = - s"params.${e.identifier.fieldAlias.getOrElse(e.identifier.name)}" - e.maybeValue match { - case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) - case None => - e.operator match { - case IsNull => s"$param == null" - case IsNotNull => s"$param != null" - case _ => - throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") - } - } - } else { - "1 == 1" + case e: Expression if e.aggregation => + val param = + s"params.${e.identifier.aliasOrName}" + e.maybeValue match { + case Some(v) => toPainless(param, e.operator, v, e.maybeNot.nonEmpty) + case None => + e.operator match { + case IsNull => s"$param == null" + case IsNotNull => s"$param != null" + case _ => + throw new IllegalArgumentException(s"Operator ${e.operator} requires a value") + } } - case _ => throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") + case _ => "1 == 1" //throw new IllegalArgumentException(s"Unsupported SQLCriteria type: $expr") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 47ba30fe..ebf18129 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -301,6 +301,8 @@ package object sql { lazy val innerHitsName: Option[String] = if (nested) tableAlias else None + lazy val aliasOrName: String = fieldAlias.getOrElse(name) + def update(request: SQLSearchRequest): SQLIdentifier = { val parts: Seq[String] = name.split("\\.").toSeq if (request.tableAliases.values.toSeq.contains(parts.head)) { diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index d52d3e8d..b265aa1f 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -89,6 +89,12 @@ object Queries { "select * from Table where createdAt < current_date and createdAt >= current_date() - interval 10 day" val filterWithTimeAndInterval: String = "select * from Table where createdAt < current_time and createdAt >= current_time() - interval 10 minute" + val groupByWithHavingAndDateTimeFunctions: String = + """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen + |from Table + |group by Country,City + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |""".stripMargin.replaceAll("\n", " ") } /** Created by smanciot on 15/02/17. @@ -385,4 +391,11 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { filterWithTimeAndInterval ) } + + it should "parse group by with having and date time functions" in { + val result = SQLParser(groupByWithHavingAndDateTimeFunctions) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + groupByWithHavingAndDateTimeFunctions + ) + } } From 34434ca0cf34b1418af27a10c899d9687fc3d0cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 01:31:36 +0200 Subject: [PATCH 07/22] fix parser specifications --- .../test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index b265aa1f..3c1fdf69 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -93,8 +93,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table |group by Country,City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day - |""".stripMargin.replaceAll("\n", " ") + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day""".stripMargin.replaceAll("\n", " ") } /** Created by smanciot on 15/02/17. From c80405a784171b38937256d6d9fe55a6c026ccb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 01:47:06 +0200 Subject: [PATCH 08/22] fix parser specifications lint bug --- .../test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 3c1fdf69..8969371a 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -93,7 +93,8 @@ object Queries { """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table |group by Country,City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day""".stripMargin.replaceAll("\n", " ") + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day""".stripMargin + .replaceAll("\n", " ") } /** Created by smanciot on 15/02/17. From 473cfcc21c9afbfd13ab8ff18c226b5570ac080e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 07:46:01 +0200 Subject: [PATCH 09/22] fix order by with buckets and aggregations --- .../elastic/sql/SQLQuerySpec.scala | 25 +++++++++++-------- .../elastic/sql/SQLQuerySpec.scala | 25 +++++++++++-------- .../softnetwork/elastic/sql/SQLGroupBy.scala | 2 +- .../softnetwork/elastic/sql/SQLOrderBy.scala | 2 +- .../softnetwork/elastic/sql/SQLParser.scala | 8 ++---- .../app/softnetwork/elastic/sql/package.scala | 7 ++---- .../elastic/sql/SQLParserSpec.scala | 15 +++++------ 7 files changed, 42 insertions(+), 42 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 3f96b33b..cf335677 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -538,8 +538,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "country": { - | "value": "usa" + | "Country": { + | "value": "USA" | } | } | } @@ -551,8 +551,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "city": { - | "value": "berlin" + | "City": { + | "value": "Berlin" | } | } | } @@ -566,17 +566,17 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } | }, | "aggs": { - | "country": { + | "Country": { | "terms": { - | "field": "country.keyword", + | "field": "Country.keyword", | "order": { - | "country": "asc" + | "Country": "asc" | } | }, | "aggs": { - | "city": { + | "City": { | "terms": { - | "field": "city.keyword", + | "field": "City.keyword", | "order": { | "cnt": "desc" | } @@ -584,7 +584,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "cnt": { | "value_count": { - | "field": "customerid" + | "field": "CustomerID" | } | }, | "having_filter": { @@ -1117,7 +1117,10 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "Country": { | "terms": { - | "field": "Country.keyword" + | "field": "Country.keyword", + | "order": { + | "Country": "asc" + | } | }, | "aggs": { | "City": { diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 12ae803f..8218f8b2 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -538,8 +538,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "country": { - | "value": "usa" + | "Country": { + | "value": "USA" | } | } | } @@ -551,8 +551,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "must_not": [ | { | "term": { - | "city": { - | "value": "berlin" + | "City": { + | "value": "Berlin" | } | } | } @@ -566,17 +566,17 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } | }, | "aggs": { - | "country": { + | "Country": { | "terms": { - | "field": "country.keyword", + | "field": "Country.keyword", | "order": { - | "country": "asc" + | "Country": "asc" | } | }, | "aggs": { - | "city": { + | "City": { | "terms": { - | "field": "city.keyword", + | "field": "City.keyword", | "order": { | "cnt": "desc" | } @@ -584,7 +584,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "cnt": { | "value_count": { - | "field": "customerid" + | "field": "CustomerID" | } | }, | "having_filter": { @@ -1115,7 +1115,10 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "aggs": { | "Country": { | "terms": { - | "field": "Country.keyword" + | "field": "Country.keyword", + | "order": { + | "Country": "asc" + | } | }, | "aggs": { | "City": { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index bac04eb1..6a7b4ec2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -3,7 +3,7 @@ package app.softnetwork.elastic.sql case object GroupBy extends SQLExpr("group by") with SQLRegex case class SQLGroupBy(buckets: Seq[SQLBucket]) extends Updateable { - override def sql: String = s" $GroupBy ${buckets.mkString(",")}" + override def sql: String = s" $GroupBy ${buckets.mkString(", ")}" def update(request: SQLSearchRequest): SQLGroupBy = this.copy(buckets = buckets.map(_.update(request))) lazy val bucketNames: Map[String, SQLBucket] = buckets.map { b => diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala index 74f3110a..656c8e2c 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala @@ -23,5 +23,5 @@ case class SQLFieldSort( } case class SQLOrderBy(sorts: Seq[SQLFieldSort]) extends SQLToken { - override def sql: String = s" $OrderBy ${sorts.mkString(",")}" + override def sql: String = s" $OrderBy ${sorts.mkString(", ")}" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 77930cf6..2b395b17 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -534,14 +534,10 @@ trait SQLWhereParser { trait SQLGroupByParser { self: SQLParser with SQLWhereParser => - private def having: PackratParser[SQLHaving] = Having.regex ~> whereCriteria ^^ { rawTokens => - SQLHaving( - processTokens(rawTokens) - ) + def bucket: PackratParser[SQLBucket] = identifier ^^ { i => + SQLBucket(i) } - def bucket: PackratParser[SQLBucket] = identifier ^^ (i => SQLBucket(i)) - def groupBy: PackratParser[SQLGroupBy] = GroupBy.regex ~ rep1sep(bucket, separator) ^^ { case _ ~ buckets => SQLGroupBy(buckets) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index ebf18129..603b3a3a 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -288,14 +288,11 @@ package object sql { with SQLSource with SQLTokenWithFunction { - lazy val aggregationName: Option[String] = - if (aggregation) fieldAlias.orElse(Option(name)) else None - lazy val identifierName: String = - (function match { + function match { case Some(f) => s"${f.sql}($name)" case _ => name - }).toLowerCase + } lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 8969371a..c5311cfd 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -74,11 +74,11 @@ object Queries { |order by identifier2 desc |limit 10""".stripMargin.replaceAll("\n", " ") val groupByWithHaving: String = - """SELECT COUNT(CustomerID) as cnt, City, Country - |FROM Customers - |GROUP BY Country,City - |HAVING Country <> "USA" AND City <> "Berlin" AND COUNT(CustomerID) > 1 - |ORDER BY COUNT(CustomerID) DESC,Country asc""".stripMargin.replaceAll("\n", " ").toLowerCase + """select count(CustomerID) as cnt, City, Country + |from Customers + |group by Country, City + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 + |order by count(CustomerID) desc, Country asc""".stripMargin.replaceAll("\n", " ") val dateTimeWithIntervalFields: String = "select current_timestamp() - interval 3 day as ct, current_date as cd, current_time as t, now as n from dual" val fieldsWithInterval: String = @@ -92,8 +92,9 @@ object Queries { val groupByWithHavingAndDateTimeFunctions: String = """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table - |group by Country,City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day""".stripMargin + |group by Country, City + |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |order by Country asc""".stripMargin .replaceAll("\n", " ") } From c931b70afe328ef3ef5f110a6e90021b7a82d0ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 13:19:07 +0200 Subject: [PATCH 10/22] add functions composition mechanism, add trait SQLTransformFunction and ParametrizedFunction, implements parse_date and parse_datetime --- .../sql/bridge/ElasticAggregation.scala | 33 +++++- .../elastic/sql/SQLQuerySpec.scala | 112 ++++++++++++++++++ .../sql/bridge/ElasticAggregation.scala | 33 +++++- .../elastic/sql/SQLQuerySpec.scala | 112 ++++++++++++++++++ .../softnetwork/elastic/sql/SQLFunction.scala | 37 +++++- .../softnetwork/elastic/sql/SQLOrderBy.scala | 11 +- .../softnetwork/elastic/sql/SQLParser.scala | 36 ++++-- .../softnetwork/elastic/sql/SQLSelect.scala | 4 +- .../softnetwork/elastic/sql/SQLWhere.scala | 10 +- .../app/softnetwork/elastic/sql/package.scala | 24 ++-- .../elastic/sql/SQLParserSpec.scala | 19 +++ 11 files changed, 385 insertions(+), 46 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index ff17e7e8..da15ce69 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -12,6 +12,7 @@ import app.softnetwork.elastic.sql.{ Min, SQLBucket, SQLCriteria, + SQLTransformFunction, SortOrder, Sum } @@ -88,6 +89,30 @@ object ElasticAggregation { var aggPath = Seq[String]() + val (aggFuncs, transformFuncs) = identifier.functions.partition { + case _: AggregateFunction => true + case _ => false + } + + require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") + + def aggWithFieldOrScript( + buildField: (String, String) => Aggregation, + buildScript: (String, Script) => Aggregation + ): Aggregation = { + if (transformFuncs.nonEmpty) { + val base = s"doc['$sourceField'].value" + val scriptSrc = transformFuncs.foldLeft(base) { + case (expr, f: SQLTransformFunction) => f.toPainless(expr) + case (expr, f) => f.toSQL(expr) // fallback + } + val script = Script(scriptSrc).lang("painless") + buildScript(aggName, script) + } else { + buildField(aggName, sourceField) + } + } + val _agg = aggType match { case Count => @@ -96,10 +121,10 @@ object ElasticAggregation { else { valueCountAgg(aggName, sourceField) } - case Min => minAgg(aggName, sourceField) - case Max => maxAgg(aggName, sourceField) - case Avg => avgAgg(aggName, sourceField) - case Sum => sumAgg(aggName, sourceField) + case Min => aggWithFieldOrScript(minAgg, (name, s) => minAgg(name, sourceField).script(s)) + case Max => aggWithFieldOrScript(maxAgg, (name, s) => maxAgg(name, sourceField).script(s)) + case Avg => aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s)) + case Sum => aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s)) } val filteredAggName = "filtered_agg" diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index cf335677..d7088219 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1163,4 +1163,116 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("&&", " && ") .replaceAll(">", " > ") } + + it should "handle parse_date function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDate) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } + + it should "handle parse_datetime function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDateTime) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index 4d0894f4..2a6adda2 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -12,6 +12,7 @@ import app.softnetwork.elastic.sql.{ Min, SQLBucket, SQLCriteria, + SQLTransformFunction, SortOrder, Sum } @@ -87,6 +88,30 @@ object ElasticAggregation { var aggPath = Seq[String]() + val (aggFuncs, transformFuncs) = identifier.functions.partition { + case _: AggregateFunction => true + case _ => false + } + + require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") + + def aggWithFieldOrScript( + buildField: (String, String) => Aggregation, + buildScript: (String, Script) => Aggregation + ): Aggregation = { + if (transformFuncs.nonEmpty) { + val base = s"doc['$sourceField'].value" + val scriptSrc = transformFuncs.foldLeft(base) { + case (expr, f: SQLTransformFunction) => f.toPainless(expr) + case (expr, f) => f.toSQL(expr) // fallback + } + val script = Script(scriptSrc).lang("painless") + buildScript(aggName, script) + } else { + buildField(aggName, sourceField) + } + } + val _agg = aggType match { case Count => @@ -95,10 +120,10 @@ object ElasticAggregation { else { valueCountAgg(aggName, sourceField) } - case Min => minAgg(aggName, sourceField) - case Max => maxAgg(aggName, sourceField) - case Avg => avgAgg(aggName, sourceField) - case Sum => sumAgg(aggName, sourceField) + case Min => aggWithFieldOrScript(minAgg, (name, s) => minAgg(name, sourceField).script(s)) + case Max => aggWithFieldOrScript(maxAgg, (name, s) => maxAgg(name, sourceField).script(s)) + case Avg => aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s)) + case Sum => aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s)) } val filteredAggName = "filtered_agg" diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 8218f8b2..4ca76076 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1160,4 +1160,116 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("&&", " && ") .replaceAll(">", " > ") } + + it should "handle parse_date function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDate) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } + + it should "handle parse_datetime function" in { + val select: ElasticSearchRequest = + SQLQuery(parseDateTime) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword", + | "order": { + | "ct": "desc" + | } + | }, + | "aggs": { + | "ct": { + | "value_count": { + | "field": "identifier2" + | } + | }, + | "lastSeen": { + | "max": { + | "field": "createdAt", + | "script": { + | "lang": "painless", + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll("==", " == ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll(">", " > ") + .replaceAll(",LocalDate", ", LocalDate") + } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 3d3d9936..dc5a0f22 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -2,7 +2,21 @@ package app.softnetwork.elastic.sql import scala.util.matching.Regex -sealed trait SQLFunction extends SQLRegex +sealed trait SQLFunction extends SQLRegex { + def toSQL(base: String): String = s"$sql($base)" +} + +sealed trait SQLTransformFunction extends SQLFunction { + def toPainless(base: String): String +} + +sealed trait ParametrizedFunction extends SQLFunction { + def params: Seq[String] + override def toSQL(base: String): String = { + val paramsStr = params.map(p => s"'$p'").mkString(", ") + s"$sql($paramsStr)($base)" + } +} sealed trait AggregateFunction extends SQLFunction case object Count extends SQLExpr("count") with AggregateFunction @@ -118,4 +132,23 @@ case class DateSub(interval: TimeInterval) extends SQLExpr("date_sub") with Date case class DateTrunc(unit: TimeUnit) extends SQLExpr("date_trunc") with DateTimeFunction case class Extract(unit: TimeUnit) extends SQLExpr("extract") with DateTimeFunction case class FormatDate(format: String) extends SQLExpr("format_date") with DateTimeFunction -case class ParseDate(format: String) extends SQLExpr("parse_date") with DateTimeFunction + +case class ParseDate(format: String) + extends SQLExpr("parse_date") + with DateTimeFunction + with SQLTransformFunction + with ParametrizedFunction { + override def params: Seq[String] = Seq(format) + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" +} + +case class ParseDateTime(format: String) + extends SQLExpr("parse_datetime") + with DateTimeFunction + with SQLTransformFunction + with ParametrizedFunction { + override def params: Seq[String] = Seq(format) + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDateTime::from)" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala index 656c8e2c..54acfb3d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala @@ -11,12 +11,13 @@ case object Asc extends SQLExpr("asc") with SortOrder case class SQLFieldSort( field: String, order: Option[SortOrder], - function: Option[SQLFunction] = None + functions: List[SQLFunction] = List.empty ) extends SQLTokenWithFunction { - private[this] lazy val fieldWithFunction: String = function match { - case Some(f) => s"$f($field)" - case _ => field - } + private[this] lazy val fieldWithFunction: String = + functions.foldLeft(field)((expr, fun) => { + fun.toSQL(expr) + }) + lazy val direction: SortOrder = order.getOrElse(Asc) lazy val name: String = fieldWithFunction override def sql: String = s"$name $direction" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 2b395b17..631a2ce8 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -143,17 +143,29 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } - def aggregateFunction: PackratParser[AggregateFunction] = count | min | max | avg | sum + def parse_date: PackratParser[DateTimeFunction] = + "(?i)parse_date".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => + ParseDate(f.value) + } + + def parse_datetime: PackratParser[DateTimeFunction] = + "(?i)parse_datetime".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => + ParseDateTime(f.value) + } + + def date_functions: PackratParser[DateTimeFunction] = parse_date | parse_datetime + + def aggregates: PackratParser[AggregateFunction] = count | min | max | avg | sum - def distanceFunction: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) + def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def sqlFunction: PackratParser[SQLFunction] = aggregateFunction | distanceFunction + def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_functions private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" def identifierWithFunction: PackratParser[SQLIdentifier] = - sqlFunction ~ start ~ identifier ~ end ^^ { case f ~ _ ~ i ~ _ => - i.copy(function = Some(f)) + rep1sep(sql_functions, start) ~ start.? ~ identifier ~ rep1(end) ^^ { case f ~ _ ~ i ~ _ => + i.copy(functions = f) } def identifier: PackratParser[SQLIdentifier] = @@ -311,8 +323,8 @@ trait SQLWhereParser { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLDoubleFromTo(from, to), n) } - def distance: PackratParser[SQLCriteria] = - distanceFunction ~ start ~ identifier ~ separator ~ start ~ double ~ separator ~ double ~ end ~ end ~ le ~ literal ^^ { + def sql_distance: PackratParser[SQLCriteria] = + distance ~ start ~ identifier ~ separator ~ start ~ double ~ separator ~ double ~ end ~ end ~ le ~ literal ^^ { case _ ~ _ ~ i ~ _ ~ _ ~ lat ~ _ ~ lon ~ _ ~ _ ~ _ ~ d => ElasticGeoDistance(i, d, lat, lon) } @@ -336,7 +348,7 @@ trait SQLWhereParser { def not: PackratParser[Not.type] = Not.regex ^^ (_ => Not) def criteria: PackratParser[SQLCriteria] = - (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | distance | matchCriteria) ^^ ( + (equality | like | dateTimeComparison | comparison | inLiteral | inLongs | inDoubles | between | betweenLongs | betweenDoubles | isNotNull | isNull | sql_distance | matchCriteria) ^^ ( c => c ) @@ -566,16 +578,16 @@ trait SQLOrderByParser { private def fieldName: PackratParser[String] = """\b(?!(?i)limit\b)[a-zA-Z_][a-zA-Z0-9_]*""".r ^^ (f => f) - def fieldWithFunction: PackratParser[(String, SQLFunction)] = - sqlFunction ~ start ~ fieldName ~ end ^^ { case f ~ _ ~ n ~ _ => + def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = + rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => (n, f) } def sort: PackratParser[SQLFieldSort] = (fieldWithFunction | fieldName) ~ (asc | desc).? ^^ { case f ~ o => f match { - case i: (String, SQLFunction) => SQLFieldSort(i._1, o, Some(i._2)) - case s: String => SQLFieldSort(s, o, None) + case i: (String, List[SQLFunction]) => SQLFieldSort(i._1, o, i._2) + case s: String => SQLFieldSort(s, o, List.empty) } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 1f788fc2..4cc6700a 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -22,7 +22,7 @@ sealed trait Field extends Updateable with SQLTokenWithFunction { identifier.name.replace("(", "").replace(")", "") } - override def function: Option[SQLFunction] = identifier.function + override def functions: List[SQLFunction] = identifier.functions def update(request: SQLSearchRequest): Field } @@ -54,7 +54,7 @@ case class SQLDateTimeField( def update(request: SQLSearchRequest): SQLDateTimeField = this.copy(identifier = identifier.update(request)) override def painless: String = { - val base = identifier.function match { + val base = identifier.functions.headOption match { // FIXME case f @ Some(CurrentDate | CurrentTime | CurrentTimestamp | Now) => f.asInstanceOf[PainlessScript].painless case _ => s"doc['$sourceField'].value" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 4f3992a8..68b36cd0 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -94,7 +94,7 @@ sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLTokenWithFunc override def nested: Boolean = identifier.nested override def group: Boolean = false override lazy val limit: Option[SQLLimit] = identifier.limit - override val function: Option[SQLFunction] = identifier.function + override val functions: List[SQLFunction] = identifier.functions } case class ElasticBoolQuery( @@ -219,7 +219,7 @@ case class SQLIn[R, +T <: SQLValue[R]]( values: SQLValues[R, T], maybeNot: Option[Not.type] = None ) extends Expression { this: SQLIn[R, T] => - private[this] lazy val id = function match { + private[this] lazy val id = functions.headOption match { case Some(f) => s"$f($identifier)" case _ => s"$identifier" } @@ -244,7 +244,7 @@ case class SQLBetween[+T]( fromTo: SQLFromTo[T], maybeNot: Option[Not.type] ) extends Expression { - private[this] lazy val id = function match { + private[this] lazy val id = functions.headOption match { case Some(f) => s"$f($identifier)" case _ => s"$identifier" } @@ -271,7 +271,7 @@ case class ElasticGeoDistance( lon: SQLDouble ) extends Expression { override def sql = s"$Distance($identifier,($lat,$lon)) $operator $distance" - override val function: Option[SQLFunction] = Some(Distance) + override val functions: List[SQLFunction] = List(Distance) override def operator: SQLOperator = Le override def update(request: SQLSearchRequest): ElasticGeoDistance = this.copy(identifier = identifier.update(request)) @@ -333,7 +333,7 @@ case class SQLComparisonDateMath( override def update(request: SQLSearchRequest): SQLCriteria = this.copy(identifier = identifier.update(request)) - override def maybeValue: Option[SQLToken] = Some(dateTimeFunction) + override def maybeValue: Option[SQLToken] = Some(SQLScript(script)) override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 603b3a3a..3e878ec1 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -30,11 +30,13 @@ package object sql { } trait SQLTokenWithFunction extends SQLToken { - def function: Option[SQLFunction] + def functions: List[SQLFunction] - lazy val aggregateFunction: Option[AggregateFunction] = function match { + lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { case Some(af: AggregateFunction) => Some(af) - case _ => None + case other => + Console.println(this) + None } lazy val aggregation: Boolean = aggregateFunction.isDefined @@ -264,7 +266,7 @@ package object sql { distinct: Boolean = false, nested: Boolean = false, limit: Option[SQLLimit] = None, - function: Option[SQLFunction] = None, + functions: List[SQLFunction] = List.empty, fieldAlias: Option[String] = None, bucket: Option[SQLBucket] = None ) extends SQLExpr({ @@ -280,19 +282,17 @@ package object sql { parts.mkString(".").trim } } - function match { - case Some(f) => s"$f($sql)" - case _ => sql - } + functions.reverse.foldLeft(sql)((expr, fun) => { + fun.toSQL(expr) + }) }) with SQLSource with SQLTokenWithFunction { lazy val identifierName: String = - function match { - case Some(f) => s"${f.sql}($name)" - case _ => name - } + functions.reverse.foldLeft(name)((expr, fun) => { + fun.toSQL(expr) + }) lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index c5311cfd..8032dd4b 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -96,6 +96,11 @@ object Queries { |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day |order by Country asc""".stripMargin .replaceAll("\n", " ") + val parseDate = + "select identifier, count(identifier2) as ct, max(parse_date('yyyy-MM-dd')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + val parseDateTime = + "select identifier, count(identifier2) as ct, max(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + } /** Created by smanciot on 15/02/17. @@ -399,4 +404,18 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { groupByWithHavingAndDateTimeFunctions ) } + + it should "parse parse_date function" in { + val result = SQLParser(parseDate) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + parseDate + ) + } + + it should "parse parse_date_time function" in { + val result = SQLParser(parseDateTime) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + parseDateTime + ) + } } From 334c32656cbb9af3bf1af4cdaec406c2b333c827 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 19:08:06 +0200 Subject: [PATCH 11/22] add SQLType and SQLTypedFunction[IN, OUT], add validations for functions using sql type OUT -> IN, add YEAR, MONTH, DAY, HOUR, MINUTE, SECOND extractors --- .../sql/bridge/ElasticAggregation.scala | 7 +- .../elastic/sql/SQLQuerySpec.scala | 12 +- .../sql/bridge/ElasticAggregation.scala | 7 +- .../elastic/sql/SQLQuerySpec.scala | 8 +- .../softnetwork/elastic/sql/SQLFunction.scala | 227 ++++++++++++++---- .../softnetwork/elastic/sql/SQLParser.scala | 87 ++++++- .../app/softnetwork/elastic/sql/SQLType.scala | 9 + .../softnetwork/elastic/sql/SQLTypes.scala | 9 + .../elastic/sql/SQLValidator.scala | 18 ++ .../app/softnetwork/elastic/sql/package.scala | 5 + .../sql/SQLDateTimeFunctionSuite.scala | 95 ++++++++ .../elastic/sql/SQLParserSpec.scala | 2 +- 12 files changed, 420 insertions(+), 66 deletions(-) create mode 100644 sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala create mode 100644 sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala create mode 100644 sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala create mode 100644 sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index da15ce69..8f6021fe 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -102,9 +102,10 @@ object ElasticAggregation { ): Aggregation = { if (transformFuncs.nonEmpty) { val base = s"doc['$sourceField'].value" - val scriptSrc = transformFuncs.foldLeft(base) { - case (expr, f: SQLTransformFunction) => f.toPainless(expr) - case (expr, f) => f.toSQL(expr) // fallback + val orderedTransforms = transformFuncs.reverse + val scriptSrc = orderedTransforms.foldLeft(base) { + case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) + case (expr, f) => f.toSQL(expr) // fallback } val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index d7088219..c1c3b710 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -882,7 +882,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": "painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTE)" + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" | } | } | }, @@ -984,7 +984,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTE);" + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" | } | } | } @@ -1041,7 +1041,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAY).toInstant().toEpochMilli())" + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" | } | } | } @@ -1212,7 +1212,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",ChronoUnit", ", ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") @@ -1259,7 +1259,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES)" | } | } | } @@ -1268,7 +1268,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",ChronoUnit", ", ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index 2a6adda2..a360a24f 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -101,9 +101,10 @@ object ElasticAggregation { ): Aggregation = { if (transformFuncs.nonEmpty) { val base = s"doc['$sourceField'].value" - val scriptSrc = transformFuncs.foldLeft(base) { - case (expr, f: SQLTransformFunction) => f.toPainless(expr) - case (expr, f) => f.toSQL(expr) // fallback + val orderedTransforms = transformFuncs.reverse + val scriptSrc = orderedTransforms.foldLeft(base) { + case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) + case (expr, f) => f.toSQL(expr) // fallback } val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 4ca76076..f614673a 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -881,7 +881,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "ct": { | "script": { | "lang": "painless", - | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTE)" + | "source": "doc['createdAt'].value.minus(35, ChronoUnit.MINUTES)" | } | } | }, @@ -983,7 +983,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "script": { | "script": { | "lang": "painless", - | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTE);" + | "source": "return doc['createdAt'].value.toLocalTime() >= LocalTime.now().minus(10, ChronoUnit.MINUTES);" | } | } | } @@ -1040,7 +1040,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "lastSeen": "lastSeen" | }, | "script": { - | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAY).toInstant().toEpochMilli())" + | "source": "(params.lastSeen != null) && (params.lastSeen > ZonedDateTime.now(ZoneId.of('Z')).minus(7, ChronoUnit.DAYS).toInstant().toEpochMilli())" | } | } | } @@ -1256,7 +1256,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES)" | } | } | } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index dc5a0f22..2ff9cf1b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -6,15 +6,34 @@ sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = s"$sql($base)" } -sealed trait SQLTransformFunction extends SQLFunction { +trait SQLTypedFunction[In <: SQLType, Out <: SQLType] extends SQLFunction { + def inputType: In + def outputType: Out + def in(other: SQLTypedFunction[_, _]): Boolean = + inputType.typeId == other.outputType.asInstanceOf[SQLType].typeId || + (inputType.typeId == "temporal" && Set("date", "datetime").contains( + other.outputType.asInstanceOf[SQLType].typeId + )) + def out(other: SQLTypedFunction[_, _]): Boolean = + outputType.typeId == other.inputType.asInstanceOf[SQLType].typeId || + (outputType.typeId == "temporal" && Set("date", "datetime").contains( + other.inputType.asInstanceOf[SQLType].typeId + )) +} + +sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLTypedFunction[In, Out] { def toPainless(base: String): String } sealed trait ParametrizedFunction extends SQLFunction { def params: Seq[String] override def toSQL(base: String): String = { - val paramsStr = params.map(p => s"'$p'").mkString(", ") - s"$sql($paramsStr)($base)" + params match { + case Nil => s"$sql($base)" + case _ => + val paramsStr = params.mkString(", ") + s"$sql($paramsStr)($base)" + } } } @@ -28,41 +47,44 @@ case object Sum extends SQLExpr("sum") with AggregateFunction case object Distance extends SQLExpr("distance") with SQLFunction with SQLOperator sealed trait TimeUnit extends PainlessScript with MathScript { - lazy val regex: Regex = s"\\b(?i)${sql}s?\\b".r + lazy val regex: Regex = s"\\b(?i)$sql(s)?\\b".r - override def painless: String = s"ChronoUnit.${sql.toUpperCase()}" + override def painless: String = s"ChronoUnit.${sql.toUpperCase()}S" } sealed trait CalendarUnit extends TimeUnit sealed trait FixedUnit extends TimeUnit -case object Year extends SQLExpr("year") with CalendarUnit { - override def script: String = "y" -} -case object Month extends SQLExpr("month") with CalendarUnit { - override def script: String = "M" -} -case object Quarter extends SQLExpr("quarter") with CalendarUnit { - override def script: String = throw new IllegalArgumentException( - "Quarter must be converted to months (value * 3) before creating date-math" - ) -} -case object Week extends SQLExpr("week") with CalendarUnit { - override def script: String = "w" -} +object TimeUnit { + case object Year extends SQLExpr("year") with CalendarUnit { + override def script: String = "y" + } + case object Month extends SQLExpr("month") with CalendarUnit { + override def script: String = "M" + } + case object Quarter extends SQLExpr("quarter") with CalendarUnit { + override def script: String = throw new IllegalArgumentException( + "Quarter must be converted to months (value * 3) before creating date-math" + ) + } + case object Week extends SQLExpr("week") with CalendarUnit { + override def script: String = "w" + } -case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit { - override def script: String = "d" -} + case object Day extends SQLExpr("day") with CalendarUnit with FixedUnit { + override def script: String = "d" + } + + case object Hour extends SQLExpr("hour") with FixedUnit { + override def script: String = "H" + } + case object Minute extends SQLExpr("minute") with FixedUnit { + override def script: String = "m" + } + case object Second extends SQLExpr("second") with FixedUnit { + override def script: String = "s" + } -case object Hour extends SQLExpr("hour") with FixedUnit { - override def script: String = "H" -} -case object Minute extends SQLExpr("minute") with FixedUnit { - override def script: String = "m" -} -case object Second extends SQLExpr("second") with FixedUnit { - override def script: String = "s" } case object Interval extends SQLExpr("interval") with SQLFunction with SQLRegex @@ -77,6 +99,8 @@ sealed trait TimeInterval extends PainlessScript with MathScript { override def script: String = TimeInterval.script(this) } +import TimeUnit._ + case class CalendarInterval(value: Int, unit: CalendarUnit) extends TimeInterval case class FixedInterval(value: Int, unit: FixedUnit) extends TimeInterval @@ -94,17 +118,21 @@ object TimeInterval { sealed trait DateTimeFunction extends SQLFunction +sealed trait DateFunction extends DateTimeFunction + +sealed trait TimeFunction extends DateTimeFunction + sealed trait CurrentDateTimeFunction extends DateTimeFunction with PainlessScript with MathScript { override def painless: String = "ZonedDateTime.of(LocalDateTime.now(), ZoneId.of('Z')).toLocalDateTime()" override def script: String = "now" } -sealed trait CurrentDateFunction extends CurrentDateTimeFunction { +sealed trait CurrentDateFunction extends CurrentDateTimeFunction with DateFunction { override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalDate()" } -sealed trait CurrentTimeFunction extends CurrentDateTimeFunction { +sealed trait CurrentTimeFunction extends CurrentDateTimeFunction with TimeFunction { override def painless: String = "ZonedDateTime.of(LocalDate.now(), ZoneId.of('Z')).toLocalTime()" } @@ -126,29 +154,142 @@ case object Now extends SQLExpr("now") with CurrentDateTimeFunction case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction -case class DateAdd(interval: TimeInterval) extends SQLExpr("date_add") with DateTimeFunction -case class DateDiff(interval: TimeInterval) extends SQLExpr("date_diff") with DateTimeFunction -case class DateSub(interval: TimeInterval) extends SQLExpr("date_sub") with DateTimeFunction -case class DateTrunc(unit: TimeUnit) extends SQLExpr("date_trunc") with DateTimeFunction -case class Extract(unit: TimeUnit) extends SQLExpr("extract") with DateTimeFunction -case class FormatDate(format: String) extends SQLExpr("format_date") with DateTimeFunction +// case class DateDiff(interval: TimeInterval) extends SQLExpr("date_diff") with DateTimeFunction + +case class DateTrunc(unit: TimeUnit) + extends SQLExpr("date_trunc") + with DateTimeFunction + with SQLTransformFunction[SQLTemporal, SQLTemporal] + with ParametrizedFunction { + override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut + override def outputType: SQLTemporal = SQLTypes.Temporal // idem + override def params: Seq[String] = Seq(unit.sql) + override def toPainless(base: String): String = s"$base.truncatedTo(${unit.painless})" +} + +case class Extract(unit: TimeUnit, override val sql: String = "extract") + extends SQLExpr(sql) + with DateTimeFunction + with SQLTransformFunction[SQLTemporal, SQLNumber] + with ParametrizedFunction { + override def inputType: SQLTemporal = SQLTypes.Temporal + override def outputType: SQLNumber = SQLTypes.Number + override def params: Seq[String] = Seq(unit.sql) + override def toPainless(base: String): String = s"$base.get(${unit.painless})" +} + +object YEAR extends Extract(Year, Year.sql) { + override def params: Seq[String] = Seq.empty +} + +object MONTH extends Extract(Month, Month.sql) { + override def params: Seq[String] = Seq.empty +} + +object DAY extends Extract(Day, Day.sql) { + override def params: Seq[String] = Seq.empty +} + +object HOUR extends Extract(Hour, Hour.sql) { + override def params: Seq[String] = Seq.empty +} + +object MINUTE extends Extract(Minute, Minute.sql) { + override def params: Seq[String] = Seq.empty +} + +object SECOND extends Extract(Second, Second.sql) { + override def params: Seq[String] = Seq.empty +} + +case class DateAdd(interval: TimeInterval) + extends SQLExpr("date_add") + with DateFunction + with SQLTransformFunction[SQLDate, SQLDate] + with ParametrizedFunction { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLDate = SQLTypes.Date + override def params: Seq[String] = Seq(interval.sql) + override def toPainless(base: String): String = s"$base.plus(${interval.painless})" +} + +case class DateSub(interval: TimeInterval) + extends SQLExpr("date_sub") + with DateFunction + with SQLTransformFunction[SQLDate, SQLDate] + with ParametrizedFunction { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLDate = SQLTypes.Date + override def params: Seq[String] = Seq(interval.sql) + override def toPainless(base: String): String = s"$base.minus(${interval.painless})" +} case class ParseDate(format: String) extends SQLExpr("parse_date") - with DateTimeFunction - with SQLTransformFunction + with DateFunction + with SQLTransformFunction[SQLString, SQLDate] with ParametrizedFunction { - override def params: Seq[String] = Seq(format) + override def inputType: SQLString = SQLTypes.String + override def outputType: SQLDate = SQLTypes.Date + override def params: Seq[String] = Seq(s"'$format'") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } +case class FormatDate(format: String) + extends SQLExpr("format_date") + with DateFunction + with SQLTransformFunction[SQLDate, SQLString] + with ParametrizedFunction { + override def inputType: SQLDate = SQLTypes.Date + override def outputType: SQLString = SQLTypes.String + override def params: Seq[String] = Seq(s"'$format'") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').format($base)" +} + +case class DateTimeAdd(interval: TimeInterval) + extends SQLExpr("datetime_add") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLDateTime] + with ParametrizedFunction { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def params: Seq[String] = Seq(interval.sql) + override def toPainless(base: String): String = s"$base.plus(${interval.painless})" +} + +case class DateTimeSub(interval: TimeInterval) + extends SQLExpr("datetime_sub") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLDateTime] + with ParametrizedFunction { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def params: Seq[String] = Seq(interval.sql) + override def toPainless(base: String): String = s"$base.minus(${interval.painless})" +} + case class ParseDateTime(format: String) extends SQLExpr("parse_datetime") with DateTimeFunction - with SQLTransformFunction + with SQLTransformFunction[SQLString, SQLDateTime] with ParametrizedFunction { - override def params: Seq[String] = Seq(format) + override def inputType: SQLString = SQLTypes.String + override def outputType: SQLDateTime = SQLTypes.DateTime + override def params: Seq[String] = Seq(s"'$format'") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDateTime::from)" } + +case class FormatDateTime(format: String) + extends SQLExpr("format_datetime") + with DateTimeFunction + with SQLTransformFunction[SQLDateTime, SQLString] + with ParametrizedFunction { + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLString = SQLTypes.String + override def params: Seq[String] = Seq(s"'$format'") + override def toPainless(base: String): String = + s"DateTimeFormatter.ofPattern('$format').format($base)" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 631a2ce8..714a0c72 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -2,6 +2,7 @@ package app.softnetwork.elastic.sql import scala.util.parsing.combinator.{PackratParsers, RegexParsers} import scala.util.parsing.input.CharSequenceReader +import TimeUnit._ /** Created by smanciot on 27/06/2018. * @@ -101,10 +102,12 @@ trait SQLParser extends RegexParsers with PackratParsers { def second: PackratParser[TimeUnit] = Second.regex ^^ (_ => Second) + def time_unit: PackratParser[TimeUnit] = + year | month | quarter | week | day | hour | minute | second + def interval: PackratParser[TimeInterval] = - Interval.regex ~ long ~ (year | month | quarter | week | day | hour | minute | second) ^^ { - case _ ~ l ~ u => - TimeInterval(l.value.toInt, u) + Interval.regex ~ long ~ time_unit ^^ { case _ ~ l ~ u => + TimeInterval(l.value.toInt, u) } def current_date: PackratParser[CurrentDateTimeFunction] = @@ -143,28 +146,96 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } - def parse_date: PackratParser[DateTimeFunction] = + def date_trunc: PackratParser[SQLTypedFunction[SQLTemporal, SQLTemporal]] = + "(?i)date_trunc".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => + DateTrunc(u) + } + + def extract: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => + Extract(u) + } + + def extract_year: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + Year.regex ^^ (_ => YEAR) + + def extract_month: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + Month.regex ^^ (_ => MONTH) + + def extract_day: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) + + def extract_hour: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + Hour.regex ^^ (_ => HOUR) + + def extract_minute: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + Minute.regex ^^ (_ => MINUTE) + + def extract_second: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + Second.regex ^^ (_ => SECOND) + + def extractors: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second + + def date_add: PackratParser[DateFunction] = + "(?i)date_add".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => + DateAdd(i) + } + + def date_sub: PackratParser[DateFunction] = + "(?i)date_sub".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => + DateSub(i) + } + + def parse_date: PackratParser[DateFunction] = "(?i)parse_date".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => ParseDate(f.value) } + def format_date: PackratParser[DateFunction] = + "(?i)format_date".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => + FormatDate(f.value) + } + + def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date + + def datetime_add: PackratParser[DateTimeFunction] = + "(?i)datetime_add".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => + DateTimeAdd(i) + } + + def datetime_sub: PackratParser[DateTimeFunction] = + "(?i)datetime_sub".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => + DateTimeSub(i) + } + def parse_datetime: PackratParser[DateTimeFunction] = "(?i)parse_datetime".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => ParseDateTime(f.value) } - def date_functions: PackratParser[DateTimeFunction] = parse_date | parse_datetime + def format_datetime: PackratParser[DateTimeFunction] = + "(?i)format_datetime".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => + FormatDateTime(f.value) + } + + def datetime_functions: PackratParser[DateTimeFunction] = + datetime_add | datetime_sub | parse_datetime | format_datetime def aggregates: PackratParser[AggregateFunction] = count | min | max | avg | sum def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_functions + def sql_functions: PackratParser[SQLFunction] = + aggregates | distance | date_trunc | extractors | date_functions | datetime_functions private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" def identifierWithFunction: PackratParser[SQLIdentifier] = rep1sep(sql_functions, start) ~ start.? ~ identifier ~ rep1(end) ^^ { case f ~ _ ~ i ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw new IllegalArgumentException(error) + case _ => + } i.copy(functions = f) } @@ -580,6 +651,10 @@ trait SQLOrderByParser { def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw new IllegalArgumentException(error) + case _ => + } (n, f) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala new file mode 100644 index 00000000..11988dfc --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -0,0 +1,9 @@ +package app.softnetwork.elastic.sql + +sealed trait SQLType { def typeId: String } + +trait SQLTemporal extends SQLType +trait SQLDate extends SQLTemporal +trait SQLDateTime extends SQLTemporal +trait SQLNumber extends SQLType +trait SQLString extends SQLType diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala new file mode 100644 index 00000000..85873e9b --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -0,0 +1,9 @@ +package app.softnetwork.elastic.sql + +object SQLTypes { + case object Temporal extends SQLTemporal { val typeId = "temporal" } + case object Date extends SQLDate { val typeId = "date" } + case object DateTime extends SQLDateTime { val typeId = "datetime" } + case object Number extends SQLNumber { val typeId = "number" } + case object String extends SQLString { val typeId = "string" } +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala new file mode 100644 index 00000000..d0612077 --- /dev/null +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -0,0 +1,18 @@ +package app.softnetwork.elastic.sql + +object SQLValidator { + + def validateChain(functions: List[SQLFunction]): Either[String, Unit] = { + functions + .collect { case f: SQLTypedFunction[_, _] => f } + .sliding(2) + .foreach { + case Seq(f1, f2) => + if (!f1.in(f2)) { + return Left(s"Type mismatch: ${f2.outputType} -> ${f1.inputType}") + } + case _ => // ok + } + Right(()) + } +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 3e878ec1..d2b4e3a2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -40,6 +40,11 @@ package object sql { } lazy val aggregation: Boolean = aggregateFunction.isDefined + + def validate(): Either[String, Unit] = { + SQLValidator.validateChain(functions) + } + } trait Updateable extends SQLToken { diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala new file mode 100644 index 00000000..8385201d --- /dev/null +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -0,0 +1,95 @@ +package app.softnetwork.elastic.sql + +import org.scalatest.funsuite.AnyFunSuite +import TimeUnit._ + +class SQLDateTimeFunctionSuite extends AnyFunSuite { + + // Base d'exemple + val baseDate = "doc['createdAt'].value" + + // Liste de toutes les fonctions transformables avec leurs types + val transformFunctions: Seq[SQLTransformFunction[_, _]] = Seq( + ParseDate("yyyy-MM-dd"), + ParseDateTime("yyyy-MM-dd HH:mm:ss"), + DateAdd(TimeInterval(1, Day)), + DateSub(TimeInterval(2, Month)), + DateTimeAdd(TimeInterval(3, Hour)), + DateTimeSub(TimeInterval(30, Minute)), + DateTrunc(Day), + Extract(Day), + FormatDate("yyyy-MM-dd"), + FormatDateTime("yyyy-MM-dd HH:mm:ss"), + YEAR, + MONTH, + DAY, + HOUR, + MINUTE, + SECOND + ) + + // Fonction pour chaîner une séquence de transformations en vérifiant les types + def chainTransformsTyped( + base: String, + transforms: Seq[SQLTransformFunction[_, _]] + ): String = { + require(transforms.nonEmpty, "No transforms provided") + + val initial: (String, SQLType) = + (transforms.head.toPainless(base), transforms.head.outputType.asInstanceOf[SQLType]) + + val (finalExpr, _) = transforms.tail.foldLeft(initial) { + case ((expr, currentType), t: SQLTypedFunction[_, _]) => + if (!currentType.getClass.isAssignableFrom(t.inputType.getClass)) { + throw new IllegalArgumentException( + s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" + ) + } + (t.toPainless(expr), t.outputType.asInstanceOf[SQLType]) + } + + finalExpr + } + + // Générer dynamiquement tous les chaînages valides jusqu'à N fonctions + def generateChains( + functions: Seq[SQLTransformFunction[_, _]], + maxLength: Int + ): Seq[Seq[SQLTransformFunction[_, _]]] = { + if (maxLength <= 1) functions.map(Seq(_)) + else { + val shorter = generateChains(functions, maxLength - 1) + for { + chain <- shorter + f <- functions + if f.inputType.getClass.isAssignableFrom(chain.last.outputType.getClass) + } yield chain :+ f + } + } + + // Tester tous les chaînages pour N=2 et N=3 + val chains2: Seq[Seq[SQLTransformFunction[_, _]]] = + generateChains(transformFunctions, 2) + val chains3: Seq[Seq[SQLTransformFunction[_, _]]] = + generateChains(transformFunctions, 3) + + (chains2 ++ chains3).zipWithIndex.foreach { case (chain, idx) => + val names = chain.map(_.sql).mkString(" -> ") + test(s"Valid chain $idx: $names") { + val chained = chainTransformsTyped(baseDate, chain) + val expected = chain.reverse.tail.foldLeft(chain.last.toPainless(baseDate)) { (expr, f) => + f.toPainless(expr) + } + // On ne teste que la génération de code Painless sans évaluer le résultat + assert(chained.nonEmpty) + } + } + + // Test simple pour chaque fonction individuelle + transformFunctions.foreach { f => + test(s"Single transformation ${f.sql}") { + val result = f.toPainless(baseDate) + assert(result.nonEmpty) + } + } +} diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 8032dd4b..425f3415 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -99,7 +99,7 @@ object Queries { val parseDate = "select identifier, count(identifier2) as ct, max(parse_date('yyyy-MM-dd')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" val parseDateTime = - "select identifier, count(identifier2) as ct, max(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + "select identifier, count(identifier2) as ct, max(date_trunc(minute)(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt))) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" } From dbff05d0f1c27bbf68ee364047cecf81eba49e65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 19:15:08 +0200 Subject: [PATCH 12/22] add SQLScript --- sql/src/main/scala/app/softnetwork/elastic/sql/package.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index d2b4e3a2..c8e52d42 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -334,4 +334,6 @@ package object sql { } } } + + case class SQLScript(script: String) extends SQLExpr(script) } From 04a1fc448121d723889605d2341c880c51e41230 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 19:31:49 +0200 Subject: [PATCH 13/22] to fix query specifications --- .../test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala | 1 - .../test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index c1c3b710..8658f195 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1268,7 +1268,6 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll(",ChronoUnit", ", ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index f614673a..9d0dc857 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1265,7 +1265,6 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | } |}""".stripMargin .replaceAll("\\s", "") - .replaceAll("ChronoUnit", " ChronoUnit") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") From 522faa0f5ad0ab5e174a999d857e6794cebf0ddc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Fri, 5 Sep 2025 20:33:11 +0200 Subject: [PATCH 14/22] test extract function --- .../softnetwork/elastic/sql/SQLQuerySpec.scala | 2 +- .../softnetwork/elastic/sql/SQLQuerySpec.scala | 2 +- .../softnetwork/elastic/sql/SQLParserSpec.scala | 17 +++++++++++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 8658f195..16c46dab 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1259,7 +1259,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" | } | } | } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 9d0dc857..53dc8045 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1256,7 +1256,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" | } | } | } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 425f3415..71991fa7 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -98,8 +98,21 @@ object Queries { .replaceAll("\n", " ") val parseDate = "select identifier, count(identifier2) as ct, max(parse_date('yyyy-MM-dd')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" - val parseDateTime = - "select identifier, count(identifier2) as ct, max(date_trunc(minute)(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt))) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + val parseDateTime: String = + """select identifier, count(identifier2) as ct, + |max( + |year( + |date_trunc(minute)( + |parse_datetime('yyyy-MM-ddTHH:mm:ssZ')( + |createdAt + |)))) as lastSeen + |from Table + |where identifier2 is not null + |group by identifier + |order by count(identifier2) desc""".stripMargin + .replaceAll("\n", " ") + .replaceAll("\\( ", "(") + .replaceAll(" \\)", ")") } From 2bbc83e01a7eeb84e500368035be1039aee7d5c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sat, 6 Sep 2025 06:58:44 +0200 Subject: [PATCH 15/22] add diff (!=) operator --- .../elastic/sql/bridge/package.scala | 42 +++++++++---------- .../elastic/sql/bridge/package.scala | 32 +++++++------- .../softnetwork/elastic/sql/SQLFunction.scala | 6 +-- .../softnetwork/elastic/sql/SQLOperator.scala | 13 +++--- .../softnetwork/elastic/sql/SQLParser.scala | 6 ++- .../elastic/sql/SQLValidator.scala | 2 +- .../app/softnetwork/elastic/sql/package.scala | 32 +++++++------- .../elastic/sql/SQLParserSpec.scala | 2 +- 8 files changed, 68 insertions(+), 67 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 0cc2adca..3f990f01 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -146,7 +146,7 @@ package object bridge { value match { case n: SQLNumeric[_] if !aggregation => operator match { - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -159,7 +159,7 @@ package object bridge { d => rangeQuery(identifier.name) gte d ) } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -172,7 +172,7 @@ package object bridge { d => rangeQuery(identifier.name) gt d ) } - case _: Le.type => + case Le => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -185,7 +185,7 @@ package object bridge { d => rangeQuery(identifier.name) lte d ) } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -198,7 +198,7 @@ package object bridge { d => rangeQuery(identifier.name) lt d ) } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -211,7 +211,7 @@ package object bridge { d => termQuery(identifier.name, d) ) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -228,49 +228,49 @@ package object bridge { } case l: SQLLiteral if !aggregation => operator match { - case _: Like.type => + case Like => maybeNot match { case Some(_) => not(regexQuery(identifier.name, toRegex(l.value))) case _ => regexQuery(identifier.name, toRegex(l.value)) } - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => rangeQuery(identifier.name) lt l.value case _ => rangeQuery(identifier.name) gte l.value } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => rangeQuery(identifier.name) lte l.value case _ => rangeQuery(identifier.name) gt l.value } - case _: Le.type => + case Le => maybeNot match { case Some(_) => rangeQuery(identifier.name) gt l.value case _ => rangeQuery(identifier.name) lte l.value } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => rangeQuery(identifier.name) gte l.value case _ => rangeQuery(identifier.name) lt l.value } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, l.value)) case _ => termQuery(identifier.name, l.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, l.value) @@ -281,14 +281,14 @@ package object bridge { } case b: SQLBoolean if !aggregation => operator match { - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, b.value)) case _ => termQuery(identifier.name, b.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, b.value) @@ -311,12 +311,12 @@ package object bridge { case _ => val op = if (maybeNot.isDefined) operator.not else operator op match { - case Gt => rangeQuery(identifier.name) gt script - case Ge => rangeQuery(identifier.name) gte script - case Lt => rangeQuery(identifier.name) lt script - case Le => rangeQuery(identifier.name) lte script - case Eq => rangeQuery(identifier.name) gte script lte script - case Ne => not(rangeQuery(identifier.name) gte script lte script) + case Gt => rangeQuery(identifier.name) gt script + case Ge => rangeQuery(identifier.name) gte script + case Lt => rangeQuery(identifier.name) lt script + case Le => rangeQuery(identifier.name) lte script + case Eq => rangeQuery(identifier.name) gte script lte script + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) } } } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index ea9d4695..bd6f12b0 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -147,7 +147,7 @@ package object bridge { value match { case n: SQLNumeric[_] if !aggregation => operator match { - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -160,7 +160,7 @@ package object bridge { d => rangeQuery(identifier.name) gte d ) } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -173,7 +173,7 @@ package object bridge { d => rangeQuery(identifier.name) gt d ) } - case _: Le.type => + case Le => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -186,7 +186,7 @@ package object bridge { d => rangeQuery(identifier.name) lte d ) } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -199,7 +199,7 @@ package object bridge { d => rangeQuery(identifier.name) lt d ) } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -212,7 +212,7 @@ package object bridge { d => termQuery(identifier.name, d) ) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => applyNumericOp(n)( @@ -229,49 +229,49 @@ package object bridge { } case l: SQLLiteral if !aggregation => operator match { - case _: Like.type => + case Like => maybeNot match { case Some(_) => not(regexQuery(identifier.name, toRegex(l.value))) case _ => regexQuery(identifier.name, toRegex(l.value)) } - case _: Ge.type => + case Ge => maybeNot match { case Some(_) => rangeQuery(identifier.name) lt l.value case _ => rangeQuery(identifier.name) gte l.value } - case _: Gt.type => + case Gt => maybeNot match { case Some(_) => rangeQuery(identifier.name) lte l.value case _ => rangeQuery(identifier.name) gt l.value } - case _: Le.type => + case Le => maybeNot match { case Some(_) => rangeQuery(identifier.name) gt l.value case _ => rangeQuery(identifier.name) lte l.value } - case _: Lt.type => + case Lt => maybeNot match { case Some(_) => rangeQuery(identifier.name) gte l.value case _ => rangeQuery(identifier.name) lt l.value } - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, l.value)) case _ => termQuery(identifier.name, l.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, l.value) @@ -282,14 +282,14 @@ package object bridge { } case b: SQLBoolean if !aggregation => operator match { - case _: Eq.type => + case Eq => maybeNot match { case Some(_) => not(termQuery(identifier.name, b.value)) case _ => termQuery(identifier.name, b.value) } - case _: Ne.type => + case Ne | Diff => maybeNot match { case Some(_) => termQuery(identifier.name, b.value) @@ -317,7 +317,7 @@ package object bridge { case Lt => rangeQuery(identifier.name) lt script case Le => rangeQuery(identifier.name) lte script case Eq => rangeQuery(identifier.name) gte script lte script - case Ne => not(rangeQuery(identifier.name) gte script lte script) + case Ne | Diff => not(rangeQuery(identifier.name) gte script lte script) } } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 2ff9cf1b..6ae1599b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -9,12 +9,12 @@ sealed trait SQLFunction extends SQLRegex { trait SQLTypedFunction[In <: SQLType, Out <: SQLType] extends SQLFunction { def inputType: In def outputType: Out - def in(other: SQLTypedFunction[_, _]): Boolean = + def from(other: SQLTypedFunction[_, _]): Boolean = inputType.typeId == other.outputType.asInstanceOf[SQLType].typeId || (inputType.typeId == "temporal" && Set("date", "datetime").contains( other.outputType.asInstanceOf[SQLType].typeId )) - def out(other: SQLTypedFunction[_, _]): Boolean = + def to(other: SQLTypedFunction[_, _]): Boolean = outputType.typeId == other.inputType.asInstanceOf[SQLType].typeId || (outputType.typeId == "temporal" && Set("date", "datetime").contains( other.inputType.asInstanceOf[SQLType].typeId @@ -154,8 +154,6 @@ case object Now extends SQLExpr("now") with CurrentDateTimeFunction case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction -// case class DateDiff(interval: TimeInterval) extends SQLExpr("date_diff") with DateTimeFunction - case class DateTrunc(unit: TimeUnit) extends SQLExpr("date_trunc") with DateTimeFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 2bc52639..c3a073b0 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -21,17 +21,18 @@ sealed trait SQLComparisonOperator extends SQLExpressionOperator with PainlessSc } def not: SQLComparisonOperator = this match { - case Eq => Ne - case Ne => Eq - case Ge => Lt - case Gt => Le - case Le => Gt - case Lt => Ge + case Eq => Ne + case Ne | Diff => Eq + case Ge => Lt + case Gt => Le + case Le => Gt + case Lt => Ge } } case object Eq extends SQLExpr("=") with SQLComparisonOperator case object Ne extends SQLExpr("<>") with SQLComparisonOperator +case object Diff extends SQLExpr("!=") with SQLComparisonOperator case object Ge extends SQLExpr(">=") with SQLComparisonOperator case object Gt extends SQLExpr(">") with SQLComparisonOperator case object Le extends SQLExpr("<=") with SQLComparisonOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 714a0c72..48143d4e 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -320,8 +320,10 @@ trait SQLWhereParser { private def ne: PackratParser[SQLComparisonOperator] = Ne.sql ^^ (_ => Ne) + private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) + private def equality: PackratParser[SQLExpression] = - not.? ~ (identifierWithFunction | identifier) ~ (eq | ne) ~ (boolean | literal | double | long) ^^ { + not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -408,7 +410,7 @@ trait SQLWhereParser { } private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = - not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index d0612077..9a129ddf 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -8,7 +8,7 @@ object SQLValidator { .sliding(2) .foreach { case Seq(f1, f2) => - if (!f1.in(f2)) { + if (!f1.from(f2)) { return Left(s"Type mismatch: ${f2.outputType} -> ${f1.inputType}") } case _ => // ok diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index c8e52d42..5364df34 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -67,13 +67,13 @@ package object sql { None else operator match { - case Some(_: Eq.type) => values.find(_ == value) - case Some(_: Ne.type) => values.find(_ != value) - case Some(_: Ge.type) => values.filter(_ >= value).sorted.reverse.headOption - case Some(_: Gt.type) => values.filter(_ > value).sorted.reverse.headOption - case Some(_: Le.type) => values.filter(_ <= value).sorted.headOption - case Some(_: Lt.type) => values.filter(_ < value).sorted.headOption - case _ => values.headOption + case Some(Eq) => values.find(_ == value) + case Some(Ne | Diff) => values.find(_ != value) + case Some(Ge) => values.filter(_ >= value).sorted.reverse.headOption + case Some(Gt) => values.filter(_ > value).sorted.reverse.headOption + case Some(Le) => values.filter(_ <= value).sorted.headOption + case Some(Lt) => values.filter(_ < value).sorted.headOption + case _ => values.headOption } } def painless: String = value match { @@ -107,11 +107,11 @@ package object sql { separator: String = "|" )(implicit ev: R => Ordered[R]): Option[R] = { operator match { - case Some(_: Eq.type) => values.find(v => v.toString contentEquals value) - case Some(_: Ne.type) => values.find(v => !(v.toString contentEquals value)) - case Some(_: Like.type) => values.find(v => pattern.matcher(v.toString).matches()) - case None => Some(values.mkString(separator)) - case _ => super.choose(values, operator, separator) + case Some(Eq) => values.find(v => v.toString contentEquals value) + case Some(Ne | Diff) => values.find(v => !(v.toString contentEquals value)) + case Some(Like) => values.find(v => pattern.matcher(v.toString).matches()) + case None => Some(values.mkString(separator)) + case _ => super.choose(values, operator, separator) } } } @@ -228,10 +228,10 @@ package object sql { value.choose[T](values, Some(operator)) case _ => function match { - case Some(_: Min.type) => Some(values.min) - case Some(_: Max.type) => Some(values.max) - // FIXME case Some(_: SQLSum.type) => Some(values.sum) - // FIXME case Some(_: SQLAvg.type) => Some(values.sum / values.length ) + case Some(Min) => Some(values.min) + case Some(Max) => Some(values.max) + // FIXME case Some(SQLSum) => Some(values.sum) + // FIXME case Some(SQLAvg) => Some(values.sum / values.length ) case _ => values.headOption } } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 71991fa7..af3421cf 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -93,7 +93,7 @@ object Queries { """select count(CustomerID) as cnt, City, Country, max(createdAt) as lastSeen |from Table |group by Country, City - |having Country <> "USA" and City <> "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day + |having Country <> "USA" and City != "Berlin" and count(CustomerID) > 1 and lastSeen > now - interval 7 day |order by Country asc""".stripMargin .replaceAll("\n", " ") val parseDate = From 61370014f79637bd50d2ef694fee32eee48419ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Sat, 6 Sep 2025 12:17:13 +0200 Subject: [PATCH 16/22] add SQLFunctionChain, SQLBinaryFunction, SQLFunctionField, implements date_diff --- .../sql/bridge/ElasticAggregation.scala | 9 +- .../sql/bridge/ElasticAggregation.scala | 9 +- .../softnetwork/elastic/sql/SQLFunction.scala | 95 +++++++++++++++---- .../softnetwork/elastic/sql/SQLOrderBy.scala | 9 +- .../softnetwork/elastic/sql/SQLParser.scala | 69 ++++++++++---- .../softnetwork/elastic/sql/SQLSelect.scala | 18 +++- .../app/softnetwork/elastic/sql/SQLType.scala | 9 ++ .../softnetwork/elastic/sql/SQLTypes.scala | 1 + .../elastic/sql/SQLValidator.scala | 28 ++++-- .../softnetwork/elastic/sql/SQLWhere.scala | 2 +- .../app/softnetwork/elastic/sql/package.scala | 23 +---- .../sql/SQLDateTimeFunctionSuite.scala | 2 +- .../elastic/sql/SQLParserSpec.scala | 18 ++++ 13 files changed, 204 insertions(+), 88 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index 8f6021fe..4e03dcfa 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -12,7 +12,7 @@ import app.softnetwork.elastic.sql.{ Min, SQLBucket, SQLCriteria, - SQLTransformFunction, + SQLFunctionUtils, SortOrder, Sum } @@ -101,12 +101,7 @@ object ElasticAggregation { buildScript: (String, Script) => Aggregation ): Aggregation = { if (transformFuncs.nonEmpty) { - val base = s"doc['$sourceField'].value" - val orderedTransforms = transformFuncs.reverse - val scriptSrc = orderedTransforms.foldLeft(base) { - case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (expr, f) => f.toSQL(expr) // fallback - } + val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs) val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index a360a24f..0d44ddd2 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -12,7 +12,7 @@ import app.softnetwork.elastic.sql.{ Min, SQLBucket, SQLCriteria, - SQLTransformFunction, + SQLFunctionUtils, SortOrder, Sum } @@ -100,12 +100,7 @@ object ElasticAggregation { buildScript: (String, Script) => Aggregation ): Aggregation = { if (transformFuncs.nonEmpty) { - val base = s"doc['$sourceField'].value" - val orderedTransforms = transformFuncs.reverse - val scriptSrc = orderedTransforms.foldLeft(base) { - case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (expr, f) => f.toSQL(expr) // fallback - } + val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs) val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 6ae1599b..c4a179ed 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -6,23 +6,62 @@ sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = s"$sql($base)" } -trait SQLTypedFunction[In <: SQLType, Out <: SQLType] extends SQLFunction { +object SQLFunctionUtils { + def buildPainless(functions: List[SQLFunction]): String = + buildPainless(None, functions) + + def buildPainless( + painless: Option[PainlessScript] = None, + functions: List[SQLFunction] + ): String = { + val base = painless.map(_.painless).getOrElse("") + val orderedFunctions = functions.reverse + orderedFunctions.foldLeft(base) { + case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) + case (_, f: PainlessScript) => f.painless + case (expr, f) => f.toSQL(expr) // fallback + } + } +} + +trait SQLFunctionChain extends SQLFunction with SQLValidation { + def functions: List[SQLFunction] + + override def validate(): Either[String, Unit] = + SQLValidator.validateChain(functions) + + override def toSQL(base: String): String = + functions.reverse.foldLeft(base)((expr, fun) => { + fun.toSQL(expr) + }) + + lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { + case Some(af: AggregateFunction) => Some(af) + case _ => None + } + + lazy val aggregation: Boolean = aggregateFunction.isDefined +} + +sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] + extends SQLFunction + with PainlessScript { def inputType: In def outputType: Out - def from(other: SQLTypedFunction[_, _]): Boolean = - inputType.typeId == other.outputType.asInstanceOf[SQLType].typeId || - (inputType.typeId == "temporal" && Set("date", "datetime").contains( - other.outputType.asInstanceOf[SQLType].typeId - )) - def to(other: SQLTypedFunction[_, _]): Boolean = - outputType.typeId == other.inputType.asInstanceOf[SQLType].typeId || - (outputType.typeId == "temporal" && Set("date", "datetime").contains( - other.inputType.asInstanceOf[SQLType].typeId - )) } -sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLTypedFunction[In, Out] { - def toPainless(base: String): String +trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] + extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction => + + override def inputType: SQLAny = SQLTypes.Any + + def left: PainlessScript + def right: PainlessScript + +} + +sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] { + def toPainless(base: String): String = s"$base$painless" } sealed trait ParametrizedFunction extends SQLFunction { @@ -162,7 +201,7 @@ case class DateTrunc(unit: TimeUnit) override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut override def outputType: SQLTemporal = SQLTypes.Temporal // idem override def params: Seq[String] = Seq(unit.sql) - override def toPainless(base: String): String = s"$base.truncatedTo(${unit.painless})" + override def painless: String = s".truncatedTo(${unit.painless})" } case class Extract(unit: TimeUnit, override val sql: String = "extract") @@ -173,7 +212,7 @@ case class Extract(unit: TimeUnit, override val sql: String = "extract") override def inputType: SQLTemporal = SQLTypes.Temporal override def outputType: SQLNumber = SQLTypes.Number override def params: Seq[String] = Seq(unit.sql) - override def toPainless(base: String): String = s"$base.get(${unit.painless})" + override def painless: String = s".get(${unit.painless})" } object YEAR extends Extract(Year, Year.sql) { @@ -200,6 +239,20 @@ object SECOND extends Extract(Second, Second.sql) { override def params: Seq[String] = Seq.empty } +case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) + extends SQLExpr("date_diff") + with DateTimeFunction + with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumber] + with PainlessScript { + override def outputType: SQLNumber = SQLTypes.Number + override def left: PainlessScript = end + override def right: PainlessScript = start + override def toSQL(base: String): String = { + s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" + } + override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" +} + case class DateAdd(interval: TimeInterval) extends SQLExpr("date_add") with DateFunction @@ -208,7 +261,7 @@ case class DateAdd(interval: TimeInterval) override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date override def params: Seq[String] = Seq(interval.sql) - override def toPainless(base: String): String = s"$base.plus(${interval.painless})" + override def painless: String = s".plus(${interval.painless})" } case class DateSub(interval: TimeInterval) @@ -219,7 +272,7 @@ case class DateSub(interval: TimeInterval) override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date override def params: Seq[String] = Seq(interval.sql) - override def toPainless(base: String): String = s"$base.minus(${interval.painless})" + override def painless: String = s".minus(${interval.painless})" } case class ParseDate(format: String) @@ -230,6 +283,7 @@ case class ParseDate(format: String) override def inputType: SQLString = SQLTypes.String override def outputType: SQLDate = SQLTypes.Date override def params: Seq[String] = Seq(s"'$format'") + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } @@ -242,6 +296,7 @@ case class FormatDate(format: String) override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLString = SQLTypes.String override def params: Seq[String] = Seq(s"'$format'") + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').format($base)" } @@ -254,7 +309,7 @@ case class DateTimeAdd(interval: TimeInterval) override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime override def params: Seq[String] = Seq(interval.sql) - override def toPainless(base: String): String = s"$base.plus(${interval.painless})" + override def painless: String = s".plus(${interval.painless})" } case class DateTimeSub(interval: TimeInterval) @@ -265,7 +320,7 @@ case class DateTimeSub(interval: TimeInterval) override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime override def params: Seq[String] = Seq(interval.sql) - override def toPainless(base: String): String = s"$base.minus(${interval.painless})" + override def painless: String = s".minus(${interval.painless})" } case class ParseDateTime(format: String) @@ -276,6 +331,7 @@ case class ParseDateTime(format: String) override def inputType: SQLString = SQLTypes.String override def outputType: SQLDateTime = SQLTypes.DateTime override def params: Seq[String] = Seq(s"'$format'") + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDateTime::from)" } @@ -288,6 +344,7 @@ case class FormatDateTime(format: String) override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLString = SQLTypes.String override def params: Seq[String] = Seq(s"'$format'") + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').format($base)" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala index 54acfb3d..4a04f005 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOrderBy.scala @@ -12,14 +12,9 @@ case class SQLFieldSort( field: String, order: Option[SortOrder], functions: List[SQLFunction] = List.empty -) extends SQLTokenWithFunction { - private[this] lazy val fieldWithFunction: String = - functions.foldLeft(field)((expr, fun) => { - fun.toSQL(expr) - }) - +) extends SQLFunctionChain { lazy val direction: SortOrder = order.getOrElse(Asc) - lazy val name: String = fieldWithFunction + lazy val name: String = toSQL(field) override def sql: String = s"$name $direction" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 48143d4e..8db81295 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -146,34 +146,34 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } - def date_trunc: PackratParser[SQLTypedFunction[SQLTemporal, SQLTemporal]] = + def date_trunc: PackratParser[SQLUnaryFunction[SQLTemporal, SQLTemporal]] = "(?i)date_trunc".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => DateTrunc(u) } - def extract: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = "(?i)extract".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => Extract(u) } - def extract_year: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Year.regex ^^ (_ => YEAR) - def extract_month: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Month.regex ^^ (_ => MONTH) - def extract_day: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) + def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Day.regex ^^ (_ => DAY) - def extract_hour: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Hour.regex ^^ (_ => HOUR) - def extract_minute: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Minute.regex ^^ (_ => MINUTE) - def extract_second: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = Second.regex ^^ (_ => SECOND) - def extractors: PackratParser[SQLTypedFunction[SQLTemporal, SQLNumber]] = + def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction] = @@ -225,15 +225,35 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) + def date_painless: PackratParser[PainlessScript] = + repsep( + date_trunc | extractors | date_functions | datetime_functions, + start + ) ~ start.? ~ identifier.? ~ rep(end) ^^ { case f ~ _ ~ i ~ _ => + SQLValidator.validateChain(f) match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + i match { + case Some(id) => id.copy(functions = f) + case None => SQLIdentifier("", functions = f) + } + } + + def date_diff: PackratParser[DateDiff] = + "(?i)date_diff".r ~ start ~ (date_painless | identifier) ~ separator ~ (date_painless | identifier) ~ separator ~ time_unit ~ end ^^ { + case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) + } + def sql_functions: PackratParser[SQLFunction] = - aggregates | distance | date_trunc | extractors | date_functions | datetime_functions + aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" def identifierWithFunction: PackratParser[SQLIdentifier] = rep1sep(sql_functions, start) ~ start.? ~ identifier ~ rep1(end) ^^ { case f ~ _ ~ i ~ _ => SQLValidator.validateChain(f) match { - case Left(error) => throw new IllegalArgumentException(error) + case Left(error) => throw SQLValidationError(error) case _ => } i.copy(functions = f) @@ -271,6 +291,20 @@ trait SQLParser extends RegexParsers with PackratParsers { (dateTimeWithInterval | identifierWithInterval) ~ alias.? ^^ { case d ~ a => d.copy(fieldAlias = a) } + + def date_diff_field: PackratParser[SQLFunctionField] = date_diff ~ alias.? ^^ { case d ~ a => + SQLFunctionField(d :: Nil, a) + } + + def functionField: PackratParser[SQLFunctionField] = + rep1sep(sql_functions, start) ~ start.? ~ rep1(end) ~ alias.? ^^ { case f ~ _ ~ _ ~ a => + SQLValidator.validateChain(f) match { + case Left(error) => throw SQLValidationError(error) + case _ => + } + SQLFunctionField(f, a) + } + } trait SQLSelectParser { @@ -282,7 +316,10 @@ trait SQLSelectParser { } def select: PackratParser[SQLSelect] = - Select.regex ~ rep1sep(scriptField | field, separator) ~ except.? ^^ { case _ ~ fields ~ e => + Select.regex ~ rep1sep( + date_diff_field | functionField | scriptField | field, + separator + ) ~ except.? ^^ { case _ ~ fields ~ e => SQLSelect(fields, e) } @@ -550,7 +587,7 @@ trait SQLWhereParser { case _ :: Nil => processTokensHelper(rest, op :: stack) case _ => - throw new IllegalStateException("Invalid stack state for predicate creation") + throw SQLValidationError("Invalid stack state for predicate creation") } case (_: EndDelimiter) :: rest => processTokensHelper(rest, stack) // Ignore and move on @@ -581,7 +618,7 @@ trait SQLWhereParser { */ private def processSubTokens(tokens: List[SQLToken]): SQLCriteria = { processTokensHelper(tokens, Nil).getOrElse( - throw new IllegalStateException("Empty sub-expression") + throw SQLValidationError("Empty sub-expression") ) } @@ -604,7 +641,7 @@ trait SQLWhereParser { subTokens: List[SQLToken] = Nil ): (List[SQLToken], List[SQLToken]) = { tokens match { - case Nil => throw new IllegalStateException("Unbalanced parentheses") + case Nil => throw SQLValidationError("Unbalanced parentheses") case (start: StartDelimiter) :: rest => extractSubTokens(rest, openCount + 1, start :: subTokens) case (end: EndDelimiter) :: rest => @@ -654,7 +691,7 @@ trait SQLOrderByParser { def fieldWithFunction: PackratParser[(String, List[SQLFunction])] = rep1sep(sql_functions, start) ~ start.? ~ fieldName ~ rep1(end) ^^ { case f ~ _ ~ n ~ _ => SQLValidator.validateChain(f) match { - case Left(error) => throw new IllegalArgumentException(error) + case Left(error) => throw SQLValidationError(error) case _ => } (n, f) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 4cc6700a..cb6294d5 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -2,7 +2,7 @@ package app.softnetwork.elastic.sql case object Select extends SQLExpr("select") with SQLRegex -sealed trait Field extends Updateable with SQLTokenWithFunction { +sealed trait Field extends Updateable with SQLFunctionChain { def identifier: SQLIdentifier def fieldAlias: Option[SQLAlias] def isScriptField: Boolean = false @@ -43,6 +43,22 @@ sealed trait ScriptField extends Field with PainlessScript { lazy val name: String = fieldAlias.map(_.alias).getOrElse(sourceField) } +case class SQLFunctionField( + override val functions: List[SQLFunction], + fieldAlias: Option[SQLAlias] = None +) extends ScriptField + with SQLFunctionChain { + + override def update(request: SQLSearchRequest): SQLFunctionField = + this // TODO update SQLAlias if needed + + override def identifier: SQLIdentifier = SQLIdentifier("", functions = functions) + + override def painless: String = SQLFunctionUtils.buildPainless(functions) + + override lazy val sourceField: String = toSQL("").replace("(", "").replace(")", "") +} + case class SQLDateTimeField( identifier: SQLIdentifier, operator: Option[ArithmeticOperator] = None, diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala index 11988dfc..ff4cebf7 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLType.scala @@ -2,8 +2,17 @@ package app.softnetwork.elastic.sql sealed trait SQLType { def typeId: String } +trait SQLAny extends SQLType trait SQLTemporal extends SQLType trait SQLDate extends SQLTemporal trait SQLDateTime extends SQLTemporal trait SQLNumber extends SQLType trait SQLString extends SQLType + +object SQLTypeCompatibility { + def matches(out: SQLType, in: SQLType): Boolean = + out.typeId == in.typeId || + (out.typeId == "temporal" && Set("date", "datetime").contains(in.typeId)) || + (in.typeId == "temporal" && Set("date", "datetime").contains(out.typeId)) || + out.typeId == "any" || in.typeId == "any" +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala index 85873e9b..131c9a01 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypes.scala @@ -1,6 +1,7 @@ package app.softnetwork.elastic.sql object SQLTypes { + case object Any extends SQLAny { val typeId = "any" } case object Temporal extends SQLTemporal { val typeId = "temporal" } case object Date extends SQLDate { val typeId = "date" } case object DateTime extends SQLDateTime { val typeId = "datetime" } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index 9a129ddf..776bab53 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -3,16 +3,24 @@ package app.softnetwork.elastic.sql object SQLValidator { def validateChain(functions: List[SQLFunction]): Either[String, Unit] = { - functions - .collect { case f: SQLTypedFunction[_, _] => f } - .sliding(2) - .foreach { - case Seq(f1, f2) => - if (!f1.from(f2)) { - return Left(s"Type mismatch: ${f2.outputType} -> ${f1.inputType}") - } - case _ => // ok - } + // validate function chain type compatibility + val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } + unaryFuncs.sliding(2).foreach { + case Seq(f1, f2) => + if (!SQLTypeCompatibility.matches(f2.outputType, f1.inputType)) { + return Left( + s"Type mismatch: output '${f2.outputType.typeId}' of `${f2.sql}` " + + s"is not compatible with input '${f1.inputType.typeId}' of `${f1.sql}`" + ) + } + case _ => // ok + } Right(()) } } + +trait SQLValidation { + def validate(): Either[String, Unit] = Right(()) +} + +case class SQLValidationError(message: String) extends Exception(message) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 68b36cd0..162587e6 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -89,7 +89,7 @@ case class SQLPredicate( sealed trait ElasticFilter -sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLTokenWithFunction { +sealed trait SQLCriteriaWithIdentifier extends SQLCriteria with SQLFunctionChain { def identifier: SQLIdentifier override def nested: Boolean = identifier.nested override def group: Boolean = false diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 5364df34..0781513b 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -29,24 +29,6 @@ package object sql { def script: String } - trait SQLTokenWithFunction extends SQLToken { - def functions: List[SQLFunction] - - lazy val aggregateFunction: Option[AggregateFunction] = functions.headOption match { - case Some(af: AggregateFunction) => Some(af) - case other => - Console.println(this) - None - } - - lazy val aggregation: Boolean = aggregateFunction.isDefined - - def validate(): Either[String, Unit] = { - SQLValidator.validateChain(functions) - } - - } - trait Updateable extends SQLToken { def update(request: SQLSearchRequest): Updateable } @@ -292,7 +274,8 @@ package object sql { }) }) with SQLSource - with SQLTokenWithFunction { + with SQLFunctionChain + with PainlessScript { lazy val identifierName: String = functions.reverse.foldLeft(name)((expr, fun) => { @@ -305,6 +288,8 @@ package object sql { lazy val aliasOrName: String = fieldAlias.getOrElse(name) + override def painless: String = s"doc['$name'].value" + def update(request: SQLSearchRequest): SQLIdentifier = { val parts: Seq[String] = name.split("\\.").toSeq if (request.tableAliases.values.toSeq.contains(parts.head)) { diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala index 8385201d..87411cb3 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -39,7 +39,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { (transforms.head.toPainless(base), transforms.head.outputType.asInstanceOf[SQLType]) val (finalExpr, _) = transforms.tail.foldLeft(initial) { - case ((expr, currentType), t: SQLTypedFunction[_, _]) => + case ((expr, currentType), t: SQLUnaryFunction[_, _]) => if (!currentType.getClass.isAssignableFrom(t.inputType.getClass)) { throw new IllegalArgumentException( s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index af3421cf..748b7189 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -114,6 +114,10 @@ object Queries { .replaceAll("\\( ", "(") .replaceAll(" \\)", ")") + val dateDiff = "select date_diff(createdAt, updatedAt, day) as diff from Table" + + val dateDiffWithAggregation = + "select max(date_diff(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt), updatedAt, day)) as max_diff from Table" } /** Created by smanciot on 15/02/17. @@ -431,4 +435,18 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { parseDateTime ) } + + it should "parse date_diff function" in { + val result = SQLParser(dateDiff) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateDiff + ) + } + + it should "parse date_diff function with aggregation" in { + val result = SQLParser(dateDiffWithAggregation) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateDiffWithAggregation + ) + } } From 6c85817baa9133cc87ebf0b4a603c7871ff89b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 11:34:55 +0200 Subject: [PATCH 17/22] add trait Identifier, add / subtract interval as an arithmetic function, update parser --- build.sbt | 2 +- .../sql/bridge/ElasticAggregation.scala | 7 +- .../elastic/sql/bridge/package.scala | 2 +- .../sql/bridge/ElasticAggregation.scala | 13 +- .../elastic/sql/bridge/package.scala | 2 +- .../softnetwork/elastic/sql/SQLFunction.scala | 55 +++++-- .../softnetwork/elastic/sql/SQLGroupBy.scala | 6 +- .../softnetwork/elastic/sql/SQLOperator.scala | 7 +- .../softnetwork/elastic/sql/SQLParser.scala | 146 ++++++++++-------- .../elastic/sql/SQLSearchRequest.scala | 5 +- .../softnetwork/elastic/sql/SQLSelect.scala | 59 +------ .../softnetwork/elastic/sql/SQLWhere.scala | 12 +- .../app/softnetwork/elastic/sql/package.scala | 41 +++-- 13 files changed, 185 insertions(+), 172 deletions(-) diff --git a/build.sbt b/build.sbt index 6f0014dc..bae919ad 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.4.0" +ThisBuild / version := "0.5.0" ThisBuild / scalaVersion := scala213 diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index 4e03dcfa..f2b46f12 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -89,10 +89,7 @@ object ElasticAggregation { var aggPath = Seq[String]() - val (aggFuncs, transformFuncs) = identifier.functions.partition { - case _: AggregateFunction => true - case _ => false - } + val (aggFuncs, transformFuncs) = SQLFunctionUtils.aggregateAndTransformFunctions(identifier) require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") @@ -101,7 +98,7 @@ object ElasticAggregation { buildScript: (String, Script) => Aggregation ): Aggregation = { if (transformFuncs.nonEmpty) { - val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs) + val scriptSrc = SQLFunctionUtils.buildPainless(identifier) val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index 3f990f01..160ee4e2 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -101,7 +101,7 @@ package object bridge { case _ => _search scriptfields scriptFields.map { field => scriptField( - field.name, + field.scriptName, Script(script = field.painless).lang("painless").scriptType("source") ) } diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index 0d44ddd2..e7e31289 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -88,19 +88,16 @@ object ElasticAggregation { var aggPath = Seq[String]() - val (aggFuncs, transformFuncs) = identifier.functions.partition { - case _: AggregateFunction => true - case _ => false - } + val (aggFuncs, transformFuncs) = SQLFunctionUtils.aggregateAndTransformFunctions(identifier) require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs") def aggWithFieldOrScript( - buildField: (String, String) => Aggregation, - buildScript: (String, Script) => Aggregation - ): Aggregation = { + buildField: (String, String) => Aggregation, + buildScript: (String, Script) => Aggregation + ): Aggregation = { if (transformFuncs.nonEmpty) { - val scriptSrc = SQLFunctionUtils.buildPainless(Option(identifier), transformFuncs) + val scriptSrc = SQLFunctionUtils.buildPainless(identifier) val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala index bd6f12b0..b2edb050 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala @@ -102,7 +102,7 @@ package object bridge { case _ => _search scriptfields scriptFields.map { field => scriptField( - field.name, + field.scriptName, Script(script = field.painless).lang("painless").scriptType("source") ) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index c4a179ed..50c49f49 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -3,22 +3,31 @@ package app.softnetwork.elastic.sql import scala.util.matching.Regex sealed trait SQLFunction extends SQLRegex { - def toSQL(base: String): String = s"$sql($base)" + def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql } object SQLFunctionUtils { - def buildPainless(functions: List[SQLFunction]): String = - buildPainless(None, functions) + def aggregateAndTransformFunctions( + identifier: Identifier + ): (List[SQLFunction], List[SQLFunction]) = { + identifier.functions.partition { + case _: AggregateFunction => true + case _ => false + } + } + + def transformFunctions(identifier: Identifier): List[SQLFunction] = { + aggregateAndTransformFunctions(identifier)._2 + } def buildPainless( - painless: Option[PainlessScript] = None, - functions: List[SQLFunction] + identifier: Identifier ): String = { - val base = painless.map(_.painless).getOrElse("") - val orderedFunctions = functions.reverse + val base = identifier.painless + val orderedFunctions = transformFunctions(identifier).reverse orderedFunctions.foldLeft(base) { case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (_, f: PainlessScript) => f.painless + case (expr, f: PainlessScript) => s"$expr${f.painless}" case (expr, f) => f.toSQL(expr) // fallback } } @@ -50,7 +59,7 @@ sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] def outputType: Out } -trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] +sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction => override def inputType: SQLAny = SQLTypes.Any @@ -64,6 +73,12 @@ sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnar def toPainless(base: String): String = s"$base$painless" } +sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] + extends SQLTransformFunction[In, Out] { + def operator: ArithmeticOperator + override def toSQL(base: String): String = s"$base$operator$sql" +} + sealed trait ParametrizedFunction extends SQLFunction { def params: Seq[String] override def toSQL(base: String): String = { @@ -155,6 +170,28 @@ object TimeInterval { } } +case class SQLAddInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] + with MathScript { + override def operator: ArithmeticOperator = Add + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def painless: String = s".plus(${interval.painless})" + override def script: String = s"${operator.script}${interval.script}" +} + +case class SQLSubstractInterval(interval: TimeInterval) + extends SQLExpr(interval.sql) + with SQLArithmeticFunction[SQLDateTime, SQLDateTime] + with MathScript { + override def operator: ArithmeticOperator = Subtract + override def inputType: SQLDateTime = SQLTypes.DateTime + override def outputType: SQLDateTime = SQLTypes.DateTime + override def painless: String = s".minus(${interval.painless})" + override def script: String = s"${operator.script}${interval.script}" +} + sealed trait DateTimeFunction extends SQLFunction sealed trait DateFunction extends DateTimeFunction diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 6a7b4ec2..75956b86 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -129,9 +129,9 @@ object BucketSelectorScript { // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) val rightBase = (arithOp, interval) match { - case (Some(Plus), Some(i)) => s"$now.plus(${i.painless})" - case (Some(Minus), Some(i)) => s"$now.minus(${i.painless})" - case _ => now + case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" + case (Some(Subtract), Some(i)) => s"$now.minus(${i.painless})" + case _ => now } val rightZdt = dateFunc match { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index c3a073b0..21342118 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -2,11 +2,12 @@ package app.softnetwork.elastic.sql trait SQLOperator extends SQLToken -sealed trait ArithmeticOperator extends SQLOperator { +sealed trait ArithmeticOperator extends SQLOperator with MathScript { override def toString: String = s" $sql " + override def script: String = sql } -case object Plus extends SQLExpr("+") with ArithmeticOperator -case object Minus extends SQLExpr("-") with ArithmeticOperator +case object Add extends SQLExpr("+") with ArithmeticOperator +case object Subtract extends SQLExpr("-") with ArithmeticOperator case object Multiply extends SQLExpr("*") with ArithmeticOperator case object Divide extends SQLExpr("/") with ArithmeticOperator case object Modulo extends SQLExpr("%") with ArithmeticOperator diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 8db81295..2b1c0ac9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -130,20 +130,36 @@ trait SQLParser extends RegexParsers with PackratParsers { if (s.isDefined && t.isDefined) NowWithParens else Now } - def plus: PackratParser[ArithmeticOperator] = Plus.sql ^^ (_ => Plus) + def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) - def minus: PackratParser[ArithmeticOperator] = Minus.sql ^^ (_ => Minus) + def substract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) - def arithmeticOperator: PackratParser[ArithmeticOperator] = plus | minus + def intervalOperator: PackratParser[ArithmeticOperator] = add | substract - def dateTimeWithInterval: PackratParser[SQLDateTimeField] = - (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { - case f ~ o ~ i => - SQLDateTimeField( - SQLIdentifier(f.sql), - o, - i - ) + def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator + + def addInterval: PackratParser[SQLAddInterval] = + add ~ interval ^^ { case _ ~ it => + SQLAddInterval(it) + } + + def substractInterval: PackratParser[SQLSubstractInterval] = + substract ~ interval ^^ { case _ ~ it => + SQLSubstractInterval(it) + } + + def intervalFunction: PackratParser[SQLArithmeticFunction[SQLDateTime, SQLDateTime]] = + addInterval | substractInterval + + def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction + + def identifierWithSystemFunction: PackratParser[SQLIdentifier] = + (current_date | current_time | current_timestamp | now) ~ intervalFunction.? ^^ { + case f1 ~ f2 => + f2 match { + case Some(f) => SQLIdentifier("", functions = List(f, f1)) + case None => SQLIdentifier("", functions = List(f1)) + } } def date_trunc: PackratParser[SQLUnaryFunction[SQLTemporal, SQLTemporal]] = @@ -225,40 +241,37 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def date_painless: PackratParser[PainlessScript] = + def painless_identifier: PackratParser[Identifier] = repsep( date_trunc | extractors | date_functions | datetime_functions, start - ) ~ start.? ~ identifier.? ~ rep(end) ^^ { case f ~ _ ~ i ~ _ => + ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep( + end + ) ^^ { case f ~ _ ~ i ~ _ => SQLValidator.validateChain(f) match { case Left(error) => throw SQLValidationError(error) case _ => } i match { - case Some(id) => id.copy(functions = f) + case Some(id) => id.copy(functions = id.functions ++ f) case None => SQLIdentifier("", functions = f) } } - def date_diff: PackratParser[DateDiff] = - "(?i)date_diff".r ~ start ~ (date_painless | identifier) ~ separator ~ (date_painless | identifier) ~ separator ~ time_unit ~ end ^^ { + def date_diff: PackratParser[SQLBinaryFunction[_, _, _]] = + "(?i)date_diff".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) } + def date_diff_identifier: PackratParser[SQLIdentifier] = date_diff ^^ { dd => + SQLIdentifier("", functions = dd :: Nil) + } + def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions private val regexIdentifier = """[\*a-zA-Z_\-][a-zA-Z0-9_\-\.\[\]\*]*""" - def identifierWithFunction: PackratParser[SQLIdentifier] = - rep1sep(sql_functions, start) ~ start.? ~ identifier ~ rep1(end) ^^ { case f ~ _ ~ i ~ _ => - SQLValidator.validateChain(f) match { - case Left(error) => throw SQLValidationError(error) - case _ => - } - i.copy(functions = f) - } - def identifier: PackratParser[SQLIdentifier] = Distinct.regex.? ~ regexIdentifier.r ^^ { case d ~ i => SQLIdentifier( @@ -268,41 +281,43 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } - def identifierWithInterval: PackratParser[SQLDateTimeField] = - identifier ~ arithmeticOperator ~ interval ^^ { case f ~ o ~ i => - SQLDateTimeField( - f, - Some(o), - Some(i) - ) + def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = + (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => + i.copy(functions = af +: i.functions) } - private val regexAlias = - """\b(?!(?i)as\b)\b(?!(?i)except\b)\b(?!(?i)where\b)\b(?!(?i)filter\b)\b(?!(?i)from\b)\b(?!(?i)group\b)\b(?!(?i)having\b)\b(?!(?i)order\b)\b(?!(?i)limit\b)[a-zA-Z0-9_]*""" - - def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } - - def field: PackratParser[Field] = (identifierWithFunction | identifier) ~ alias.? ^^ { - case i ~ a => - SQLField(i, a) - } - - def scriptField: PackratParser[ScriptField] = - (dateTimeWithInterval | identifierWithInterval) ~ alias.? ^^ { case d ~ a => - d.copy(fieldAlias = a) + def identifierWithAggregation: PackratParser[SQLIdentifier] = + aggregates ~ start ~ (identifierWithFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + case a ~ _ ~ i ~ _ => + i.copy(functions = a +: i.functions) } - def date_diff_field: PackratParser[SQLFunctionField] = date_diff ~ alias.? ^^ { case d ~ a => - SQLFunctionField(d :: Nil, a) - } - - def functionField: PackratParser[SQLFunctionField] = - rep1sep(sql_functions, start) ~ start.? ~ rep1(end) ~ alias.? ^^ { case f ~ _ ~ _ ~ a => + def identifierWithFunction: PackratParser[SQLIdentifier] = + rep1sep( + sql_functions, + start + ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep1( + end + ) ^^ { case f ~ _ ~ i ~ _ => SQLValidator.validateChain(f) match { case Left(error) => throw SQLValidationError(error) case _ => } - SQLFunctionField(f, a) + i match { + case None => SQLIdentifier("", functions = f) + case Some(id) => id.copy(functions = id.functions ++ f) + } + } + + private val regexAlias = + """\b(?!(?i)as\b)\b(?!(?i)except\b)\b(?!(?i)where\b)\b(?!(?i)filter\b)\b(?!(?i)from\b)\b(?!(?i)group\b)\b(?!(?i)having\b)\b(?!(?i)order\b)\b(?!(?i)limit\b)[a-zA-Z0-9_]*""" + + def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } + + def field: PackratParser[Field] = + (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ alias.? ^^ { + case i ~ a => + SQLField(i, a) } } @@ -317,7 +332,7 @@ trait SQLSelectParser { def select: PackratParser[SQLSelect] = Select.regex ~ rep1sep( - date_diff_field | functionField | scriptField | field, + field, separator ) ~ except.? ^^ { case _ ~ fields ~ e => SQLSelect(fields, e) @@ -360,13 +375,14 @@ trait SQLWhereParser { private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) private def equality: PackratParser[SQLExpression] = - not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { + not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (eq | ne | diff) ~ (boolean | literal | double | long) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } def like: PackratParser[SQLExpression] = - (identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { case i ~ n ~ _ ~ v => - SQLExpression(i, Like, v, n) + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Like.regex ~ literal ^^ { + case i ~ n ~ _ ~ v => + SQLExpression(i, Like, v, n) } private def ge: PackratParser[SQLComparisonOperator] = Ge.sql ^^ (_ => Ge) @@ -378,7 +394,7 @@ trait SQLWhereParser { def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ (identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { + not.? ~ (identifierWithAggregation | identifierWithFunction | identifier) ~ (ge | gt | le | lt) ~ (double | long | literal) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -395,7 +411,7 @@ trait SQLWhereParser { } private def inDoubles: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( double, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -407,7 +423,7 @@ trait SQLWhereParser { } private def inLongs: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ in ~ start ~ rep1sep( long, separator ) ~ end ^^ { case i ~ n ~ _ ~ _ ~ v ~ _ => @@ -419,17 +435,17 @@ trait SQLWhereParser { } def between: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ literal ~ and ~ literal ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLiteralFromTo(from, to), n) } def betweenLongs: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ long ~ and ~ long ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLLongFromTo(from, to), n) } def betweenDoubles: PackratParser[SQLCriteria] = - (identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { + (identifierWithAggregation | identifierWithFunction | identifier) ~ not.? ~ Between.regex ~ double ~ and ~ double ^^ { case i ~ n ~ _ ~ from ~ _ ~ to => SQLBetween(i, SQLDoubleFromTo(from, to), n) } @@ -446,10 +462,12 @@ trait SQLWhereParser { SQLMatch(i, l) } - private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = - not.? ~ (identifierWithFunction | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { + private def dateTimeComparison: PackratParser[SQLComparisonDateMath] = { + // identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifier + not.? ~ (identifierWithAggregation | identifier) ~ (eq | ne | diff | ge | gt | le | lt) ~ (current_date | current_time | current_timestamp | now) ~ arithmeticOperator.? ~ interval.? ^^ { case n ~ i ~ o ~ dt ~ ao ~ it => SQLComparisonDateMath(i, o, dt, ao, it, n) } + } def and: PackratParser[SQLPredicateOperator] = And.regex ^^ (_ => And) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala index 40c87efb..a578cc53 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSearchRequest.scala @@ -30,10 +30,7 @@ case class SQLSearchRequest( ) } - lazy val scriptFields: Seq[ScriptField] = select.fields.flatMap { - case s: ScriptField => Some(s) - case _ => None - } + lazy val scriptFields: Seq[Field] = select.fields.filter(_.isScriptField) lazy val fields: Seq[String] = { if (aggregates.isEmpty && buckets.isEmpty) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index cb6294d5..b8f4a2be 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -2,10 +2,11 @@ package app.softnetwork.elastic.sql case object Select extends SQLExpr("select") with SQLRegex -sealed trait Field extends Updateable with SQLFunctionChain { - def identifier: SQLIdentifier +sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript { + def identifier: Identifier def fieldAlias: Option[SQLAlias] - def isScriptField: Boolean = false + def isScriptField: Boolean = + identifier.name.isEmpty || (functions.nonEmpty && !aggregation && identifier.bucket.isEmpty) override def sql: String = s"$identifier${asString(fieldAlias)}" lazy val sourceField: String = if (identifier.nested) { @@ -25,6 +26,10 @@ sealed trait Field extends Updateable with SQLFunctionChain { override def functions: List[SQLFunction] = identifier.functions def update(request: SQLSearchRequest): Field + + def painless: String = SQLFunctionUtils.buildPainless(identifier) + + lazy val scriptName: String = fieldAlias.map(_.alias).getOrElse(sourceField) } case class SQLField( @@ -35,54 +40,6 @@ case class SQLField( this.copy(identifier = identifier.update(request)) } -sealed trait ScriptField extends Field with PainlessScript { - override def isScriptField: Boolean = true - - def update(request: SQLSearchRequest): ScriptField - - lazy val name: String = fieldAlias.map(_.alias).getOrElse(sourceField) -} - -case class SQLFunctionField( - override val functions: List[SQLFunction], - fieldAlias: Option[SQLAlias] = None -) extends ScriptField - with SQLFunctionChain { - - override def update(request: SQLSearchRequest): SQLFunctionField = - this // TODO update SQLAlias if needed - - override def identifier: SQLIdentifier = SQLIdentifier("", functions = functions) - - override def painless: String = SQLFunctionUtils.buildPainless(functions) - - override lazy val sourceField: String = toSQL("").replace("(", "").replace(")", "") -} - -case class SQLDateTimeField( - identifier: SQLIdentifier, - operator: Option[ArithmeticOperator] = None, - interval: Option[TimeInterval], - fieldAlias: Option[SQLAlias] = None -) extends ScriptField { - override def sql: String = - s"$identifier${asString(operator)}${asString(interval)}${asString(fieldAlias)}" - def update(request: SQLSearchRequest): SQLDateTimeField = - this.copy(identifier = identifier.update(request)) - override def painless: String = { - val base = identifier.functions.headOption match { // FIXME - case f @ Some(CurrentDate | CurrentTime | CurrentTimestamp | Now) => - f.asInstanceOf[PainlessScript].painless - case _ => s"doc['$sourceField'].value" - } - (operator, interval) match { - case (Some(Minus), Some(i)) => s"$base.minus(${i.painless})" - case (Some(Plus), Some(i)) => s"$base.plus(${i.painless})" - case _ => base - } - } -} - case object Except extends SQLExpr("except") with SQLRegex case class SQLExcept(fields: Seq[Field]) extends Updateable { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 162587e6..4d0c1218 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -151,7 +151,7 @@ case class ElasticBoolQuery( } -trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { +sealed trait Expression extends SQLCriteriaWithIdentifier with ElasticFilter { def maybeValue: Option[SQLToken] def maybeNot: Option[Not.type] def notAsString: String = maybeNot.map(v => s"$v ").getOrElse("") @@ -342,10 +342,10 @@ case class SQLComparisonDateMath( case _: CurrentTimeFunction => val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless (arithmeticOperator, interval) match { - case (Some(Plus), Some(i)) => // compare doc time with now + interval + case (Some(Add), Some(i)) => // compare doc time with now + interval s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().plus(${i.value}, ${i.unit.painless});" - case (Some(Minus), Some(i)) => // compare doc time with now + case (Some(Subtract), Some(i)) => // compare doc time with now s"return doc['${identifier.name}'].value.toLocalTime() $painlessOp LocalTime.now().minus(${i.value}, ${i.unit.painless});" case _ => @@ -355,9 +355,9 @@ case class SQLComparisonDateMath( val base = s"${dateTimeFunction.script}" val dateMath = (arithmeticOperator, interval) match { - case (Some(Plus), Some(i)) => s"$base+${i.script}" - case (Some(Minus), Some(i)) => s"$base-${i.script}" - case _ => base + case (Some(Add), Some(i)) => s"$base+${i.script}" + case (Some(Subtract), Some(i)) => s"$base-${i.script}" + case _ => base } dateTimeFunction match { case _: CurrentDateFunction => s"$dateMath/d" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 0781513b..fbc97271 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -247,6 +247,30 @@ package object sql { def update(request: SQLSearchRequest): SQLSource } + trait Identifier extends SQLToken with SQLSource with SQLFunctionChain with PainlessScript { + def name: String + + def tableAlias: Option[String] + def distinct: Boolean + def nested: Boolean + def fieldAlias: Option[String] + def bucket: Option[SQLBucket] + + lazy val identifierName: String = + functions.reverse.foldLeft(name)((expr, fun) => { + fun.toSQL(expr) + }) + + lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None + + lazy val innerHitsName: Option[String] = if (nested) tableAlias else None + + lazy val aliasOrName: String = fieldAlias.getOrElse(name) + + override def painless: String = if (name.nonEmpty) s"doc['$name'].value" else "" + + } + case class SQLIdentifier( name: String, tableAlias: Option[String] = None, @@ -273,22 +297,7 @@ package object sql { fun.toSQL(expr) }) }) - with SQLSource - with SQLFunctionChain - with PainlessScript { - - lazy val identifierName: String = - functions.reverse.foldLeft(name)((expr, fun) => { - fun.toSQL(expr) - }) - - lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None - - lazy val innerHitsName: Option[String] = if (nested) tableAlias else None - - lazy val aliasOrName: String = fieldAlias.getOrElse(name) - - override def painless: String = s"doc['$name'].value" + with Identifier { def update(request: SQLSearchRequest): SQLIdentifier = { val parts: Seq[String] = name.split("\\.").toSeq From 7c30aa6ecf87d6348351d693a0141abf7d757ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 12:03:53 +0200 Subject: [PATCH 18/22] add specifications for query with date_diff --- .../elastic/sql/SQLQuerySpec.scala | 63 +++++++++++++++++++ .../elastic/sql/SQLQuerySpec.scala | 63 +++++++++++++++++++ .../softnetwork/elastic/sql/SQLSelect.scala | 3 +- .../elastic/sql/SQLParserSpec.scala | 10 +-- 4 files changed, 132 insertions(+), 7 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 16c46dab..7944b1c6 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1274,4 +1274,67 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } + + it should "handle date_diff function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "diff": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + + it should "handle aggregation with date_diff function" in { + val select: ElasticSearchRequest = + SQLQuery(aggregationWithDateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword" + | }, + | "aggs": { + | "max_diff": { + | "max": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 53dc8045..70a7a805 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1271,4 +1271,67 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll(">", " > ") .replaceAll(",LocalDate", ", LocalDate") } + + it should "handle date_diff function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "script_fields": { + | "diff": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + + it should "handle aggregation with date_diff function" in { + val select: ElasticSearchRequest = + SQLQuery(aggregationWithDateDiff) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "match_all": {} + | }, + | "size": 0, + | "_source": true, + | "aggs": { + | "identifier": { + | "terms": { + | "field": "identifier.keyword" + | }, + | "aggs": { + | "max_diff": { + | "max": { + | "script": { + | "lang": "painless", + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | } + | } + | } + | } + | } + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll(",doc", ", doc") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index b8f4a2be..51c30da2 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -5,8 +5,7 @@ case object Select extends SQLExpr("select") with SQLRegex sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript { def identifier: Identifier def fieldAlias: Option[SQLAlias] - def isScriptField: Boolean = - identifier.name.isEmpty || (functions.nonEmpty && !aggregation && identifier.bucket.isEmpty) + def isScriptField: Boolean = functions.nonEmpty && !aggregation && identifier.bucket.isEmpty override def sql: String = s"$identifier${asString(fieldAlias)}" lazy val sourceField: String = if (identifier.nested) { diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 748b7189..1abac616 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -114,10 +114,10 @@ object Queries { .replaceAll("\\( ", "(") .replaceAll(" \\)", ")") - val dateDiff = "select date_diff(createdAt, updatedAt, day) as diff from Table" + val dateDiff = "select date_diff(createdAt, updatedAt, day) as diff, identifier from Table" - val dateDiffWithAggregation = - "select max(date_diff(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt), updatedAt, day)) as max_diff from Table" + val aggregationWithDateDiff = + "select max(date_diff(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt), updatedAt, day)) as max_diff from Table group by identifier" } /** Created by smanciot on 15/02/17. @@ -444,9 +444,9 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { } it should "parse date_diff function with aggregation" in { - val result = SQLParser(dateDiffWithAggregation) + val result = SQLParser(aggregationWithDateDiff) result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( - dateDiffWithAggregation + aggregationWithDateDiff ) } } From d08ce7736fd3fdac61f8bc8d10fa2f658610924f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 12:33:07 +0200 Subject: [PATCH 19/22] to fix painless script with date_diff --- .../app/softnetwork/elastic/sql/SQLQuerySpec.scala | 8 +++++--- .../app/softnetwork/elastic/sql/SQLQuerySpec.scala | 8 +++++--- .../app/softnetwork/elastic/sql/SQLFunction.scala | 12 ++++++++++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 7944b1c6..10fabb36 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1203,7 +1203,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, ZonedDateTime::from)" | } | } | } @@ -1217,7 +1217,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("!=", " != ") .replaceAll("&&", " && ") .replaceAll(">", " > ") - .replaceAll(",LocalDate", ", LocalDate") + .replaceAll(",ZonedDateTime", ", ZonedDateTime") } it should "handle parse_datetime function" in { @@ -1325,7 +1325,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" | } | } | } @@ -1335,6 +1335,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin .replaceAll("\\s", "") .replaceAll(",doc", ", doc") + .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") } } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 70a7a805..5570bf20 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1256,7 +1256,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" | } | } | } @@ -1269,7 +1269,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("!=", " != ") .replaceAll("&&", " && ") .replaceAll(">", " > ") - .replaceAll(",LocalDate", ", LocalDate") + .replaceAll(",ZonedDateTime", ", ZonedDateTime") } it should "handle date_diff function as script field" in { @@ -1322,7 +1322,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, doc['createdAt'].value)" + | "source": "ChronoUnit.DAYS.between(doc['updatedAt'].value, DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from))" | } | } | } @@ -1332,6 +1332,8 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin .replaceAll("\\s", "") .replaceAll(",doc", ", doc") + .replaceAll("DateTimeFormatter", " DateTimeFormatter") + .replaceAll("ZonedDateTime", " ZonedDateTime") } } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 50c49f49..22388585 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -287,7 +287,15 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) override def toSQL(base: String): String = { s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" } - override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" + lazy val startPainless: String = start match { + case i: Identifier => SQLFunctionUtils.buildPainless(i) + case _ => start.painless + } + lazy val endPainless: String = end match { + case i: Identifier => SQLFunctionUtils.buildPainless(i) + case _ => end.painless + } + override def painless: String = s"${unit.painless}.between($startPainless, $endPainless)" } case class DateAdd(interval: TimeInterval) @@ -370,7 +378,7 @@ case class ParseDateTime(format: String) override def params: Seq[String] = Seq(s"'$format'") override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = - s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDateTime::from)" + s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" } case class FormatDateTime(format: String) From 75c88f2691a54e87fddc161338ddddebccd641fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 12:51:32 +0200 Subject: [PATCH 20/22] update painless within identifier --- .../sql/bridge/ElasticAggregation.scala | 2 +- .../elastic/sql/SQLQuerySpec.scala | 8 +++---- .../sql/bridge/ElasticAggregation.scala | 2 +- .../softnetwork/elastic/sql/SQLFunction.scala | 21 +------------------ .../softnetwork/elastic/sql/SQLSelect.scala | 2 +- .../app/softnetwork/elastic/sql/package.scala | 10 ++++++++- 6 files changed, 17 insertions(+), 28 deletions(-) diff --git a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index f2b46f12..86aeec7c 100644 --- a/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/es6/sql-bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -98,7 +98,7 @@ object ElasticAggregation { buildScript: (String, Script) => Aggregation ): Aggregation = { if (transformFuncs.nonEmpty) { - val scriptSrc = SQLFunctionUtils.buildPainless(identifier) + val scriptSrc = identifier.painless val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 10fabb36..e1a1d08b 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1203,7 +1203,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, ZonedDateTime::from)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(doc['createdAt'].value, LocalDate::from)" | } | } | } @@ -1217,7 +1217,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("!=", " != ") .replaceAll("&&", " && ") .replaceAll(">", " > ") - .replaceAll(",ZonedDateTime", ", ZonedDateTime") + .replaceAll(",LocalDate", ", LocalDate") } it should "handle parse_datetime function" in { @@ -1259,7 +1259,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "field": "createdAt", | "script": { | "lang": "painless", - | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, LocalDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" + | "source": "DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(doc['createdAt'].value, ZonedDateTime::from).truncatedTo(ChronoUnit.MINUTES).get(ChronoUnit.YEARS)" | } | } | } @@ -1272,7 +1272,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("!=", " != ") .replaceAll("&&", " && ") .replaceAll(">", " > ") - .replaceAll(",LocalDate", ", LocalDate") + .replaceAll(",ZonedDateTime", ", ZonedDateTime") } it should "handle date_diff function as script field" in { diff --git a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala index e7e31289..1bedbaa4 100644 --- a/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala +++ b/sql/bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala @@ -97,7 +97,7 @@ object ElasticAggregation { buildScript: (String, Script) => Aggregation ): Aggregation = { if (transformFuncs.nonEmpty) { - val scriptSrc = SQLFunctionUtils.buildPainless(identifier) + val scriptSrc = identifier.painless val script = Script(scriptSrc).lang("painless") buildScript(aggName, script) } else { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 22388585..6ea886ec 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -20,17 +20,6 @@ object SQLFunctionUtils { aggregateAndTransformFunctions(identifier)._2 } - def buildPainless( - identifier: Identifier - ): String = { - val base = identifier.painless - val orderedFunctions = transformFunctions(identifier).reverse - orderedFunctions.foldLeft(base) { - case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) - case (expr, f: PainlessScript) => s"$expr${f.painless}" - case (expr, f) => f.toSQL(expr) // fallback - } - } } trait SQLFunctionChain extends SQLFunction with SQLValidation { @@ -287,15 +276,7 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) override def toSQL(base: String): String = { s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" } - lazy val startPainless: String = start match { - case i: Identifier => SQLFunctionUtils.buildPainless(i) - case _ => start.painless - } - lazy val endPainless: String = end match { - case i: Identifier => SQLFunctionUtils.buildPainless(i) - case _ => end.painless - } - override def painless: String = s"${unit.painless}.between($startPainless, $endPainless)" + override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" } case class DateAdd(interval: TimeInterval) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 51c30da2..e2991f9d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -26,7 +26,7 @@ sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript def update(request: SQLSearchRequest): Field - def painless: String = SQLFunctionUtils.buildPainless(identifier) + def painless: String = identifier.painless lazy val scriptName: String = fieldAlias.map(_.alias).getOrElse(sourceField) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index fbc97271..35448221 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -267,7 +267,15 @@ package object sql { lazy val aliasOrName: String = fieldAlias.getOrElse(name) - override def painless: String = if (name.nonEmpty) s"doc['$name'].value" else "" + override def painless: String = { + val base = if (name.nonEmpty) s"doc['$name'].value" else "" + val orderedFunctions = SQLFunctionUtils.transformFunctions(this).reverse + orderedFunctions.foldLeft(base) { + case (expr, f: SQLTransformFunction[_, _]) => f.toPainless(expr) + case (expr, f: PainlessScript) => s"$expr${f.painless}" + case (expr, f) => f.toSQL(expr) // fallback + } + } } From 36f800471044404f0dc7877133c074a0d904d6b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 15:56:47 +0200 Subject: [PATCH 21/22] add SQLFunctionWithIdentifier, update date_add, date_sub, date_trunc, parse_date, format_date, datetime_add, datetime_sub, parse_datetime, format_datetime parameters --- .../elastic/sql/SQLQuerySpec.scala | 136 ++++++++++++++++++ .../elastic/sql/SQLQuerySpec.scala | 136 ++++++++++++++++++ .../softnetwork/elastic/sql/SQLFunction.scala | 76 ++++++---- .../softnetwork/elastic/sql/SQLParser.scala | 94 ++++++++---- .../sql/SQLDateTimeFunctionSuite.scala | 18 +-- .../elastic/sql/SQLParserSpec.scala | 69 ++++++++- 6 files changed, 456 insertions(+), 73 deletions(-) diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index e1a1d08b..97892ef4 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1339,4 +1339,140 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ZonedDateTime", " ZonedDateTime") } + it should "handle date_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle date_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 5570bf20..d7f25e09 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1336,4 +1336,140 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ZonedDateTime", " ZonedDateTime") } + it should "handle date_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle date_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_add function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeAdd) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.plus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + + it should "handle datetime_sub function as script field" in { + val select: ElasticSearchRequest = + SQLQuery(dateTimeSub) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "exists": { + | "field": "identifier2" + | } + | } + | ] + | } + | }, + | "script_fields": { + | "lastSeen": { + | "script": { + | "lang": "painless", + | "source": "doc['lastUpdated'].value.minus(10, ChronoUnit.DAYS)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin.replaceAll("\\s+", "").replaceAll("ChronoUnit", " ChronoUnit") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index 6ea886ec..96ca2453 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -6,6 +6,10 @@ sealed trait SQLFunction extends SQLRegex { def toSQL(base: String): String = if (base.nonEmpty) s"$sql($base)" else sql } +sealed trait SQLFunctionWithIdentifier extends SQLFunction { + def identifier: SQLIdentifier +} + object SQLFunctionUtils { def aggregateAndTransformFunctions( identifier: Identifier @@ -219,14 +223,16 @@ case object Now extends SQLExpr("now") with CurrentDateTimeFunction case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction -case class DateTrunc(unit: TimeUnit) +case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) extends SQLExpr("date_trunc") with DateTimeFunction with SQLTransformFunction[SQLTemporal, SQLTemporal] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut override def outputType: SQLTemporal = SQLTypes.Temporal // idem - override def params: Seq[String] = Seq(unit.sql) + override def toSQL(base: String): String = { + s"$sql($base, ${unit.sql})" + } override def painless: String = s".truncatedTo(${unit.painless})" } @@ -279,97 +285,113 @@ case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) override def painless: String = s"${unit.painless}.between(${start.painless}, ${end.painless})" } -case class DateAdd(interval: TimeInterval) +case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_add") with DateFunction with SQLTransformFunction[SQLDate, SQLDate] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date - override def params: Seq[String] = Seq(interval.sql) + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } override def painless: String = s".plus(${interval.painless})" } -case class DateSub(interval: TimeInterval) +case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("date_sub") with DateFunction with SQLTransformFunction[SQLDate, SQLDate] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date - override def params: Seq[String] = Seq(interval.sql) + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } override def painless: String = s".minus(${interval.painless})" } -case class ParseDate(format: String) +case class ParseDate(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_date") with DateFunction with SQLTransformFunction[SQLString, SQLDate] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLString = SQLTypes.String override def outputType: SQLDate = SQLTypes.Date - override def params: Seq[String] = Seq(s"'$format'") + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } -case class FormatDate(format: String) +case class FormatDate(identifier: SQLIdentifier, format: String) extends SQLExpr("format_date") with DateFunction with SQLTransformFunction[SQLDate, SQLString] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLString = SQLTypes.String - override def params: Seq[String] = Seq(s"'$format'") + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').format($base)" } -case class DateTimeAdd(interval: TimeInterval) +case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_add") with DateTimeFunction with SQLTransformFunction[SQLDateTime, SQLDateTime] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime - override def params: Seq[String] = Seq(interval.sql) + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } override def painless: String = s".plus(${interval.painless})" } -case class DateTimeSub(interval: TimeInterval) +case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) extends SQLExpr("datetime_sub") with DateTimeFunction with SQLTransformFunction[SQLDateTime, SQLDateTime] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime - override def params: Seq[String] = Seq(interval.sql) + override def toSQL(base: String): String = { + s"$sql($base, ${interval.sql})" + } override def painless: String = s".minus(${interval.painless})" } -case class ParseDateTime(format: String) +case class ParseDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("parse_datetime") with DateTimeFunction with SQLTransformFunction[SQLString, SQLDateTime] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLString = SQLTypes.String override def outputType: SQLDateTime = SQLTypes.DateTime - override def params: Seq[String] = Seq(s"'$format'") + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" } -case class FormatDateTime(format: String) +case class FormatDateTime(identifier: SQLIdentifier, format: String) extends SQLExpr("format_datetime") with DateTimeFunction with SQLTransformFunction[SQLDateTime, SQLString] - with ParametrizedFunction { + with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLString = SQLTypes.String - override def params: Seq[String] = Seq(s"'$format'") + override def toSQL(base: String): String = { + s"$sql($base, '$format')" + } override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String): String = s"DateTimeFormatter.ofPattern('$format').format($base)" diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index 2b1c0ac9..ad1b1bd4 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -162,9 +162,10 @@ trait SQLParser extends RegexParsers with PackratParsers { } } - def date_trunc: PackratParser[SQLUnaryFunction[SQLTemporal, SQLTemporal]] = - "(?i)date_trunc".r ~ start ~ time_unit ~ end ^^ { case _ ~ _ ~ u ~ _ => - DateTrunc(u) + def date_trunc: PackratParser[SQLFunctionWithIdentifier] = + "(?i)date_trunc".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ u ~ _ => + DateTrunc(i, u) } def extract: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = @@ -192,46 +193,54 @@ trait SQLParser extends RegexParsers with PackratParsers { def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumber]] = extract | extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second - def date_add: PackratParser[DateFunction] = - "(?i)date_add".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => - DateAdd(i) + def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)date_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateAdd(i, t) } - def date_sub: PackratParser[DateFunction] = - "(?i)date_sub".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => - DateSub(i) + def date_sub: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)date_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateSub(i, t) } - def parse_date: PackratParser[DateFunction] = - "(?i)parse_date".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => - ParseDate(f.value) + def parse_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)parse_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + ParseDate(i, f.value) } - def format_date: PackratParser[DateFunction] = - "(?i)format_date".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => - FormatDate(f.value) + def format_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = + "(?i)format_date".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + FormatDate(i, f.value) } def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date - def datetime_add: PackratParser[DateTimeFunction] = - "(?i)datetime_add".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => - DateTimeAdd(i) + def datetime_add: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)datetime_add".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateTimeAdd(i, t) } - def datetime_sub: PackratParser[DateTimeFunction] = - "(?i)datetime_sub".r ~ start ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ => - DateTimeSub(i) + def datetime_sub: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)datetime_sub".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ t ~ _ => + DateTimeSub(i, t) } - def parse_datetime: PackratParser[DateTimeFunction] = - "(?i)parse_datetime".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => - ParseDateTime(f.value) + def parse_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)parse_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + ParseDateTime(i, f.value) } - def format_datetime: PackratParser[DateTimeFunction] = - "(?i)format_datetime".r ~ start ~ literal ~ end ^^ { case _ ~ _ ~ f ~ _ => - FormatDateTime(f.value) + def format_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = + "(?i)format_datetime".r ~ start ~ (painless_identifier | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + case _ ~ _ ~ i ~ _ ~ f ~ _ => + FormatDateTime(i, f.value) } def datetime_functions: PackratParser[DateTimeFunction] = @@ -241,7 +250,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def distance: PackratParser[SQLFunction] = Distance.regex ^^ (_ => Distance) - def painless_identifier: PackratParser[Identifier] = + def painless_identifier: PackratParser[SQLIdentifier] = repsep( date_trunc | extractors | date_functions | datetime_functions, start @@ -254,7 +263,12 @@ trait SQLParser extends RegexParsers with PackratParsers { } i match { case Some(id) => id.copy(functions = id.functions ++ f) - case None => SQLIdentifier("", functions = f) + case None => + f.lastOption match { + case Some(fi: SQLFunctionWithIdentifier) => + fi.identifier.copy(functions = f ++ fi.identifier.functions) + case _ => SQLIdentifier("", functions = f) + } } } @@ -281,6 +295,19 @@ trait SQLParser extends RegexParsers with PackratParsers { ) } + private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (parse_date | format_date | date_add | date_sub) ^^ { t => + t.identifier.copy(functions = t +: t.identifier.functions) + } + + private[this] def dateTimeFunctionWithIdentifier: PackratParser[SQLIdentifier] = + (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ^^ { t => + t.identifier.copy(functions = t +: t.identifier.functions) + } + + def identifierWithTransformation: PackratParser[SQLIdentifier] = + dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier + def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => i.copy(functions = af +: i.functions) @@ -304,7 +331,12 @@ trait SQLParser extends RegexParsers with PackratParsers { case _ => } i match { - case None => SQLIdentifier("", functions = f) + case None => + f.lastOption match { + case Some(fi: SQLFunctionWithIdentifier) => + fi.identifier.copy(functions = f ++ fi.identifier.functions) + case _ => SQLIdentifier("", functions = f) + } case Some(id) => id.copy(functions = id.functions ++ f) } } @@ -315,7 +347,7 @@ trait SQLParser extends RegexParsers with PackratParsers { def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | identifier) ~ alias.? ^^ { + (identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | identifierWithTransformation | date_diff_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala index 87411cb3..2017093c 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -10,16 +10,16 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { // Liste de toutes les fonctions transformables avec leurs types val transformFunctions: Seq[SQLTransformFunction[_, _]] = Seq( - ParseDate("yyyy-MM-dd"), - ParseDateTime("yyyy-MM-dd HH:mm:ss"), - DateAdd(TimeInterval(1, Day)), - DateSub(TimeInterval(2, Month)), - DateTimeAdd(TimeInterval(3, Hour)), - DateTimeSub(TimeInterval(30, Minute)), - DateTrunc(Day), + ParseDate(SQLIdentifier(""), "yyyy-MM-dd"), + ParseDateTime(SQLIdentifier(""), "yyyy-MM-dd HH:mm:ss"), + DateAdd(SQLIdentifier(""), TimeInterval(1, Day)), + DateSub(SQLIdentifier(""), TimeInterval(2, Month)), + DateTimeAdd(SQLIdentifier(""), TimeInterval(3, Hour)), + DateTimeSub(SQLIdentifier(""), TimeInterval(30, Minute)), + DateTrunc(SQLIdentifier(""), Day), Extract(Day), - FormatDate("yyyy-MM-dd"), - FormatDateTime("yyyy-MM-dd HH:mm:ss"), + FormatDate(SQLIdentifier(""), "yyyy-MM-dd"), + FormatDateTime(SQLIdentifier(""), "yyyy-MM-dd HH:mm:ss"), YEAR, MONTH, DAY, diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 1abac616..e961a9fc 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -97,15 +97,16 @@ object Queries { |order by Country asc""".stripMargin .replaceAll("\n", " ") val parseDate = - "select identifier, count(identifier2) as ct, max(parse_date('yyyy-MM-dd')(createdAt)) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" + "select identifier, count(identifier2) as ct, max(parse_date(createdAt, 'yyyy-MM-dd')) as lastSeen from Table where identifier2 is not null group by identifier order by count(identifier2) desc" val parseDateTime: String = """select identifier, count(identifier2) as ct, |max( |year( - |date_trunc(minute)( - |parse_datetime('yyyy-MM-ddTHH:mm:ssZ')( - |createdAt - |)))) as lastSeen + |date_trunc( + |parse_datetime( + |createdAt, + |'yyyy-MM-ddTHH:mm:ssZ' + |), minute))) as lastSeen |from Table |where identifier2 is not null |group by identifier @@ -117,7 +118,21 @@ object Queries { val dateDiff = "select date_diff(createdAt, updatedAt, day) as diff, identifier from Table" val aggregationWithDateDiff = - "select max(date_diff(parse_datetime('yyyy-MM-ddTHH:mm:ssZ')(createdAt), updatedAt, day)) as max_diff from Table group by identifier" + "select max(date_diff(parse_datetime(createdAt, 'yyyy-MM-ddTHH:mm:ssZ'), updatedAt, day)) as max_diff from Table group by identifier" + + val formatDate = + "select identifier, format_date(date_trunc(lastUpdated, month), 'yyyy-MM-dd') as lastSeen from Table where identifier2 is not null" + val formatDateTime = + "select identifier, format_datetime(date_trunc(lastUpdated, month), 'yyyy-MM-ddThh:mm:ssZ') as lastSeen from Table where identifier2 is not null" + val dateAdd = + "select identifier, date_add(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateSub = + "select identifier, date_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateTimeAdd = + "select identifier, datetime_add(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + val dateTimeSub = + "select identifier, datetime_sub(lastUpdated, interval 10 day) as lastSeen from Table where identifier2 is not null" + } /** Created by smanciot on 15/02/17. @@ -449,4 +464,46 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { aggregationWithDateDiff ) } + + it should "parse format_date function" in { + val result = SQLParser(formatDate) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + formatDate + ) + } + + it should "parse format_datetime function" in { + val result = SQLParser(formatDateTime) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + formatDateTime + ) + } + + it should "parse date_add function" in { + val result = SQLParser(dateAdd) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateAdd + ) + } + + it should "parse date_sub function" in { + val result = SQLParser(dateSub) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateSub + ) + } + + it should "parse datetime_add function" in { + val result = SQLParser(dateTimeAdd) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeAdd + ) + } + + it should "parse datetime_sub function" in { + val result = SQLParser(dateTimeSub) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + dateTimeSub + ) + } } From 9d60fcf3fe27f6c3de9e8b8e646be88b9b9774ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Manciot?= Date: Mon, 8 Sep 2025 16:31:30 +0200 Subject: [PATCH 22/22] introduce BinaryExpression --- .../softnetwork/elastic/sql/SQLGroupBy.scala | 16 +++++++--- .../softnetwork/elastic/sql/SQLWhere.scala | 31 +++++++++++++++++-- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala index 75956b86..e30e7f73 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLGroupBy.scala @@ -95,8 +95,16 @@ object BucketSelectorScript { extractBucketsPath(left) ++ extractBucketsPath(right) case relation: ElasticRelation => extractBucketsPath(relation.criteria) case _: SQLMatch => Map.empty //MATCH is not supported in bucket_selector - case SQLComparisonDateMath(identifier, _, _, _, _, _) if identifier.aggregation => - Map(identifier.aliasOrName -> identifier.aliasOrName) + case b: BinaryExpression => + import b._ + if (left.aggregation && right.aggregation) + Map(left.aliasOrName -> left.aliasOrName, right.aliasOrName -> right.aliasOrName) + else if (left.aggregation) + Map(left.aliasOrName -> left.aliasOrName) + else if (right.aggregation) + Map(right.aliasOrName -> right.aliasOrName) + else + Map.empty case e: Expression if e.aggregation => import e._ Map(identifier.aliasOrName -> identifier.aliasOrName) @@ -129,9 +137,9 @@ object BucketSelectorScript { // build the RHS as a Painless ZonedDateTime (apply +/- interval using TimeInterval.painless) val rightBase = (arithOp, interval) match { - case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" + case (Some(Add), Some(i)) => s"$now.plus(${i.painless})" case (Some(Subtract), Some(i)) => s"$now.minus(${i.painless})" - case _ => now + case _ => now } val rightZdt = dateFunc match { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 4d0c1218..b40a09d9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -178,6 +178,33 @@ case class SQLExpression( override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this } +sealed trait BinaryExpression extends Expression { + def left: SQLIdentifier + def right: SQLIdentifier + override def identifier: SQLIdentifier = left + override def maybeValue: Option[SQLToken] = Some(right) + override lazy val aggregation: Boolean = left.aggregation || right.aggregation +} + +case class SQLBinaryExpression( + left: SQLIdentifier, + operator: SQLComparisonOperator, // Gt, Ge, Lt, Le, Eq, ... + right: SQLIdentifier, + maybeNot: Option[Not.type] = None +) extends BinaryExpression + with PainlessScript { + override def update(request: SQLSearchRequest): SQLCriteria = + this.copy(left = left.update(request), right = right.update(request)) + + override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this + + override def painless: String = { + val painlessOp = (if (maybeNot.isDefined) operator.not else operator).painless + s"${left.painless} $painlessOp ${right.painless}" + } + +} + case class SQLIsNull(identifier: SQLIdentifier) extends Expression { override val operator: SQLOperator = IsNull @@ -355,9 +382,9 @@ case class SQLComparisonDateMath( val base = s"${dateTimeFunction.script}" val dateMath = (arithmeticOperator, interval) match { - case (Some(Add), Some(i)) => s"$base+${i.script}" + case (Some(Add), Some(i)) => s"$base+${i.script}" case (Some(Subtract), Some(i)) => s"$base-${i.script}" - case _ => base + case _ => base } dateTimeFunction match { case _: CurrentDateFunction => s"$dateMath/d"