diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst index b5e868abea209..33b4ddd3b402a 100644 --- a/providers/common/ai/docs/toolsets.rst +++ b/providers/common/ai/docs/toolsets.rst @@ -155,10 +155,10 @@ the validator, so ``SHOW`` is recognized on databases that support it (Snowflake MySQL, etc.); on databases without ``SHOW`` it stays rejected. Data-modifying statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN`` (e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator -rejects by scanning the parsed statement for write operations. Like ``SELECT``, -metadata statements are not scoped by ``allowed_tables`` (see -:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the -list, so rely on database permissions to restrict access. +rejects by scanning the parsed statement for write operations. When +``allowed_tables`` is set it scopes these statements too: a ``DESCRIBE`` names a +table, so its target must be on the list, while ``SHOW`` enumerates objects beyond +any single table and is rejected outright (see :ref:`allowed-tables-enforcement`). Multi-schema warehouses ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -186,11 +186,13 @@ Parameters ^^^^^^^^^^ - ``db_conn_id``: Airflow connection ID for the database. -- ``allowed_tables``: Restrict which tables the agent can discover via - ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in - ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span - multiple schemas; see above. Matching is case-insensitive. - See :ref:`allowed-tables-limitation` for an important caveat. +- ``allowed_tables``: Restrict the agent to a fixed set of tables. ``None`` + (default) exposes all tables in ``schema``. Entries may be schema-qualified + (``"SCHEMA.TABLE"``) to span multiple schemas; see above. Matching is + case-insensitive. 
When set, the list is enforced on ``query`` and + ``check_query`` as well as discovery -- every table a query references must be + on it. See :ref:`allowed-tables-enforcement` for what this does and does not + guarantee. - ``schema``: Default schema/namespace for unqualified table listing and introspection. Schema-qualified ``allowed_tables`` entries override it per table. - ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.). @@ -560,11 +562,14 @@ No single layer is sufficient — they work together. INTO, and other non-SELECT statements. - Does not prevent the agent from reading any registered data source. * - **SQLToolset: allowed_tables** - - Restricts which tables appear in ``list_tables`` and ``get_schema`` - responses, limiting the agent's knowledge of the schema. - - Does **not** validate table references in SQL queries. The agent can - still query unlisted tables if it guesses the name. See - :ref:`allowed-tables-limitation` below. + - Restricts the agent to listed tables across ``list_tables``, + ``get_schema``, ``query``, and ``check_query``. Queries are parsed and + every referenced table (including via subqueries, CTEs, JOINs, and + ``DESCRIBE``) is checked against the list before execution. + - Cannot police data reached through side-effecting scalar functions + (e.g. ``pg_read_file``), and is only as exact as the SQL parser. Pair it + with least-privilege database grants. See + :ref:`allowed-tables-enforcement` below. * - **SQLToolset: max_rows** - Truncates query results to ``max_rows`` (default 50), preventing the agent from pulling entire tables into context. @@ -575,21 +580,38 @@ No single layer is sufficient — they work together. - Requires explicit configuration — the default allows many rounds. -.. _allowed-tables-limitation: +.. 
_allowed-tables-enforcement: -The ``allowed_tables`` Limitation -""""""""""""""""""""""""""""""""" +How ``allowed_tables`` Is Enforced +"""""""""""""""""""""""""""""""""" -``allowed_tables`` is a **metadata filter**, not an access control mechanism. -It hides table names from ``list_tables`` and blocks ``get_schema`` for -unlisted tables, but does not parse SQL queries to validate table references. +When ``allowed_tables`` is set it governs every tool, not just discovery: -An LLM can craft ``SELECT * FROM secrets`` even when -``allowed_tables=["orders"]``. Parsing SQL for table references (including -CTEs, subqueries, aliases, and vendor-specific syntax) is complex and -error-prone; we chose not to provide a false sense of security. +- ``list_tables`` and ``get_schema`` only reveal listed tables. +- ``query`` and ``check_query`` parse the SQL with `sqlglot + `_ and reject it before execution if it + references any table that is not on the list. Tables reached indirectly are + caught too -- through subqueries, CTEs, JOINs, set operations (``UNION`` etc.), + ``DESCRIBE``, catalog views such as ``information_schema``, and DML. CTE + references are excluded by lexical scope, so a same-named CTE in another scope + cannot hide a real table, and the database/catalog is part of the match, so a + cross-database reference like ``otherdb.public.orders`` is refused. +- Constructs the list cannot describe are rejected outright while it is active: + table-valued functions (``dblink``), ``TABLE('name')`` row sources, the + ``TABLE `` shorthand, ``SHOW``, dynamic SQL (``EXEC``), and **inline + comments** -- the last because parser-vs-engine differences hide in comments + (MySQL executes ``/*! ... */`` while sqlglot and other engines ignore it). -For query-level restrictions, use database permissions: +So ``SELECT * FROM secrets`` with ``allowed_tables=["orders"]`` is refused, and +the rejection is handed back to the agent so it can re-target an allowed table. 
+ +This is a strong **application-level guardrail**, but it is not a substitute for +database permissions. It cannot police data reached through a function whose +argument is itself SQL or a path: ``pg_read_file('/etc/passwd')`` reads a file, +and ``query_to_xml('SELECT * FROM other_table', ...)`` or a scalar ``dblink`` +reads a table through a string the parser cannot inspect. Any query the engine +parses differently from sqlglot is also a residual gap. For a hard boundary, also +run the connection as a least-privilege role: .. code-block:: sql @@ -598,8 +620,10 @@ For query-level restrictions, use database permissions: GRANT SELECT ON orders, customers TO airflow_agent_reader; -- Use this role's credentials in the Airflow connection -The Airflow connection should use a database user with the minimum privileges -required. +Defense in depth: the allow-list contains the agent's *intent* (and gives it a +correctable error), while the database role is the boundary that holds even if +the agent reaches data the parser cannot see. The connection should use a +database user with the minimum privileges required. HookToolset Guidelines diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py index 45990901e1598..a44359719698c 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py +++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py @@ -24,6 +24,9 @@ try: from airflow.providers.common.ai.utils.sql_validation import ( + SQLSafetyError, + collect_table_references, + parse_sql as _parse_sql, resolve_sqlglot_dialect, validate_sql as _validate_sql, ) @@ -94,21 +97,37 @@ class SQLToolset(AbstractToolset[Any]): toolset does not inspect the error type or message. :param db_conn_id: Airflow connection ID for the database. - :param allowed_tables: Restrict which tables the agent can discover via - ``list_tables`` and ``get_schema``. 
``None`` (default) exposes all tables - in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span - multiple schemas in one database -- common on warehouses such as Snowflake. - ``list_tables`` then introspects each referenced schema and returns the - matching tables fully qualified, and ``get_schema`` routes to the table's - own schema. Unqualified entries use ``schema``. Matching is - case-insensitive, since databases reflect identifiers in their own case. + :param allowed_tables: Restrict the agent to a fixed set of tables. ``None`` + (default) exposes every table in ``schema``. Entries may be schema-qualified + (``"SCHEMA.TABLE"``) to span multiple schemas in one database -- common on + warehouses such as Snowflake. ``list_tables`` introspects each referenced + schema and returns the matching tables fully qualified, and ``get_schema`` + routes to the table's own schema. Unqualified entries use ``schema``. + Matching is case-insensitive, since databases reflect identifiers in their + own case. + + When set, the list is enforced on the ``query`` and ``check_query`` tools as + well as on discovery: every table a query reaches -- through subqueries, CTEs, + JOINs, set operations, ``DESCRIBE``, catalog views such as + ``information_schema``, or DML -- must be on the list, resolved with its + database/catalog, or the query is rejected before it runs. CTE references are + excluded by lexical scope (a same-named CTE in another scope never hides a real + table). Constructs the list cannot describe are rejected outright while it is + active: table-valued functions (``dblink``), ``TABLE('name')`` row sources, the + ``TABLE `` shorthand, ``SHOW``, dynamic SQL, and **inline comments** + (where parser-vs-engine differences such as MySQL ``/*! ... */`` executable + comments hide). .. note:: - ``allowed_tables`` controls metadata visibility only. It does **not** - parse or validate table references in SQL queries. 
An LLM can still - query tables outside this list if it guesses the name. For query-level - restrictions, use database-level permissions (e.g. a read-only role - with grants limited to specific tables). + This is an application-level guardrail, enforced by parsing the SQL with + sqlglot. It is strong defense-in-depth but not a substitute for database + permissions: it cannot police data reached through a function whose + argument is itself SQL or a path -- ``pg_read_file('...')`` (a file) or + ``query_to_xml('SELECT ... FROM other_table', ...)`` and ``dblink`` in + scalar position (a table, read through a string the parser cannot inspect) + -- and any query the engine parses differently from sqlglot is a residual + gap. For a hard guarantee, also point ``db_conn_id`` at a least-privilege + role whose ``SELECT`` grants are limited to the same tables. :param schema: Default schema/namespace for table listing and introspection, used for unqualified ``allowed_tables`` entries and unqualified @@ -131,42 +150,63 @@ def __init__( ) -> None: self._db_conn_id = db_conn_id self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None - # Case-folded view for membership tests: databases reflect identifiers in - # their own case (Snowflake stores unquoted names uppercase but reflects - # them lowercased), so a byte-exact match against the user's entries would - # silently miss. allowed_tables is a visibility hint, not access control, - # so case-insensitive matching is safe. - self._allowed_tables_ci: frozenset[str] | None = ( - frozenset(t.casefold() for t in self._allowed_tables) - if self._allowed_tables is not None - else None - ) self._schema = schema self._allow_writes = allow_writes self._max_rows = max_rows self._hook: DbApiHook | None = None - # Derive which schemas to introspect from schema-qualified allowed_tables. 
+ # Canonical ``(catalog, schema, table)`` view of allowed_tables for membership + # tests, plus the schemas to introspect. Built once: every reference -- a + # discovery hit, a get_schema arg, or a table parsed out of a query -- is + # normalised to the same shape and matched against this set. + # + # Identifiers are case-folded: databases reflect them in their own case + # (Snowflake stores unquoted names uppercase but reflects them lowercased), so + # a byte-exact match against the user's entries would silently miss. Unqualified + # entries resolve to the default ``schema`` (``None`` when unset) so that + # ``"orders"`` and ``".orders"`` denote the same table. Allow-list + # entries carry no catalog, so any catalog-qualified reference + # (``otherdb.public.orders``) has a non-null catalog in its key and cannot match + # -- that closes cross-database access the single-connection allow-list can't + # describe. + self._allowed_canonical: frozenset[tuple[str | None, str | None, str]] | None = None # Qualified entries ("SCHEMA.TABLE") are listed under their own schema and # returned fully qualified; unqualified entries (and allow-all) use the # default ``schema``. 
self._qualified_schemas: frozenset[str] = frozenset() self._include_default_schema: bool = True if self._allowed_tables is not None: + canonical: set[tuple[str | None, str | None, str]] = set() qualified_schemas: set[str] = set() include_default = False for entry in self._allowed_tables: - entry_schema, sep, _ = entry.rpartition(".") + entry_schema, sep, table = entry.rpartition(".") if sep: qualified_schemas.add(entry_schema) + canonical.add(self._canonical_ref("", entry_schema, table)) else: include_default = True + canonical.add(self._canonical_ref("", self._schema, entry)) + self._allowed_canonical = frozenset(canonical) self._qualified_schemas = frozenset(qualified_schemas) self._include_default_schema = include_default - def _is_table_allowed(self, name: str) -> bool: - """Case-insensitive membership test against ``allowed_tables`` (allow-all when unset).""" - return self._allowed_tables_ci is None or name.casefold() in self._allowed_tables_ci + @staticmethod + def _canonical_ref( + catalog: str | None, schema: str | None, table: str + ) -> tuple[str | None, str | None, str]: + """Normalise a ``(catalog, schema, table)`` reference to its case-folded comparison key.""" + return ( + catalog.casefold() if catalog else None, + schema.casefold() if schema else None, + table.casefold(), + ) + + def _is_ref_allowed(self, catalog: str | None, schema: str | None, table: str) -> bool: + """Membership test for a resolved ``(catalog, schema, table)`` reference (allow-all when unset).""" + if self._allowed_canonical is None: + return True + return self._canonical_ref(catalog, schema, table) in self._allowed_canonical @property def id(self) -> str: @@ -263,11 +303,11 @@ def _list_tables(self) -> str: # Dedupe by (schema, table) so a table reachable both qualified and via the # default schema (e.g. "public.users" and "users" with schema="public") is # listed once. Case-folded because databases reflect identifiers in their case. 
- seen: set[tuple[str | None, str]] = set() + seen: set[tuple[str | None, str | None, str]] = set() def add(schema: str | None, name: str, display: str) -> None: - key = (schema.casefold() if schema else None, name.casefold()) - if self._is_table_allowed(display) and key not in seen: + key = self._canonical_ref("", schema, name) + if self._is_ref_allowed("", schema, name) and key not in seen: seen.add(key) tables.append(display) @@ -286,10 +326,10 @@ def add(schema: str | None, name: str, display: str) -> None: return json.dumps(tables) def _get_schema(self, table_name: str) -> str: - if not self._is_table_allowed(table_name): + schema, table = self._split_table_identifier(table_name) + if not self._is_ref_allowed("", schema, table): return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."}) hook = self._get_db_hook() - schema, table = self._split_table_identifier(table_name) columns = hook.get_table_schema(table, schema=schema) return json.dumps(columns) @@ -300,15 +340,19 @@ def _dialect_for_validation(self) -> str | None: def _query(self, sql: str) -> str: hook = self._get_db_hook() + dialect = self._dialect_for_validation() + statements: list[Any] | None = None if not self._allow_writes: # allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW # (a common first move) instead of hard-failing; the deep scan still # rejects any data-modifying statement, including EXPLAIN . - _validate_sql( - sql, - dialect=self._dialect_for_validation(), - allow_read_only_metadata=True, - ) + statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True) + elif self._allowed_canonical is not None: + # Writes are allowed but tables are restricted: parse anyway so the + # allow-list still governs which tables a write may touch. + statements = _parse_sql(sql, dialect=dialect) + if statements is not None: + self._enforce_allowed_tables(statements) rows = hook.get_records(sql) # Fetch column names from cursor description. 
@@ -336,7 +380,40 @@ def _check_query(self, sql: str) -> str: with suppress(Exception): dialect = self._dialect_for_validation() try: - _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True) + statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True) + self._enforce_allowed_tables(statements) return json.dumps({"valid": True}) except Exception as e: return json.dumps({"valid": False, "error": str(e)}) + + def _enforce_allowed_tables(self, statements: list[Any]) -> None: + """ + Reject a parsed query that reaches any table outside ``allowed_tables``. + + No-op when ``allowed_tables`` is unset (allow-all). Otherwise every table the + query references (resolved scope-correctly, including catalog) must be on the + list, and any construct the list cannot describe -- a table-valued function, + ``SHOW``, dynamic SQL, an inline comment, or the ``TABLE `` shorthand -- + is refused. Raises :class:`SQLSafetyError` -- ``call_tool`` turns it into a + ``ModelRetry`` so the agent can re-target an allowed table, while + ``check_query`` reports it invalid. + """ + if self._allowed_canonical is None: + return + scan = collect_table_references(statements) + if scan.unverifiable_sources: + raise SQLSafetyError( + f"Query uses a data source that cannot be checked against allowed_tables: " + f"{'; '.join(scan.unverifiable_sources)}. Query the allowed tables directly: " + f"use list_tables to see them." + ) + disallowed = [ + ".".join(part for part in (catalog, schema, table) if part) + for catalog, schema, table in scan.tables + if not self._is_ref_allowed(catalog, schema or self._schema, table) + ] + if disallowed: + raise SQLSafetyError( + f"Query references tables that are not in the allowed tables list: " + f"{', '.join(sorted(set(disallowed)))}. Use list_tables to see the allowed tables." 
+ ) diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py index a00b4dc11e62d..3b39210710e52 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py +++ b/providers/common/ai/src/airflow/providers/common/ai/utils/sql_validation.py @@ -24,6 +24,8 @@ from __future__ import annotations +from typing import NamedTuple + import sqlglot from sqlglot import exp from sqlglot.dialects import Dialects @@ -105,6 +107,245 @@ class SQLSafetyError(Exception): """Generated SQL failed safety validation.""" +def parse_sql( + sql: str, + *, + dialect: str | None = None, + allow_multiple_statements: bool = False, +) -> list[exp.Expr]: + """ + Parse SQL into statements, enforcing the empty- and multi-statement guards only. + + Shared by :func:`validate_sql` (which then applies statement-type checks) and by + callers that need the parsed AST for their own analysis -- e.g. table-reference + extraction for ``allowed_tables`` enforcement -- without the read-only allow-list. + + :param sql: SQL string to parse. + :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.). + :param allow_multiple_statements: Whether to allow multiple semicolon-separated + statements. Default ``False`` -- multi-statement input can hide a dangerous + operation after a benign one. + :return: List of parsed sqlglot Expression objects (never empty). + :raises SQLSafetyError: If the SQL is empty, cannot be parsed, or contains multiple + statements when not permitted. 
+ """ + if not sql or not sql.strip(): + raise SQLSafetyError("Empty SQL input.") + + try: + statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE) + except sqlglot.errors.ParseError as e: + raise SQLSafetyError(f"SQL parse error: {e}") from e + + # sqlglot.parse can return [None] for empty input + parsed = [s for s in statements if s is not None] + if not parsed: + raise SQLSafetyError("Empty SQL input.") + + if not allow_multiple_statements and len(parsed) > 1: + raise SQLSafetyError( + f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default." + ) + return parsed + + +class TableScan(NamedTuple): + """Result of :func:`collect_table_references`.""" + + #: ``(catalog, schema, table)`` for every real base table referenced anywhere in + #: the AST. ``catalog`` and ``schema`` are ``""`` when the reference omits them. + #: In-scope CTE references are excluded. Catalog is reported so the caller can + #: reject cross-database references (``otherdb.public.orders``) that a + #: ``schema.table`` allow-list cannot describe. + tables: list[tuple[str, str, str]] + #: Human-readable descriptions of constructs that cannot be checked against an + #: allow-list and so must be rejected while one is active: table-valued functions + #: (``dblink``), ``TABLE('name')`` row sources, ``SHOW``, dynamic SQL + #: (``EXEC``/``Command``), inline comments (a parser-vs-engine differential), and + #: the ``TABLE `` shorthand. Empty when every construct is verifiable. + unverifiable_sources: list[str] + + +_DML_TYPES: tuple[type[exp.Expr], ...] = (exp.Insert, exp.Update, exp.Delete, exp.Merge) + + +def _same_identifier(a: exp.Identifier, b: exp.Identifier) -> bool: + """ + Compare two identifiers under standard identifier-folding rules. + + Unquoted names fold (case-insensitive); quoted names are case-preserving and + distinct from unquoted ones. 
Used to decide whether a table reference names a CTE: + being *stricter* here is safe -- a near-miss falls through to the allow-list check. + """ + aq, bq = bool(a.args.get("quoted")), bool(b.args.get("quoted")) + if not aq and not bq: + return str(a.this).casefold() == str(b.this).casefold() + if aq and bq: + return str(a.this) == str(b.this) + return False + + +def _enclosing_cte(table: exp.Expr, with_: exp.With) -> exp.CTE | None: + """Return the CTE of ``with_`` whose *definition* contains ``table`` (else ``None``).""" + node = table.parent + while node is not None and node is not with_.parent: + if isinstance(node, exp.CTE) and node.parent is with_: + return node + node = node.parent + return None + + +def _is_in_scope_cte(table: exp.Table) -> bool: + """ + Report whether ``table`` is a bare reference resolved by a CTE visible at its scope. + + Walks the ancestor chain (lexical scope) collecting CTE names from each enclosing + ``WITH``. A CTE defined in a *sibling* or *inner* subquery is not an ancestor, so a + real top-level table is never excluded by an unrelated same-named CTE + (``SELECT * FROM secret WHERE id IN (WITH secret AS (...) SELECT ...)``). A + non-recursive CTE is not visible inside its own definition, so + ``WITH secret AS (SELECT * FROM secret) ...`` still reports the real ``secret``. + CTE order matters too: inside one CTE's body only *earlier* siblings are in scope + (forward references need ``RECURSIVE``), so ``WITH a AS (SELECT * FROM secret), + secret AS (...) SELECT * FROM a`` still reports the real ``secret`` read by ``a``. + """ + ref = table.this + if not isinstance(ref, exp.Identifier): + return False + node: exp.Expr | None = table.parent + while node is not None: + # A WITH attaches to its owning query (Select/Union/DML) as a sibling of the + # body, so the query -- an ancestor of the table -- holds it. Find it by type + # rather than a fixed arg key (sqlglot has used both ``with`` and ``with_``). 
+ with_ = ( + next((v for v in node.args.values() if isinstance(v, exp.With)), None) + if isinstance(node, exp.Expression) + else None + ) + if isinstance(with_, exp.With): + recursive = bool(with_.args.get("recursive")) + ctes = list(with_.expressions) + enclosing = _enclosing_cte(table, with_) + # If the reference sits inside CTE E's own body, only CTEs defined *before* + # E are visible there (plus E itself when RECURSIVE); a CTE defined after E + # is not yet in scope. In the main query body every CTE is visible. + enclosing_idx = next((i for i, c in enumerate(ctes) if c is enclosing), None) + for idx, cte in enumerate(ctes): + if enclosing_idx is not None: + if idx > enclosing_idx: + continue + if idx == enclosing_idx and not recursive: + continue + alias = cte.args.get("alias") + cte_ident = alias.this if isinstance(alias, exp.TableAlias) else None + if isinstance(cte_ident, exp.Identifier) and _same_identifier(cte_ident, ref): + return True + node = node.parent + return False + + +def collect_table_references(statements: list[exp.Expr]) -> TableScan: + """ + Walk parsed statements and report every real table they reach, scope-correctly. + + This is the AST half of ``allowed_tables`` enforcement: it returns the concrete + base tables a query reaches (including those nested in subqueries, CTEs, JOINs, set + operations, ``DESCRIBE``, and DML) as ``(catalog, schema, table)`` so the caller can + check each against its allow-list, plus a list of constructs that cannot be checked + and must therefore be rejected while an allow-list is active. + + Handled carefully (each was a confirmed bypass before it was closed): + + - **CTE references are excluded by lexical scope, not by name.** A table is treated + as a CTE only when a ``WITH`` *enclosing that reference* defines the name (see + :func:`_is_in_scope_cte`); a same-named CTE in a sibling/inner query no longer + hides a real top-level table. 
A DML *target* is always a real table (you cannot + write to a CTE, so a same-named CTE does not shadow it), but DML *sources* follow + normal CTE scoping -- a CTE used as an INSERT/UPDATE source is not flagged. + - **Catalog-qualified references are reported with their catalog**, so the caller + rejects ``otherdb.public.orders`` instead of matching it to ``public.orders``. + - **Unverifiable constructs are listed, not silently dropped:** nameless + table-valued functions (``dblink``), ``TABLE('name')`` row sources + (``exp.TableFromRows``), ``SHOW``, dynamic SQL (``EXEC``/``Command``), the + ``TABLE `` shorthand (which sqlglot parses incorrectly, leaking the + ``TABLE`` keyword as a column), a **quoted identifier** (case-sensitive on the engine but + matched case-insensitively here, so ``"Orders"`` could otherwise reach a table + distinct from the allow-listed ``orders``), and **any inline comment** -- + comments are where parser-vs-engine differentials hide (MySQL executable + ``/*! ... */``, ``--`` not followed by whitespace, ``#``). + + :param statements: Parsed sqlglot statements (from :func:`parse_sql`). + :return: A :class:`TableScan` of real table references and unverifiable constructs. + """ + tables: list[tuple[str, str, str]] = [] + unverifiable: list[str] = [] + for stmt in statements: + # SHOW enumerates objects / leaks a table's columns outside any single table. + if isinstance(stmt, exp.Show): + unverifiable.append("a SHOW statement") + continue + # Dynamic SQL and anything sqlglot can only represent as a raw Command reach + # data through text the parser cannot inspect. + if isinstance(stmt, (exp.Command, exp.Execute)): + unverifiable.append(f"a {type(stmt).__name__.lower()} statement") + continue + + # A comment is a parser-vs-engine differential vector: sqlglot drops it, but the + # engine may execute it (MySQL `/*! ... */`) or tokenize it differently (`--` + # without a trailing space, `#`). 
sqlglot tokenizes string literals correctly, + # so a `--` inside a quoted string is not flagged here. + if any(node.comments for node in stmt.walk()): + unverifiable.append("an inline comment") + continue + + # `TABLE('name')` / `TABLE($$name$$)` name a table through a string the parser + # cannot resolve; sqlglot models them as TableFromRows, not exp.Table. + if any(True for _ in stmt.find_all(exp.TableFromRows)): + unverifiable.append("a TABLE(...) row source") + continue + + # `TABLE ` (Postgres/MySQL shorthand for SELECT * FROM ) is not + # modelled by sqlglot; it parses incorrectly, leaking the reserved word TABLE as an + # unquoted column identifier. No real query has an unquoted column named TABLE. + if any( + isinstance(col.this, exp.Identifier) + and not col.this.args.get("quoted") + and str(col.this.this).upper() == "TABLE" + for col in stmt.find_all(exp.Column) + ): + unverifiable.append("a TABLE shorthand") + continue + + # A DML statement's *target* (the table written to) is always a real table -- + # you cannot INSERT/UPDATE/DELETE/MERGE into a CTE, so even a same-named CTE does + # not shadow it. Its *sources* (the SELECT/USING/subqueries) follow normal CTE + # scoping, so a CTE used as a source is not mistaken for a base table. + target = stmt.args.get("this") if isinstance(stmt, _DML_TYPES) else None + target_ids = {id(t) for t in target.find_all(exp.Table)} if target is not None else set() + for table in stmt.find_all(exp.Table): + name = table.name + if not name: + unverifiable.append(f"table-valued function ({table.sql()})") + continue + # A bare, non-target reference may be a CTE; a qualified one or a DML target + # never is. 
+ if id(table) not in target_ids and not table.db and not table.catalog and _is_in_scope_cte(table): + continue + # A quoted identifier is case-sensitive on the engine, but the allow-list is + # matched case-insensitively (and a plain ``schema.table`` string cannot + # carry quoting), so a quoted reference cannot be matched soundly: on + # Postgres/Snowflake ``"Orders"`` is a *different* table from the allow-listed + # ``orders``. Reject rather than risk reaching a case-distinct table. + if any( + isinstance(part, exp.Identifier) and part.args.get("quoted") + for part in (table.this, table.args.get("db"), table.args.get("catalog")) + ): + unverifiable.append("a quoted identifier") + continue + tables.append((table.catalog, table.db, name)) + return TableScan(tables=tables, unverifiable_sources=unverifiable) + + def validate_sql( sql: str, *, @@ -138,9 +379,6 @@ def validate_sql( :raises SQLSafetyError: If the SQL is empty, contains disallowed statement types, or has multiple statements when not permitted. """ - if not sql or not sql.strip(): - raise SQLSafetyError("Empty SQL input.") - # A caller-supplied ``allowed_types`` is an explicit opt-out of the curated # read-only defaults (and the data-modifying deep scan). Otherwise we use the # read-only defaults, optionally widened with metadata statements, and keep @@ -154,20 +392,7 @@ def validate_sql( types = allowed_types run_data_modifying_scan = types == DEFAULT_ALLOWED_TYPES - try: - statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE) - except sqlglot.errors.ParseError as e: - raise SQLSafetyError(f"SQL parse error: {e}") from e - - # sqlglot.parse can return [None] for empty input - parsed = [s for s in statements if s is not None] - if not parsed: - raise SQLSafetyError("Empty SQL input.") - - if not allow_multiple_statements and len(parsed) > 1: - raise SQLSafetyError( - f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default." 
- ) + parsed = parse_sql(sql, dialect=dialect, allow_multiple_statements=allow_multiple_statements) for stmt in parsed: if not isinstance(stmt, types): diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py index 92e033a6b60f6..ed0619a1db0c1 100644 --- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py +++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py @@ -467,3 +467,285 @@ def test_check_query_handles_unresolvable_connection(self): ts.call_tool("check_query", {"sql": "SELECT 1"}, ctx=MagicMock(), tool=MagicMock()) ) assert json.loads(result)["valid"] is True + + +def _run_query(ts: SQLToolset, sql: str): + return asyncio.run(ts.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock())) + + +def _run_check(ts: SQLToolset, sql: str): + return json.loads( + asyncio.run(ts.call_tool("check_query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock())) + ) + + +class TestSQLToolsetAllowedTablesQueryEnforcement: + """``allowed_tables`` is enforced on the query/check_query tools, not just on discovery.""" + + def test_query_allows_table_on_the_list(self): + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + + result = _run_query(ts, "SELECT id FROM orders") + + assert "rows" in json.loads(result) + ts._hook.get_records.assert_called_once_with("SELECT id FROM orders") + + def test_query_blocks_table_off_the_list(self): + """The headline escape: querying a table that is not on the allow-list is refused.""" + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "SELECT * FROM secret_salaries") + + assert "not in the allowed tables list" in exc_info.value.message + assert "secret_salaries" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + 
@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM (SELECT * FROM secret_salaries) x", + "WITH s AS (SELECT * FROM secret_salaries) SELECT * FROM s", + "SELECT * FROM orders JOIN secret_salaries ON orders.id = secret_salaries.id", + "SELECT * FROM orders UNION SELECT * FROM secret_salaries", + "SELECT * FROM secret_salaries WHERE id IN (SELECT id FROM orders)", + ], + ids=["subquery", "cte_body", "join", "union", "where_subquery"], + ) + def test_query_blocks_disallowed_table_reached_indirectly(self, sql): + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, sql) + + assert "secret_salaries" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_query_blocks_catalog_enumeration(self): + """information_schema/pg_catalog are ordinary tables, so the allow-list blocks them too.""" + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "SELECT table_name FROM information_schema.tables") + + assert "information_schema.tables" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_query_allows_cte_reference_not_mistaken_for_table(self): + """A CTE whose name is not on the list is fine as long as its body stays allowed.""" + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + + result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked") + + assert "rows" in json.loads(result) + + def test_query_blocks_table_valued_function(self): + """dblink reaches data through a path the list can't describe, so it is refused.""" + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = "postgresql" + + with pytest.raises(ModelRetry) as exc_info: + 
_run_query(ts, "SELECT * FROM dblink('host=evil', 'SELECT 1') AS t(x int)") + + assert "cannot be checked against allowed_tables" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_query_blocks_show_when_allowlist_active(self): + ts = SQLToolset("sf_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = "snowflake" + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "SHOW TABLES") + + assert "cannot be checked against allowed_tables" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_query_blocks_describe_of_disallowed_table(self): + ts = SQLToolset("sf_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = "snowflake" + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "DESCRIBE TABLE secret_salaries") + + assert "secret_salaries" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_query_allows_describe_of_allowed_table(self): + ts = SQLToolset("sf_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook(records=[("id", "INT")], last_description=[("name",), ("type",)]) + ts._hook.dialect_name = "snowflake" + + result = _run_query(ts, "DESCRIBE TABLE orders") + + assert "rows" in json.loads(result) + ts._hook.get_records.assert_called_once_with("DESCRIBE TABLE orders") + + def test_query_allows_schema_qualified_table_on_list(self): + ts = SQLToolset("sf", allowed_tables=["MODEL_CRM.SF_ASTRO_ORGS"]) + ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + ts._hook.dialect_name = "snowflake" + + result = _run_query(ts, "SELECT * FROM MODEL_CRM.SF_ASTRO_ORGS") + + assert "rows" in json.loads(result) + + def test_query_unqualified_resolves_to_default_schema(self): + """``public.orders`` and ``orders`` denote the same table when schema='public'.""" + ts = SQLToolset("pg", allowed_tables=["orders"], schema="public") + ts._hook = 
_make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + + # Qualifying with the default schema must still match the bare allow-list entry. + result = _run_query(ts, "SELECT * FROM public.orders") + assert "rows" in json.loads(result) + + def test_no_allowlist_leaves_queries_unrestricted(self): + """Without allowed_tables the query tool behaves exactly as before (allow-all).""" + ts = SQLToolset("pg_default") + ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + + result = _run_query(ts, "SELECT * FROM anything_at_all") + + assert "rows" in json.loads(result) + ts._hook.get_records.assert_called_once_with("SELECT * FROM anything_at_all") + + def test_check_query_reports_disallowed_table_as_invalid(self): + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + + data = _run_check(ts, "SELECT * FROM secret_salaries") + + assert data["valid"] is False + assert "secret_salaries" in data["error"] + + def test_check_query_valid_for_allowed_table(self): + ts = SQLToolset("pg_default", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook() + + assert _run_check(ts, "SELECT * FROM orders")["valid"] is True + + def test_writes_still_bounded_by_allowed_tables(self): + """allow_writes widens the statement types, but the allow-list still scopes the target.""" + ts = SQLToolset("pg_default", allowed_tables=["orders"], allow_writes=True) + ts._hook = _make_mock_db_hook(records=[], last_description=None) + + # An allowed target is written. + _run_query(ts, "INSERT INTO orders (id) VALUES (1)") + ts._hook.get_records.assert_called_once_with("INSERT INTO orders (id) VALUES (1)") + + # A disallowed target is refused before execution. 
+ ts._hook.get_records.reset_mock() + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "INSERT INTO secret_salaries (id) VALUES (1)") + assert "secret_salaries" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + def test_writes_reject_dynamic_sql_the_parser_cannot_inspect(self): + """allow_writes skips the read-only validator, so the allow-list must still + refuse dynamic SQL (EXEC/EXECUTE) whose table access is opaque.""" + ts = SQLToolset("mssql_default", allowed_tables=["orders"], allow_writes=True) + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = "mssql" + + with pytest.raises(ModelRetry) as exc_info: + _run_query(ts, "EXEC sp_who") + + assert "cannot be checked against allowed_tables" in exc_info.value.message + ts._hook.get_records.assert_not_called() + + +class TestSQLToolsetAllowedTablesBypassRegressions: + """Regression tests for bypasses found by adversarial red-teaming of the allow-list.""" + + @pytest.mark.parametrize( + ("sql", "dialect", "allow_writes"), + [ + # CTE scope: a same-named CTE in an inner/sibling scope must not hide the real table. + ( + "SELECT * FROM secret_salaries WHERE id IN " + "(WITH secret_salaries AS (SELECT 1 id) SELECT id FROM secret_salaries)", + "postgresql", + False, + ), + # Non-recursive CTE is not in scope within its own body. + ( + "WITH secret_salaries AS (SELECT * FROM secret_salaries) SELECT * FROM secret_salaries", + "postgresql", + False, + ), + # A CTE may only reference earlier siblings; a later-defined name is the real table. + ( + "WITH a AS (SELECT * FROM secret_salaries), secret_salaries AS (SELECT 1 id) SELECT * FROM a", + "postgresql", + False, + ), + # Cross-database / catalog qualifier the schema.table allow-list cannot describe. + ("SELECT * FROM secretdb.public.orders", "snowflake", False), + ("SELECT * FROM secret_salaries..orders", "mssql", False), + # MySQL executable comments execute on the engine but sqlglot treats them as inert. 
+ ("SELECT * FROM orders/*!UNION SELECT * FROM secret_salaries*/", "mysql", False), + ("SELECT id FROM orders /*!50000 UNION SELECT id FROM secret_salaries */", "mysql", False), + # TABLE shorthand (mis-parsed) and TABLE('name') row source (string-named). + ("TABLE secret_salaries UNION SELECT * FROM orders", "postgresql", False), + ("SELECT * FROM TABLE('secret_salaries')", "snowflake", False), + # Write-mode CTE shadowing the DML target. + ("WITH secret_salaries AS (SELECT 1) DELETE FROM secret_salaries", "postgresql", True), + # Quoted identifier is case-distinct on the engine but case-folds into the list. + ('SELECT * FROM "Orders"', "postgresql", False), + # A DML source CTE whose body reads an off-list table is still caught. + ( + "WITH src AS (SELECT * FROM secret_salaries) INSERT INTO orders SELECT * FROM src", + "postgresql", + True, + ), + ], + ids=[ + "cte_inner_shadow", + "cte_self_body", + "cte_forward_ref", + "catalog_cross_db", + "mssql_empty_middle", + "mysql_exec_comment", + "mysql_versioned_comment", + "table_shorthand", + "table_row_source", + "write_cte_target", + "quoted_case_distinct", + "dml_cte_body_reads_offlist", + ], + ) + def test_known_bypasses_are_rejected(self, sql, dialect, allow_writes): + ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=allow_writes) + ts._hook = _make_mock_db_hook() + ts._hook.dialect_name = dialect + + with pytest.raises(ModelRetry): + _run_query(ts, sql) + ts._hook.get_records.assert_not_called() + + def test_legit_cte_over_allowed_table_still_runs(self): + """The scope-aware fix must not false-reject a genuine CTE over an allowed table.""" + ts = SQLToolset("c", allowed_tables=["orders"]) + ts._hook = _make_mock_db_hook(records=[(1,)], last_description=[("id",)]) + + result = _run_query(ts, "WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked") + + assert "rows" in json.loads(result) + ts._hook.get_records.assert_called_once() + + def test_dml_with_cte_source_over_allowed_table_runs(self): + 
"""A CTE used as a DML source must not be mistaken for a disallowed base table.""" + ts = SQLToolset("c", allowed_tables=["orders"], allow_writes=True) + ts._hook = _make_mock_db_hook(records=[], last_description=None) + + sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src" + _run_query(ts, sql) + + ts._hook.get_records.assert_called_once_with(sql) diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py index 9ca6604ba5e61..257761d78e06a 100644 --- a/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py +++ b/providers/common/ai/tests/unit/common/ai/utils/test_sql_validation.py @@ -21,6 +21,8 @@ from airflow.providers.common.ai.utils.sql_validation import ( SQLSafetyError, + collect_table_references, + parse_sql, resolve_sqlglot_dialect, validate_sql, ) @@ -312,3 +314,193 @@ class TestResolveSqlglotDialect: ) def test_resolution(self, dialect_name, expected): assert resolve_sqlglot_dialect(dialect_name) == expected + + +class TestParseSQL: + """``parse_sql`` enforces only the empty- and multi-statement guards.""" + + def test_returns_statements(self): + parsed = parse_sql("SELECT 1") + assert len(parsed) == 1 + assert isinstance(parsed[0], exp.Select) + + def test_does_not_apply_type_checks(self): + """Unlike validate_sql, parse_sql accepts writes -- callers add their own policy.""" + parsed = parse_sql("DELETE FROM users WHERE id = 1") + assert isinstance(parsed[0], exp.Delete) + + @pytest.mark.parametrize("sql", ["", " ", "\n\t"]) + def test_rejects_empty(self, sql): + with pytest.raises(SQLSafetyError, match="Empty SQL"): + parse_sql(sql) + + def test_rejects_multiple_statements_by_default(self): + with pytest.raises(SQLSafetyError, match="Multiple statements"): + parse_sql("SELECT 1; SELECT 2") + + def test_allows_multiple_statements_when_opted_in(self): + assert len(parse_sql("SELECT 1; SELECT 2", 
allow_multiple_statements=True)) == 2 + + def test_rejects_unparsable(self): + with pytest.raises(SQLSafetyError, match="parse error"): + parse_sql("SELECT FROM WHERE )(") + + +class TestCollectTableReferences: + """``collect_table_references`` reports the real tables a query reaches.""" + + @pytest.mark.parametrize( + ("sql", "dialect", "expected"), + [ + ("SELECT * FROM secret", None, [("", "", "secret")]), + ("SELECT * FROM model_crm.orders", None, [("", "model_crm", "orders")]), + ("SELECT * FROM a JOIN b ON a.id = b.id", None, [("", "", "a"), ("", "", "b")]), + ("SELECT * FROM (SELECT * FROM inner_t) x", None, [("", "", "inner_t")]), + ("SELECT * FROM a UNION SELECT * FROM b", None, [("", "", "a"), ("", "", "b")]), + ( + "SELECT table_name FROM information_schema.tables", + "postgres", + [("", "information_schema", "tables")], + ), + ("DESCRIBE secret", "mysql", [("", "", "secret")]), + # Cross-database reference carries its catalog so the caller can reject it. + ("SELECT * FROM otherdb.public.orders", "snowflake", [("otherdb", "public", "orders")]), + ], + ids=["table", "qualified", "join", "subquery", "union", "catalog", "describe", "cross_db"], + ) + def test_collects_real_tables(self, sql, dialect, expected): + scan = collect_table_references(parse_sql(sql, dialect=dialect)) + assert sorted(scan.tables) == sorted(expected) + assert scan.unverifiable_sources == [] + + def test_excludes_cte_reference_but_keeps_its_body(self): + scan = collect_table_references(parse_sql("WITH s AS (SELECT * FROM base) SELECT * FROM s")) + assert scan.tables == [("", "", "base")] # 's' is the CTE, not a table + + def test_cte_that_shadows_a_table_name_yields_no_table(self): + scan = collect_table_references(parse_sql("WITH secret AS (SELECT 1 AS x) SELECT * FROM secret")) + assert scan.tables == [] + + def test_schema_qualified_name_is_never_treated_as_a_cte(self): + scan = collect_table_references(parse_sql("WITH s AS (SELECT 1 AS x) SELECT * FROM myschema.s")) + assert 
scan.tables == [("", "myschema", "s")] + + def test_inner_cte_does_not_shadow_outer_real_table(self): + """A same-named CTE in an inner subquery must not hide the real top-level table.""" + sql = "SELECT * FROM secret WHERE id IN (WITH secret AS (SELECT 1 id) SELECT id FROM secret)" + scan = collect_table_references(parse_sql(sql)) + assert ("", "", "secret") in scan.tables # the top-level real table is reported + + def test_cte_self_body_references_real_table(self): + """A non-recursive CTE is not in scope within its own body, so the real table shows.""" + scan = collect_table_references( + parse_sql("WITH secret AS (SELECT * FROM secret) SELECT * FROM secret") + ) + assert ("", "", "secret") in scan.tables + + def test_cte_forward_reference_is_real_table(self): + """A CTE may only reference earlier siblings; a later-defined name is the real table.""" + sql = "WITH a AS (SELECT * FROM secret), secret AS (SELECT 1 id) SELECT * FROM a" + scan = collect_table_references(parse_sql(sql)) + assert ("", "", "secret") in scan.tables + + def test_legit_cte_reference_is_excluded(self): + """A genuine CTE reference is not reported as a base table (no false reject).""" + scan = collect_table_references( + parse_sql("WITH ranked AS (SELECT * FROM orders) SELECT * FROM ranked") + ) + assert scan.tables == [("", "", "orders")] + + @pytest.mark.parametrize( + ("sql", "dialect"), + [ + ("SELECT * FROM orders/*!UNION SELECT * FROM secret*/", "mysql"), + ("SELECT * FROM orders /*!50000 UNION SELECT * FROM secret */", "mysql"), + ("SELECT * FROM orders WHERE 0--+1 OR id IN (SELECT id FROM secret)", "mysql"), + ], + ids=["exec_comment", "versioned_comment", "dashdash"], + ) + def test_flags_inline_comment_as_unverifiable(self, sql, dialect): + """Comments hide parser-vs-engine differentials (MySQL executable comments), so reject them.""" + scan = collect_table_references(parse_sql(sql, dialect=dialect)) + assert scan.unverifiable_sources + + @pytest.mark.parametrize( + "sql", + 
["SELECT * FROM TABLE('secret')", "SELECT * FROM TABLE($$secret$$)"], + ids=["string", "dollar"], + ) + def test_flags_table_row_source_as_unverifiable(self, sql): + """Snowflake TABLE('name') names a table through a string the parser can't resolve.""" + scan = collect_table_references(parse_sql(sql, dialect="snowflake")) + assert scan.unverifiable_sources + + def test_flags_table_shorthand_as_unverifiable(self): + """The TABLE shorthand is mis-parsed by sqlglot, so reject it.""" + scan = collect_table_references( + parse_sql("TABLE secret UNION SELECT * FROM orders", dialect="postgres") + ) + assert scan.unverifiable_sources + + @pytest.mark.parametrize( + "sql", + ['SELECT * FROM public."Orders"', 'SELECT * FROM "Orders"', 'SELECT * FROM "PUBLIC".orders'], + ids=["quoted_table", "quoted_bare", "quoted_schema"], + ) + def test_flags_quoted_identifier_as_unverifiable(self, sql): + """A quoted identifier is case-sensitive; case-insensitive matching can't verify it.""" + scan = collect_table_references(parse_sql(sql, dialect="postgres")) + assert scan.unverifiable_sources + + def test_dml_target_is_real_but_cte_source_is_excluded(self): + """A CTE used as a DML source is not a base table; the target and CTE body are.""" + sql = "WITH src AS (SELECT * FROM orders) INSERT INTO orders SELECT * FROM src" + scan = collect_table_references(parse_sql(sql, dialect="postgres")) + # Only the real table `orders` is reported (target + CTE body); `src` is the CTE. 
+ assert {t for _, _, t in scan.tables} == {"orders"} + assert scan.unverifiable_sources == [] + + def test_dml_target_shadowed_by_cte_is_still_reported(self): + """A DML target is a real table even when a same-named CTE exists (can't write a CTE).""" + sql = "WITH secret AS (SELECT 1) DELETE FROM secret" + scan = collect_table_references(parse_sql(sql, dialect="postgres")) + assert ("", "", "secret") in scan.tables + + @pytest.mark.parametrize( + ("sql", "dialect"), + [ + ("SELECT * FROM dblink('h', 'SELECT 1') AS t(x int)", "postgres"), + ("SELECT * FROM generate_series(1, 10)", "postgres"), + ], + ids=["dblink", "generate_series"], + ) + def test_flags_table_valued_functions_as_unverifiable(self, sql, dialect): + scan = collect_table_references(parse_sql(sql, dialect=dialect)) + assert scan.tables == [] + assert scan.unverifiable_sources + + @pytest.mark.parametrize( + ("sql", "dialect"), + [("SHOW TABLES", "snowflake"), ("SHOW COLUMNS FROM secret", "mysql")], + ids=["show_tables", "show_columns"], + ) + def test_flags_show_as_unverifiable(self, sql, dialect): + scan = collect_table_references(parse_sql(sql, dialect=dialect)) + assert scan.unverifiable_sources + + def test_scalar_function_without_table_has_no_references(self): + """A scalar function call references no table -- the allow-list does not cover it.""" + scan = collect_table_references(parse_sql("SELECT pg_read_file('/etc/passwd')", dialect="postgres")) + assert scan.tables == [] + assert scan.unverifiable_sources == [] + + @pytest.mark.parametrize( + ("sql", "dialect"), + [("EXEC sp_who", "tsql"), ("EXECUTE my_proc", "tsql")], + ids=["exec", "execute"], + ) + def test_flags_dynamic_sql_as_unverifiable(self, sql, dialect): + """EXEC/EXECUTE hide their table access in text the parser can't read.""" + scan = collect_table_references(parse_sql(sql, dialect=dialect)) + assert scan.tables == [] + assert scan.unverifiable_sources
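Review note (illustrative only, not part of the patch): the tests in this diff pin down how a collected `(catalog, schema, table)` reference is expected to match `allowed_tables`. Matching is case-insensitive, a bare name resolves to the default schema, and a catalog-qualified reference is refused because a `schema.table` list cannot describe it. The sketch below restates that rule in plain Python. The helper names `normalize_entry` and `is_allowed` are hypothetical and do not appear in the diff. The handling of a schema-qualified reference when no default schema is configured is an assumption, not something the tests shown here confirm.

```python
from __future__ import annotations

# Hypothetical helpers for illustration; the provider's real implementation may differ.
TableRef = tuple[str, str, str]  # (catalog, schema, table), "" when a part is absent


def normalize_entry(entry: str, default_schema: str | None) -> tuple[str | None, str]:
    """Split an allow-list entry into a lowercased (schema, table) pair."""
    schema, _, table = entry.rpartition(".")
    resolved = (schema or default_schema or "").lower() or None
    return resolved, table.lower()


def is_allowed(ref: TableRef, allowed_tables: list[str], default_schema: str | None = None) -> bool:
    """Return True when ``ref`` matches an entry of ``allowed_tables``."""
    catalog, schema, table = ref
    if catalog:
        # A cross-database reference cannot be expressed as schema.table, so refuse it.
        return False
    allowed = {normalize_entry(entry, default_schema) for entry in allowed_tables}
    # Assumption: an unqualified reference resolves to the default schema, and a
    # qualified one matches a bare entry only if it names that same default schema.
    resolved = (schema or default_schema or "").lower() or None
    return (resolved, table.lower()) in allowed
```

Under these assumptions, `is_allowed(("", "public", "orders"), ["orders"], default_schema="public")` is accepted, while `is_allowed(("secretdb", "public", "orders"), ["orders"])` is refused. Sources that name no table at all (table-valued functions, `SHOW`, `EXEC`) are out of scope for this helper; in the patch they are reported separately through `unverifiable_sources`.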