From 077ef05c7a2c6291148fb243a1650ddbcda03d24 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 5 Jun 2026 15:19:01 +0100 Subject: [PATCH] Support multi-schema introspection in common.ai SQLToolset SQLToolset's metadata tools (list_tables, get_schema) operated against a single schema, so an agent over a multi-schema warehouse (common on Snowflake) could not discover tables across schemas. With no schema set and schema-qualified tables, list_tables introspected a literal "None" schema (SHOW TABLES IN SCHEMA "DB"."None") and failed outright. allowed_tables entries may now be schema-qualified ("SCHEMA.TABLE"). list_tables introspects each referenced schema and returns the matching tables fully qualified, and get_schema routes each qualified name to its own schema. Unqualified entries and the allow-all case keep the previous single-schema behaviour using the default schema. Table-name matching is case-insensitive, because databases reflect identifiers in their own case (Snowflake reflects unquoted names lowercased) and a byte-exact match would silently return nothing. Results are de-duplicated by (schema, table) so a table reachable both qualified and via the default schema is listed once. --- providers/common/ai/docs/toolsets.rst | 29 +++- .../providers/common/ai/toolsets/sql.py | 84 ++++++++++- .../tests/unit/common/ai/toolsets/test_sql.py | 132 +++++++++++++++++- 3 files changed, 230 insertions(+), 15 deletions(-) diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst index ee6d8a85b638b..65a3ab11b04d3 100644 --- a/providers/common/ai/docs/toolsets.rst +++ b/providers/common/ai/docs/toolsets.rst @@ -146,14 +146,39 @@ Curated toolset wrapping The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call via ``BaseHook.get_connection(conn_id).get_hook()``. +Multi-schema warehouses +^^^^^^^^^^^^^^^^^^^^^^^^^ + +When an agent's tables live in several schemas of one database -- common on +Snowflake -- list them with schema-qualified ``allowed_tables`` entries: + +.. code-block:: python + + SQLToolset( + db_conn_id="snowflake_hq", + allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"], + ) + +``list_tables`` then introspects each referenced schema and returns the matching +tables fully qualified (e.g. ``MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS``), and +``get_schema`` routes each qualified name to its own schema. Without this, a +single ``schema`` only covers one namespace, and leaving ``schema`` unset made +introspection query a literal ``"None"`` schema and fail. Unqualified entries +fall back to ``schema``, and table-name matching is case-insensitive (databases +reflect identifiers in their own case). For tables in a different *database*, use +a separate toolset whose connection points at that database. + Parameters ^^^^^^^^^^ - ``db_conn_id``: Airflow connection ID for the database. - ``allowed_tables``: Restrict which tables the agent can discover via - ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables. + ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables in + ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span + multiple schemas; see above. Matching is case-insensitive. See :ref:`allowed-tables-limitation` for an important caveat. -- ``schema``: Database schema/namespace for table listing and introspection. +- ``schema``: Default schema/namespace for unqualified table listing and + introspection. Schema-qualified ``allowed_tables`` entries override it per table. - ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.). Default ``False`` — only SELECT-family statements are permitted. - ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``. diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py index 0902cff99f25a..fca07177597e9 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py +++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py @@ -111,7 +111,13 @@ class SQLToolset(AbstractToolset[Any]): :param db_conn_id: Airflow connection ID for the database. :param allowed_tables: Restrict which tables the agent can discover via - ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables. + ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables + in ``schema``. Entries may be schema-qualified (``"SCHEMA.TABLE"``) to span + multiple schemas in one database -- common on warehouses such as Snowflake. + ``list_tables`` then introspects each referenced schema and returns the + matching tables fully qualified, and ``get_schema`` routes to the table's + own schema. Unqualified entries use ``schema``. Matching is + case-insensitive, since databases reflect identifiers in their own case. .. note:: ``allowed_tables`` controls metadata visibility only. It does **not** @@ -120,7 +126,10 @@ class SQLToolset(AbstractToolset[Any]): restrictions, use database-level permissions (e.g. a read-only role with grants limited to specific tables). - :param schema: Database schema/namespace for table listing and introspection. + :param schema: Default schema/namespace for table listing and introspection, + used for unqualified ``allowed_tables`` entries and unqualified + ``get_schema`` calls. Schema-qualified ``allowed_tables`` entries override + it per table. :param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.). Default ``False`` — only SELECT-family statements are permitted. :param max_rows: Maximum number of rows returned from the ``query`` tool. @@ -138,11 +147,43 @@ def __init__( ) -> None: self._db_conn_id = db_conn_id self._allowed_tables: frozenset[str] | None = frozenset(allowed_tables) if allowed_tables else None + # Case-folded view for membership tests: databases reflect identifiers in + # their own case (Snowflake stores unquoted names uppercase but reflects + # them lowercased), so a byte-exact match against the user's entries would + # silently miss. allowed_tables is a visibility hint, not access control, + # so case-insensitive matching is safe. + self._allowed_tables_ci: frozenset[str] | None = ( + frozenset(t.casefold() for t in self._allowed_tables) + if self._allowed_tables is not None + else None + ) self._schema = schema self._allow_writes = allow_writes self._max_rows = max_rows self._hook: DbApiHook | None = None + # Derive which schemas to introspect from schema-qualified allowed_tables. + # Qualified entries ("SCHEMA.TABLE") are listed under their own schema and + # returned fully qualified; unqualified entries (and allow-all) use the + # default ``schema``. + self._qualified_schemas: frozenset[str] = frozenset() + self._include_default_schema: bool = True + if self._allowed_tables is not None: + qualified_schemas: set[str] = set() + include_default = False + for entry in self._allowed_tables: + entry_schema, sep, _ = entry.rpartition(".") + if sep: + qualified_schemas.add(entry_schema) + else: + include_default = True + self._qualified_schemas = frozenset(qualified_schemas) + self._include_default_schema = include_default + + def _is_table_allowed(self, name: str) -> bool: + """Case-insensitive membership test against ``allowed_tables`` (allow-all when unset).""" + return self._allowed_tables_ci is None or name.casefold() in self._allowed_tables_ci + @property def id(self) -> str: return f"sql-{self._db_conn_id}" @@ -213,18 +254,47 @@ async def call_tool( # Tool implementations # ------------------------------------------------------------------ + def _split_table_identifier(self, table_name: str) -> tuple[str | None, str]: + """Split ``"SCHEMA.TABLE"`` into ``(schema, table)``; unqualified uses the default schema.""" + schema, sep, table = table_name.rpartition(".") + if not sep: + return self._schema, table_name + return schema, table + def _list_tables(self) -> str: hook = self._get_db_hook() - tables: list[str] = hook.inspector.get_table_names(schema=self._schema) - if self._allowed_tables is not None: - tables = [t for t in tables if t in self._allowed_tables] + tables: list[str] = [] + # Dedupe by (schema, table) so a table reachable both qualified and via the + # default schema (e.g. "public.users" and "users" with schema="public") is + # listed once. Case-folded because databases reflect identifiers in their case. + seen: set[tuple[str | None, str]] = set() + + def add(schema: str | None, name: str, display: str) -> None: + key = (schema.casefold() if schema else None, name.casefold()) + if self._is_table_allowed(display) and key not in seen: + seen.add(key) + tables.append(display) + + # Schemas referenced by qualified allowed_tables entries: introspect each + # and return matching tables fully qualified so they round-trip to get_schema. + for schema in sorted(self._qualified_schemas): + for name in hook.inspector.get_table_names(schema=schema): + add(schema, name, f"{schema}.{name}") + + # Default schema: used for allow-all and unqualified allowed_tables entries. + # Names stay bare to preserve the single-schema behaviour. + if self._include_default_schema: + for name in hook.inspector.get_table_names(schema=self._schema): + add(self._schema, name, name) + return json.dumps(tables) def _get_schema(self, table_name: str) -> str: - if self._allowed_tables is not None and table_name not in self._allowed_tables: + if not self._is_table_allowed(table_name): return json.dumps({"error": f"Table {table_name!r} is not in the allowed tables list."}) hook = self._get_db_hook() - columns = hook.get_table_schema(table_name, schema=self._schema) + schema, table = self._split_table_identifier(table_name) + columns = hook.get_table_schema(table, schema=schema) return json.dumps(columns) def _query(self, sql: str) -> str: diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py index 471b956385d82..c1aae15aad542 100644 --- a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py +++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py @@ -27,6 +27,7 @@ from airflow.providers.common.ai.toolsets.sql import SQLToolset from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError +from airflow.providers.common.sql.hooks.sql import DbApiHook def _make_mock_db_hook( @@ -36,8 +37,6 @@ def _make_mock_db_hook( last_description: list[tuple] | None = None, ): """Create a mock DbApiHook with sensible defaults.""" - from airflow.providers.common.sql.hooks.sql import DbApiHook - mock = MagicMock(spec=DbApiHook) mock.inspector = MagicMock() mock.inspector.get_table_names.return_value = table_names or ["users", "orders"] @@ -335,8 +334,6 @@ def test_invalid_sql(self): class TestSQLToolsetHookResolution: @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True) def test_lazy_resolves_db_hook(self, mock_base_hook): - from airflow.providers.common.sql.hooks.sql import DbApiHook - mock_hook = MagicMock(spec=DbApiHook) mock_conn = MagicMock(spec=["get_hook"]) mock_conn.get_hook.return_value = mock_hook @@ -361,8 +358,6 @@ def test_raises_for_non_dbapi_hook(self, mock_base_hook): @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True) def test_caches_hook_after_first_resolution(self, mock_base_hook): - from airflow.providers.common.sql.hooks.sql import DbApiHook - mock_hook = MagicMock(spec=DbApiHook) mock_conn = MagicMock(spec=["get_hook"]) mock_conn.get_hook.return_value = mock_hook @@ -374,3 +369,128 @@ def test_caches_hook_after_first_resolution(self, mock_base_hook): # Only called once because result is cached. mock_base_hook.get_connection.assert_called_once() + + +class TestSQLToolsetMultiSchema: + """Schema-qualified allowed_tables span multiple schemas in one database.""" + + @staticmethod + def _schema_aware_hook(tables_by_schema: dict[str | None, list[str]]): + hook = MagicMock(spec=DbApiHook) + hook.inspector = MagicMock() + hook.inspector.get_table_names.side_effect = lambda schema=None: tables_by_schema.get(schema, []) + hook.get_table_schema.return_value = [{"name": "id", "type": "INTEGER"}] + return hook + + def test_list_tables_spans_multiple_schemas(self): + ts = SQLToolset( + "sf", + allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"], + ) + ts._hook = self._schema_aware_hook( + { + "MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS", "OTHER_TABLE"], + "MODEL_CRM": ["SF_ASTRO_ORGS"], + } + ) + + result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))) + assert result == ["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"] + + def test_list_tables_never_introspects_none_schema_when_all_qualified(self): + """Regression for the 'SHOW TABLES IN SCHEMA "DB"."None"' failure.""" + ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X", "MODEL_CRM.Y"]) + ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"], "MODEL_CRM": ["Y"]}) + + asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock())) + + called_schemas = {c.kwargs.get("schema") for c in ts._hook.inspector.get_table_names.call_args_list} + assert called_schemas == {"MODEL_ASTRO", "MODEL_CRM"} + assert None not in called_schemas + + def test_list_tables_mixed_qualified_and_default(self): + ts = SQLToolset("pg", allowed_tables=["users", "MODEL_ASTRO.X"], schema="public") + ts._hook = self._schema_aware_hook({"public": ["users", "orders"], "MODEL_ASTRO": ["X", "Z"]}) + + result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))) + # Qualified schemas listed first (sorted), then the default schema. + assert result == ["MODEL_ASTRO.X", "users"] + + def test_get_schema_routes_to_qualified_schema(self): + ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"]) + ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["DEPLOYMENT_IMAGE_DETAILS"]}) + + result = json.loads( + asyncio.run( + ts.call_tool( + "get_schema", + {"table_name": "MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"}, + ctx=MagicMock(), + tool=MagicMock(), + ) + ) + ) + assert result == [{"name": "id", "type": "INTEGER"}] + ts._hook.get_table_schema.assert_called_once_with("DEPLOYMENT_IMAGE_DETAILS", schema="MODEL_ASTRO") + + def test_get_schema_blocks_table_outside_allowed_schema(self): + ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.X"]) + ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["X"]}) + + result = json.loads( + asyncio.run( + ts.call_tool( + "get_schema", {"table_name": "SECRETS.PASSWORDS"}, ctx=MagicMock(), tool=MagicMock() + ) + ) + ) + assert "error" in result + ts._hook.get_table_schema.assert_not_called() + + def test_get_schema_unqualified_uses_default_schema(self): + ts = SQLToolset("pg", schema="public") + ts._hook = self._schema_aware_hook({"public": ["users"]}) + + asyncio.run(ts.call_tool("get_schema", {"table_name": "users"}, ctx=MagicMock(), tool=MagicMock())) + ts._hook.get_table_schema.assert_called_once_with("users", schema="public") + + def test_list_tables_matches_case_insensitively(self): + """Snowflake reflects unquoted names lowercased; uppercase allowed_tables still match.""" + ts = SQLToolset( + "sf", + allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS", "MODEL_CRM.SF_ASTRO_ORGS"], + ) + ts._hook = self._schema_aware_hook( + { + "MODEL_ASTRO": ["deployment_image_details", "other"], + "MODEL_CRM": ["sf_astro_orgs"], + } + ) + + result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))) + assert result == ["MODEL_ASTRO.deployment_image_details", "MODEL_CRM.sf_astro_orgs"] + + def test_get_schema_matches_case_insensitively(self): + ts = SQLToolset("sf", allowed_tables=["MODEL_ASTRO.DEPLOYMENT_IMAGE_DETAILS"]) + ts._hook = self._schema_aware_hook({"MODEL_ASTRO": ["deployment_image_details"]}) + + result = json.loads( + asyncio.run( + ts.call_tool( + "get_schema", + {"table_name": "MODEL_ASTRO.deployment_image_details"}, + ctx=MagicMock(), + tool=MagicMock(), + ) + ) + ) + assert "error" not in result + ts._hook.get_table_schema.assert_called_once_with("deployment_image_details", schema="MODEL_ASTRO") + + def test_list_tables_deduplicates_same_table(self): + """A table listed both qualified and unqualified appears once.""" + ts = SQLToolset("pg", allowed_tables=["public.users", "users"], schema="public") + ts._hook = self._schema_aware_hook({"public": ["users"]}) + + result = json.loads(asyncio.run(ts.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))) + assert result == ["public.users"]