diff --git a/CLAUDE.md b/CLAUDE.md index 690660d..aa3fddf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## What This Project Does -pycel2sql converts CEL (Common Expression Language) expressions into SQL WHERE clauses. It supports five SQL dialects: PostgreSQL, DuckDB, BigQuery, MySQL, and SQLite. +pycel2sql converts CEL (Common Expression Language) expressions into SQL WHERE clauses. It supports six SQL dialects: PostgreSQL, DuckDB, BigQuery, MySQL, SQLite, and Apache Spark. ## Commands @@ -68,7 +68,7 @@ Lark grammar rule names encode operators: `relation_eq`, `addition_add`, `multip | `__init__.py` | Public API: `convert()`, `convert_parameterized()`, `analyze()`, `introspect()` | | `_converter.py` | Core Converter — Lark Interpreter with visitor methods for every grammar rule | | `dialect/_base.py` | `Dialect` ABC (40+ abstract methods), `WriteFunc` type alias, `IndexAdvisor` protocol | -| `dialect/{postgres,duckdb,bigquery,mysql,sqlite}.py` | Concrete dialect implementations | +| `dialect/{postgres,duckdb,bigquery,mysql,sqlite,spark}.py` | Concrete dialect implementations | | `schema.py` | `Schema` / `FieldSchema` for JSON/array field detection | | `_analysis.py` | `IndexAnalyzer` — second-pass tree walker for index recommendations | | `_utils.py` | Validation, escaping, RE2→SQL regex conversion | @@ -78,11 +78,12 @@ Lark grammar rule names encode operators: `relation_eq`, `addition_add`, `multip ### Dialect Differences -- **PostgreSQL**: `$N` params, `ARRAY[...]`, `~ / ~*` regex, `->>/->` JSON, `POSITION()` for contains -- **DuckDB**: `$N` params, `[...]` arrays, RE2 regex, `CONTAINS()`, `STRING_SPLIT()` -- **BigQuery**: `@pN` params, `[...]` arrays, `REGEXP_CONTAINS()`, `JSON_VALUE()`, `TIMESTAMP_ADD/SUB()` -- **MySQL**: `?` params, `JSON_ARRAY()`, `REGEXP`, `JSON_TABLE()` for unnest -- **SQLite**: `?` params, `json_array()`, no regex/split/join, `json_each()` for unnest 
+- **PostgreSQL**: `$N` params, `ARRAY[...]`, `~ / ~*` regex, `->>/->` JSON, `POSITION()` for contains, `FORMAT()` +- **DuckDB**: `$N` params, `[...]` arrays, RE2 regex, `CONTAINS()`, `STRING_SPLIT()`, `printf()` +- **BigQuery**: `@pN` params, `[...]` arrays, `REGEXP_CONTAINS()`, `JSON_VALUE()`, `TIMESTAMP_ADD/SUB()`, `FORMAT()` +- **MySQL**: `?` params, `JSON_ARRAY()`, `REGEXP`, `JSON_TABLE()` for unnest, `format()` raises `UnsupportedDialectFeatureError` +- **SQLite**: `?` params, `json_array()`, no regex/split/join, `json_each()` for unnest, `printf()` +- **Apache Spark**: `?` positional params, `array(...)`, `RLIKE`, `get_json_object()`, `concat()`, `array_contains(arr, elem)` (arg order swap), `EXPLODE` / `(SELECT collect_list(...))`, `format_string()`, `(dayofweek(t) - 1)` for day-of-week, JSON array membership raises (no boolean predicate available) ### Test Organization @@ -97,4 +98,8 @@ Unit tests (`tests/test_*.py`) cover each feature area per dialect. Integration - Depth tracking: `_visit_child()` increments/decrements `_depth` and checks limits - Error types use dual messaging pattern to prevent information disclosure (CWE-209) - `validate_schema` parameter: opt-in strict validation on `convert()`/`convert_parameterized()`/`analyze()`. Validates `table.field` references exist in schemas; skips comprehension variables, bare identifiers, and nested JSON keys beyond the first field. Raises `InvalidSchemaError` (with dual messaging). Requires schemas to be provided. +- `json_variables` parameter: opt-in declaration that named CEL variables are flat JSONB columns. Field access (dot or bracket) emits dialect-specific JSON extraction. Takes precedence over schema-declared JSON. Comprehension iter vars shadow `json_variables` (collisions are not treated as JSON inside the comprehension body). +- `column_aliases` parameter: maps CEL identifier names to SQL column names. 
The alias is validated against the dialect's identifier rules; the original CEL name remains the schema key (alias is output-only). +- `param_start_index` parameter (only on `convert_parameterized()`): shifts the placeholder counter so the first parameter is `$N` / `@pN` instead of `$1` / `@p1`. Values < 1 are clamped to 1. Positional-`?` dialects (MySQL, SQLite, Spark) ignore the index in placeholder text but still preserve parameter ordering. +- `format()` is dispatched per-dialect via `Dialect.write_format`: PostgreSQL/BigQuery emit `FORMAT(...)`, SQLite/DuckDB emit `printf(...)`, Apache Spark emits `format_string(...)`, MySQL raises `UnsupportedDialectFeatureError`. - Ruff for linting, mypy strict for type checking, line length 100, target Python 3.12+ diff --git a/README.md b/README.md index a9050f7..0bd9a65 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ [![BigQuery](https://img.shields.io/badge/BigQuery-669DF6?logo=googlebigquery&logoColor=white)](https://cloud.google.com/bigquery) [![MySQL](https://img.shields.io/badge/MySQL-4479A1?logo=mysql&logoColor=white)](https://www.mysql.com/) [![SQLite](https://img.shields.io/badge/SQLite-003B57?logo=sqlite&logoColor=white)](https://www.sqlite.org/) +[![Apache Spark](https://img.shields.io/badge/Apache%20Spark-E25A1C?logo=apachespark&logoColor=white)](https://spark.apache.org/) Convert [CEL (Common Expression Language)](https://cel.dev/) expressions to SQL WHERE clauses. 
@@ -38,7 +39,7 @@ sql = convert('status == "active" || tags.size() > 0') ## Dialects -Five SQL dialects are supported: +Six SQL dialects are supported: ```python from pycel2sql import convert @@ -50,9 +51,13 @@ sql = convert('name == "alice"', dialect=get_dialect("mysql")) sql = convert('name == "alice"', dialect=get_dialect("sqlite")) sql = convert('name == "alice"', dialect=get_dialect("duckdb")) sql = convert('name == "alice"', dialect=get_dialect("bigquery")) +sql = convert('name == "alice"', dialect=get_dialect("spark")) # Or instantiate directly -from pycel2sql import PostgresDialect, MySQLDialect, SQLiteDialect, DuckDBDialect, BigQueryDialect +from pycel2sql import ( + PostgresDialect, MySQLDialect, SQLiteDialect, DuckDBDialect, + BigQueryDialect, SparkDialect, +) sql = convert('name == "alice"', dialect=MySQLDialect()) ``` @@ -75,13 +80,76 @@ result = convert_parameterized('name == "alice"', dialect=MySQLDialect()) Placeholder styles per dialect: -| Dialect | Placeholder | -|------------|-------------| -| PostgreSQL | `$1`, `$2`, ... | -| DuckDB | `$1`, `$2`, ... | -| BigQuery | `@p1`, `@p2`, ... | -| MySQL | `?` | -| SQLite | `?` | +| Dialect | Placeholder | +|---------------|--------------------| +| PostgreSQL | `$1`, `$2`, ... | +| DuckDB | `$1`, `$2`, ... | +| BigQuery | `@p1`, `@p2`, ... | +| MySQL | `?` (positional) | +| SQLite | `?` (positional) | +| Apache Spark | `?` (positional) | + +## Conversion Options + +### `json_variables` + +Declare CEL variable names that correspond to flat JSONB columns. 
Field access via dot notation or bracket notation emits dialect-specific JSON extraction: + +```python +from pycel2sql import convert + +# PostgreSQL: dot and bracket notation both produce ->> operators +sql = convert("context.host == 'a'", json_variables={"context"}) +# => context->>'host' = 'a' + +sql = convert('context["host"] == "a"', json_variables={"context"}) +# => context->>'host' = 'a' + +# Nested paths: intermediate keys use ->, final key uses ->> +sql = convert("tags.corpus.section == 'x'", json_variables={"tags"}) +# => tags->'corpus'->>'section' = 'x' +``` + +`json_variables` takes precedence over schema-declared JSON. Comprehension iter vars shadow `json_variables` (collisions are not treated as JSON inside the comprehension body). + +### `column_aliases` + +Map CEL identifier names to SQL column names. Useful when database columns use prefixed names while user-facing CEL expressions use clean names: + +```python +sql = convert("name == 'a'", column_aliases={"name": "usr_name"}) +# => usr_name = 'a' +``` + +The alias is validated against the dialect's identifier rules. The original CEL name remains the schema key — alias is output-only. + +### `param_start_index` + +Shift the placeholder counter for `convert_parameterized()` when embedding the generated fragment into a larger pre-parameterized query: + +```python +result = convert_parameterized( + "name == 'a' && age > 30", + param_start_index=5, +) +# result.sql => 'name = $5 AND age > $6' +# result.parameters => ['a', 30] +``` + +Values less than 1 are clamped to 1. For positional-`?` dialects (MySQL, SQLite, Apache Spark) the placeholder text is unchanged but the parameter ordering is preserved. 
+ +### `format()` per-dialect mapping + +CEL's `string.format(args)` dispatches to dialect-specific SQL: + +| Dialect | Output | +|---------------|-------------------------| +| PostgreSQL | `FORMAT('...', ...)` | +| BigQuery | `FORMAT('...', ...)` | +| SQLite | `printf('...', ...)` | +| DuckDB | `printf('...', ...)` | +| Apache Spark | `format_string('...', ...)` | +| MySQL | raises `UnsupportedDialectFeatureError` | ## JSON Fields @@ -173,7 +241,7 @@ schemas = introspect_sqlite( ) ``` -All five dialects are supported: `introspect_postgres`, `introspect_duckdb`, `introspect_bigquery`, `introspect_mysql`, `introspect_sqlite`. +All five JDBC-style dialects are supported: `introspect_postgres`, `introspect_duckdb`, `introspect_bigquery`, `introspect_mysql`, `introspect_sqlite`. Apache Spark introspection is not provided — construct `Schema` directly. ## Supported CEL Features diff --git a/src/pycel2sql/__init__.py b/src/pycel2sql/__init__.py index 55c4451..2898ec4 100644 --- a/src/pycel2sql/__init__.py +++ b/src/pycel2sql/__init__.py @@ -19,6 +19,7 @@ from pycel2sql.dialect.duckdb import DuckDBDialect from pycel2sql.dialect.mysql import MySQLDialect from pycel2sql.dialect.postgres import PostgresDialect +from pycel2sql.dialect.spark import SparkDialect from pycel2sql.dialect.sqlite import SQLiteDialect from pycel2sql.introspect import introspect from pycel2sql.schema import Schema @@ -38,6 +39,7 @@ "DuckDBDialect", "MySQLDialect", "PostgresDialect", + "SparkDialect", "SQLiteDialect", ] @@ -60,6 +62,8 @@ def convert( max_depth: int | None = None, max_output_length: int | None = None, validate_schema: bool = False, + json_variables: set[str] | frozenset[str] | list[str] | None = None, + column_aliases: dict[str, str] | None = None, ) -> str: """Convert a CEL expression to an inline SQL WHERE clause string. @@ -71,6 +75,13 @@ def convert( max_output_length: Maximum SQL output length. Defaults to 50000. 
validate_schema: If True, raise InvalidSchemaError for unrecognized table or field references. Requires schemas to be provided. + json_variables: CEL variable names that correspond to flat JSONB + columns. Field access (dot or bracket notation) against these + variables emits dialect-specific JSON extraction instead of + plain dot notation. + column_aliases: Map CEL identifier names to SQL column names. When a + CEL identifier matches a key, the alias is emitted (and validated + against the dialect's identifier rules). Returns: The SQL WHERE clause string. @@ -84,6 +95,30 @@ def convert( tree = _parser.parse(cel_expr) + kwargs: dict[str, Any] = _build_kwargs( + schemas=schemas, + max_depth=max_depth, + max_output_length=max_output_length, + validate_schema=validate_schema, + json_variables=json_variables, + column_aliases=column_aliases, + ) + + converter = Converter(dialect, **kwargs) + converter.visit(tree) + return converter.result + + +def _build_kwargs( + *, + schemas: dict[str, Schema] | None = None, + max_depth: int | None = None, + max_output_length: int | None = None, + validate_schema: bool = False, + json_variables: set[str] | frozenset[str] | list[str] | None = None, + column_aliases: dict[str, str] | None = None, + param_start_index: int | None = None, +) -> dict[str, Any]: kwargs: dict[str, Any] = {} if schemas is not None: kwargs["schemas"] = schemas @@ -93,10 +128,13 @@ def convert( kwargs["max_output_length"] = max_output_length if validate_schema: kwargs["validate_schema"] = validate_schema - - converter = Converter(dialect, **kwargs) - converter.visit(tree) - return converter.result + if json_variables is not None: + kwargs["json_variables"] = frozenset(json_variables) + if column_aliases is not None: + kwargs["column_aliases"] = dict(column_aliases) + if param_start_index is not None: + kwargs["param_start_index"] = max(1, param_start_index) + return kwargs def convert_parameterized( @@ -107,6 +145,9 @@ def convert_parameterized( max_depth: int 
| None = None, max_output_length: int | None = None, validate_schema: bool = False, + json_variables: set[str] | frozenset[str] | list[str] | None = None, + column_aliases: dict[str, str] | None = None, + param_start_index: int | None = None, ) -> Result: """Convert a CEL expression to a parameterized SQL WHERE clause. @@ -118,6 +159,11 @@ def convert_parameterized( max_output_length: Maximum SQL output length. Defaults to 50000. validate_schema: If True, raise InvalidSchemaError for unrecognized table or field references. Requires schemas to be provided. + json_variables: CEL variable names that correspond to flat JSONB columns. + column_aliases: Map CEL identifier names to SQL column names. + param_start_index: First placeholder index. Defaults to 1. Useful when + embedding the generated fragment in a larger parameterized query. + Values less than 1 are clamped to 1. Returns: Result with SQL containing $1, $2, ... placeholders and parameter list. @@ -131,15 +177,16 @@ def convert_parameterized( tree = _parser.parse(cel_expr) - kwargs: dict[str, Any] = {"parameterize": True} - if schemas is not None: - kwargs["schemas"] = schemas - if max_depth is not None: - kwargs["max_depth"] = max_depth - if max_output_length is not None: - kwargs["max_output_length"] = max_output_length - if validate_schema: - kwargs["validate_schema"] = validate_schema + kwargs: dict[str, Any] = _build_kwargs( + schemas=schemas, + max_depth=max_depth, + max_output_length=max_output_length, + validate_schema=validate_schema, + json_variables=json_variables, + column_aliases=column_aliases, + param_start_index=param_start_index, + ) + kwargs["parameterize"] = True converter = Converter(dialect, **kwargs) converter.visit(tree) @@ -162,6 +209,8 @@ def analyze( max_depth: int | None = None, max_output_length: int | None = None, validate_schema: bool = False, + json_variables: set[str] | frozenset[str] | list[str] | None = None, + column_aliases: dict[str, str] | None = None, ) -> AnalysisResult: 
"""Analyze a CEL expression for SQL conversion and index recommendations. @@ -173,6 +222,8 @@ def analyze( max_output_length: Maximum SQL output length. validate_schema: If True, raise InvalidSchemaError for unrecognized table or field references. Requires schemas to be provided. + json_variables: CEL variable names that correspond to flat JSONB columns. + column_aliases: Map CEL identifier names to SQL column names. Returns: AnalysisResult with SQL and index recommendations. @@ -189,15 +240,14 @@ def analyze( tree = _parser.parse(cel_expr) # Pass 1: Generate SQL - kwargs: dict[str, Any] = {} - if schemas is not None: - kwargs["schemas"] = schemas - if max_depth is not None: - kwargs["max_depth"] = max_depth - if max_output_length is not None: - kwargs["max_output_length"] = max_output_length - if validate_schema: - kwargs["validate_schema"] = validate_schema + kwargs: dict[str, Any] = _build_kwargs( + schemas=schemas, + max_depth=max_depth, + max_output_length=max_output_length, + validate_schema=validate_schema, + json_variables=json_variables, + column_aliases=column_aliases, + ) converter = Converter(dialect, **kwargs) converter.visit(tree) diff --git a/src/pycel2sql/_converter.py b/src/pycel2sql/_converter.py index f160ee1..2d7ca93 100644 --- a/src/pycel2sql/_converter.py +++ b/src/pycel2sql/_converter.py @@ -148,6 +148,9 @@ def __init__( max_output_length: int = DEFAULT_MAX_SQL_OUTPUT_LENGTH, parameterize: bool = False, validate_schema: bool = False, + json_variables: frozenset[str] | None = None, + column_aliases: dict[str, str] | None = None, + param_start_index: int = 1, ) -> None: self._w = StringIO() self._dialect = dialect @@ -158,9 +161,11 @@ def __init__( self._comprehension_depth = 0 self._parameterize = parameterize self._parameters: list[Any] = [] - self._param_count = 0 + self._param_count = max(1, param_start_index) - 1 self._comprehension_vars: set[str] = set() self._validate_schema = validate_schema + self._json_variables: frozenset[str] = 
json_variables or frozenset() + self._column_aliases: dict[str, str] = column_aliases or {} if self._validate_schema and not self._schemas: raise InvalidSchemaError( ERR_MSG_SCHEMA_VALIDATION_FAILED, @@ -590,6 +595,13 @@ def member_dot(self, tree: Tree) -> None: # Check for JSON path table_name = self._get_root_ident(obj) + # If the root identifier is declared as a flat JSONB variable, route to + # JSON-path emission (skip schema validation — the caller declared it + # JSONB explicitly). + if self._is_json_variable_root(table_name): + self._build_json_path(tree) + return + # Schema validation (before JSON check and SQL writing) if table_name and not self._is_comprehension_var(table_name): first_field = self._get_first_field(obj, field_name) @@ -701,6 +713,12 @@ def member_index(self, tree: Tree) -> None: if index_literal and _is_string_token(index_literal): raw_key = _strip_quotes(str(index_literal)) validate_field_name(raw_key) + # If the indexed object's root is declared as a flat JSONB + # variable, emit JSON-path extraction instead of plain dot access. 
+ root = self._get_root_ident(obj) + if self._is_json_variable_root(root): + self._emit_json_path(obj, [raw_key], root_is_column=True) + return self._visit_child(obj) self._w.write(f".{raw_key}") return @@ -749,10 +767,13 @@ def primary(self, tree: Tree) -> None: def ident(self, tree: Tree) -> None: """Bare identifier.""" name = str(tree.children[0]) - # Don't validate comprehension iteration variables - if not self._is_comprehension_var(name): - validate_field_name(name) - self._w.write(name) + # Don't validate or alias comprehension iteration variables + if self._is_comprehension_var(name): + self._w.write(name) + return + resolved = self._column_aliases.get(name, name) + validate_field_name(resolved) + self._w.write(resolved) def ident_arg(self, tree: Tree) -> None: """Function call: func(args).""" @@ -1302,23 +1323,26 @@ def _visit_format(self, obj: Tree, args: list) -> None: f"format specifier %{spec} cannot be converted to SQL", ) - # Convert %d, %f etc. to %s for SQL FORMAT() + # Normalize supported numeric specifiers (%d, %f, etc.) to %s before + # dispatching to Dialect.write_format(). %s is universally accepted by + # FORMAT()/printf()/format_string(), avoiding type-coercion mismatches + # when CEL arguments don't match the original specifier exactly. 
sql_fmt = re.sub(r"%([dfoFeEgG])", "%s", raw_fmt) # Get the argument list arg_list = args[0] list_node = _unwrap_to_data(arg_list, "list_lit") - self._w.write("FORMAT(") - self._dialect.write_string_literal(self._w, sql_fmt) - + write_args: list[Any] = [] if list_node is not None and list_node.children: exprlist = list_node.children[0] if isinstance(exprlist, Tree) and exprlist.data == "exprlist": for child in exprlist.children: - self._w.write(", ") - self._visit_child(child) - self._w.write(")") + write_args.append( + (lambda c=child: self._visit_child(c)) # noqa: B023 + ) + + self._dialect.write_format(self._w, sql_fmt, write_args) # ---- has() function ---- @@ -1331,6 +1355,20 @@ def _visit_has(self, args: list) -> None: if member_dot: table_name = self._get_root_ident(member_dot.children[0]) field_name = str(member_dot.children[1]) + # Single-level has() against a flat JSONB variable: + # has(context.host) -> dialect-specific JSON existence check. + if ( + self._is_json_variable_root(table_name) + and _unwrap_to_data(member_dot.children[0], "member_dot") is None + ): + operand = member_dot.children[0] + self._dialect.write_json_existence( + self._w, + True, + field_name, + lambda: self._visit_child(operand), + ) + return # Check if the parent object (not the field itself) is a JSON field # e.g., has(usr.metadata.key) -> usr.metadata ? 'key' # but has(usr.metadata) -> usr.metadata IS NOT NULL @@ -1656,6 +1694,19 @@ def _visit_comp_filter(self, source: Tree, args: list) -> None: # ---- JSON support ---- + def _is_json_variable_root(self, name: str | None) -> bool: + """Return True if name is a CEL variable declared as flat JSONB. + + Comprehension iter vars shadow json_variables — a json_variable that + collides with a comprehension variable is treated as a non-JSON ident + inside the comprehension body. 
+ """ + if not name: + return False + if self._is_comprehension_var(name): + return False + return name in self._json_variables + def _is_field_json(self, table_name: str, field_name: str) -> bool: schema = self._schemas.get(table_name) if not schema: @@ -1695,26 +1746,56 @@ def _build_json_path(self, tree: Tree) -> None: parts.reverse() - # node should now be the root ident - # First part after root is the JSON column, rest are path segments - if len(parts) < 2: - # Simple field access - self._visit_child(node) - self._w.write(f".{parts[0]}") + root_node = node + root_is_json_column = self._is_json_variable_root(self._get_root_ident(root_node)) + self._emit_json_path(root_node, parts, root_is_column=root_is_json_column) + + def _emit_json_path( + self, + root_node: Tree | Token, + parts: list[str], + *, + root_is_column: bool, + ) -> None: + """Emit a JSON path expression for a (root, parts) chain. + + When root_is_column is True, the root ident IS the JSON column and + every entry in parts is a JSON key. Otherwise the first entry of parts + is the JSON column on the table and the remainder are JSON keys. + """ + # Bind the root visitor once; visit() already handles bare tokens. + def visit_root() -> None: + if isinstance(root_node, Tree): + self._visit_child(root_node) + else: + self._w.write(str(root_node)) + + if not parts: + visit_root() return - root_node = node - json_col = parts[0] + if root_is_column: + # Root ident is the JSON column. All parts are JSON keys. + write_base = visit_root + key_parts = parts + else: + # Schema-driven: first part is the JSON column on the table. + json_col = parts[0] + + def write_base() -> None: + visit_root() + self._w.write(f".{json_col}") - # Callback that writes "root.json_column" - def write_base() -> None: - self._visit_child(root_node) - self._w.write(f".{json_col}") + key_parts = parts[1:] + if not key_parts: + # Simple table.json_col field access — emit as plain ident. 
+ write_base() + return - # Chain through path segments with real callbacks + # Chain through key segments with real callbacks current_base = write_base - for i, part in enumerate(parts[1:]): - is_final = i == len(parts) - 2 + for i, part in enumerate(key_parts): + is_final = i == len(key_parts) - 1 if is_final: self._dialect.write_json_field_access( diff --git a/src/pycel2sql/dialect/__init__.py b/src/pycel2sql/dialect/__init__.py index 49a81f2..eb472f5 100644 --- a/src/pycel2sql/dialect/__init__.py +++ b/src/pycel2sql/dialect/__init__.py @@ -5,6 +5,7 @@ from pycel2sql.dialect.duckdb import DuckDBDialect from pycel2sql.dialect.mysql import MySQLDialect from pycel2sql.dialect.postgres import PostgresDialect +from pycel2sql.dialect.spark import SparkDialect from pycel2sql.dialect.sqlite import SQLiteDialect __all__ = [ @@ -14,6 +15,7 @@ "DuckDBDialect", "MySQLDialect", "PostgresDialect", + "SparkDialect", "SQLiteDialect", "get_dialect", ] @@ -23,6 +25,7 @@ DialectName.DUCKDB: DuckDBDialect, DialectName.BIGQUERY: BigQueryDialect, DialectName.MYSQL: MySQLDialect, + DialectName.SPARK: SparkDialect, DialectName.SQLITE: SQLiteDialect, } diff --git a/src/pycel2sql/dialect/_base.py b/src/pycel2sql/dialect/_base.py index b272754..efeff07 100644 --- a/src/pycel2sql/dialect/_base.py +++ b/src/pycel2sql/dialect/_base.py @@ -22,6 +22,7 @@ class DialectName(enum.StrEnum): SQLITE = "sqlite" DUCKDB = "duckdb" BIGQUERY = "bigquery" + SPARK = "spark" WriteFunc = Callable[[], None] @@ -186,6 +187,11 @@ def write_join( self, w: StringIO, write_array: WriteFunc, write_delim: WriteFunc ) -> None: ... + @abstractmethod + def write_format( + self, w: StringIO, fmt_string: str, write_args: list[WriteFunc] + ) -> None: ... 
+ # --- Comprehensions --- @abstractmethod diff --git a/src/pycel2sql/dialect/bigquery.py b/src/pycel2sql/dialect/bigquery.py index 7ad7227..e5efaa0 100644 --- a/src/pycel2sql/dialect/bigquery.py +++ b/src/pycel2sql/dialect/bigquery.py @@ -146,9 +146,9 @@ def write_array_literal_close(self, w: StringIO) -> None: def write_array_length( self, w: StringIO, dimension: int, write_expr: WriteFunc ) -> None: - w.write("ARRAY_LENGTH(") + w.write("COALESCE(ARRAY_LENGTH(") write_expr() - w.write(")") + w.write("), 0)") def write_list_index( self, w: StringIO, write_array: WriteFunc, write_index: WriteFunc @@ -199,9 +199,9 @@ def write_json_array_elements( w.write("))") def write_json_array_length(self, w: StringIO, write_expr: WriteFunc) -> None: - w.write("ARRAY_LENGTH(JSON_QUERY_ARRAY(") + w.write("COALESCE(ARRAY_LENGTH(JSON_QUERY_ARRAY(") write_expr() - w.write("))") + w.write(")), 0)") def write_json_array_membership( self, w: StringIO, json_func: str, write_expr: WriteFunc @@ -304,6 +304,16 @@ def write_join( write_delim() w.write(")") + def write_format( + self, w: StringIO, fmt_string: str, write_args: list[WriteFunc] + ) -> None: + w.write("FORMAT(") + self.write_string_literal(w, fmt_string) + for arg in write_args: + w.write(", ") + arg() + w.write(")") + # --- Comprehensions --- def write_unnest(self, w: StringIO, write_source: WriteFunc) -> None: diff --git a/src/pycel2sql/dialect/duckdb.py b/src/pycel2sql/dialect/duckdb.py index eb586cf..070f744 100644 --- a/src/pycel2sql/dialect/duckdb.py +++ b/src/pycel2sql/dialect/duckdb.py @@ -261,6 +261,16 @@ def write_join( write_delim() w.write(")") + def write_format( + self, w: StringIO, fmt_string: str, write_args: list[WriteFunc] + ) -> None: + w.write("printf(") + self.write_string_literal(w, fmt_string) + for arg in write_args: + w.write(", ") + arg() + w.write(")") + # --- Comprehensions --- def write_unnest(self, w: StringIO, write_source: WriteFunc) -> None: diff --git a/src/pycel2sql/dialect/mysql.py 
b/src/pycel2sql/dialect/mysql.py index 35370a0..64f80b6 100644 --- a/src/pycel2sql/dialect/mysql.py +++ b/src/pycel2sql/dialect/mysql.py @@ -6,7 +6,7 @@ from io import StringIO -from pycel2sql._errors import InvalidFieldNameError +from pycel2sql._errors import InvalidFieldNameError, UnsupportedDialectFeatureError from pycel2sql._utils import convert_re2_to_mysql from pycel2sql.dialect._base import Dialect, WriteFunc @@ -291,6 +291,14 @@ def write_join( write_array() w.write(")") + def write_format( + self, w: StringIO, fmt_string: str, write_args: list[WriteFunc] + ) -> None: + raise UnsupportedDialectFeatureError( + "format() is not supported in MySQL", + "MySQL has no equivalent to FORMAT/printf for templated strings", + ) + # --- Comprehensions --- def write_unnest(self, w: StringIO, write_source: WriteFunc) -> None: diff --git a/src/pycel2sql/dialect/postgres.py b/src/pycel2sql/dialect/postgres.py index 95eeb61..48d4006 100644 --- a/src/pycel2sql/dialect/postgres.py +++ b/src/pycel2sql/dialect/postgres.py @@ -253,6 +253,16 @@ def write_join( write_delim() w.write(", '')") + def write_format( + self, w: StringIO, fmt_string: str, write_args: list[WriteFunc] + ) -> None: + w.write("FORMAT(") + self.write_string_literal(w, fmt_string) + for arg in write_args: + w.write(", ") + arg() + w.write(")") + # --- Comprehensions --- def write_unnest(self, w: StringIO, write_source: WriteFunc) -> None: diff --git a/src/pycel2sql/dialect/spark.py b/src/pycel2sql/dialect/spark.py new file mode 100644 index 0000000..5c6f009 --- /dev/null +++ b/src/pycel2sql/dialect/spark.py @@ -0,0 +1,575 @@ +"""Apache Spark SQL dialect implementation. + +Ported from cel2sql Go (PR #117) and cel2sql4j (PR #10). Spark runs on the JVM +and uses ``java.util.regex.Pattern``, so the regex translator is mostly a +passthrough. Spark has no separate JSONB type — JSON fields are accessed via +``get_json_object``; arrays use the native ``ARRAY`` type with +``array_contains`` / ``size`` / ``EXPLODE``. 
+ +Spark indexing is storage-layer-specific (Delta Z-order vs Iceberg sort vs +plain Parquet) and not portable as a single set of SQL recommendations. For +this reason ``SparkDialect`` deliberately does NOT implement ``IndexAdvisor``; +``get_index_advisor()`` returns None for Spark, and ``analyze()`` produces an +empty recommendation list. +""" + +from __future__ import annotations + +import re +from io import StringIO + +from pycel2sql._errors import ( + InvalidFieldNameError, + InvalidRegexPatternError, + UnsupportedDialectFeatureError, +) +from pycel2sql.dialect._base import Dialect, WriteFunc + +# Spark / Hive identifier limit. +_MAX_IDENTIFIER_LENGTH = 128 + +# Apache Spark SQL reserved keywords (lowercased). Sourced from the Apache +# Spark docs (sql-ref-ansi-compliance.html#sql-keywords) plus the standard SQL +# set. +_SPARK_RESERVED: set[str] = { + "all", "alter", "and", "anti", "any", "array", "as", "asc", "between", + "both", "by", "case", "cast", "check", "cluster", "collate", "column", + "create", "cross", "cube", "current", "current_date", "current_time", + "current_timestamp", "current_user", "default", "delete", "desc", + "describe", "distinct", "drop", "else", "end", "escape", "except", + "exists", "false", "fetch", "filter", "for", "foreign", "from", "full", + "function", "grant", "group", "grouping", "having", "hour", "in", "inner", + "insert", "intersect", "interval", "into", "is", "join", "lateral", + "leading", "left", "like", "limit", "local", "map", "minute", "month", + "natural", "no", "not", "null", "of", "on", "only", "or", "order", + "outer", "overlaps", "primary", "references", "right", "rollup", "row", + "rows", "second", "select", "semi", "session_user", "set", "some", + "struct", "table", "tablesample", "then", "time", "to", "trailing", + "true", "union", "unique", "unknown", "update", "user", "using", "values", + "when", "where", "window", "with", "year", +} + +_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + +# CEL type 
name -> Spark SQL type name. +_TYPE_MAP: dict[str, str] = { + "bool": "BOOLEAN", + "bytes": "BINARY", + "double": "DOUBLE", + "int": "BIGINT", + "uint": "BIGINT", + "string": "STRING", +} + +# Lowercased element-type aliases for write_empty_typed_array. +_SPARK_ELEMENT_TYPE_MAP: dict[str, str] = { + "text": "STRING", + "string": "STRING", + "varchar": "STRING", + "char": "STRING", + "int": "BIGINT", + "integer": "BIGINT", + "bigint": "BIGINT", + "int64": "BIGINT", + "long": "BIGINT", + "double": "DOUBLE", + "float": "DOUBLE", + "real": "DOUBLE", + "float64": "DOUBLE", + "boolean": "BOOLEAN", + "bool": "BOOLEAN", + "bytes": "BINARY", + "bytea": "BINARY", + "blob": "BINARY", + "binary": "BINARY", +} + +# Regex limits — same numbers as DuckDB / Postgres regex helpers. +_MAX_PATTERN_LENGTH = 500 +_MAX_GROUPS = 20 +_MAX_NESTING_DEPTH = 10 + +_NESTED_QUANTIFIERS_RE = re.compile(r"[*+][*+]") +_QUANTIFIED_ALTERNATION_RE = re.compile(r"\([^)]*\|[^)]*\)[*+]") +# Captures the flag-letter set inside an inline flag group: `(?flags)` or `(?flags:`. 
+_INLINE_FLAG_GROUP_RE = re.compile(r"\(\?([a-zA-Z]+)[:)]") + + +def _spark_element_type(type_name: str) -> str: + return _SPARK_ELEMENT_TYPE_MAP.get(type_name.lower(), type_name.upper()) + + +def _validate_spark_field_name(name: str) -> None: + if not name: + raise InvalidFieldNameError( + "field name cannot be empty", + "empty field name provided", + ) + if len(name) > _MAX_IDENTIFIER_LENGTH: + raise InvalidFieldNameError( + "field name exceeds Spark identifier length limit", + f"field name length {len(name)} exceeds Spark limit of " + f"{_MAX_IDENTIFIER_LENGTH}", + ) + if not _FIELD_NAME_RE.match(name): + raise InvalidFieldNameError( + "invalid field name format", + f"field name '{name}' must start with a letter or underscore and " + "contain only alphanumeric characters and underscores", + ) + if name.lower() in _SPARK_RESERVED: + raise InvalidFieldNameError( + "field name is a reserved SQL keyword", + f"field name '{name}' is a reserved Spark SQL keyword", + ) + + +def _count_unescaped_parens(pattern: str) -> int: + count = 0 + for i, ch in enumerate(pattern): + if ch == "(" and (i == 0 or pattern[i - 1] != "\\"): + count += 1 + return count + + +def _max_nesting_depth(pattern: str) -> int: + max_depth = 0 + current = 0 + for i, ch in enumerate(pattern): + if ch == "(" and (i == 0 or pattern[i - 1] != "\\"): + current += 1 + if current > max_depth: + max_depth = current + elif ch == ")" and (i == 0 or pattern[i - 1] != "\\"): + current -= 1 + return max_depth + + +def _validate_no_nested_quantifiers(pattern: str) -> None: + """Raise InvalidRegexPatternError if a quantifier is nested inside another + quantified group (potential ReDoS). 
+ """ + stack: list[bool] = [] # True if the group has seen a quantifier + i = 0 + while i < len(pattern): + ch = pattern[i] + if i > 0 and pattern[i - 1] == "\\": + i += 1 + continue + if ch == "(": + stack.append(False) + elif ch == ")": + if not stack: + i += 1 + continue + inner_quantified = stack.pop() + next_ch = pattern[i + 1] if i + 1 < len(pattern) else "" + if next_ch in ("*", "+", "?", "{") and inner_quantified: + raise InvalidRegexPatternError( + "invalid regex pattern", + "regex contains catastrophic nested quantifiers that " + "could cause ReDoS", + ) + if stack and inner_quantified: + stack[-1] = True + elif ch in ("*", "+", "?", "{"): + if stack: + stack[-1] = True + i += 1 + + +def _convert_re2_to_spark(pattern: str) -> tuple[str, bool]: + """Validate an RE2-style regex pattern and pass it through to Spark. + + Spark uses java.util.regex which handles inline ``(?i)`` natively, so the + returned ``case_insensitive`` flag is always False — the pattern is emitted + verbatim and the engine honours any inline flag. + """ + if len(pattern) > _MAX_PATTERN_LENGTH: + raise InvalidRegexPatternError( + "invalid pattern in expression", + f"pattern length {len(pattern)} exceeds limit of " + f"{_MAX_PATTERN_LENGTH} characters", + ) + try: + re.compile(pattern) + except re.error as e: + raise InvalidRegexPatternError( + "invalid pattern in expression", + f"regex pattern does not compile: {e}", + ) from e + if "(?=" in pattern or "(?!" in pattern: + raise InvalidRegexPatternError( + "invalid pattern in expression", + "lookahead assertions (?=...), (?!...) are not supported in Spark regex", + ) + if "(?<=" in pattern or "(?...) 
are not supported in Spark regex", + ) + if _NESTED_QUANTIFIERS_RE.search(pattern): + raise InvalidRegexPatternError( + "invalid pattern in expression", + "regex contains catastrophic nested quantifiers that could cause ReDoS", + ) + _validate_no_nested_quantifiers(pattern) + if _count_unescaped_parens(pattern) > _MAX_GROUPS: + raise InvalidRegexPatternError( + "invalid pattern in expression", + f"regex contains more than {_MAX_GROUPS} capture groups", + ) + if _QUANTIFIED_ALTERNATION_RE.search(pattern): + raise InvalidRegexPatternError( + "invalid pattern in expression", + "regex contains quantified alternation that could cause ReDoS", + ) + if _max_nesting_depth(pattern) > _MAX_NESTING_DEPTH: + raise InvalidRegexPatternError( + "invalid pattern in expression", + f"nesting depth exceeds limit of {_MAX_NESTING_DEPTH}", + ) + # Reject any inline flag group whose flag-letter set contains anything + # other than `i`. This catches combined groups like `(?im)` / `(?mi)` / + # `(?ix:...)` that the previous substring-only check missed. The `(?-...)` + # flag-clear form is rejected too because the regex matches its leading + # `(?` followed by letters; a literal `(?-i)` (no letters before `-`) is + # caught by the explicit `(?-` substring check. 
class SparkDialect(Dialect):
    """Apache Spark SQL dialect for CEL-to-SQL conversion.

    Text that ends up inside a single-quoted SQL literal (string values,
    regex patterns, JSON paths) is escaped with :meth:`_escape_single_quoted`.
    With Spark's default parser settings a backslash inside a string literal
    starts an escape sequence, so both the backslash and the single quote
    must be backslash-escaped; doubling the quote alone is not sufficient.
    """

    @staticmethod
    def _escape_single_quoted(text: str) -> str:
        """Return ``text`` escaped for use inside a single-quoted Spark literal.

        Assumes the default ``spark.sql.parser.escapedStringLiterals=false``,
        under which the parser unescapes backslash sequences.
        """
        # Backslashes first, so the backslash added for each quote is not
        # doubled by the second pass.
        return text.replace("\\", "\\\\").replace("'", "\\'")

    # --- Literals ---

    def write_string_literal(self, w: StringIO, value: str) -> None:
        """Write ``value`` as a single-quoted Spark string literal.

        Backslashes and single quotes are backslash-escaped so the value
        cannot terminate the literal early.
        """
        w.write(f"'{self._escape_single_quoted(value)}'")

    def write_bytes_literal(self, w: StringIO, value: bytes) -> None:
        """Write ``value`` as an upper-case hex literal, ``X'..'``."""
        hex_str = value.hex().upper()
        w.write(f"X'{hex_str}'")

    def write_param_placeholder(self, w: StringIO, param_index: int) -> None:
        """Write a positional ``?`` placeholder; ``param_index`` is not used."""
        # Spark JDBC uses positional ? placeholders.
        w.write("?")

    # --- Operators ---

    def write_string_concat(
        self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc
    ) -> None:
        """Write ``concat(lhs, rhs)``."""
        # concat() works in all Spark versions; the || operator was added in 3.0+.
        w.write("concat(")
        write_lhs()
        w.write(", ")
        write_rhs()
        w.write(")")

    def write_regex_match(
        self, w: StringIO, write_target: WriteFunc, pattern: str, case_insensitive: bool
    ) -> None:
        """Write ``target RLIKE 'pattern'``.

        ``case_insensitive`` is not consulted: an inline ``(?i)`` flag stays
        in ``pattern`` and is honoured by the engine.
        """
        write_target()
        # The SQL parser unescapes the literal before the regex engine sees
        # it, so backslashes must be doubled or `\d` would arrive as `d`.
        escaped = self._escape_single_quoted(pattern)
        w.write(f" RLIKE '{escaped}'")

    def write_like_escape(self, w: StringIO) -> None:
        """Write the ``ESCAPE`` clause naming backslash as the LIKE escape."""
        # Two backslashes in the SQL text denote one backslash character.
        w.write(" ESCAPE '\\\\'")

    def write_array_membership(
        self, w: StringIO, write_elem: WriteFunc, write_array: WriteFunc
    ) -> None:
        """Write ``array_contains(array, elem)`` (array first, then element)."""
        w.write("array_contains(")
        write_array()
        w.write(", ")
        write_elem()
        w.write(")")

    # --- Type Casting ---

    def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None:
        """Write ``expr + 0`` to coerce the expression to a number."""
        # Spark has no postfix `::TYPE` cast; arithmetic coercion `+ 0` works
        # (same trick MySQL/SQLite use), forcing string→number coercion.
        write_expr()
        w.write(" + 0")

    def write_type_name(self, w: StringIO, cel_type_name: str) -> None:
        """Write the Spark type for a CEL type name.

        Names missing from ``_TYPE_MAP`` are written upper-cased as-is.
        """
        sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper())
        w.write(sql_type)

    def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None:
        """Write ``UNIX_TIMESTAMP(expr)``."""
        w.write("UNIX_TIMESTAMP(")
        write_expr()
        w.write(")")

    def write_timestamp_cast(self, w: StringIO, write_expr: WriteFunc) -> None:
        """Write ``CAST(expr AS TIMESTAMP)``."""
        w.write("CAST(")
        write_expr()
        w.write(" AS TIMESTAMP)")

    # --- Arrays ---

    def write_array_literal_open(self, w: StringIO) -> None:
        """Write the opening ``array(`` of an array literal."""
        w.write("array(")

    def write_array_literal_close(self, w: StringIO) -> None:
        """Write the closing ``)`` of an array literal."""
        w.write(")")

    def write_array_length(
        self, w: StringIO, dimension: int, write_expr: WriteFunc
    ) -> None:
        """Write ``COALESCE(size(expr), 0)``.

        Raises:
            UnsupportedDialectFeatureError: If ``dimension`` is greater than 1.
        """
        if dimension > 1:
            raise UnsupportedDialectFeatureError(
                "multi-dimensional array length is not supported in Spark",
                f"Spark dialect does not support multi-dimensional array "
                f"length (dimension={dimension})",
            )
        # In Spark SQL, size(NULL) evaluates to NULL; COALESCE converts that to 0.
        w.write("COALESCE(size(")
        write_expr()
        w.write("), 0)")

    def write_list_index(
        self, w: StringIO, write_array: WriteFunc, write_index: WriteFunc
    ) -> None:
        """Write ``array[index]`` with the index passed through unchanged."""
        # Spark arrays are 0-indexed (Java/Scala convention).
        write_array()
        w.write("[")
        write_index()
        w.write("]")

    def write_list_index_const(
        self, w: StringIO, write_array: WriteFunc, index: int
    ) -> None:
        """Write ``array[index]`` for a constant integer index."""
        write_array()
        w.write(f"[{index}]")

    def write_empty_typed_array(self, w: StringIO, type_name: str) -> None:
        """Write an empty array cast to ``ARRAY<element type>``."""
        w.write(f"CAST(array() AS ARRAY<{_spark_element_type(type_name)}>)")

    # --- JSON ---

    def write_json_field_access(
        self, w: StringIO, write_base: WriteFunc, field_name: str, is_final: bool
    ) -> None:
        """Write ``get_json_object(base, '$.field')``.

        ``is_final`` is not consulted; the same function is emitted for
        intermediate and final path steps.
        """
        # Spark's get_json_object always returns a string; the same function is
        # used for both intermediate and final access (no JSON_QUERY equivalent).
        escaped = self._escape_single_quoted(field_name)
        w.write("get_json_object(")
        write_base()
        w.write(f", '$.{escaped}')")

    def write_json_existence(
        self, w: StringIO, is_jsonb: bool, field_name: str, write_base: WriteFunc
    ) -> None:
        """Write ``get_json_object(base, '$.field') IS NOT NULL``.

        ``is_jsonb`` is not consulted.
        """
        escaped = self._escape_single_quoted(field_name)
        w.write("get_json_object(")
        write_base()
        w.write(f", '$.{escaped}') IS NOT NULL")

    def write_json_array_elements(
        self, w: StringIO, is_jsonb: bool, as_text: bool, write_expr: WriteFunc
    ) -> None:
        """Write ``EXPLODE(from_json(expr, 'ARRAY<STRING>'))``.

        ``is_jsonb`` and ``as_text`` are not consulted.
        """
        # Element type is fixed to STRING; numeric comparisons coerce via
        # write_cast_to_numeric.
        w.write("EXPLODE(from_json(")
        write_expr()
        w.write(", 'ARRAY<STRING>'))")

    def write_json_array_length(self, w: StringIO, write_expr: WriteFunc) -> None:
        """Write ``COALESCE(size(from_json(expr, 'ARRAY<STRING>')), 0)``."""
        w.write("COALESCE(size(from_json(")
        write_expr()
        w.write(", 'ARRAY<STRING>')), 0)")

    def write_json_array_membership(
        self, w: StringIO, json_func: str, write_expr: WriteFunc
    ) -> None:
        """Always raise: JSON array membership cannot be expressed here.

        Raises:
            UnsupportedDialectFeatureError: Unconditionally.
        """
        # The converter emits `lhs = <output of this method>` for this
        # construct, and a scalar subquery built from EXPLODE(from_json(...))
        # can return multiple rows — Spark rejects that at runtime. The dialect
        # contract here does not provide the candidate element, so we cannot
        # rewrite to a boolean predicate (e.g. array_contains(from_json(...),
        # elem)). Failing fast at conversion time is preferable to emitting
        # SQL that fails at execution.
        raise UnsupportedDialectFeatureError(
            "JSON array membership is not supported in Spark",
            "Spark JSON array membership requires a boolean predicate "
            "(array_contains/EXISTS); the dialect contract does not provide "
            "the candidate element to build one. Use a typed ARRAY column "
            "or rewrite the expression in application code.",
        )

    def write_nested_json_array_membership(
        self, w: StringIO, write_expr: WriteFunc
    ) -> None:
        """Always raise: nested JSON array membership is unsupported.

        Raises:
            UnsupportedDialectFeatureError: Unconditionally.
        """
        raise UnsupportedDialectFeatureError(
            "nested JSON array membership is not supported in Spark",
            "Spark nested JSON array membership requires a boolean predicate "
            "(array_contains/EXISTS); the dialect contract does not provide "
            "the candidate element to build one. Use a typed ARRAY column "
            "or rewrite the expression in application code.",
        )

    # --- Timestamps ---

    def write_duration(self, w: StringIO, value: int, unit: str) -> None:
        """Write ``INTERVAL <value> <unit>`` for a constant duration."""
        w.write(f"INTERVAL {value} {unit}")

    def write_interval(
        self, w: StringIO, write_value: WriteFunc, unit: str
    ) -> None:
        """Write ``INTERVAL <value expression> <unit>``."""
        w.write("INTERVAL ")
        write_value()
        w.write(f" {unit}")

    def write_extract(
        self,
        w: StringIO,
        part: str,
        write_expr: WriteFunc,
        write_tz: WriteFunc | None,
    ) -> None:
        """Write the extraction of ``part`` from a timestamp expression.

        ``DOW`` is emitted as ``(dayofweek(expr) - 1)``; every other part as
        ``EXTRACT(part FROM expr)``. When ``write_tz`` is given,
        `` AT TIME ZONE <tz>`` is appended to the expression.
        """
        # NOTE(review): confirm the targeted Spark versions accept
        # `AT TIME ZONE` in this position.
        if part == "DOW":
            # Spark dayofweek() returns 1=Sunday..7=Saturday; CEL convention is
            # 0=Sunday..6=Saturday. Adjust by subtracting 1.
            w.write("(dayofweek(")
            write_expr()
            if write_tz is not None:
                w.write(" AT TIME ZONE ")
                write_tz()
            w.write(") - 1)")
            return
        w.write(f"EXTRACT({part} FROM ")
        write_expr()
        if write_tz is not None:
            w.write(" AT TIME ZONE ")
            write_tz()
        w.write(")")

    def write_timestamp_arithmetic(
        self,
        w: StringIO,
        op: str,
        write_ts: WriteFunc,
        write_dur: WriteFunc,
    ) -> None:
        """Write ``<timestamp> <op> <duration>``."""
        write_ts()
        w.write(f" {op} ")
        write_dur()

    # --- String Functions ---

    def write_contains(
        self, w: StringIO, write_haystack: WriteFunc, write_needle: WriteFunc
    ) -> None:
        """Write ``LOCATE(needle, haystack) > 0``."""
        # LOCATE(substr, str) returns 1-based position or 0 when not found.
        w.write("LOCATE(")
        write_needle()
        w.write(", ")
        write_haystack()
        w.write(") > 0")

    def write_split(
        self, w: StringIO, write_str: WriteFunc, write_delim: WriteFunc
    ) -> None:
        """Write ``split(str, delim)``."""
        # NOTE(review): Spark's split() treats its second argument as a regex;
        # confirm delimiters containing regex metacharacters are handled by
        # the caller.
        w.write("split(")
        write_str()
        w.write(", ")
        write_delim()
        w.write(")")

    def write_split_with_limit(
        self, w: StringIO, write_str: WriteFunc, write_delim: WriteFunc, limit: int
    ) -> None:
        """Write ``split(str, delim, limit)``."""
        # Spark 3.x+ supports the 3-arg split.
        w.write("split(")
        write_str()
        w.write(", ")
        write_delim()
        w.write(f", {limit})")

    def write_join(
        self, w: StringIO, write_array: WriteFunc, write_delim: WriteFunc
    ) -> None:
        """Write ``array_join(array, delim)``."""
        w.write("array_join(")
        write_array()
        w.write(", ")
        write_delim()
        w.write(")")

    def write_format(
        self, w: StringIO, fmt_string: str, write_args: list[WriteFunc]
    ) -> None:
        """Write ``format_string('<fmt>', arg1, ...)``.

        The format string goes through :meth:`write_string_literal`, so it
        receives the same escaping as any other literal.
        """
        # Spark's format_string() is its printf-equivalent (supports %s/%d/%f
        # directly).
        w.write("format_string(")
        self.write_string_literal(w, fmt_string)
        for arg in write_args:
            w.write(", ")
            arg()
        w.write(")")

    # --- Comprehensions ---

    def write_unnest(self, w: StringIO, write_source: WriteFunc) -> None:
        """Write ``EXPLODE(source)``."""
        w.write("EXPLODE(")
        write_source()
        w.write(")")

    def write_array_subquery_open(self, w: StringIO) -> None:
        """Write the opening ``(SELECT collect_list(`` of an array subquery."""
        # Spark has no ARRAY(SELECT ...) constructor; collect_list() is the
        # closest equivalent.
        w.write("(SELECT collect_list(")

    def write_array_subquery_expr_close(self, w: StringIO) -> None:
        """Write the ``)`` that closes the ``collect_list(`` argument."""
        w.write(")")

    # --- Regex ---

    def convert_regex(self, re2_pattern: str) -> tuple[str, bool]:
        """Validate ``re2_pattern`` via ``_convert_re2_to_spark`` and return its result."""
        return _convert_re2_to_spark(re2_pattern)

    # --- Struct ---

    def write_struct_open(self, w: StringIO) -> None:
        """Write the opening ``struct(``."""
        w.write("struct(")

    def write_struct_close(self, w: StringIO) -> None:
        """Write the closing ``)`` of a struct."""
        w.write(")")

    # --- Validation ---

    def max_identifier_length(self) -> int:
        """Return the module's ``_MAX_IDENTIFIER_LENGTH`` limit."""
        return _MAX_IDENTIFIER_LENGTH

    def validate_field_name(self, name: str) -> None:
        """Validate ``name`` by delegating to ``_validate_spark_field_name``."""
        _validate_spark_field_name(name)

    # --- Capabilities ---

    def supports_native_arrays(self) -> bool:
        """Report native array support (always True)."""
        return True

    def supports_jsonb(self) -> bool:
        """Report JSONB support (always False)."""
        return False
BigQueryDialect(), MySQLDialect(), + SparkDialect(), SQLiteDialect(), ] diff --git a/tests/test_bigquery.py b/tests/test_bigquery.py index fcba049..8c831d7 100644 --- a/tests/test_bigquery.py +++ b/tests/test_bigquery.py @@ -56,9 +56,7 @@ def test_array_index_const(self, d): def test_array_length(self, d): schemas = {"t": Schema([FieldSchema("arr", repeated=True)])} result = convert("t.arr.size()", dialect=d, schemas=schemas) - assert "ARRAY_LENGTH(" in result - # No COALESCE wrapper in BigQuery - assert "COALESCE" not in result + assert result == "COALESCE(ARRAY_LENGTH(t.arr), 0)" def test_empty_typed_array(self, d): result = convert('"a,b".split(",", 0)', dialect=d) diff --git a/tests/test_dialect_parametrized.py b/tests/test_dialect_parametrized.py index 3c9885d..fb94da2 100644 --- a/tests/test_dialect_parametrized.py +++ b/tests/test_dialect_parametrized.py @@ -7,6 +7,7 @@ from pycel2sql.dialect.duckdb import DuckDBDialect from pycel2sql.dialect.mysql import MySQLDialect from pycel2sql.dialect.postgres import PostgresDialect +from pycel2sql.dialect.spark import SparkDialect from pycel2sql.dialect.sqlite import SQLiteDialect ALL_DIALECTS = [ @@ -14,6 +15,7 @@ pytest.param(DuckDBDialect(), id="duckdb"), pytest.param(BigQueryDialect(), id="bigquery"), pytest.param(MySQLDialect(), id="mysql"), + pytest.param(SparkDialect(), id="spark"), pytest.param(SQLiteDialect(), id="sqlite"), ] diff --git a/tests/test_options.py b/tests/test_options.py new file mode 100644 index 0000000..8b96525 --- /dev/null +++ b/tests/test_options.py @@ -0,0 +1,182 @@ +"""Tests for the json_variables, column_aliases, and param_start_index options. + +Ported from cel2sql Go (PR #113) and cel2sql4j (PR #9). 
+""" + +import pytest + +from pycel2sql import convert, convert_parameterized +from pycel2sql._errors import InvalidFieldNameError +from pycel2sql.dialect.bigquery import BigQueryDialect +from pycel2sql.dialect.duckdb import DuckDBDialect +from pycel2sql.dialect.mysql import MySQLDialect +from pycel2sql.dialect.sqlite import SQLiteDialect + + +class TestParamStartIndex: + def test_default_starts_at_1(self): + result = convert_parameterized("name == 'alice'") + assert result.sql == "name = $1" + assert result.parameters == ["alice"] + + def test_postgres_offset_5(self): + result = convert_parameterized("name == 'alice'", param_start_index=5) + assert result.sql == "name = $5" + assert result.parameters == ["alice"] + + def test_bigquery_offset_10_two_params(self): + result = convert_parameterized( + "name == 'alice' && age > 30", + dialect=BigQueryDialect(), + param_start_index=10, + ) + assert "@p10" in result.sql + assert "@p11" in result.sql + assert result.parameters == ["alice", 30] + + def test_postgres_two_params_count_up(self): + result = convert_parameterized( + "name == 'alice' && age > 30", + param_start_index=7, + ) + assert result.sql == "name = $7 AND age > $8" + assert result.parameters == ["alice", 30] + + @pytest.mark.parametrize("dialect_cls", [MySQLDialect, SQLiteDialect]) + def test_positional_dialects_param_text_unchanged(self, dialect_cls): + # MySQL/SQLite use ? placeholders that don't carry an index in the text, + # but the parameters list ordering must still be preserved. 
+ result = convert_parameterized( + "name == 'a' && age > 30", + dialect=dialect_cls(), + param_start_index=5, + ) + assert result.sql.count("?") == 2 + assert result.parameters == ["a", 30] + + @pytest.mark.parametrize("bad", [0, -5, -100]) + def test_clamping_below_one(self, bad): + result = convert_parameterized("name == 'a'", param_start_index=bad) + assert result.sql == "name = $1" + + def test_many_params_count_up_past_9(self): + # 11 params, starting at 5: should yield $5..$15 + expr = " && ".join([f"x{i} == 'v{i}'" for i in range(11)]) + result = convert_parameterized(expr, param_start_index=5) + for i in range(5, 16): + assert f"${i}" in result.sql + assert len(result.parameters) == 11 + + +class TestColumnAliases: + def test_bare_ident(self): + result = convert("name == 'a'", column_aliases={"name": "usr_name"}) + assert result == "usr_name = 'a'" + + def test_ident_inside_member_dot(self): + # name.first → user_name.first (alias applied at the root ident emission) + result = convert( + "name.first == 'a'", column_aliases={"name": "user_name"} + ) + assert result == "user_name.first = 'a'" + + def test_no_alias_passes_through(self): + result = convert("name == 'a'", column_aliases={"other": "x"}) + assert result == "name = 'a'" + + def test_invalid_alias_raises(self): + with pytest.raises(InvalidFieldNameError): + convert("name == 'a'", column_aliases={"name": "bad name; DROP"}) + + def test_alias_in_parameterized_form(self): + result = convert_parameterized( + "name == 'a'", column_aliases={"name": "usr_name"} + ) + assert result.sql == "usr_name = $1" + assert result.parameters == ["a"] + + +class TestJsonVariables: + def test_postgres_single_level_dot(self): + result = convert( + "context.host == 'a'", json_variables={"context"} + ) + assert result == "context->>'host' = 'a'" + + def test_postgres_bracket_notation(self): + result = convert( + 'context["host"] == "a"', json_variables={"context"} + ) + assert result == "context->>'host' = 'a'" + + def 
test_postgres_nested_three_levels(self): + result = convert( + "tags.corpus.section == 'a'", json_variables={"tags"} + ) + assert result == "tags->'corpus'->>'section' = 'a'" + + def test_duckdb_single_level(self): + result = convert( + "context.host == 'a'", + dialect=DuckDBDialect(), + json_variables={"context"}, + ) + assert "context->>'host'" in result + + def test_bigquery_single_level(self): + result = convert( + "context.host == 'a'", + dialect=BigQueryDialect(), + json_variables={"context"}, + ) + assert "JSON_VALUE(context, '$.host')" in result + + def test_sqlite_single_level(self): + result = convert( + "context.host == 'a'", + dialect=SQLiteDialect(), + json_variables={"context"}, + ) + assert "json_extract(context, '$.host')" in result + + def test_mysql_single_level(self): + result = convert( + "context.host == 'a'", + dialect=MySQLDialect(), + json_variables={"context"}, + ) + # MySQL's JSON ->> operator extracts as text. + assert result == "context->>'$.host' = 'a'" + + def test_no_json_variable_no_change(self): + # Without json_variables, dot notation stays plain. + result = convert("context.host == 'a'") + assert result == "context.host = 'a'" + + def test_json_variable_with_alias(self): + # Aliased name is the column root in the SQL output. + result = convert( + "context.host == 'a'", + json_variables={"context"}, + column_aliases={"context": "ctx_jsonb"}, + ) + assert result == "ctx_jsonb->>'host' = 'a'" + + def test_comprehension_iter_var_not_treated_as_json_variable(self): + # Iter var "x" colliding with a json_variable name should not be + # treated as JSON inside the comprehension body. + result = convert( + "items.exists(x, x == 1)", + json_variables={"x"}, + ) + # x should appear plainly inside the EXISTS body, not as JSON path. + assert "->>" not in result + + def test_has_against_json_variable(self): + # has(context.host) with context a json_variable should emit the + # dialect's JSON-existence operator (Postgres `?` for JSONB). 
+ result = convert( + "has(context.host)", + json_variables={"context"}, + ) + assert result == "context ? 'host'" diff --git a/tests/test_spark.py b/tests/test_spark.py new file mode 100644 index 0000000..5f178bc --- /dev/null +++ b/tests/test_spark.py @@ -0,0 +1,275 @@ +"""Apache Spark dialect-specific tests. + +Mirrors the Java Cel2SqlSparkTest surface from cel2sql4j PR #10. +""" + +import pytest + +from pycel2sql import convert, convert_parameterized +from pycel2sql._errors import ( + InvalidFieldNameError, + InvalidRegexPatternError, + UnsupportedDialectFeatureError, +) +from pycel2sql.dialect.spark import ( + SparkDialect, + _convert_re2_to_spark, + _validate_spark_field_name, +) +from pycel2sql.schema import FieldSchema, Schema + + +@pytest.fixture +def d(): + return SparkDialect() + + +class TestSparkLiterals: + def test_string_literal(self, d): + assert convert('name == "alice"', dialect=d) == "name = 'alice'" + + def test_int_literal(self, d): + assert convert("age == 25", dialect=d) == "age = 25" + + def test_bool_literal(self, d): + assert convert("active == true", dialect=d) == "active IS TRUE" + + def test_bytes_literal_hex(self, d): + result = convert('b"abc" == data', dialect=d) + # Hex form: X'' + assert "X'" in result and "61" in result.upper() + + +class TestSparkParams: + def test_positional_placeholder(self, d): + result = convert_parameterized('name == "alice"', dialect=d) + assert result.sql == "name = ?" + assert result.parameters == ["alice"] + + def test_multiple_positional(self, d): + result = convert_parameterized("age > 10 && age < 30", dialect=d) + assert result.sql == "age > ? AND age < ?" + assert result.parameters == [10, 30] + + def test_param_start_index_is_a_no_op_for_text(self, d): + result = convert_parameterized( + "age > 10 && age < 30", dialect=d, param_start_index=5 + ) + # Spark uses positional ? — placeholder text is identical. + assert result.sql == "age > ? AND age < ?" 
+ assert result.parameters == [10, 30] + + +class TestSparkArrays: + def test_array_literal(self, d): + assert convert("[1, 2, 3]", dialect=d) == "array(1, 2, 3)" + + def test_array_membership_arg_order(self, d): + # Spark: array_contains(arr, elem) — arg order swap. + result = convert("x in [1, 2, 3]", dialect=d) + assert result == "array_contains(array(1, 2, 3), x)" + + def test_array_index_const_zero_based(self, d): + schemas = {"t": Schema([FieldSchema("arr", repeated=True)])} + result = convert("t.arr[0]", dialect=d, schemas=schemas) + assert result == "t.arr[0]" + + def test_array_index_const_nonzero(self, d): + schemas = {"t": Schema([FieldSchema("arr", repeated=True)])} + result = convert("t.arr[2]", dialect=d, schemas=schemas) + assert result == "t.arr[2]" + + def test_array_length(self, d): + schemas = {"t": Schema([FieldSchema("arr", repeated=True)])} + result = convert("t.arr.size()", dialect=d, schemas=schemas) + assert result == "COALESCE(size(t.arr), 0)" + + def test_empty_typed_array(self, d): + result = convert('"a,b".split(",", 0)', dialect=d) + assert "CAST(array() AS ARRAY<" in result + + +class TestSparkStringFunctions: + def test_contains_locate_arg_order(self, d): + # Spark: LOCATE(needle, haystack) > 0 (arg order: needle then haystack). 
+ result = convert('name.contains("test")', dialect=d) + assert result == "LOCATE('test', name) > 0" + + def test_starts_with_escape(self, d): + result = convert('name.startsWith("a")', dialect=d) + assert "LIKE 'a%'" in result + assert "ESCAPE '\\\\'" in result + + def test_concat_for_string_plus(self, d): + result = convert('"a" + "b" == "ab"', dialect=d) + assert "concat('a', 'b')" in result + + def test_split_basic(self, d): + result = convert('"a,b,c".split(",")', dialect=d) + assert "split('a,b,c', ',')" in result + + def test_split_with_limit(self, d): + result = convert('"a,b,c".split(",", 3)', dialect=d) + assert "split('a,b,c', ',', 3)" in result + + def test_join(self, d): + result = convert('["a", "b"].join(",")', dialect=d) + assert "array_join(array('a', 'b'), ',')" in result + + def test_format_uses_format_string(self, d): + result = convert("'%s = %d'.format([name, 10])", dialect=d) + assert result == "format_string('%s = %s', name, 10)" + + +class TestSparkRegex: + def test_basic_match_rlike(self, d): + result = convert('name.matches("^[a-z]+$")', dialect=d) + assert result == "name RLIKE '^[a-z]+$'" + + def test_inline_case_insensitive_passthrough(self, d): + # Spark honours (?i) inline; we pass it through verbatim. 
+ result = convert('name.matches("(?i)test")', dialect=d) + assert result == "name RLIKE '(?i)test'" + + def test_lookahead_rejected(self, d): + with pytest.raises(InvalidRegexPatternError): + convert('name.matches("(?=test)abc")', dialect=d) + + def test_lookbehind_rejected(self, d): + with pytest.raises(InvalidRegexPatternError): + convert('name.matches("(?<=test)abc")', dialect=d) + + def test_named_group_rejected(self, d): + with pytest.raises(InvalidRegexPatternError): + convert('name.matches("(?Pabc)")', dialect=d) + + def test_quantified_alternation_rejected(self, d): + with pytest.raises(InvalidRegexPatternError): + convert('name.matches("(a|b)+")', dialect=d) + + def test_inline_flag_other_than_i_rejected(self, d): + with pytest.raises(InvalidRegexPatternError): + convert('name.matches("(?m)foo")', dialect=d) + + @pytest.mark.parametrize("pat", ["(?im)foo", "(?mi)foo", "(?ix:foo)", "(?-i)foo"]) + def test_combined_flag_groups_rejected(self, d, pat): + # `(?i)` alone is allowed; any other flags (even combined with `i`) + # must be rejected. + with pytest.raises(InvalidRegexPatternError): + convert(f'name.matches("{pat}")', dialect=d) + + def test_lone_inline_i_flag_allowed(self, d): + # Sanity check: a bare `(?i)` group still passes. + result = convert('name.matches("(?i)foo")', dialect=d) + assert "RLIKE '(?i)foo'" in result + + def test_overlong_pattern_rejected_at_validator(self): + big = "a" * 600 + with pytest.raises(InvalidRegexPatternError): + _convert_re2_to_spark(big) + + def test_nested_quantifier_rejected_at_validator(self): + # Direct validator test (the converter's deeper machinery may rewrite + # before reaching here in some forms; the validator is the source of + # truth for ReDoS guarding). + with pytest.raises(InvalidRegexPatternError): + _convert_re2_to_spark("(a+)+") + + +class TestSparkTimestamps: + def test_int_of_timestamp_uses_unix_timestamp(self, d): + # Bare 'ts' identifier hits the timestamp-name heuristic. 
+ result = convert("int(ts) == 100", dialect=d) + assert result == "UNIX_TIMESTAMP(ts) = 100" + + def test_timestamp_cast(self, d): + result = convert('timestamp("2024-01-01T00:00:00Z") == ts', dialect=d) + assert "CAST(" in result and "AS TIMESTAMP" in result + + def test_duration_interval(self, d): + result = convert('ts + duration("24h") == ts2', dialect=d) + assert "INTERVAL" in result + + def test_extract_year(self, d): + result = convert("ts.getFullYear() == 2024", dialect=d) + assert "EXTRACT(YEAR FROM ts)" in result + + def test_dow_special_case(self, d): + result = convert("ts.getDayOfWeek() == 0", dialect=d) + assert "(dayofweek(ts) - 1)" in result + + +class TestSparkJSON: + def test_json_field_access(self, d): + schemas = {"t": Schema([FieldSchema("data", is_json=True)])} + result = convert('t.data.field == "x"', dialect=d, schemas=schemas) + assert "get_json_object(t.data, '$.field')" in result + + def test_json_existence_via_has(self, d): + result = convert("has(context.host)", dialect=d, json_variables={"context"}) + assert result == "get_json_object(context, '$.host') IS NOT NULL" + + def test_json_variable_dot_access(self, d): + result = convert("context.host == 'a'", dialect=d, json_variables={"context"}) + assert result == "get_json_object(context, '$.host') = 'a'" + + def test_json_array_membership_dialect_method_raises(self, d): + # Direct dialect-level call: the converter doesn't currently route + # `in` against a JSON-array field through write_json_array_membership, + # but the Spark dialect must raise if it ever does. 
+ from io import StringIO + + with pytest.raises(UnsupportedDialectFeatureError): + d.write_json_array_membership(StringIO(), "x", lambda: None) + + def test_nested_json_array_membership_dialect_method_raises(self, d): + from io import StringIO + + with pytest.raises(UnsupportedDialectFeatureError): + d.write_nested_json_array_membership(StringIO(), lambda: None) + + +class TestSparkValidation: + def test_reserved_keyword_rejected(self, d): + with pytest.raises(InvalidFieldNameError): + convert("select == 1", dialect=d) + + def test_overlong_field_rejected_at_validator(self): + # Direct dialect-validator test (the generic field-name validator + # caps at 63 chars and would reject 130 first; testing the + # Spark-specific check directly). + with pytest.raises(InvalidFieldNameError): + _validate_spark_field_name("a" * 130) + + def test_empty_field_rejected_at_validator(self): + with pytest.raises(InvalidFieldNameError): + _validate_spark_field_name("") + + +class TestSparkTypeCasting: + def test_cel_int_to_bigint(self, d): + # Plain int(x) on a non-timestamp ident emits CAST(... AS BIGINT). + result = convert("int(x) == 1", dialect=d) + assert result == "CAST(x AS BIGINT) = 1" + + def test_cel_string_cast(self, d): + result = convert('string(x) == "a"', dialect=d) + assert result == "CAST(x AS STRING) = 'a'" + + def test_cel_bool_cast(self, d): + result = convert("bool(x) == true", dialect=d) + assert "CAST(x AS BOOLEAN)" in result + + +class TestSparkComprehensions: + def test_exists_uses_explode_and_collect_list(self, d): + result = convert("[1, 2, 3].exists(x, x > 1)", dialect=d) + # The comprehension scaffolding should reference EXPLODE. 
+ assert "EXPLODE" in result + + +class TestSparkStructs: + def test_struct_literal(self, d): + # CEL struct/map literal: {a: 1, b: 2} + result = convert("{'a': 1, 'b': 2} == x", dialect=d) + assert "struct(" in result.lower() diff --git a/tests/test_string_functions.py b/tests/test_string_functions.py index 0f55a1d..3437b13 100644 --- a/tests/test_string_functions.py +++ b/tests/test_string_functions.py @@ -142,3 +142,34 @@ def test_without_delimiter(self): def test_with_space(self): result = convert("['hello', 'world'].join(' ') == 'hello world'") assert result == "ARRAY_TO_STRING(ARRAY['hello', 'world'], ' ', '') = 'hello world'" + + +class TestFormatPerDialect: + """format() dispatches per dialect: FORMAT for Postgres/BigQuery, + format_string for Spark, printf for SQLite/DuckDB, raises for MySQL.""" + + def test_postgres_emits_FORMAT(self): + from pycel2sql.dialect.postgres import PostgresDialect + result = convert("'%s = %d'.format([name, 10])", dialect=PostgresDialect()) + assert result == "FORMAT('%s = %s', name, 10)" + + def test_bigquery_emits_FORMAT(self): + from pycel2sql.dialect.bigquery import BigQueryDialect + result = convert("'%s = %d'.format([name, 10])", dialect=BigQueryDialect()) + assert result == "FORMAT('%s = %s', name, 10)" + + def test_sqlite_emits_printf(self): + from pycel2sql.dialect.sqlite import SQLiteDialect + result = convert("'%s = %d'.format([name, 10])", dialect=SQLiteDialect()) + assert result == "printf('%s = %s', name, 10)" + + def test_duckdb_emits_printf(self): + from pycel2sql.dialect.duckdb import DuckDBDialect + result = convert("'%s = %d'.format([name, 10])", dialect=DuckDBDialect()) + assert result == "printf('%s = %s', name, 10)" + + def test_mysql_raises(self): + from pycel2sql._errors import UnsupportedDialectFeatureError + from pycel2sql.dialect.mysql import MySQLDialect + with pytest.raises(UnsupportedDialectFeatureError, match="format"): + convert("'%s'.format([name])", dialect=MySQLDialect())