Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions providers/common/ai/docs/toolsets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,39 @@ Curated toolset wrapping
The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
via ``BaseHook.get_connection(conn_id).get_hook()``.

Multi-schema warehouses
^^^^^^^^^^^^^^^^^^^^^^^^^

When an agent's tables live in several schemas of one database -- common on
Snowflake -- list them with schema-qualified ``allowed_tables`` entries:

.. code-block:: python

    SQLToolset(
        db_conn_id="snowflake_hq",
        allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"],
    )

``list_tables`` then introspects each referenced schema and returns the matching
tables fully qualified (e.g. ``MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS``), and
``get_schema`` routes each qualified name to its own schema. Without qualified
entries, a single ``schema`` covers only one namespace, and (before qualified
entries were supported) leaving ``schema`` unset made introspection query a
literal ``"None"`` schema and fail. Unqualified entries
fall back to ``schema``, and table-name matching is case-insensitive (databases
reflect identifiers in their own case). For tables in a different *database*, use
a separate toolset whose connection points at that database.

Parameters
^^^^^^^^^^

- ``db_conn_id``: Airflow connection ID for the database.
- ``allowed_tables``: Restrict which tables the agent can discover via
  ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in
  ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
  multiple schemas; see above. Matching is case-insensitive.
  See :ref:`allowed-tables-limitation` for an important caveat.
- ``schema``: Default schema/namespace for unqualified table listing and
  introspection. Schema-qualified ``allowed_tables`` entries override it per table.
- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
Default ``False`` — only SELECT-family statements are permitted.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,13 @@ class SQLToolset(AbstractToolset[Any]):

:param db_conn_id: Airflow connection ID for the database.
:param allowed_tables: Restrict which tables the agent can discover via
``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.
``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables
in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span
multiple schemas in one database -- common on warehouses such as Snowflake.
``list_tables`` then introspects each referenced schema and returns the
matching tables fully qualified, and ``get_schema`` routes to the table's
own schema. Unqualified entries use ``schema``. Matching is
case-insensitive, since databases reflect identifiers in their own case.

.. note::
``allowed_tables`` controls metadata visibility only. It does **not**
Expand All @@ -120,7 +126,10 @@ class SQLToolset(AbstractToolset[Any]):
restrictions, use database-level permissions (e.g. a read-only role
with grants limited to specific tables).

:param schema: Database schema/namespace for table listing and introspection.
:param schema: Default schema/namespace for table listing and introspection,
used for unqualified ``allowed_tables`` entries and unqualified
``get_schema`` calls. Schema-qualified ``allowed_tables`` entries override
it per table.
:param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
Default ``False`` — only SELECT-family statements are permitted.
:param max_rows: Maximum number of rows returned from the ``query`` tool.
Expand All @@ -138,11 +147,43 @@ def __init__(
) -> None:
self._db_conn_id = db_conn_id
self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None
# Case-folded view for membership tests: databases reflect identifiers in
# their own case (Snowflake stores unquoted names uppercase but reflects
# them lowercased), so a byte-exact match against the user's entries would
# silently miss. allowed_tables is a visibility hint, not access control,
# so case-insensitive matching is safe.
self._allowed_tables_ci: frozenset[str] | None = (
frozenset(t.casefold() for t in self._allowed_tables)
if self._allowed_tables is not None
else None
)
self._schema = schema
self._allow_writes = allow_writes
self._max_rows = max_rows
self._hook: DbApiHook | None = None

# Derive which schemas to introspect from schema-qualified allowed_tables.
# Qualified entries ("SCHEMA.TABLE") are listed under their own schema and
# returned fully qualified; unqualified entries (and allow-all) use the
# default ``schema``.
self._qualified_schemas: frozenset[str] = frozenset()
self._include_default_schema: bool = True
if self._allowed_tables is not None:
qualified_schemas: set[str] = set()
include_default = False
for entry in self._allowed_tables:
entry_schema, sep, _ = entry.rpartition(".")
if sep:
qualified_schemas.add(entry_schema)
else:
include_default = True
self._qualified_schemas = frozenset(qualified_schemas)
self._include_default_schema = include_default

def _is_table_allowed(self, name: str) -> bool:
    """Tell whether ``name`` may be exposed, comparing case-insensitively.

    :param name: Table name to test, bare or schema-qualified.
    :return: ``True`` when no allow-list is configured, or when the case-folded
        ``name`` is present in the case-folded allow-list.
    """
    allowed_ci = self._allowed_tables_ci
    if allowed_ci is None:
        # No allow-list configured: every table is visible.
        return True
    return name.casefold() in allowed_ci

@property
def id(self) -> str:
    """Identifier of this toolset: ``"sql-"`` followed by the connection ID."""
    conn_id = self._db_conn_id
    return f"sql-{conn_id}"
