diff --git a/build.sbt b/build.sbt index 4b12cadf..64180f7c 100644 --- a/build.sbt +++ b/build.sbt @@ -19,7 +19,7 @@ ThisBuild / organization := "app.softnetwork" name := "softclient4es" -ThisBuild / version := "0.6.0" +ThisBuild / version := "0.7.0" ThisBuild / scalaVersion := scala213 diff --git a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index c8a9b8e3..436a9f88 100644 --- a/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/es6/sql-bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1332,7 +1332,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" + | "source": "(def arg0 = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def arg1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null || arg1 == null) ? null : ChronoUnit.DAYS.between(arg0, arg1))" | } | } | }, @@ -1345,14 +1345,14 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\s", "") .replaceAll("defv", "def v") .replaceAll("defe", "def e") - .replaceAll("defs", "def s") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") .replaceAll(":null", " : null") .replaceAll("null:", "null : ") .replaceAll("return", " return ") - .replaceAll("between\\(s,", "between(s, ") + .replaceAll(",a", ", a") .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") @@ -1382,7 +1382,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" + | "source": "(def arg0 = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def arg1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); (arg0 == null || arg1 == null) ? null : ChronoUnit.DAYS.between(arg0, arg1))" | } | } | } @@ -1393,14 +1393,14 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\s", "") .replaceAll("defv", "def v") .replaceAll("defe", "def e") - .replaceAll("defs", "def s") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") .replaceAll(":null", " : null") .replaceAll("null:", "null : ") .replaceAll("return", " return ") - .replaceAll("between\\(s,", "between(s, ") + .replaceAll(",a", ", a") .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") @@ -1808,7 +1808,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | "source": "{ def v0 = ((def arg0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null) ? null : arg0 == DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS) ? null : arg0));if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" | } | } | }, @@ -1820,7 +1820,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin .replaceAll("\\s+", "") .replaceAll("defv", " def v") - .replaceAll("defe", " def e") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") @@ -1832,6 +1832,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("returnv", " return v") .replaceAll("returne", " return e") .replaceAll(";}", "; }") + .replaceAll(";\\(", "; (") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") @@ -1857,7 +1858,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | "source": "{ def v0 = ((def arg0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null) ? null : arg0 == DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from) ? null : arg0));if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" | } | } | }, @@ -1869,7 +1870,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin .replaceAll("\\s+", "") .replaceAll("defv", " def v") - .replaceAll("defe", " def e") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") @@ -2061,4 +2062,261 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("if \\(\\s*def", "if (def") } + it should "handle arithmetic function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(arithmetic.replace("as group1", "")) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().get(ChronoUnit.YEARS) - 10)) > 10000" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "add": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 + 1)" + | } + | }, + | "sub": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 - 1)" + | } + | }, + | "mul": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * 2)" + | } + | }, + | "div": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 / 2)" + | } + | }, + | "mod": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 % 2)" + | } + | }, + | "identifier_mul_identifier2_minus_10": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((def lv1 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); def rv1 = ((!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value)); ( lv1 == null || rv1 == null ) ? null : (lv1 * rv1))); ( lv0 == null ) ? null : (lv0 - 10)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defl", "def l") + .replaceAll("defr", "def r") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll(">", " > ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll("%", " % ") + .replaceAll("\\+", " + ") + .replaceAll("-", " - ") + .replaceAll("==", " == ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle mathematic function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(mathematical) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sqrt(arg0)) > 100.0" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "abs_identifier_plus_1_0_mul_2": { + | "script": { + | "lang": "painless", + | "source": "((def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.abs(arg0)) + 1.0) * ((double) 2)" + | } + | }, + | "ceil_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.ceil(arg0))" + | } + | }, + | "floor_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.floor(arg0))" + | } + | }, + | "sqrt_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sqrt(arg0))" + | } + | }, + | "exp_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.exp(arg0))" + | } + | }, + | "log_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.log(arg0))" + | } + | }, + | "log10_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.log10(arg0))" + | } + | }, + | "pow_identifier_3": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.pow(arg0, 3))" + | } + | }, + | "round_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : (def p = Math.pow(10, 0); Math.round((arg0 * p) / p)))" + | } + | }, + | "round_identifier_2": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : (def p = Math.pow(10, 2); Math.round((arg0 * p) / p)))" + | } + | }, + | "sign_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); arg0 != null ? (arg0 > 0 ? 1 : (arg0 < 0 ? -1 : 0)) : null)" + | } + | }, + | "cos_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.cos(arg0))" + | } + | }, + | "acos_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.acos(arg0))" + | } + | }, + | "sin_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sin(arg0))" + | } + | }, + | "asin_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.asin(arg0))" + | } + | }, + | "tan_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.tan(arg0))" + | } + | }, + | "atan_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.atan(arg0))" + | } + | }, + | "atan2_identifier_3_0": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.atan2(arg0, 3.0))" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defa", "def a") + .replaceAll("defp", "def p") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll(":\\(", " : (") + .replaceAll(":0", " : 0") + .replaceAll("=Math", " = Math") + .replaceAll(",(\\d)", ", $1") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("\\+", " + ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll(">", " > ") + .replaceAll("<", " < ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("\\(double\\)(\\d)", "(double) $1") + } + } diff --git a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala index 4037983e..f19f2457 100644 --- a/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala +++ b/sql/bridge/src/test/scala/app/softnetwork/elastic/sql/SQLQuerySpec.scala @@ -1327,7 +1327,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "diff": { | "script": { | "lang": "painless", - | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" + | "source": "(def arg0 = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def arg1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null || arg1 == null) ? null : ChronoUnit.DAYS.between(arg0, arg1))" | } | } | }, @@ -1340,14 +1340,14 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\s", "") .replaceAll("defv", "def v") .replaceAll("defe", "def e") - .replaceAll("defs", "def s") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") .replaceAll(":null", " : null") .replaceAll("null:", "null : ") .replaceAll("return", " return ") - .replaceAll("between\\(s,", "between(s, ") + .replaceAll(",a", ", a") .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") @@ -1377,7 +1377,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "max": { | "script": { | "lang": "painless", - | "source": "(def s = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def e = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); s != null && e != null ? ChronoUnit.DAYS.between(s, e) : null)" + | "source": "(def arg0 = (!doc.containsKey('updatedAt') || doc['updatedAt'].empty ? null : doc['updatedAt'].value); def arg1 = (def e0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); e0 != null ? DateTimeFormatter.ofPattern('yyyy-MM-ddTHH:mm:ssZ').parse(e0, ZonedDateTime::from) : null); (arg0 == null || arg1 == null) ? null : ChronoUnit.DAYS.between(arg0, arg1))" | } | } | } @@ -1388,14 +1388,14 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("\\s", "") .replaceAll("defv", "def v") .replaceAll("defe", "def e") - .replaceAll("defs", "def s") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") .replaceAll(":null", " : null") .replaceAll("null:", "null : ") .replaceAll("return", " return ") - .replaceAll("between\\(s,", "between(s, ") + .replaceAll(",a", ", a") .replaceAll(";", "; ") .replaceAll("==", " == ") .replaceAll("!=", " != ") @@ -1797,7 +1797,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def v0 = ({ def e1=(!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2=DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" + | "source": "{ def v0 = ((def arg0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null) ? null : arg0 == DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from).minus(2, ChronoUnit.DAYS) ? null : arg0));if (v0 != null) return v0; return ZonedDateTime.now(ZoneId.of('Z')).toLocalDate(); }" | } | } | }, @@ -1806,9 +1806,10 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "identifier" | ] | } - |}""".stripMargin.replaceAll("\\s+", "") + |}""".stripMargin + .replaceAll("\\s+", "") .replaceAll("defv", " def v") - .replaceAll("defe", " def e") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") @@ -1820,6 +1821,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("returnv", " return v") .replaceAll("returne", " return e") .replaceAll(";}", "; }") + .replaceAll(";\\(", "; (") .replaceAll("==", " == ") .replaceAll("!=", " != ") .replaceAll("&&", " && ") @@ -1831,7 +1833,6 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("ZonedDateTime", " ZonedDateTime") } - it should "handle cast function as script field" in { val select: ElasticSearchRequest = SQLQuery(cast) @@ -1846,7 +1847,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { | "c": { | "script": { | "lang": "painless", - | "source": "{ def v0 = ({ def e1 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); def e2 = DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from); return e1 == e2 ? null : e1; });if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" + | "source": "{ def v0 = ((def arg0 = (!doc.containsKey('createdAt') || doc['createdAt'].empty ? null : doc['createdAt'].value); (arg0 == null) ? null : arg0 == DateTimeFormatter.ofPattern('yyyy-MM-dd').parse(\"2025-09-11\", LocalDate::from) ? null : arg0));if (v0 != null) return v0; return (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate()).atStartOfDay(ZoneId.of('Z')).minus(2, ChronoUnit.HOURS); }.toInstant().toEpochMilli()" | } | } | }, @@ -1858,7 +1859,7 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { |}""".stripMargin .replaceAll("\\s+", "") .replaceAll("defv", " def v") - .replaceAll("defe", " def e") + .replaceAll("defa", "def a") .replaceAll("if\\(", "if (") .replaceAll("=\\(", " = (") .replaceAll("\\?", " ? ") @@ -2050,4 +2051,261 @@ class SQLQuerySpec extends AnyFlatSpec with Matchers { .replaceAll("if \\(\\s*def", "if (def") } + it should "handle arithmetic function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(arithmetic.replace("as group1", "")) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * (ZonedDateTime.now(ZoneId.of('Z')).toLocalDate().get(ChronoUnit.YEARS) - 10)) > 10000" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "add": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 + 1)" + | } + | }, + | "sub": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 - 1)" + | } + | }, + | "mul": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 * 2)" + | } + | }, + | "div": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 / 2)" + | } + | }, + | "mod": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); ( lv0 == null ) ? null : (lv0 % 2)" + | } + | }, + | "identifier_mul_identifier2_minus_10": { + | "script": { + | "lang": "painless", + | "source": "def lv0 = ((def lv1 = ((!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value)); def rv1 = ((!doc.containsKey('identifier2') || doc['identifier2'].empty ? null : doc['identifier2'].value)); ( lv1 == null || rv1 == null ) ? null : (lv1 * rv1))); ( lv0 == null ) ? null : (lv0 - 10)" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s", "") + .replaceAll("defv", "def v") + .replaceAll("defe", "def e") + .replaceAll("defl", "def l") + .replaceAll("defr", "def r") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll(";", "; ") + .replaceAll(">", " > ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll("%", " % ") + .replaceAll("\\+", " + ") + .replaceAll("-", " - ") + .replaceAll("==", " == ") + .replaceAll("\\|\\|", " || ") + } + + it should "handle mathematic function as script field and condition" in { + val select: ElasticSearchRequest = + SQLQuery(mathematical) + val query = select.query + println(query) + query shouldBe + """{ + | "query": { + | "bool": { + | "filter": [ + | { + | "script": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sqrt(arg0)) > 100.0" + | } + | } + | } + | ] + | } + | }, + | "script_fields": { + | "abs_identifier_plus_1_0_mul_2": { + | "script": { + | "lang": "painless", + | "source": "((def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.abs(arg0)) + 1.0) * ((double) 2)" + | } + | }, + | "ceil_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.ceil(arg0))" + | } + | }, + | "floor_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.floor(arg0))" + | } + | }, + | "sqrt_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sqrt(arg0))" + | } + | }, + | "exp_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.exp(arg0))" + | } + | }, + | "log_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.log(arg0))" + | } + | }, + | "log10_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.log10(arg0))" + | } + | }, + | "pow_identifier_3": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.pow(arg0, 3))" + | } + | }, + | "round_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : (def p = Math.pow(10, 0); Math.round((arg0 * p) / p)))" + | } + | }, + | "round_identifier_2": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : (def p = Math.pow(10, 2); Math.round((arg0 * p) / p)))" + | } + | }, + | "sign_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); arg0 != null ? (arg0 > 0 ? 1 : (arg0 < 0 ? -1 : 0)) : null)" + | } + | }, + | "cos_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.cos(arg0))" + | } + | }, + | "acos_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.acos(arg0))" + | } + | }, + | "sin_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.sin(arg0))" + | } + | }, + | "asin_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.asin(arg0))" + | } + | }, + | "tan_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.tan(arg0))" + | } + | }, + | "atan_identifier": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.atan(arg0))" + | } + | }, + | "atan2_identifier_3_0": { + | "script": { + | "lang": "painless", + | "source": "(def arg0 = (!doc.containsKey('identifier') || doc['identifier'].empty ? null : doc['identifier'].value); (arg0 == null) ? null : Math.atan2(arg0, 3.0))" + | } + | } + | }, + | "_source": { + | "includes": [ + | "identifier" + | ] + | } + |}""".stripMargin + .replaceAll("\\s+", "") + .replaceAll("defv", " def v") + .replaceAll("defa", "def a") + .replaceAll("defp", "def p") + .replaceAll("if\\(", "if (") + .replaceAll("=\\(", " = (") + .replaceAll(":\\(", " : (") + .replaceAll(":0", " : 0") + .replaceAll("=Math", " = Math") + .replaceAll(",(\\d)", ", $1") + .replaceAll("\\?", " ? ") + .replaceAll(":null", " : null") + .replaceAll("null:", "null : ") + .replaceAll("return", " return ") + .replaceAll("between\\(s,", "between(s, ") + .replaceAll(";", "; ") + .replaceAll("; if", ";if") + .replaceAll("==", " == ") + .replaceAll("\\+", " + ") + .replaceAll("\\*", " * ") + .replaceAll("/", " / ") + .replaceAll(">", " > ") + .replaceAll("<", " < ") + .replaceAll("!=", " != ") + .replaceAll("&&", " && ") + .replaceAll("\\|\\|", " || ") + .replaceAll(";\\s\\s", "; ") + .replaceAll("ChronoUnit", " ChronoUnit") + .replaceAll(",LocalDate", ", LocalDate") + .replaceAll("=DateTimeFormatter", " = DateTimeFormatter") + .replaceAll("\\(double\\)(\\d)", "(double) $1") + } + } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala index cbdd1552..8bf6ab49 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLFunction.scala @@ -14,7 +14,7 @@ sealed trait SQLFunction extends SQLRegex { } sealed trait SQLFunctionWithIdentifier extends SQLFunction { - def identifier: SQLIdentifier + def identifier: SQLIdentifier //= SQLIdentifier("", functions = this :: Nil) } trait SQLFunctionWithValue[+T] extends SQLFunction { @@ -97,46 +97,83 @@ trait SQLFunctionChain extends SQLFunction { } } + def arithmetic: Boolean = functions.nonEmpty && functions.forall { + case _: SQLArithmeticExpression => true + case _ => false + } } -sealed trait SQLUnaryFunction[In <: SQLType, Out <: SQLType] - extends SQLFunction - with PainlessScript { +sealed trait SQLFunctionN[In <: SQLType, Out <: SQLType] extends SQLFunction with PainlessScript { + def fun: Option[PainlessScript] = None + + def args: List[PainlessScript] + def argsSeparator: String = ", " + def inputType: In def outputType: Out + override def in: SQLType = inputType override def out: SQLType = outputType + override def applyType(in: SQLType): SQLType = outputType + + override def sql: String = + s"${fun.map(_.sql).getOrElse("")}(${args.map(_.sql).mkString(argsSeparator)})" + + override def toSQL(base: String): String = s"$base$sql" + + override def painless: String = { + val nullCheck = + args.filter(_.nullable).zipWithIndex.map { case (_, i) => s"arg$i == null" }.mkString(" || ") + + val assignments = + args + .filter(_.nullable) + .zipWithIndex + .map { case (a, i) => s"def arg$i = ${a.painless};" } + .mkString(" ") + + val callArgs = args.zipWithIndex + .map { case (a, i) => + if (a.nullable) + s"arg$i" + else + a.painless + } + + if (args.exists(_.nullable)) + s"($assignments ($nullCheck) ? null : ${toPainlessCall(callArgs)})" + else + s"${toPainlessCall(callArgs)}" + } + + def toPainlessCall(callArgs: List[String]): String = + if (callArgs.nonEmpty) + s"${fun.map(_.painless).getOrElse("")}(${callArgs.mkString(argsSeparator)})" + else + fun.map(_.painless).getOrElse("") } sealed trait SQLBinaryFunction[In1 <: SQLType, In2 <: SQLType, Out <: SQLType] - extends SQLUnaryFunction[SQLAny, Out] { self: SQLFunction => - - override def inputType: SQLAny = SQLTypes.Any + extends SQLFunctionN[In2, Out] { self: SQLFunction => def left: PainlessScript def right: PainlessScript + override def args: List[PainlessScript] = List(left, right) + override def nullable: Boolean = left.nullable || right.nullable } -sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLUnaryFunction[In, Out] { +sealed trait SQLTransformFunction[In <: SQLType, Out <: SQLType] extends SQLFunctionN[In, Out] { def toPainless(base: String, idx: Int): String = { - if (nullable) + if (nullable && base.nonEmpty) s"(def e$idx = $base; e$idx != null ? e$idx$painless : null)" else s"$base$painless" } } -sealed trait SQLArithmeticFunction[In <: SQLType, Out <: SQLType] - extends SQLTransformFunction[In, Out] - with MathScript { - def operator: ArithmeticOperator - override def toSQL(base: String): String = s"$base$operator$sql" - override def applyType(in: SQLType): SQLType = in -} - sealed trait AggregateFunction extends SQLFunction case object Count extends SQLExpr("count") with AggregateFunction case object Min extends SQLExpr("min") with AggregateFunction @@ -150,6 +187,8 @@ sealed trait TimeUnit extends PainlessScript with MathScript { lazy val regex: Regex = s"\\b(?i)$sql(s)?\\b".r override def painless: String = s"ChronoUnit.${sql.toUpperCase()}S" + + override def nullable: Boolean = false } sealed trait CalendarUnit extends TimeUnit @@ -222,6 +261,8 @@ sealed trait TimeInterval extends PainlessScript with MathScript { Left(s"Intervals not supported for type $in") } } + + override def nullable: Boolean = false } import TimeUnit._ @@ -241,10 +282,24 @@ object TimeInterval { } } -sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunction[IO, IO] { +sealed trait SQLIntervalFunction[IO <: SQLTemporal] + extends SQLTransformFunction[IO, IO] + with MathScript { + def operator: IntervalOperator + + override def fun: Option[IntervalOperator] = Some(operator) + def interval: TimeInterval + + override def args: List[PainlessScript] = List(interval) + + override def argsSeparator: String = " " + override def sql: String = s"$operator${args.map(_.sql).mkString(argsSeparator)}" + override def script: String = s"${operator.script}${interval.script}" + private[this] var _out: SQLType = outputType + override def out: SQLType = _out override def applyType(in: SQLType): SQLType = { @@ -265,25 +320,19 @@ sealed trait SQLIntervalFunction[IO <: SQLTemporal] extends SQLArithmeticFunctio } sealed trait AddInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { - override def operator: ArithmeticOperator = Add - override def painless: String = s".plus(${interval.painless})" + override def operator: IntervalOperator = Plus } sealed trait SubtractInterval[IO <: SQLTemporal] extends SQLIntervalFunction[IO] { - override def operator: ArithmeticOperator = Subtract - override def painless: String = s".minus(${interval.painless})" + override def operator: IntervalOperator = Minus } -case class SQLAddInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with AddInterval[SQLTemporal] { +case class SQLAddInterval(interval: TimeInterval) extends AddInterval[SQLTemporal] { override def inputType: SQLTemporal = SQLTypes.Temporal override def outputType: SQLTemporal = SQLTypes.Temporal } -case class SQLSubtractInterval(interval: TimeInterval) - extends SQLExpr(interval.sql) - with SubtractInterval[SQLTemporal] { +case class SQLSubtractInterval(interval: TimeInterval) extends SubtractInterval[SQLTemporal] { override def inputType: SQLTemporal = SQLTypes.Temporal override def outputType: SQLTemporal = SQLTypes.Temporal } @@ -339,27 +388,43 @@ case object Now extends SQLExpr("now") with CurrentDateTimeFunction case object NowWithParens extends SQLExpr("now()") with CurrentDateTimeFunction +case object DateTrunc extends SQLExpr("date_trunc") with SQLRegex with PainlessScript { + override def painless: String = ".truncatedTo" +} + case class DateTrunc(identifier: SQLIdentifier, unit: TimeUnit) - extends SQLExpr("date_trunc") - with DateTimeFunction + extends DateTimeFunction with SQLTransformFunction[SQLTemporal, SQLTemporal] with SQLFunctionWithIdentifier { + override def fun: Option[PainlessScript] = Some(DateTrunc) + + override def args: List[PainlessScript] = List(unit) + override def inputType: SQLTemporal = SQLTypes.Temporal // par défaut override def outputType: SQLTemporal = SQLTypes.Temporal // idem + + override def sql: String = DateTrunc.sql override def toSQL(base: String): String = { s"$sql($base, ${unit.sql})" } - override def painless: String = s".truncatedTo(${unit.painless})" +} + +case object Extract extends SQLExpr("extract") with SQLRegex with PainlessScript { + override def painless: String = ".get" } case class Extract(unit: TimeUnit, override val sql: String = "extract") - extends SQLExpr(sql) - with DateTimeFunction + extends DateTimeFunction with SQLTransformFunction[SQLTemporal, SQLNumeric] { + override def fun: Option[PainlessScript] = Some(Extract) + + override def args: List[PainlessScript] = List(unit) + override def inputType: SQLTemporal = SQLTypes.Temporal override def outputType: SQLNumeric = SQLTypes.Numeric + override def toSQL(base: String): String = s"$sql(${unit.sql} from $base)" - override def painless: String = s".get(${unit.painless})" + } object YEAR extends Extract(Year, Year.sql) { @@ -386,65 +451,80 @@ object SECOND extends Extract(Second, Second.sql) { override def toSQL(base: String): String = s"$sql($base)" } +case object DateDiff extends SQLExpr("date_diff") with SQLRegex with PainlessScript { + override def painless: String = ".between" +} + case class DateDiff(end: PainlessScript, start: PainlessScript, unit: TimeUnit) - extends SQLExpr("date_diff") - with DateTimeFunction + extends DateTimeFunction with SQLBinaryFunction[SQLDateTime, SQLDateTime, SQLNumeric] with PainlessScript { + override def fun: Option[PainlessScript] = Some(DateDiff) + + override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLNumeric = SQLTypes.Numeric - override def left: PainlessScript = end - override def right: PainlessScript = start - override def toSQL(base: String): String = { - s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" - } - override def painless: String = { - if (start.nullable && end.nullable) - s"(def s = ${start.painless}; def e = ${end.painless}; s != null && e != null ? ${unit.painless}.between(s, e) : null)" - else if (start.nullable) - s"(def s = ${start.painless}; s != null ? ${unit.painless}.between(s, ${end.painless}) : null)" - else if (end.nullable) - s"(def e = ${end.painless}; e != null ? ${unit.painless}.between(${start.painless}, e) : null)" - else - s"${unit.painless}.between(${start.painless}, ${end.painless})" - } + + override def left: PainlessScript = start + override def right: PainlessScript = end + + override def sql: String = DateDiff.sql + + override def toSQL(base: String): String = s"$sql(${end.sql}, ${start.sql}, ${unit.sql})" + + override def toPainlessCall(callArgs: List[String]): String = + s"${unit.painless}${DateDiff.painless}(${callArgs.mkString(", ")})" } +case object DateAdd extends SQLExpr("date_add") with SQLRegex + case class DateAdd(identifier: SQLIdentifier, interval: TimeInterval) - extends SQLExpr("date_add") - with DateFunction + extends DateFunction with AddInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date + override def sql: String = DateAdd.sql override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } } +case object DateSub extends SQLExpr("date_sub") with SQLRegex + case class DateSub(identifier: SQLIdentifier, interval: TimeInterval) - extends SQLExpr("date_sub") - with DateFunction + extends DateFunction with SubtractInterval[SQLDate] with SQLTransformFunction[SQLDate, SQLDate] with SQLFunctionWithIdentifier { override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLDate = SQLTypes.Date + override def sql: String = DateSub.sql override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } } +case object ParseDate extends SQLExpr("parse_date") with SQLRegex with PainlessScript { + override def painless: String = ".parse" +} + case class ParseDate(identifier: SQLIdentifier, format: String) - extends SQLExpr("parse_date") - with DateFunction + extends DateFunction with SQLTransformFunction[SQLVarchar, SQLDate] with SQLFunctionWithIdentifier { + override def fun: Option[PainlessScript] = Some(ParseDate) + + override def args: List[PainlessScript] = List.empty + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDate = SQLTypes.Date + + override def sql: String = ParseDate.sql override def toSQL(base: String): String = { s"$sql($base, '$format')" } + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String, idx: Int): String = if (nullable) @@ -453,16 +533,26 @@ case class ParseDate(identifier: SQLIdentifier, format: String) s"DateTimeFormatter.ofPattern('$format').parse($base, LocalDate::from)" } +case object FormatDate extends SQLExpr("format_date") with SQLRegex with PainlessScript { + override def painless: String = ".format" +} + case class FormatDate(identifier: SQLIdentifier, format: String) - extends SQLExpr("format_date") - with DateFunction + extends DateFunction with SQLTransformFunction[SQLDate, SQLVarchar] with SQLFunctionWithIdentifier { + override def fun: Option[PainlessScript] = Some(FormatDate) + + override def args: List[PainlessScript] = List.empty + override def inputType: SQLDate = SQLTypes.Date override def outputType: SQLVarchar = SQLTypes.Varchar + + override def sql: String = FormatDate.sql override def toSQL(base: String): String = { s"$sql($base, '$format')" } + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String, idx: Int): String = if (nullable) @@ -471,42 +561,56 @@ case class FormatDate(identifier: SQLIdentifier, format: String) s"DateTimeFormatter.ofPattern('$format').format($base)" } +case object DateTimeAdd extends SQLExpr("datetime_add") with SQLRegex + case class DateTimeAdd(identifier: SQLIdentifier, interval: TimeInterval) - extends SQLExpr("datetime_add") - with DateTimeFunction + extends DateTimeFunction with AddInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime + override def sql: String = DateTimeAdd.sql override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } } +case object DateTimeSub extends SQLExpr("datetime_sub") with SQLRegex + case class DateTimeSub(identifier: SQLIdentifier, interval: TimeInterval) - extends SQLExpr("datetime_sub") - with DateTimeFunction + extends DateTimeFunction with SubtractInterval[SQLDateTime] with SQLTransformFunction[SQLDateTime, SQLDateTime] with SQLFunctionWithIdentifier { override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLDateTime = SQLTypes.DateTime + override def sql: String = DateTimeSub.sql override def toSQL(base: String): String = { s"$sql($base, ${interval.sql})" } } +case object ParseDateTime extends SQLExpr("parse_datetime") with SQLRegex with PainlessScript { + override def painless: String = ".parse" +} + case class ParseDateTime(identifier: SQLIdentifier, format: String) - extends SQLExpr("parse_datetime") - with DateTimeFunction + extends DateTimeFunction with SQLTransformFunction[SQLVarchar, SQLDateTime] with SQLFunctionWithIdentifier { + override def fun: Option[PainlessScript] = Some(ParseDateTime) + + override def args: List[PainlessScript] = List.empty + override def inputType: SQLVarchar = SQLTypes.Varchar override def outputType: SQLDateTime = SQLTypes.DateTime + + override def sql: String = ParseDateTime.sql override def toSQL(base: String): String = { s"$sql($base, '$format')" } + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String, idx: Int): String = if (nullable) @@ -515,16 +619,26 @@ case class ParseDateTime(identifier: SQLIdentifier, format: String) s"DateTimeFormatter.ofPattern('$format').parse($base, ZonedDateTime::from)" } +case object FormatDateTime extends SQLExpr("format_datetime") with SQLRegex with PainlessScript { + override def painless: String = ".format" +} + case class FormatDateTime(identifier: SQLIdentifier, format: String) - extends SQLExpr("format_datetime") - with DateTimeFunction + extends DateTimeFunction with SQLTransformFunction[SQLDateTime, SQLVarchar] with SQLFunctionWithIdentifier { + override def fun: Option[PainlessScript] = Some(FormatDateTime) + + override def args: List[PainlessScript] = List.empty + override def inputType: SQLDateTime = SQLTypes.DateTime override def outputType: SQLVarchar = SQLTypes.Varchar + + override def sql: String = FormatDateTime.sql override def toSQL(base: String): String = { s"$sql($base, '$format')" } + override def painless: String = throw new NotImplementedError("Use toPainless instead") override def toPainless(base: String, idx: Int): String = if (nullable) @@ -537,15 +651,23 @@ sealed trait SQLConditionalFunction[In <: SQLType] extends SQLTransformFunction[In, SQLBool] with SQLFunctionWithIdentifier { def operator: SQLConditionalOperator + + override def fun: Option[PainlessScript] = Some(operator) + override def outputType: SQLBool = SQLTypes.Boolean + override def toPainless(base: String, idx: Int): String = s"($base$painless)" } -case class SQLIsNullFunction(identifier: SQLIdentifier) - extends SQLExpr("isnull") - with SQLConditionalFunction[SQLAny] { - override def operator: SQLConditionalOperator = IsNull +case class SQLIsNullFunction(identifier: SQLIdentifier) extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNullFunction + + override def args: List[PainlessScript] = List(identifier) + override def inputType: SQLAny = SQLTypes.Any + + override def toSQL(base: String): String = sql + override def painless: String = s" == null" override def toPainless(base: String, idx: Int): String = { if (nullable) @@ -555,11 +677,15 @@ case class SQLIsNullFunction(identifier: SQLIdentifier) } } -case class SQLIsNotNullFunction(identifier: SQLIdentifier) - extends SQLExpr("isnotnull") - with SQLConditionalFunction[SQLAny] { - override def operator: SQLConditionalOperator = IsNotNull +case class SQLIsNotNullFunction(identifier: SQLIdentifier) extends SQLConditionalFunction[SQLAny] { + override def operator: SQLConditionalOperator = IsNotNullFunction + + override def args: List[PainlessScript] = List(identifier) + override def inputType: SQLAny = SQLTypes.Any + + override def toSQL(base: String): String = sql + override def painless: String = s" != null" override def toPainless(base: String, idx: Int): String = { if (nullable) @@ -569,14 +695,22 @@ case class SQLIsNotNullFunction(identifier: SQLIdentifier) } } -case class SQLCoalesce(values: List[PainlessScript]) extends SQLConditionalFunction[SQLAny] { - override def operator: SQLConditionalOperator = Coalesce +case class SQLCoalesce(values: List[PainlessScript]) + extends SQLTransformFunction[SQLAny, SQLType] + with SQLFunctionWithIdentifier { + def operator: SQLConditionalOperator = Coalesce + + override def fun: Option[SQLConditionalOperator] = Some(operator) + + override def args: List[PainlessScript] = values + + override def outputType: SQLType = SQLTypeUtils.leastCommonSuperType(args.map(_.out)) override def identifier: SQLIdentifier = SQLIdentifier("") override def inputType: SQLAny = SQLTypes.Any - override lazy val sql: String = s"$Coalesce(${values.map(_.sql).mkString(", ")})" + override def sql: String = s"$Coalesce(${values.map(_.sql).mkString(", ")})" // Reprend l’idée de SQLValues mais pour n’importe quel token override def out: SQLType = SQLTypeUtils.leastCommonSuperType(values.map(_.out).distinct) @@ -613,23 +747,21 @@ case class SQLNullIf(expr1: PainlessScript, expr2: PainlessScript) extends SQLConditionalFunction[SQLAny] { override def operator: SQLConditionalOperator = NullIf + override def args: List[PainlessScript] = List(expr1, expr2) + override def identifier: SQLIdentifier = SQLIdentifier("") override def inputType: SQLAny = SQLTypes.Any - override def sql: String = s"$NullIf(${expr1.sql}, ${expr2.sql})" - override def out: SQLType = expr1.out override def applyType(in: SQLType): SQLType = out - override def painless: String = { - val e1 = expr1.painless - val e2 = expr2.painless - s"""{ def e1 = $e1; - |def e2 = $e2; - |return e1 == e2 ? null : e1; - |}""".stripMargin.replaceAll("\n", " ") + override def toPainlessCall(callArgs: List[String]): String = { + callArgs match { + case List(arg0, arg1) => s"${arg0.trim} == ${arg1.trim} ? null : $arg0" + case _ => throw new IllegalArgumentException("NULLIF requires exactly two arguments") + } } } @@ -638,6 +770,8 @@ case class SQLCast(value: PainlessScript, targetType: SQLType, as: Boolean = tru override def inputType: SQLType = value.out override def outputType: SQLType = targetType + override def args: List[PainlessScript] = List.empty + override def sql: String = s"$Cast(${value.sql} ${if (as) s"$Alias " else ""}${targetType.typeId})" @@ -655,6 +789,8 @@ case class SQLCaseWhen( conditions: List[(PainlessScript, PainlessScript)], default: Option[PainlessScript] ) extends SQLTransformFunction[SQLAny, SQLAny] { + override def args: List[PainlessScript] = List.empty + override def inputType: SQLAny = SQLTypes.Any override def outputType: SQLAny = SQLTypes.Any @@ -751,3 +887,141 @@ case class SQLCaseWhen( override def nullable: Boolean = conditions.exists { case (_, res) => res.nullable } || default.forall(_.nullable) } + +case class SQLArithmeticExpression( + left: PainlessScript, + operator: ArithmeticOperator, + right: PainlessScript, + group: Boolean = false +) extends SQLTransformFunction[SQLNumeric, SQLNumeric] + with SQLBinaryFunction[SQLNumeric, SQLNumeric, SQLNumeric] { + + override def fun: Option[ArithmeticOperator] = Some(operator) + + override def inputType: SQLNumeric = SQLTypes.Numeric + override def outputType: SQLNumeric = SQLTypes.Numeric + + override def applyType(in: SQLType): SQLType = in + + override def sql: String = { + val expr = s"${left.sql}$operator${right.sql}" + if (group) + s"($expr)" + else + expr + } + + override def out: SQLType = + SQLTypeUtils.leastCommonSuperType(List(left.out, right.out)) + + override def validate(): Either[String, Unit] = { + for { + _ <- left.validate() + _ <- right.validate() + _ <- SQLValidator.validateTypesMatching(left.out, right.out) + } yield () + } + + override def nullable: Boolean = left.nullable || right.nullable + + override def toPainless(base: String, idx: Int): String = { + if (nullable) { + val l = left match { + case t: SQLTransformFunction[_, _] => + SQLTypeUtils.coerce(t.toPainless("", idx + 1), left.out, out, nullable = false) + case _ => SQLTypeUtils.coerce(left.painless, left.out, out, nullable = false) + } + val r = right match { + case t: SQLTransformFunction[_, _] => + SQLTypeUtils.coerce(t.toPainless("", idx + 1), right.out, out, nullable = false) + case _ => SQLTypeUtils.coerce(right.painless, right.out, out, nullable = false) + } + var expr = "" + if (left.nullable) + expr += s"def lv$idx = ($l); " + if (right.nullable) + expr += s"def rv$idx = ($r); " + if (left.nullable && right.nullable) + expr += s"(lv$idx == null || rv$idx == null) ? null : (lv$idx ${operator.painless} rv$idx)" + else if (left.nullable) + expr += s"(lv$idx == null) ? null : (lv$idx ${operator.painless} $r)" + else + expr += s"(rv$idx == null) ? null : ($l ${operator.painless} rv$idx)" + if (group) + expr = s"($expr)" + return s"$base$expr" + } + s"$base$painless" + } + + override def painless: String = { + val l = SQLTypeUtils.coerce(left, out) + val r = SQLTypeUtils.coerce(right, out) + val expr = s"$l ${operator.painless} $r" + if (group) + s"($expr)" + else + expr + } + +} + +sealed trait MathematicalFunction + extends SQLTransformFunction[SQLNumeric, SQLNumeric] + with SQLFunctionWithIdentifier { + override def inputType: SQLNumeric = SQLTypes.Numeric + + override def outputType: SQLNumeric = SQLTypes.Double + + def operator: UnaryArithmeticOperator + + override def fun: Option[PainlessScript] = Some(operator) + + override def identifier: SQLIdentifier = SQLIdentifier("", functions = this :: Nil) + +} + +case class SQLMathematicalFunction( + operator: UnaryArithmeticOperator, + arg: PainlessScript +) extends MathematicalFunction { + override def args: List[PainlessScript] = List(arg) +} + +case class SQLPow(arg: PainlessScript, exponent: Int) extends MathematicalFunction { + override def operator: UnaryArithmeticOperator = Pow + override def args: List[PainlessScript] = List(arg, SQLIntValue(exponent)) + override def nullable: Boolean = arg.nullable +} + +case class SQLRound(arg: PainlessScript, scale: Option[Int]) extends MathematicalFunction { + override def operator: UnaryArithmeticOperator = Round + + override def args: List[PainlessScript] = + List(arg) ++ scale.map(SQLIntValue(_)).toList + + override def toPainlessCall(callArgs: List[String]): String = + s"(def p = ${SQLPow(SQLIntValue(10), scale.getOrElse(0)).painless}; ${operator.painless}((${callArgs.head} * p) / p))" +} + +case class SQLSign(arg: PainlessScript) extends MathematicalFunction { + override def operator: UnaryArithmeticOperator = Sign + + override def args: List[PainlessScript] = List(arg) + + override def outputType: SQLNumeric = SQLTypes.Int + + override def painless: String = { + val ret = "arg0 > 0 ? 1 : (arg0 < 0 ? -1 : 0)" + if (arg.nullable) + s"(def arg0 = ${arg.painless}; arg0 != null ? ($ret) : null)" + else + s"(def arg0 = ${arg.painless}; $ret)" + } +} + +case class SQLAtan2(y: PainlessScript, x: PainlessScript) extends MathematicalFunction { + override def operator: UnaryArithmeticOperator = Atan2 + override def args: List[PainlessScript] = List(y, x) + override def nullable: Boolean = y.nullable || x.nullable +} diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala index 1e9eb1fb..cb572ef9 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLOperator.scala @@ -9,19 +9,66 @@ trait SQLOperator extends SQLToken with PainlessScript with SQLRegex { case Like | Match => ".matches" case Eq => "==" case Ne => "!=" + case Plus => ".plus" + case Minus => ".minus" + case IsNull => " == null" + case IsNotNull => " != null" case _ => sql } } -sealed trait ArithmeticOperator extends SQLOperator with MathScript { +sealed trait BinaryOperator extends SQLOperator + +sealed trait ArithmeticOperator extends SQLOperator { override def toString: String = s" $sql " +} + +sealed trait BinaryArithmeticOperator extends ArithmeticOperator with BinaryOperator + +sealed trait IntervalOperator extends BinaryArithmeticOperator with MathScript { override def script: String = sql } -case object Add extends SQLExpr("+") with ArithmeticOperator -case object Subtract extends SQLExpr("-") with ArithmeticOperator -case object Multiply extends SQLExpr("*") with ArithmeticOperator -case object Divide extends SQLExpr("/") with ArithmeticOperator -case object Modulo extends SQLExpr("%") with ArithmeticOperator +case object Plus extends SQLExpr("+") with IntervalOperator { + override def painless: String = ".plus" +} +case object Minus extends SQLExpr("-") with IntervalOperator { + override def painless: String = ".minus" +} + +case object Add extends SQLExpr("+") with IntervalOperator +case object Subtract extends SQLExpr("-") with IntervalOperator + +case object Multiply extends SQLExpr("*") with BinaryArithmeticOperator +case object Divide extends SQLExpr("/") with BinaryArithmeticOperator +case object Modulo extends SQLExpr("%") with BinaryArithmeticOperator + +sealed trait UnaryArithmeticOperator extends ArithmeticOperator { + override def painless: String = s"Math.${sql.toLowerCase()}" +} + +case object Abs extends SQLExpr("abs") with UnaryArithmeticOperator +case object Ceil extends SQLExpr("ceil") with UnaryArithmeticOperator +case object Floor extends SQLExpr("floor") with UnaryArithmeticOperator +case object Round extends SQLExpr("round") with UnaryArithmeticOperator +case object Exp extends SQLExpr("exp") with UnaryArithmeticOperator +case object Log extends SQLExpr("log") with UnaryArithmeticOperator +case object Log10 extends SQLExpr("log10") with UnaryArithmeticOperator +case object Pow extends SQLExpr("pow") with UnaryArithmeticOperator +case object Sqrt extends SQLExpr("sqrt") with UnaryArithmeticOperator +case object Sign extends SQLExpr("sign") with UnaryArithmeticOperator +case object Pi extends SQLExpr("pi") with UnaryArithmeticOperator { + override def painless: String = "Math.PI" +} + +sealed trait TrigonometricOperator extends UnaryArithmeticOperator + +case object Sin extends SQLExpr("sin") with TrigonometricOperator +case object Asin extends SQLExpr("asin") with TrigonometricOperator +case object Cos extends SQLExpr("cos") with TrigonometricOperator +case object Acos extends SQLExpr("acos") with TrigonometricOperator +case object Tan extends SQLExpr("tan") with TrigonometricOperator +case object Atan extends SQLExpr("atan") with TrigonometricOperator +case object Atan2 extends SQLExpr("atan2") with TrigonometricOperator sealed trait SQLExpressionOperator extends SQLOperator @@ -46,11 +93,9 @@ case object Lt extends SQLExpr("<") with SQLComparisonOperator case object In extends SQLExpr("in") with SQLComparisonOperator case object Like extends SQLExpr("like") with SQLComparisonOperator case object Between extends SQLExpr("between") with SQLComparisonOperator -case object IsNull extends SQLExpr("is null") with SQLComparisonOperator with SQLConditionalOperator -case object IsNotNull - extends SQLExpr("is not null") - with SQLComparisonOperator - with SQLConditionalOperator +case object IsNull extends SQLExpr("is null") with SQLComparisonOperator +case object IsNotNull extends SQLExpr("is not null") with SQLComparisonOperator + case object Match extends SQLExpr("match") with SQLComparisonOperator case object Against extends SQLExpr("against") with SQLRegex @@ -58,22 +103,26 @@ sealed trait SQLLogicalOperator extends SQLExpressionOperator case object Not extends SQLExpr("not") with SQLLogicalOperator +sealed trait SQLPredicateOperator extends SQLLogicalOperator + +case object And extends SQLExpr("and") with SQLPredicateOperator +case object Or extends SQLExpr("or") with SQLPredicateOperator + sealed trait SQLConditionalOperator extends SQLExpressionOperator case object Coalesce extends SQLExpr("coalesce") with SQLConditionalOperator +case object IsNullFunction extends SQLExpr("isnull") with SQLConditionalOperator +case object IsNotNullFunction extends SQLExpr("isnotnull") with SQLConditionalOperator case object NullIf extends SQLExpr("nullif") with SQLConditionalOperator case object Exists extends SQLExpr("exists") with SQLConditionalOperator + case object Cast extends SQLExpr("cast") with SQLConditionalOperator case object Case extends SQLExpr("case") with SQLConditionalOperator + case object When extends SQLExpr("when") with SQLRegex case object Then extends SQLExpr("then") with SQLRegex case object Else extends SQLExpr("else") with SQLRegex case object End extends SQLExpr("end") with SQLRegex -sealed trait SQLPredicateOperator extends SQLLogicalOperator - -case object And extends SQLExpr("and") with SQLPredicateOperator -case object Or extends SQLExpr("or") with SQLPredicateOperator - case object Union extends SQLExpr("union") with SQLOperator with SQLRegex sealed trait ElasticOperator extends SQLOperator with SQLRegex diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala index ff6db34c..201c18cc 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLParser.scala @@ -4,6 +4,8 @@ import scala.util.parsing.combinator.{PackratParsers, RegexParsers} import scala.util.parsing.input.CharSequenceReader import TimeUnit._ +import scala.language.implicitConversions + /** Created by smanciot on 27/06/2018. * * SQL Parser for ElasticSearch @@ -68,12 +70,16 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def double: PackratParser[SQLDoubleValue] = """(-)?(\d+\.\d+)""".r ^^ (str => SQLDoubleValue(str.toDouble)) + def pi: PackratParser[SQLValue[Double]] = + Pi.regex ^^ (_ => SQLPiValue) + def boolean: PackratParser[SQLBoolean] = """(true|false)""".r ^^ (bool => SQLBoolean(bool.toBoolean)) - /*def value_identifier: PackratParser[SQLIdentifier] = (literal | long | double | boolean) ^^ { v => - SQLIdentifier("", functions = v :: Nil) - }*/ + def value_identifier: PackratParser[SQLIdentifier] = (literal | long | double | pi | boolean) ^^ { + v => + SQLIdentifier("", functions = v :: Nil) + } def start: PackratParser[SQLDelimiter] = "(" ^^ (_ => StartPredicate) @@ -110,11 +116,6 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def time_unit: PackratParser[TimeUnit] = year | month | quarter | week | day | hour | minute | second - def interval: PackratParser[TimeInterval] = - Interval.regex ~ long ~ time_unit ^^ { case _ ~ l ~ u => - TimeInterval(l.value.toInt, u) - } - def parens: PackratParser[List[SQLDelimiter]] = start ~ end ^^ { case s ~ e => s :: e :: Nil } @@ -137,13 +138,50 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => if (p.isDefined) NowWithParens else Now } - def add: PackratParser[ArithmeticOperator] = Add.sql ^^ (_ => Add) + def add: PackratParser[IntervalOperator] = Add.sql ^^ (_ => Add) + + def subtract: PackratParser[IntervalOperator] = Subtract.sql ^^ (_ => Subtract) - def subtract: PackratParser[ArithmeticOperator] = Subtract.sql ^^ (_ => Subtract) + def multiply: PackratParser[ArithmeticOperator] = Multiply.sql ^^ (_ => Multiply) - def intervalOperator: PackratParser[ArithmeticOperator] = add | subtract + def divide: PackratParser[ArithmeticOperator] = Divide.sql ^^ (_ => Divide) - def arithmeticOperator: PackratParser[ArithmeticOperator] = intervalOperator + def modulo: PackratParser[ArithmeticOperator] = Modulo.sql ^^ (_ => Modulo) + + def factor: PackratParser[PainlessScript] = + "(" ~> arithmeticExpressionLevel2 <~ ")" ^^ { + case expr: SQLArithmeticExpression => + expr.copy(group = true) + case other => other + } | valueExpr + + def arithmeticExpressionLevel1: Parser[PainlessScript] = + factor ~ rep((multiply | divide | modulo) ~ factor) ^^ { case left ~ list => + list.foldLeft(left) { case (acc, op ~ right) => + SQLArithmeticExpression(acc, op, right) + } + } + + def arithmeticExpressionLevel2: Parser[PainlessScript] = + arithmeticExpressionLevel1 ~ rep((add | subtract) ~ arithmeticExpressionLevel1) ^^ { + case left ~ list => + list.foldLeft(left) { case (acc, op ~ right) => + SQLArithmeticExpression(acc, op, right) + } + } + + def identifierWithArithmeticExpression: Parser[SQLIdentifier] = arithmeticExpressionLevel2 ^^ { + case af: SQLArithmeticExpression => SQLIdentifier("", functions = af :: Nil) + case id: SQLIdentifier => id + case f: SQLFunctionWithIdentifier => f.identifier + case f: SQLFunction => SQLIdentifier("", functions = f :: Nil) + case other => throw new Exception(s"Unexpected expression $other") + } + + def interval: PackratParser[TimeInterval] = + Interval.regex ~ long ~ time_unit ^^ { case _ ~ l ~ u => + TimeInterval(l.value.toInt, u) + } def add_interval: PackratParser[SQLAddInterval] = add ~ interval ^^ { case _ ~ it => @@ -155,9 +193,14 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => SQLSubtractInterval(it) } - def intervalFunction: PackratParser[SQLArithmeticFunction[SQLTemporal, SQLTemporal]] = + def intervalFunction: PackratParser[SQLTransformFunction[SQLTemporal, SQLTemporal]] = add_interval | substract_interval + def identifierWithIntervalFunction: PackratParser[SQLIdentifier] = + (identifierWithFunction | identifier) ~ intervalFunction ^^ { case i ~ f => + i.copy(functions = f +: i.functions) + } + def identifierWithSystemFunction: PackratParser[SQLIdentifier] = (current_date | current_time | current_timestamp | now) ~ intervalFunction.? ^^ { case f1 ~ f2 => @@ -168,52 +211,52 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def date_trunc: PackratParser[SQLFunctionWithIdentifier] = - "(?i)date_trunc".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_trunc".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ i ~ _ ~ u ~ _ => DateTrunc(i, u) } def extract_identifier: PackratParser[SQLIdentifier] = - "(?i)extract".r ~ start ~ time_unit ~ "(?i)from".r ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + "(?i)extract".r ~ start ~ time_unit ~ "(?i)from".r ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ end ^^ { case _ ~ _ ~ u ~ _ ~ i ~ _ => i.copy(functions = Extract(u) +: i.functions) } - def extract_year: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_year: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Year.regex ^^ (_ => YEAR) - def extract_month: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_month: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Month.regex ^^ (_ => MONTH) - def extract_day: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_day: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Day.regex ^^ (_ => DAY) - def extract_hour: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_hour: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Hour.regex ^^ (_ => HOUR) - def extract_minute: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_minute: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Minute.regex ^^ (_ => MINUTE) - def extract_second: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extract_second: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = Second.regex ^^ (_ => SECOND) - def extractors: PackratParser[SQLUnaryFunction[SQLTemporal, SQLNumeric]] = + def extractors: PackratParser[SQLTransformFunction[SQLTemporal, SQLNumeric]] = extract_year | extract_month | extract_day | extract_hour | extract_minute | extract_second def date_add: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateAdd(i, t) } def date_sub: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)date_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)date_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateSub(i, t) } def parse_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)parse_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)parse_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ li ~ _ ~ f ~ _ => li match { case l: SQLStringValue => @@ -224,7 +267,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def format_date: PackratParser[DateFunction with SQLFunctionWithIdentifier] = - "(?i)format_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_date".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDate(i, f.value) } @@ -232,19 +275,19 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def date_functions: PackratParser[DateFunction] = date_add | date_sub | parse_date | format_date def datetime_add: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_add".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeAdd(i, t) } def datetime_sub: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)datetime_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ interval ~ end ^^ { + "(?i)datetime_sub".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ interval ~ end ^^ { case _ ~ _ ~ i ~ _ ~ t ~ _ => DateTimeSub(i, t) } def parse_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)parse_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)parse_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | literal | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ li ~ _ ~ f ~ _ => li match { case l: SQLLiteral => @@ -255,7 +298,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def format_datetime: PackratParser[DateTimeFunction with SQLFunctionWithIdentifier] = - "(?i)format_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ literal ~ end ^^ { + "(?i)format_datetime".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ literal ~ end ^^ { case _ ~ _ ~ i ~ _ ~ f ~ _ => FormatDateTime(i, f.value) } @@ -286,7 +329,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def date_diff: PackratParser[SQLBinaryFunction[_, _, _]] = - "(?i)date_diff".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithArithmeticFunction | identifier) ~ separator ~ time_unit ~ end ^^ { + "(?i)date_diff".r ~ start ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ (identifierWithTemporalFunction | identifierWithSystemFunction | identifierWithIntervalFunction | identifier) ~ separator ~ time_unit ~ end ^^ { case _ ~ _ ~ d1 ~ _ ~ d2 ~ _ ~ u ~ _ => DateDiff(d1, d2, u) } @@ -294,17 +337,13 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => SQLIdentifier("", functions = dd :: Nil) } - def case_when_identifier: Parser[SQLIdentifier] = case_when ^^ { cw => - SQLIdentifier("", functions = cw :: Nil) - } - def is_null: PackratParser[SQLConditionalFunction[_]] = - "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { + "(?i)isnull".r ~ start ~ (identifierWithTransformation | identifierWithIntervalFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNullFunction(i) } def is_notnull: PackratParser[SQLConditionalFunction[_]] = - "(?i)isnotnull".r ~ start ~ (identifierWithTransformation | identifierWithArithmeticFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { + "(?i)isnotnull".r ~ start ~ (identifierWithTransformation | identifierWithIntervalFunction | identifierWithTemporalFunction | identifier) ~ end ^^ { case _ ~ _ ~ i ~ _ => SQLIsNotNullFunction(i) } @@ -314,12 +353,13 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => date_diff_identifier | // date_diff(...) retournant un identifier-like extract_identifier | identifierWithSystemFunction | // CURRENT_DATE, NOW, etc. (+/- interval) - identifierWithArithmeticFunction | // foo - interval ... + identifierWithIntervalFunction | identifierWithTemporalFunction | // chaîne de fonctions appliquées à un identifier identifierWithFunction | // fonctions appliquées à un identifier literal | // 'string' - long | + pi | double | + long | boolean | identifier @@ -365,9 +405,88 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => case _ ~ e ~ c ~ r ~ _ => SQLCaseWhen(e, c, r) } + def case_when_identifier: Parser[SQLIdentifier] = case_when ^^ { cw => + SQLIdentifier("", functions = cw :: Nil) + } + def logical_functions: PackratParser[SQLTransformFunction[_, _]] = is_null | is_notnull | coalesce | nullif | case_when + private[this] def abs: PackratParser[UnaryArithmeticOperator] = Abs.regex ^^ (_ => Abs) + + private[this] def ceil: PackratParser[UnaryArithmeticOperator] = Ceil.regex ^^ (_ => Ceil) + + private[this] def floor: PackratParser[UnaryArithmeticOperator] = Floor.regex ^^ (_ => Floor) + + private[this] def exp: PackratParser[UnaryArithmeticOperator] = Exp.regex ^^ (_ => Exp) + + private[this] def sqrt: PackratParser[UnaryArithmeticOperator] = Sqrt.regex ^^ (_ => Sqrt) + + private[this] def log: PackratParser[UnaryArithmeticOperator] = Log.regex ^^ (_ => Log) + + private[this] def log10: PackratParser[UnaryArithmeticOperator] = Log10.regex ^^ (_ => Log10) + + implicit def functionAsIdentifier(mf: SQLFunction): SQLIdentifier = mf match { + case id: SQLIdentifier => id + case fid: SQLFunctionWithIdentifier => fid.identifier + case _ => SQLIdentifier("", functions = mf :: Nil) + } + + def arithmeticFunction: PackratParser[MathematicalFunction] = + (abs | ceil | exp | floor | log | log10 | sqrt) ~ start ~ valueExpr ~ end ^^ { + case op ~ _ ~ v ~ _ => SQLMathematicalFunction(op, v) + } + + private[this] def sin: PackratParser[TrigonometricOperator] = Sin.regex ^^ (_ => Sin) + + private[this] def asin: PackratParser[TrigonometricOperator] = Asin.regex ^^ (_ => Asin) + + private[this] def cos: PackratParser[TrigonometricOperator] = Cos.regex ^^ (_ => Cos) + + private[this] def acos: PackratParser[TrigonometricOperator] = Acos.regex ^^ (_ => Acos) + + private[this] def tan: PackratParser[TrigonometricOperator] = Tan.regex ^^ (_ => Tan) + + private[this] def atan: PackratParser[TrigonometricOperator] = Atan.regex ^^ (_ => Atan) + + private[this] def atan2: PackratParser[TrigonometricOperator] = Atan2.regex ^^ (_ => Atan2) + + def atan2Function: PackratParser[MathematicalFunction] = + atan2 ~ start ~ (double | valueExpr) ~ separator ~ (double | valueExpr) ~ end ^^ { + case _ ~ _ ~ y ~ _ ~ x ~ _ => SQLAtan2(y, x) + } + + def trigonometricFunction: PackratParser[MathematicalFunction] = + atan2Function | ((sin | asin | cos | acos | tan | atan) ~ start ~ valueExpr ~ end ^^ { + case op ~ _ ~ v ~ _ => SQLMathematicalFunction(op, v) + }) + + private[this] def round: PackratParser[UnaryArithmeticOperator] = Round.regex ^^ (_ => Round) + + def roundFunction: PackratParser[MathematicalFunction] = + round ~ start ~ valueExpr ~ separator.? ~ long.? ~ end ^^ { case _ ~ _ ~ v ~ _ ~ s ~ _ => + SQLRound(v, s.map(_.value.toInt)) + } + + private[this] def pow: PackratParser[UnaryArithmeticOperator] = Pow.regex ^^ (_ => Pow) + + def powFunction: PackratParser[MathematicalFunction] = + pow ~ start ~ valueExpr ~ separator ~ long ~ end ^^ { case _ ~ _ ~ v1 ~ _ ~ e ~ _ => + SQLPow(v1, e.value.toInt) + } + + private[this] def sign: PackratParser[UnaryArithmeticOperator] = Sign.regex ^^ (_ => Sign) + + def signFunction: PackratParser[MathematicalFunction] = + sign ~ start ~ valueExpr ~ end ^^ { case _ ~ _ ~ v ~ _ => SQLSign(v) } + + def mathematicalFunction: PackratParser[MathematicalFunction] = + arithmeticFunction | trigonometricFunction | roundFunction | powFunction | signFunction + + def mathematicalFunctionWithIdentifier: PackratParser[SQLIdentifier] = mathematicalFunction ^^ { + mf => mf.identifier + } + def sql_functions: PackratParser[SQLFunction] = aggregates | distance | date_diff | date_trunc | extractors | date_functions | datetime_functions | logical_functions @@ -426,6 +545,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => "real", "float", "double", + "pi", "boolean", "time", "date", @@ -448,7 +568,34 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => "when", "then", "else", - "end" + "end", + "union", + "all", + "exists", + "true", + "false", +// "nested", +// "parent", +// "child", + "match", + "against", + "abs", + "ceil", + "floor", + "exp", + "log", + "log10", + "sqrt", + "round", + "pow", + "sign", + "sin", + "asin", + "cos", + "acos", + "tan", + "atan", + "atan2" ) private val identifierRegexStr = @@ -504,7 +651,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => char_type | string_type | datetime_type | timestamp_type | date_type | time_type | boolean_type | long_type | double_type | float_type | int_type | short_type | byte_type private[this] def castFunctionWithIdentifier: PackratParser[SQLIdentifier] = - "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ arithmeticFunction.? ^^ { + "(?i)cast".r ~ start ~ (identifierWithTransformation | identifierWithSystemFunction | identifierWithIntervalFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier) ~ Alias.regex.? ~ sql_type ~ end ~ intervalFunction.? ^^ { case _ ~ _ ~ i ~ as ~ t ~ _ ~ a => i.copy(functions = a.toList ++ (SQLCast(i, targetType = t, as = as.isDefined) +: i.functions) @@ -512,7 +659,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } private[this] def dateFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (parse_date | format_date | date_add | date_sub) ~ arithmeticFunction.? ^^ { case t ~ af => + (parse_date | format_date | date_add | date_sub) ~ intervalFunction.? ^^ { case t ~ af => af match { case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) case None => t.identifier.copy(functions = t +: t.identifier.functions) @@ -520,7 +667,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } private[this] def dateTimeFunctionWithIdentifier: PackratParser[SQLIdentifier] = - (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ~ arithmeticFunction.? ^^ { + (date_trunc | parse_datetime | format_datetime | datetime_add | datetime_sub) ~ intervalFunction.? ^^ { case t ~ af => af match { case Some(f) => t.identifier.copy(functions = f +: t +: t.identifier.functions) @@ -534,17 +681,10 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => } def identifierWithTransformation: PackratParser[SQLIdentifier] = - castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier - - def arithmeticFunction: PackratParser[SQLArithmeticFunction[_, _]] = intervalFunction - - def identifierWithArithmeticFunction: PackratParser[SQLIdentifier] = - (identifierWithFunction | identifier) ~ arithmeticFunction ^^ { case i ~ af => - i.copy(functions = af +: i.functions) - } + mathematicalFunctionWithIdentifier | castFunctionWithIdentifier | conditionalFunctionWithIdentifier | dateFunctionWithIdentifier | dateTimeFunctionWithIdentifier def identifierWithAggregation: PackratParser[SQLIdentifier] = - aggregates ~ start ~ (identifierWithFunction | identifierWithArithmeticFunction | identifier) ~ end ^^ { + aggregates ~ start ~ (identifierWithFunction | identifierWithIntervalFunction | identifier) ~ end ^^ { case a ~ _ ~ i ~ _ => i.copy(functions = a +: i.functions) } @@ -553,7 +693,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => rep1sep( sql_functions, start - ) ~ start.? ~ (identifierWithSystemFunction | identifierWithArithmeticFunction | identifier).? ~ rep1( + ) ~ start.? ~ (identifierWithSystemFunction | identifierWithIntervalFunction | identifier).? ~ rep1( end ) ^^ { case f ~ _ ~ i ~ _ => i match { @@ -573,7 +713,7 @@ trait SQLParser extends RegexParsers with PackratParsers { _: SQLWhereParser => def alias: PackratParser[SQLAlias] = Alias.regex.? ~ regexAlias.r ^^ { case _ ~ b => SQLAlias(b) } def field: PackratParser[Field] = - (identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | case_when_identifier | identifier) ~ alias.? ^^ { + (identifierWithArithmeticExpression | identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithIntervalFunction | identifierWithFunction | date_diff_identifier | extract_identifier | case_when_identifier | identifier) ~ alias.? ^^ { case i ~ a => SQLField(i, a) } @@ -633,10 +773,10 @@ trait SQLWhereParser { private def diff: PackratParser[SQLComparisonOperator] = Diff.sql ^^ (_ => Diff) private def any_identifier: PackratParser[SQLIdentifier] = - identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithArithmeticFunction | identifierWithFunction | date_diff_identifier | extract_identifier | identifier + identifierWithTransformation | identifierWithAggregation | identifierWithSystemFunction | identifierWithIntervalFunction | identifierWithArithmeticExpression | identifierWithFunction | date_diff_identifier | extract_identifier | identifier private def equality: PackratParser[SQLExpression] = - not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | long | any_identifier) ^^ { + not.? ~ any_identifier ~ (eq | ne | diff) ~ (boolean | literal | double | pi | long | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } @@ -654,7 +794,7 @@ trait SQLWhereParser { def lt: PackratParser[SQLComparisonOperator] = Lt.sql ^^ (_ => Lt) private def comparison: PackratParser[SQLExpression] = - not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | long | literal | any_identifier) ^^ { + not.? ~ any_identifier ~ (ge | gt | le | lt) ~ (double | pi | long | literal | any_identifier) ^^ { case n ~ i ~ o ~ v => SQLExpression(i, o, v, n) } diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala index 5aa56a26..fcbb3021 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLSelect.scala @@ -7,7 +7,7 @@ sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript def fieldAlias: Option[SQLAlias] def isScriptField: Boolean = functions.nonEmpty && !aggregation && identifier.bucket.isEmpty override def sql: String = s"$identifier${asString(fieldAlias)}" - lazy val sourceField: String = + lazy val sourceField: String = { if (identifier.nested) { identifier.tableAlias .orElse(fieldAlias.map(_.alias)) @@ -18,9 +18,14 @@ sealed trait Field extends Updateable with SQLFunctionChain with PainlessScript .split("\\.") .tail .mkString(".") + } else if (identifier.name.nonEmpty) { + identifier.name + .replace("(", "") + .replace(")", "") } else { - identifier.name.replace("(", "").replace(")", "") + AliasUtils.normalize(identifier.identifierName) } + } override def functions: List[SQLFunction] = identifier.functions diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala index 7dc0911b..fcf2ba0a 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLTypeUtils.scala @@ -43,6 +43,7 @@ object SQLTypeUtils { .contains( out.typeId )) || + (out.isInstanceOf[SQLNumeric] && in.isInstanceOf[SQLNumeric]) || (out.typeId == Varchar.typeId && in.typeId == Varchar.typeId) || (out.typeId == Boolean.typeId && in.typeId == Boolean.typeId) || out.typeId == Any.typeId || in.typeId == Any.typeId || diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala index 15994d96..c041ce3d 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLValidator.scala @@ -12,8 +12,8 @@ object SQLValidator { case Some(left) => return left case None => } - val unaryFuncs = functions.collect { case f: SQLUnaryFunction[_, _] => f } - unaryFuncs.sliding(2).foreach { + val funcs = functions.collect { case f: SQLFunctionN[_, _] => f } + funcs.sliding(2).foreach { case Seq(f1, f2) => validateTypesMatching(f2.outputType, f1.inputType) case _ => // ok diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala index 9e244acb..44420425 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/SQLWhere.scala @@ -327,7 +327,7 @@ sealed trait SQLCriteriaWithConditionalFunction[In <: SQLType] extends Expressio override def maybeNot: Option[Not.type] = None override def asFilter(currentQuery: Option[ElasticBoolQuery]): ElasticFilter = this override val functions: List[SQLFunction] = List(conditionalFunction) - override def sql = s"${conditionalFunction.sql}($identifier)" + override def sql: String = conditionalFunction.sql } object SQLConditionalFunctionAsCriteria { diff --git a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala index 48db25a6..522e8358 100644 --- a/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala +++ b/sql/src/main/scala/app/softnetwork/elastic/sql/package.scala @@ -1,5 +1,6 @@ package app.softnetwork.elastic +import java.security.MessageDigest import java.util.regex.Pattern import scala.reflect.runtime.universe._ import scala.util.Try @@ -179,6 +180,18 @@ package object sql { override def out: SQLNumeric = SQLTypes.Double } + case object SQLPiValue extends SQLValue[Double](Math.PI) { + override def sql: String = "pi" + override def painless: String = "Math.PI" + override def out: SQLNumeric = SQLTypes.Double + } + + case object SQLEValue extends SQLValue[Double](Math.E) { + override def sql: String = "e" + override def painless: String = "Math.E" + override def out: SQLNumeric = SQLTypes.Double + } + sealed abstract class SQLFromTo[+T](val from: SQLValue[T], val to: SQLValue[T]) extends SQLToken { override def sql = s"${from.sql} and ${to.sql}" } @@ -313,6 +326,41 @@ package object sql { case class SQLAlias(alias: String) extends SQLExpr(s" ${Alias.sql} $alias") + object AliasUtils { + private val MaxAliasLength = 50 + + private val opMapping = Map( + "+" -> "plus", + "-" -> "minus", + "*" -> "mul", + "/" -> "div", + "%" -> "mod" + ) + + def normalize(expr: String): String = { + // Remplacer les opérateurs SQL par des noms lisibles + val replaced = opMapping.foldLeft(expr) { case (acc, (k, v)) => + acc.replace(k, s"_${v}_") + } + // Nettoyer pour obtenir un identifiant valide + val normalized = replaced + .replaceAll("[^a-zA-Z0-9_]", "_") // caractères invalides -> "_" + .replaceAll("_+", "_") // compacter plusieurs "_" + .stripPrefix("_") + .stripSuffix("_") + .toLowerCase + + // Tronquer si nécessaire + if (normalized.length > MaxAliasLength) { + val digest = MessageDigest.getInstance("MD5").digest(normalized.getBytes("UTF-8")) + val hash = digest.map("%02x".format(_)).mkString.take(8) // suffix court + normalized.take(MaxAliasLength - hash.length - 1) + "_" + hash + } else { + normalized + } + } + } + trait SQLRegex extends SQLToken { lazy val regex: Regex = s"\\b(?i)$sql\\b".r } @@ -330,13 +378,30 @@ package object sql { def nested: Boolean def fieldAlias: Option[String] def bucket: Option[SQLBucket] + override def sql: String = { + var parts: Seq[String] = name.split("\\.").toSeq + tableAlias match { + case Some(a) => parts = a +: (if (nested) parts.tail else parts) + case _ => + } + val sql = { + if (distinct) { + s"$Distinct ${parts.mkString(".")}".trim + } else { + parts.mkString(".").trim + } + } + functions.reverse.foldLeft(sql)((expr, fun) => { + fun.toSQL(expr) + }) + } applyTo(this) lazy val identifierName: String = functions.reverse.foldLeft(name)((expr, fun) => { fun.toSQL(expr) - }) + }) // FIXME use AliasUtils.normalize? lazy val nestedType: Option[String] = if (nested) Some(name.split('.').head) else None @@ -395,24 +460,7 @@ package object sql { functions: List[SQLFunction] = List.empty, fieldAlias: Option[String] = None, bucket: Option[SQLBucket] = None - ) extends SQLExpr({ - var parts: Seq[String] = name.split("\\.").toSeq - tableAlias match { - case Some(a) => parts = a +: (if (nested) parts.tail else parts) - case _ => - } - val sql = { - if (distinct) { - s"$Distinct ${parts.mkString(".")}".trim - } else { - parts.mkString(".").trim - } - } - functions.reverse.foldLeft(sql)((expr, fun) => { - fun.toSQL(expr) - }) - }) - with Identifier { + ) extends Identifier { def update(request: SQLSearchRequest): SQLIdentifier = { val parts: Seq[String] = name.split("\\.").toSeq diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala index dc9c05fc..d7d37b55 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLDateTimeFunctionSuite.scala @@ -39,7 +39,7 @@ class SQLDateTimeFunctionSuite extends AnyFunSuite { (transforms.head.toPainless(base, 0), transforms.head.outputType.asInstanceOf[SQLType]) val (finalExpr, _) = transforms.tail.foldLeft(initial) { - case ((expr, currentType), t: SQLUnaryFunction[_, _]) => + case ((expr, currentType), t: SQLFunctionN[_, _]) => if (!currentType.getClass.isAssignableFrom(t.inputType.getClass)) { throw new IllegalArgumentException( s"Type mismatch: expected ${currentType.getClass.getSimpleName}, got ${t.inputType.getClass.getSimpleName}" diff --git a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala index 972ba528..b09f0f34 100644 --- a/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala +++ b/sql/src/test/scala/app/softnetwork/elastic/sql/SQLParserSpec.scala @@ -152,6 +152,12 @@ object Queries { val extract: String = "select extract(day from createdAt) as day, extract(month from createdAt) as month, extract(year from createdAt) as year, extract(hour from createdAt) as hour, extract(minute from createdAt) as minute, extract(second from createdAt) as second from Table" + + val arithmetic: String = + "select identifier, identifier + 1 as add, identifier - 1 as sub, identifier * 2 as mul, identifier / 2 as div, identifier % 2 as mod, (identifier * identifier2) - 10 as group1 from Table where identifier * (extract(year from current_date) - 10) > 10000" + + val mathematical: String = + "select identifier, (abs(identifier) + 1.0) * 2, ceil(identifier), floor(identifier), sqrt(identifier), exp(identifier), log(identifier), log10(identifier), pow(identifier, 3), round(identifier), round(identifier, 2), sign(identifier), cos(identifier), acos(identifier), sin(identifier), asin(identifier), tan(identifier), atan(identifier), atan2(identifier, 3.0) from Table where sqrt(identifier) > 100.0" } /** Created by smanciot on 15/02/17. @@ -604,4 +610,18 @@ class SQLParserSpec extends AnyFlatSpec with Matchers { ) } + it should "parse arithmetic expressions" in { + val result = SQLParser(arithmetic) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + arithmetic + ) + } + + it should "parse mathematical functions" in { + val result = SQLParser(mathematical) + result.toOption.flatMap(_.left.toOption.map(_.sql)).getOrElse("") should ===( + mathematical + ) + } + }