Merged
78 changes: 51 additions & 27 deletions providers/common/ai/docs/toolsets.rst
@@ -155,10 +155,10 @@ the validator, so ``SHOW`` is recognized on databases that support it (Snowflake
MySQL, etc.); on databases without ``SHOW`` it stays rejected. Data-modifying
statements remain blocked -- including ones hidden behind ``DESCRIBE``/``EXPLAIN``
(e.g. ``EXPLAIN DELETE ...``, ``DESCRIBE DROP TABLE ...``), which the validator
rejects by scanning the parsed statement for write operations. Like ``SELECT``,
metadata statements are not scoped by ``allowed_tables`` (see
:ref:`allowed-tables-limitation`) -- an agent can ``DESCRIBE`` a table outside the
list, so rely on database permissions to restrict access.
rejects by scanning the parsed statement for write operations. When
``allowed_tables`` is set it scopes these statements too: a ``DESCRIBE`` names a
table, so its target must be on the list, while ``SHOW`` enumerates objects beyond
any single table and is rejected outright (see :ref:`allowed-tables-enforcement`).

Multi-schema warehouses
^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -186,11 +186,13 @@ Parameters
^^^^^^^^^^

- ``db_conn_id``: Airflow connection ID for the database.
- ``allowed_tables``: Restrict which tables the agent can discover via
``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
multiple schemas; see above. Matching is case-insensitive.
See :ref:`allowed-tables-limitation` for an important caveat.
- ``allowed_tables``: Restrict the agent to a fixed set of tables. ``None``
(default) exposes all tables in ``schema``. Entries may be schema-qualified
(``"SCHEMA.TABLE"``) to span multiple schemas; see above. Matching is
case-insensitive. When set, the list is enforced on ``query`` and
``check_query`` as well as discovery -- every table a query references must be
on it. See :ref:`allowed-tables-enforcement` for what this does and does not
guarantee.
- ``schema``: Default schema/namespace for unqualified table listing and
introspection. Schema-qualified ``allowed_tables`` entries override it per table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
@@ -560,11 +562,14 @@ No single layer is sufficient — they work together.
INTO, and other non-SELECT statements.
- Does not prevent the agent from reading any registered data source.
* - **SQLToolset: allowed_tables**
- Restricts which tables appear in ``list_tables`` and ``get_schema``
responses, limiting the agent's knowledge of the schema.
- Does **not** validate table references in SQL queries. The agent can
still query unlisted tables if it guesses the name. See
:ref:`allowed-tables-limitation` below.
- Restricts the agent to listed tables across ``list_tables``,
``get_schema``, ``query``, and ``check_query``. Queries are parsed and
every referenced table (including via subqueries, CTEs, JOINs, and
``DESCRIBE``) is checked against the list before execution.
- Cannot police data reached through side-effecting scalar functions
(e.g. ``pg_read_file``), and is only as exact as the SQL parser. Pair it
with least-privilege database grants. See
:ref:`allowed-tables-enforcement` below.
* - **SQLToolset: max_rows**
- Truncates query results to ``max_rows`` (default 50), preventing the
agent from pulling entire tables into context.
@@ -575,21 +580,38 @@ No single layer is sufficient — they work together.
- Requires explicit configuration — the default allows many rounds.


.. _allowed-tables-limitation:
.. _allowed-tables-enforcement:

The ``allowed_tables`` Limitation
"""""""""""""""""""""""""""""""""
How ``allowed_tables`` Is Enforced
""""""""""""""""""""""""""""""""""

``allowed_tables`` is a **metadata filter**, not an access control mechanism.
It hides table names from ``list_tables`` and blocks ``get_schema`` for
unlisted tables, but does not parse SQL queries to validate table references.
When ``allowed_tables`` is set it governs every tool, not just discovery:

An LLM can craft ``SELECT * FROM secrets`` even when
``allowed_tables=["orders"]``. Parsing SQL for table references (including
CTEs, subqueries, aliases, and vendor-specific syntax) is complex and
error-prone; we chose not to provide a false sense of security.
- ``list_tables`` and ``get_schema`` only reveal listed tables.
- ``query`` and ``check_query`` parse the SQL with `sqlglot
<https://github.com/tobymao/sqlglot>`_ and reject it before execution if it
references any table that is not on the list. Tables reached indirectly are
caught too -- through subqueries, CTEs, JOINs, set operations (``UNION`` etc.),
``DESCRIBE``, catalog views such as ``information_schema``, and DML. CTE
references are excluded by lexical scope, so a same-named CTE in another scope
cannot hide a real table, and the database/catalog is part of the match, so a
cross-database reference like ``otherdb.public.orders`` is refused.
- Constructs the list cannot describe are rejected outright while it is active:
table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL (``EXEC``), and **inline
comments** -- the last because parser-vs-engine differences hide in comments
(MySQL executes ``/*! ... */`` while sqlglot and other engines ignore it).

For query-level restrictions, use database permissions:
So ``SELECT * FROM secrets`` with ``allowed_tables=["orders"]`` is refused, and
the rejection is handed back to the agent so it can re-target an allowed table.

This is a strong **application-level guardrail**, but it is not a substitute for
database permissions. It cannot police data reached through a function whose
argument is itself SQL or a path: ``pg_read_file('/etc/passwd')`` reads a file,
and ``query_to_xml('SELECT * FROM other_table', ...)`` or a scalar ``dblink``
reads a table through a string the parser cannot inspect. Any query the engine
parses differently from sqlglot is also a residual gap. For a hard boundary, also
run the connection as a least-privilege role:

.. code-block:: sql

@@ -598,8 +620,10 @@ For query-level restrictions, use database permissions:
GRANT SELECT ON orders, customers TO airflow_agent_reader;
-- Use this role's credentials in the Airflow connection

The Airflow connection should use a database user with the minimum privileges
required.
Defense in depth: the allow-list contains the agent's *intent* (and gives it a
correctable error), while the database role is the boundary that holds even if
the agent reaches data the parser cannot see. The connection should use a
database user with the minimum privileges required.
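
The matching rule described above (case-insensitive, unqualified entries resolved to the default ``schema``, catalog-qualified references never matching) can be illustrated with a minimal standalone sketch. It mirrors the ``_canonical_ref`` logic in this change, but ``build_allowed`` and ``is_allowed`` are hypothetical helper names used only for illustration and are not part of the toolset's API.

```python
from typing import FrozenSet, List, Optional, Set, Tuple

# (catalog, schema, table) comparison key; missing parts are None.
Key = Tuple[Optional[str], Optional[str], str]


def canonical_ref(catalog: Optional[str], schema: Optional[str], table: str) -> Key:
    """Case-fold a reference; empty or missing catalog/schema become None."""
    return (
        catalog.casefold() if catalog else None,
        schema.casefold() if schema else None,
        table.casefold(),
    )


def build_allowed(entries: List[str], default_schema: Optional[str]) -> FrozenSet[Key]:
    """Hypothetical helper: turn allow-list entries into canonical keys."""
    keys: Set[Key] = set()
    for entry in entries:
        entry_schema, sep, table = entry.rpartition(".")
        if sep:
            # "SCHEMA.TABLE" entries keep their own schema.
            keys.add(canonical_ref(None, entry_schema, table))
        else:
            # Unqualified entries resolve to the default schema.
            keys.add(canonical_ref(None, default_schema, entry))
    return frozenset(keys)


def is_allowed(
    allowed: FrozenSet[Key],
    default_schema: Optional[str],
    catalog: Optional[str],
    schema: Optional[str],
    table: str,
) -> bool:
    """Hypothetical helper: unqualified references fall back to the default schema."""
    return canonical_ref(catalog, schema or default_schema, table) in allowed


allowed = build_allowed(["orders", "ANALYTICS.Events"], default_schema="public")
```

Under these assumptions, ``orders`` and ``PUBLIC.Orders`` resolve to the same key, while ``otherdb.public.orders`` carries a non-null catalog and therefore cannot match any allow-list entry.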


HookToolset Guidelines
155 changes: 116 additions & 39 deletions providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -24,6 +24,9 @@

try:
from airflow.providers.common.ai.utils.sql_validation import (
SQLSafetyError,
collect_table_references,
parse_sql as _parse_sql,
resolve_sqlglot_dialect,
validate_sql as _validate_sql,
)
@@ -94,21 +97,37 @@ class SQLToolset(AbstractToolset[Any]):
toolset does not inspect the error type or message.

:param db_conn_id: Airflow connection ID for the database.
:param allowed_tables: Restrict which tables the agent can discover via
``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables
in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
multiple schemas in one database -- common on warehouses such as Snowflake.
``list_tables`` then introspects each referenced schema and returns the
matching tables fully qualified, and ``get_schema`` routes to the table's
own schema. Unqualified entries use ``schema``. Matching is
case-insensitive, since databases reflect identifiers in their own case.
:param allowed_tables: Restrict the agent to a fixed set of tables. ``None``
(default) exposes every table in ``schema``. Entries may be schema-qualified
(``"SCHEMA.TABLE"``) to span multiple schemas in one database -- common on
warehouses such as Snowflake. ``list_tables`` introspects each referenced
schema and returns the matching tables fully qualified, and ``get_schema``
routes to the table's own schema. Unqualified entries use ``schema``.
Matching is case-insensitive, since databases reflect identifiers in their
own case.

When set, the list is enforced on the ``query`` and ``check_query`` tools as
well as on discovery: every table a query reaches -- through subqueries, CTEs,
JOINs, set operations, ``DESCRIBE``, catalog views such as
``information_schema``, or DML -- must be on the list, resolved with its
database/catalog, or the query is rejected before it runs. CTE references are
excluded by lexical scope (a same-named CTE in another scope never hides a real
table). Constructs the list cannot describe are rejected outright while it is
active: table-valued functions (``dblink``), ``TABLE('name')`` row sources, the
``TABLE <name>`` shorthand, ``SHOW``, dynamic SQL, and **inline comments**
(where parser-vs-engine differences such as MySQL ``/*! ... */`` executable
comments hide).

.. note::
``allowed_tables`` controls metadata visibility only. It does **not**
parse or validate table references in SQL queries. An LLM can still
query tables outside this list if it guesses the name. For query-level
restrictions, use database-level permissions (e.g. a read-only role
with grants limited to specific tables).
This is an application-level guardrail, enforced by parsing the SQL with
sqlglot. It is strong defense-in-depth but not a substitute for database
permissions: it cannot police data reached through a function whose
argument is itself SQL or a path -- ``pg_read_file('...')`` (a file) or
``query_to_xml('SELECT ... FROM other_table', ...)`` and ``dblink`` in
scalar position (a table, read through a string the parser cannot inspect)
-- and any query the engine parses differently from sqlglot is a residual
gap. For a hard guarantee, also point ``db_conn_id`` at a least-privilege
role whose ``SELECT`` grants are limited to the same tables.

:param schema: Default schema/namespace for table listing and introspection,
used for unqualified ``allowed_tables`` entries and unqualified
@@ -131,42 +150,63 @@ def __init__(
) -> None:
self._db_conn_id = db_conn_id
self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None
# Case-folded view for membership tests: databases reflect identifiers in
# their own case (Snowflake stores unquoted names uppercase but reflects
# them lowercased), so a byte-exact match against the user's entries would
# silently miss. allowed_tables is a visibility hint, not access control,
# so case-insensitive matching is safe.
self._allowed_tables_ci: frozenset[str] | None = (
frozenset(t.casefold() for t in self._allowed_tables)
if self._allowed_tables is not None
else None
)
self._schema = schema
self._allow_writes = allow_writes
self._max_rows = max_rows
self._hook: DbApiHook | None = None

# Derive which schemas to introspect from schema-qualified allowed_tables.
# Canonical ``(catalog, schema, table)`` view of allowed_tables for membership
# tests, plus the schemas to introspect. Built once: every reference -- a
# discovery hit, a get_schema arg, or a table parsed out of a query -- is
# normalised to the same shape and matched against this set.
#
# Identifiers are case-folded: databases reflect them in their own case
# (Snowflake stores unquoted names uppercase but reflects them lowercased), so
# a byte-exact match against the user's entries would silently miss. Unqualified
# entries resolve to the default ``schema`` (``None`` when unset) so that
# ``"orders"`` and ``"<schema>.orders"`` denote the same table. Allow-list
# entries carry no catalog, so any catalog-qualified reference
# (``otherdb.public.orders``) has a non-null catalog in its key and cannot match
# -- that closes cross-database access the single-connection allow-list can't
# describe.
self._allowed_canonical: frozenset[tuple[str | None, str | None, str]] | None = None
# Qualified entries ("SCHEMA.TABLE") are listed under their own schema and
# returned fully qualified; unqualified entries (and allow-all) use the
# default ``schema``.
self._qualified_schemas: frozenset[str] = frozenset()
self._include_default_schema: bool = True
if self._allowed_tables is not None:
canonical: set[tuple[str | None, str | None, str]] = set()
qualified_schemas: set[str] = set()
include_default = False
for entry in self._allowed_tables:
entry_schema, sep, _ = entry.rpartition(".")
entry_schema, sep, table = entry.rpartition(".")
if sep:
qualified_schemas.add(entry_schema)
canonical.add(self._canonical_ref("", entry_schema, table))
else:
include_default = True
canonical.add(self._canonical_ref("", self._schema, entry))
self._allowed_canonical = frozenset(canonical)
self._qualified_schemas = frozenset(qualified_schemas)
self._include_default_schema = include_default

def _is_table_allowed(self, name: str) -> bool:
"""Case-insensitive membership test against ``allowed_tables`` (allow-all when unset)."""
return self._allowed_tables_ci is None or name.casefold() in self._allowed_tables_ci
@staticmethod
def _canonical_ref(
catalog: str | None, schema: str | None, table: str
) -> tuple[str | None, str | None, str]:
"""Normalise a ``(catalog, schema, table)`` reference to its case-folded comparison key."""
return (
catalog.casefold() if catalog else None,
schema.casefold() if schema else None,
table.casefold(),
)

def _is_ref_allowed(self, catalog: str | None, schema: str | None, table: str) -> bool:
"""Membership test for a resolved ``(catalog, schema, table)`` reference (allow-all when unset)."""
if self._allowed_canonical is None:
return True
return self._canonical_ref(catalog, schema, table) in self._allowed_canonical

@property
def id(self) -> str:
@@ -263,11 +303,11 @@ def _list_tables(self) -> str:
# Dedupe by (schema, table) so a table reachable both qualified and via the
# default schema (e.g. "public.users" and "users" with schema="public") is
# listed once. Case-folded because databases reflect identifiers in their case.
seen: set[tuple[str | None, str]] = set()
seen: set[tuple[str | None, str | None, str]] = set()

def add(schema: str | None, name: str, display: str) -> None:
key = (schema.casefold() if schema else None, name.casefold())
if self._is_table_allowed(display) and key not in seen:
key = self._canonical_ref("", schema, name)
if self._is_ref_allowed("", schema, name) and key not in seen:
seen.add(key)
tables.append(display)

@@ -286,10 +326,10 @@ def add(schema: str | None, name: str, display: str) -> None:
return json.dumps(tables)

def _get_schema(self, table_name: str) -> str:
if not self._is_table_allowed(table_name):
schema, table = self._split_table_identifier(table_name)
if not self._is_ref_allowed("", schema, table):
return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."})
hook = self._get_db_hook()
schema, table = self._split_table_identifier(table_name)
columns = hook.get_table_schema(table, schema=schema)
return json.dumps(columns)

@@ -300,15 +340,19 @@ def _dialect_for_validation(self) -> str | None:

def _query(self, sql: str) -> str:
hook = self._get_db_hook()
dialect = self._dialect_for_validation()
statements: list[Any] | None = None
if not self._allow_writes:
# allow_read_only_metadata lets agents inspect schemas with DESCRIBE/SHOW
# (a common first move) instead of hard-failing; the deep scan still
# rejects any data-modifying statement, including EXPLAIN <write>.
_validate_sql(
sql,
dialect=self._dialect_for_validation(),
allow_read_only_metadata=True,
)
statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
elif self._allowed_canonical is not None:
# Writes are allowed but tables are restricted: parse anyway so the
# allow-list still governs which tables a write may touch.
statements = _parse_sql(sql, dialect=dialect)
if statements is not None:
self._enforce_allowed_tables(statements)

rows = hook.get_records(sql)
# Fetch column names from cursor description.
@@ -336,7 +380,40 @@ def _check_query(self, sql: str) -> str:
with suppress(Exception):
dialect = self._dialect_for_validation()
try:
_validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
statements = _validate_sql(sql, dialect=dialect, allow_read_only_metadata=True)
self._enforce_allowed_tables(statements)
return json.dumps({"valid": True})
except Exception as e:
return json.dumps({"valid": False, "error": str(e)})

def _enforce_allowed_tables(self, statements: list[Any]) -> None:
"""
Reject a parsed query that reaches any table outside ``allowed_tables``.

No-op when ``allowed_tables`` is unset (allow-all). Otherwise every table the
query references (resolved scope-correctly, including catalog) must be on the
list, and any construct the list cannot describe -- a table-valued function,
``SHOW``, dynamic SQL, an inline comment, or the ``TABLE <name>`` shorthand --
is refused. Raises :class:`SQLSafetyError` -- ``call_tool`` turns it into a
``ModelRetry`` so the agent can re-target an allowed table, while
``check_query`` reports it invalid.
"""
if self._allowed_canonical is None:
return
scan = collect_table_references(statements)
if scan.unverifiable_sources:
raise SQLSafetyError(
f"Query uses a data source that cannot be checked against allowed_tables: "
f"{'; '.join(scan.unverifiable_sources)}. Query the allowed tables directly: "
f"use list_tables to see them."
)
disallowed = [
".".join(part for part in (catalog, schema, table) if part)
for catalog, schema, table in scan.tables
if not self._is_ref_allowed(catalog, schema or self._schema, table)
]
if disallowed:
raise SQLSafetyError(
f"Query references tables that are not in the allowed tables list: "
f"{', '.join(sorted(set(disallowed)))}. Use list_tables to see the allowed tables."
)
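
The control flow of ``_enforce_allowed_tables`` can be exercised in isolation with the sketch below. It is an illustration under stated assumptions, not the provider's implementation: ``TableScan`` is a hypothetical stand-in for whatever ``collect_table_references`` returns (only the ``tables`` and ``unverifiable_sources`` attributes are taken from the diff), ``SQLSafetyError`` is redefined locally, and the error messages are shortened.

```python
from dataclasses import dataclass, field
from typing import FrozenSet, List, Optional, Tuple

Ref = Tuple[Optional[str], Optional[str], str]


class SQLSafetyError(Exception):
    """Local stand-in for the exception imported in the diff."""


@dataclass
class TableScan:
    """Hypothetical shape of a scan result; the real type is not shown in the diff."""

    tables: List[Ref] = field(default_factory=list)
    unverifiable_sources: List[str] = field(default_factory=list)


def canonical_ref(catalog: Optional[str], schema: Optional[str], table: str) -> Ref:
    """Case-folded comparison key; empty catalog/schema become None."""
    return (
        catalog.casefold() if catalog else None,
        schema.casefold() if schema else None,
        table.casefold(),
    )


def enforce_allowed_tables(
    scan: TableScan,
    allowed: Optional[FrozenSet[Ref]],
    default_schema: Optional[str],
) -> None:
    """Raise SQLSafetyError if the scan reaches anything outside the allow-list."""
    if allowed is None:
        return  # allow-all when no list is configured
    # Sources the list cannot describe are refused before any table check.
    if scan.unverifiable_sources:
        raise SQLSafetyError(
            "Query uses a data source that cannot be checked against allowed_tables: "
            + "; ".join(scan.unverifiable_sources)
        )
    # Unqualified references fall back to the default schema before matching.
    disallowed = [
        ".".join(part for part in ref if part)
        for ref in scan.tables
        if canonical_ref(ref[0], ref[1] or default_schema, ref[2]) not in allowed
    ]
    if disallowed:
        raise SQLSafetyError(
            "Query references tables that are not in the allowed tables list: "
            + ", ".join(sorted(set(disallowed)))
        )


allowed = frozenset({canonical_ref(None, "public", "orders")})
ok_scan = TableScan(tables=[(None, None, "Orders")])
bad_scan = TableScan(tables=[(None, None, "orders"), ("otherdb", "public", "orders")])
```

Here ``enforce_allowed_tables(ok_scan, allowed, "public")`` returns without error, while ``bad_scan`` raises because the catalog-qualified reference has no matching allow-list key.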