Expand Down Expand Up @@ -213,18 +254,47 @@ async def call_tool(
# Tool implementations
# ------------------------------------------------------------------

def _split_table_identifier(self, table_name: str) -> tuple[str | None, str]:
    """Break a table identifier into its ``(schema, table)`` parts.

    The split happens at the last ``"."``, so everything before it is treated
    as the schema. A name without a dot is paired with the default schema.

    :param table_name: Bare or schema-qualified table name.
    :return: ``(schema, table)``; ``schema`` is the default schema for bare names.
    """
    if "." not in table_name:
        return self._schema, table_name
    qualifier, _, bare_name = table_name.rpartition(".")
    return qualifier, bare_name

def _list_tables(self) -> str:
    """Return the visible table names as a JSON-encoded list.

    Every schema referenced by a schema-qualified ``allowed_tables`` entry is
    introspected and its allowed tables are reported as ``"SCHEMA.TABLE"``.
    The default schema is introspected only when ``_include_default_schema``
    is set, and its tables are reported by bare name.

    :return: JSON array of table names: qualified schemas first (in sorted
        schema order), followed by default-schema tables.
    """
    hook = self._get_db_hook()
    tables: list[str] = []
    # Dedupe by (schema, table) so a table reachable both qualified and via the
    # default schema (e.g. "public.users" and "users" with schema="public") is
    # listed once. Case-folded because databases reflect identifiers in their case.
    seen: set[tuple[str | None, str]] = set()

    def add(schema: str | None, name: str, display: str) -> None:
        """Record ``display`` if it is allowed and its (schema, table) is new."""
        key = (schema.casefold() if schema else None, name.casefold())
        if self._is_table_allowed(display) and key not in seen:
            seen.add(key)
            tables.append(display)

    # Schemas referenced by qualified allowed_tables entries: introspect each
    # and return matching tables fully qualified so they round-trip to get_schema.
    for schema in sorted(self._qualified_schemas):
        for name in hook.inspector.get_table_names(schema=schema):
            add(schema, name, f"{schema}.{name}")

    # Default schema: used for allow-all and unqualified allowed_tables entries.
    # Names stay bare to preserve the single-schema behaviour. This is the only
    # place the default schema is introspected, so an unset default is never
    # queried when every allowed entry is schema-qualified.
    if self._include_default_schema:
        for name in hook.inspector.get_table_names(schema=self._schema):
            add(self._schema, name, name)

    return json.dumps(tables)

def _get_schema(self, table_name: str) -> str:
    """Return the column metadata of ``table_name`` as a JSON string.

    :param table_name: Bare or schema-qualified (``"SCHEMA.TABLE"``) table name.
    :return: JSON encoding of what the hook's ``get_table_schema`` returns, or a
        JSON object with an ``"error"`` key when the table is not allowed.
    """
    # Visibility check is case-insensitive; rejected names never reach the hook.
    if not self._is_table_allowed(table_name):
        return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."})
    hook = self._get_db_hook()
    # Route a qualified name to its own schema; bare names use the default schema.
    schema, table = self._split_table_identifier(table_name)
    columns = hook.get_table_schema(table, schema=schema)
    return json.dumps(columns)

def _query(self, sql: str) -> str:
Expand Down
132 changes: 126 additions & 6 deletions providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from airflow.providers.common.ai.toolsets.sql import SQLToolset
from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
from airflow.providers.common.sql.hooks.sql import DbApiHook


def _make_mock_db_hook(
Expand All @@ -36,8 +37,6 @@ def _make_mock_db_hook(
last_description: list[tuple] | None = None,
):
"""Create a mock DbApiHook with sensible defaults."""
from airflow.providers.common.sql.hooks.sql import DbApiHook

mock = MagicMock(spec=DbApiHook)
mock.inspector = MagicMock()
mock.inspector.get_table_names.return_value = table_names or ["users", "orders"]
Expand Down Expand Up @@ -335,8 +334,6 @@ def test_invalid_sql(self):
class TestSQLToolsetHookResolution:
@patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
def test_lazy_resolves_db_hook(self, mock_base_hook):
from airflow.providers.common.sql.hooks.sql import DbApiHook

mock_hook = MagicMock(spec=DbApiHook)
mock_conn = MagicMock(spec=["get_hook"])
mock_conn.get_hook.return_value = mock_hook
Expand All @@ -361,8 +358,6 @@ def test_raises_for_non_dbapi_hook(self, mock_base_hook):

@patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
def test_caches_hook_after_first_resolution(self, mock_base_hook):
from airflow.providers.common.sql.hooks.sql import DbApiHook

mock_hook = MagicMock(spec=DbApiHook)
mock_conn = MagicMock(spec=["get_hook"])
mock_conn.get_hook.return_value = mock_hook
Expand All @@ -374,3 +369,128 @@ def test_caches_hook_after_first_resolution(self, mock_base_hook):

# Only called once because result is cached.
mock_base_hook.get_connection.assert_called_once()


class TestSQLToolsetMultiSchema:
    """Exercise schema-qualified ``allowed_tables`` spanning several schemas of one database."""

    @staticmethod
    def _schema_aware_hook(tables_by_schema: dict[str | None, list[str]]):
        """Build a ``DbApiHook`` mock whose inspector answers per schema.

        ``get_table_names(schema=...)`` returns the list registered for that
        schema in ``tables_by_schema`` (empty for unknown schemas), and
        ``get_table_schema`` always returns a single ``id`` column.
        """

        def fake_get_table_names(schema=None):
            return tables_by_schema.get(schema, [])

        fake_hook = MagicMock(spec=DbApiHook)
        fake_hook.inspector = MagicMock()
        fake_hook.inspector.get_table_names.side_effect = fake_get_table_names
        fake_hook.get_table_schema.return_value = [{"name": "id", "type": "INTEGER"}]
        return fake_hook

    @staticmethod
    def _invoke(toolset, tool_name, tool_args):
        """Run one tool call to completion and return its raw result."""
        return asyncio.run(toolset.call_tool(tool_name, tool_args, ctx=MagicMock(), tool=MagicMock()))

    @classmethod
    def _invoke_json(cls, toolset, tool_name, tool_args):
        """Run one tool call and decode its result from JSON."""
        return json.loads(cls._invoke(toolset, tool_name, tool_args))

    def test_list_tables_spans_multiple_schemas(self):
        toolset = SQLToolset(
            "sf",
            allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"],
        )
        toolset._hook = self._schema_aware_hook(
            {
                "MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS", "OTHER_TABLE"],
                "MODEL_CRM": ["SF_ASTRO_ORGS"],
            }
        )

        listed = self._invoke_json(toolset, "list_tables", {})
        assert listed == ["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"]

    def test_list_tables_never_introspects_none_schema_when_all_qualified(self):
        """Regression for the 'SHOW TABLES IN SCHEMA "DB"."None"' failure."""
        toolset = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X", "MODEL_CRM.Y"])
        toolset._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"], "MODEL_CRM": ["Y"]})

        self._invoke(toolset, "list_tables", {})

        introspection_calls = toolset._hook.inspector.get_table_names.call_args_list
        called_schemas = {call.kwargs.get("schema") for call in introspection_calls}
        assert called_schemas == {"MODEL_ASTRO", "MODEL_CRM"}
        assert None not in called_schemas

    def test_list_tables_mixed_qualified_and_default(self):
        toolset = SQLToolset("pg", allowed_tables=["users", "MODEL_ASTRO.X"], schema="public")
        toolset._hook = self._schema_aware_hook(
            {"public": ["users", "orders"], "MODEL_ASTRO": ["X", "Z"]}
        )

        listed = self._invoke_json(toolset, "list_tables", {})
        # Qualified schemas listed first (sorted), then the default schema.
        assert listed == ["MODEL_ASTRO.X", "users"]

    def test_get_schema_routes_to_qualified_schema(self):
        toolset = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
        toolset._hook = self._schema_aware_hook({"MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS"]})

        columns = self._invoke_json(
            toolset,
            "get_schema",
            {"table_name": "MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"},
        )
        assert columns == [{"name": "id", "type": "INTEGER"}]
        toolset._hook.get_table_schema.assert_called_once_with(
            "DEPLOYMENT_IMAGE_DETAILS", schema="MODEL_ASTRO"
        )

    def test_get_schema_blocks_table_outside_allowed_schema(self):
        toolset = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X"])
        toolset._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"]})

        outcome = self._invoke_json(toolset, "get_schema", {"table_name": "SECRETS.PASSWORDS"})
        assert "error" in outcome
        toolset._hook.get_table_schema.assert_not_called()

    def test_get_schema_unqualified_uses_default_schema(self):
        toolset = SQLToolset("pg", schema="public")
        toolset._hook = self._schema_aware_hook({"public": ["users"]})

        self._invoke(toolset, "get_schema", {"table_name": "users"})
        toolset._hook.get_table_schema.assert_called_once_with("users", schema="public")

    def test_list_tables_matches_case_insensitively(self):
        """Snowflake reflects unquoted names lowercased; uppercase allowed_tables still match."""
        toolset = SQLToolset(
            "sf",
            allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"],
        )
        toolset._hook = self._schema_aware_hook(
            {
                "MODEL_ASTRO": ["deployment_image_details", "other"],
                "MODEL_CRM": ["sf_astro_orgs"],
            }
        )

        listed = self._invoke_json(toolset, "list_tables", {})
        assert listed == ["MODEL_ASTRO.deployment_image_details", "MODEL_CRM.sf_astro_orgs"]

    def test_get_schema_matches_case_insensitively(self):
        toolset = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"])
        toolset._hook = self._schema_aware_hook({"MODEL_ASTRO": ["deployment_image_details"]})

        outcome = self._invoke_json(
            toolset,
            "get_schema",
            {"table_name": "MODEL_ASTRO.deployment_image_details"},
        )
        assert "error" not in outcome
        toolset._hook.get_table_schema.assert_called_once_with(
            "deployment_image_details", schema="MODEL_ASTRO"
        )

    def test_list_tables_deduplicates_same_table(self):
        """A table listed both qualified and unqualified appears once."""
        toolset = SQLToolset("pg", allowed_tables=["public.users", "users"], schema="public")
        toolset._hook = self._schema_aware_hook({"public": ["users"]})

        listed = self._invoke_json(toolset, "list_tables", {})
        assert listed == ["public.users"]
Loading