From 73bbf9efb267968faad2bdc706aa14a55ceb1e4a Mon Sep 17 00:00:00 2001 From: ojasaklechayt Date: Tue, 12 Aug 2025 14:12:33 +0530 Subject: [PATCH 1/5] Implemented async pg --- .gitignore | 4 +- foundation_sql/db.py | 129 +++++++------- foundation_sql/db_drivers.py | 254 ++++++++++++++++++++++++++++ foundation_sql/gen.py | 11 +- foundation_sql/prompts.md | 61 +++---- foundation_sql/query.py | 86 ++++++++-- plan.md | 105 ++++++++++++ requirements.txt | 2 + tests/common.py | 45 ++++- tests/test_adapters.py | 73 ++++++++ tests/test_query_decorator_async.py | 90 ++++++++++ tests/test_query_decorator_sync.py | 109 ++++++++++++ uv.lock | 8 + 13 files changed, 848 insertions(+), 129 deletions(-) create mode 100644 foundation_sql/db_drivers.py create mode 100644 plan.md create mode 100644 tests/test_adapters.py create mode 100644 tests/test_query_decorator_async.py create mode 100644 tests/test_query_decorator_sync.py create mode 100644 uv.lock diff --git a/.gitignore b/.gitignore index b9e30d2..1756956 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,4 @@ ENV/ .env # Misc -.DS_Store - -__sql__ \ No newline at end of file +.DS_Store \ No newline at end of file diff --git a/foundation_sql/db.py b/foundation_sql/db.py index c2541fe..59a7418 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -1,5 +1,5 @@ """ -Database operations module for Foundation. +Database operations module for Foundation (adapter-based). """ import logging @@ -7,19 +7,18 @@ from types import NoneType from typing import Dict, Any, Optional, Type, Union, List from pydantic import BaseModel -from sqlalchemy import create_engine, text +from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError -from jinja2sql import Jinja2SQL from datetime import datetime +from foundation_sql.db_drivers import EngineAdapter, SQLAlchemyAdapter +from foundation_sql.db_drivers import AsyncpgAdapter NESTED_SPLITTER = "." # Singleton instance DATABASES = {} -j2sql = Jinja2SQL() - # logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -27,29 +26,33 @@ class Database: - """Database operations for Foundation.""" + """Database operations for Foundation delegated to an EngineAdapter.""" - def __init__(self, db_url: Optional[str] = None) -> None: - """Initialize the database. + def __init__(self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None) -> None: + """Initialize the database facade. Args: - db_url: SQLAlchemy database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) + db_url: Database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) If not provided, will use DATABASE_URL environment variable + adapter: Optional explicit adapter. If not provided, a default sync adapter is created. """ self.db_url = db_url or os.getenv('DATABASE_URL') if not self.db_url: raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable') - # Create SQLAlchemy engine - self.engine = create_engine(self.db_url) + # Default to sync SQLAlchemy adapter unless explicitly provided + self.adapter: EngineAdapter = adapter or SQLAlchemyAdapter(self.db_url) def get_engine(self) -> Engine: - """Get the SQLAlchemy engine. + """Get the underlying SQLAlchemy engine if available. Returns: SQLAlchemy Engine instance """ - return self.engine + # Only available for SQLAlchemyAdapter + if isinstance(self.adapter, SQLAlchemyAdapter): + return self.adapter.engine + raise RuntimeError("Engine is not available for this adapter") def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None: """Initialize the database schema if it doesn't exist. @@ -65,15 +68,8 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] if not schema_sql: with open(schema_path, 'r') as f: schema_sql = f.read() - - # Connect and execute schema in a transaction - with self.engine.begin() as conn: - try: - for statement in schema_sql.split(';'): - if statement.strip(): - conn.execute(text(statement)) - except SQLAlchemyError as e: - raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + # Delegate to adapter + self.adapter.init_schema(schema_sql) def run_sql(self, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. @@ -90,50 +86,35 @@ def run_sql(self, sql_template: str, **context) -> Any: if 'now' not in context: context['now'] = datetime.now - # Use jinja2sql to render the template and get SQL with parameters - try: - query, params = j2sql.from_string( - sql_template, - context=context - ) - except Exception as e: - raise ValueError(f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}") from e - - # Execute the SQL with parameters in a transaction - with self.engine.connect() as conn: - with conn.begin(): - try: - # Split the query into individual statements for SQLite - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] - - # Track total rows affected - total_rows = 0 - last_result = None - - logger.debug(f"Executing statements: {len(statements)}") - # Execute each statement - for statement in statements: - logger.debug(f"Executing statement: {statement} {params}") - result = conn.execute(text(statement), params) - total_rows += result.rowcount - - # Store the last result for potential return - if result.returns_rows: - last_result = result - - # If last result is a SELECT, return its rows - if last_result and last_result.returns_rows: - # Convert SQLAlchemy Row objects to plain dictionaries - # to avoid leaking SQLAlchemy implementation details - rows = [dict(row._mapping) for row in last_result] - logger.debug(f"Returning rows: {rows}") - return QueryResult(rows) - - # If no rows affected, return 0 - return total_rows - except SQLAlchemyError as e: - # The transaction will be rolled back automatically - raise RuntimeError(f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}") from e + # Delegate to adapter; it returns either list[dict] rows or int rowcount + result = self.adapter.run_sql(sql_template, context) + if isinstance(result, int): + return result + # assume list of dicts + return QueryResult(result) + + # ---------- Async delegates (Phase 2) ---------- + async def init_schema_async(self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None) -> None: + if not hasattr(self.adapter, 'init_schema_async'): + raise NotImplementedError("Async schema init not supported by this adapter") + if not schema_sql: + with open(schema_path, 'r') as f: + schema_sql = f.read() + await self.adapter.init_schema_async(schema_sql) # type: ignore[attr-defined] + + async def run_sql_async(self, sql_template: str, **context) -> Any: + if not hasattr(self.adapter, 'run_sql_async'): + raise NotImplementedError("Async run_sql not supported by this adapter") + if 'now' not in context: + context['now'] = datetime.now + result = await self.adapter.run_sql_async(sql_template, context) # type: ignore[attr-defined] + if isinstance(result, int): + return result + return QueryResult(result) + + async def close_async(self) -> None: + if hasattr(self.adapter, 'close_async'): + await self.adapter.close_async() # type: ignore[attr-defined] def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any: @@ -151,7 +132,11 @@ def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = Returns: Any: Result of the execution """ - with self.engine.connect() as connection: + # Only supported on SQLAlchemy adapter path for now + if not isinstance(self.adapter, SQLAlchemyAdapter): + raise NotImplementedError("execute() is only supported for SQLAlchemy adapter") + + with self.adapter.engine.connect() as connection: try: # Replace '?' placeholders with SQLAlchemy named parameters if '?' in sql: @@ -249,6 +234,16 @@ def get_db(db_url: str) -> Database: return DATABASES[db_url] +def get_db_with_adapter(db_url: str, mode: str) -> Database: + """Internal helper for selecting adapter explicitly. + mode: "sync" | "async" (async not implemented yet) + """ + if mode == "sync": + return Database(db_url, adapter=SQLAlchemyAdapter(db_url)) + if mode == "async": + return Database(db_url, adapter=AsyncpgAdapter(db_url)) + raise ValueError(f"Unknown adapter mode: {mode}") + def run_sql(db_url: str, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. diff --git a/foundation_sql/db_drivers.py b/foundation_sql/db_drivers.py new file mode 100644 index 0000000..28f64d2 --- /dev/null +++ b/foundation_sql/db_drivers.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import logging +import re # Add missing import +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, List + +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError + +from jinja2sql import Jinja2SQL +from datetime import datetime +import asyncpg + +logger = logging.getLogger(__name__) + + +class EngineAdapter(ABC): + """Abstract adapter for DB engines (sync and async).""" + + # ---------- Sync API ---------- + @abstractmethod + def init_schema(self, schema_sql: str) -> None: + ... + + @abstractmethod + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: + ... + + @abstractmethod + def close(self) -> None: + ... + + # ---------- Async API (optional) ---------- + async def init_pool_async(self) -> None: # pragma: no cover - to be implemented by async adapters + raise NotImplementedError + + async def init_schema_async(self, schema_sql: str) -> None: # pragma: no cover + raise NotImplementedError + + async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover + raise NotImplementedError + + async def close_async(self) -> None: # pragma: no cover + raise NotImplementedError + + +class SQLAlchemyAdapter(EngineAdapter): + """Synchronous adapter backed by SQLAlchemy engine.""" + + def __init__(self, dsn: str) -> None: + self.dsn = dsn + self.engine: Engine = create_engine(self.dsn) + # Use a valid Jinja2SQL param style compatible with SQLAlchemy bound params + # "named" produces :name parameters which SQLAlchemy understands via text() bindings + self.j2sql = Jinja2SQL(param_style="named") + + def init_schema(self, schema_sql: str) -> None: + with self.engine.begin() as conn: + try: + for statement in schema_sql.split(';'): + if statement.strip(): + conn.execute(text(statement)) + except SQLAlchemyError as e: + raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: + # ensure now is available + if 'now' not in data: + data['now'] = datetime.now + + try: + # Normalize templates: remove explicit tojson filters to allow binding + normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + query, params = self.j2sql.from_string(normalized, context=data) + except Exception as e: + raise ValueError( + f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" + ) from e + + with self.engine.connect() as conn: + with conn.begin(): + try: + statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + total_rows = 0 + last_result = None + + logger.debug(f"Executing statements: {len(statements)}") + for statement in statements: + logger.debug(f"Executing statement: {statement} {params}") + result = conn.execute(text(statement), params) + total_rows += result.rowcount + if result.returns_rows: + last_result = result + + if last_result and last_result.returns_rows: + rows = [dict(row._mapping) for row in last_result] + logger.debug(f"Returning rows: {rows}") + # Return the same shape as Database currently does: a QueryResult-like object + # The Database facade will wrap rows if needed; here we just return rows + return rows + + return total_rows + except SQLAlchemyError as e: + raise RuntimeError( + f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}" + ) from e + + def close(self) -> None: + try: + self.engine.dispose() + except Exception: + logger.exception("Error disposing SQLAlchemy engine") + + +class AsyncpgAdapter(EngineAdapter): + """Adapter for asyncpg (async PostgreSQL driver).""" + + def __init__(self, dsn: str) -> None: + self.dsn = dsn + self.pool: Optional[asyncpg.Pool] = None + # jinja2sql will render SQL & parameters in asyncpg style ($1, $2, ...) + self.j2sql = Jinja2SQL(param_style="asyncpg") + + async def init_pool_async(self) -> None: + if self.pool is None: + self.pool = await asyncpg.create_pool(self.dsn) + + async def close_async(self) -> None: + if self.pool is not None: + await self.pool.close() + self.pool = None + + async def init_schema_async(self, schema_sql: str) -> None: + await self.init_pool_async() + assert self.pool is not None + async with self.pool.acquire() as conn: + # Execute statements sequentially to ensure order + for statement in [s.strip() for s in schema_sql.split(';') if s.strip()]: + try: + logger.debug(f"Executing schema statement: {statement}") + await conn.execute(statement) + except Exception as e: + logger.error(f"Failed to execute schema statement: {statement}") + raise RuntimeError(f"Failed to execute schema statement: {str(e)}") from e + + async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: + await self.init_pool_async() + assert self.pool is not None + + if 'now' not in data: + data['now'] = datetime.now + + # Special handling for templates without parameters + if '{{' not in template: + # No template variables, execute directly + async with self.pool.acquire() as conn: + is_select = template.strip().lower().startswith("select") + if is_select: + records = await conn.fetch(template) + return [dict(r) for r in records] + else: + status = await conn.execute(template) + return _parse_rowcount(status) + + try: + # Add the same template normalization as SQLAlchemyAdapter + # Normalize templates: remove explicit tojson filters to allow binding + normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + # jinja2sql generates SQL with $1, $2... placeholders and params in correct order + query, params_list = self.j2sql.from_string(normalized, context=data) + + # Debug: log parameter types + logger.debug(f"Rendered query: {query}") + logger.debug(f"Parameter values: {params_list}") + logger.debug(f"Parameter types: {[(type(p).__name__, p) for p in params_list]}") + + except Exception as e: + raise ValueError( + f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" + ) from e + + async with self.pool.acquire() as conn: + try: + # Handle multiple statements like SQLAlchemyAdapter + statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + total_rows = 0 + last_result = None + + logger.debug(f"Executing {len(statements)} statement(s)") + for i, statement in enumerate(statements): + logger.debug(f"Statement {i+1}: {statement}") + + is_select = statement.strip().lower().startswith("select") + + if is_select: + last_result = await conn.fetch(statement, *params_list) + total_rows += len(last_result) + logger.debug(f"SELECT returned {len(last_result)} rows") + else: + # Check if this is a schema operation or data operation + is_schema_op = any(keyword in statement.upper() for keyword in + ['CREATE', 'DROP', 'ALTER', 'TRUNCATE']) + + if is_schema_op: + # Schema operations typically don't use parameters + logger.debug("Executing as schema operation (no parameters)") + status = await conn.execute(statement) + else: + # Data operations use parameters + logger.debug(f"Executing as data operation with parameters: {params_list}") + status = await conn.execute(statement, *params_list) + + stmt_rows = _parse_rowcount(status) + total_rows += stmt_rows + logger.debug(f"Statement affected {stmt_rows} rows") + + # Return results similar to SQLAlchemyAdapter + if last_result is not None: + rows = [dict(r) for r in last_result] + logger.debug(f"Returning {len(rows)} rows") + return rows + + logger.debug(f"Returning row count: {total_rows}") + return total_rows + + except Exception as e: + error_msg = (f"Failed to execute SQL: {str(e)}\n" + f"Rendered SQL: {query}\n" + f"Parameters: {params_list}\n" + f"Parameter types: {[type(p).__name__ for p in params_list]}") + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + # Sync methods are not supported for asyncpg adapter + def init_schema(self, schema_sql: str) -> None: # pragma: no cover - sync not supported + raise NotImplementedError("Use init_schema_async with AsyncpgAdapter") + + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover - sync not supported + raise NotImplementedError("Use run_sql_async with AsyncpgAdapter") + + def close(self) -> None: # pragma: no cover - sync not supported + raise NotImplementedError("Use close_async with AsyncpgAdapter") + + +def _parse_rowcount(status: str) -> int: + """Parse rowcount from asyncpg status like 'INSERT 0 1', 'UPDATE 3', etc.""" + try: + parts = status.split() + # Common patterns: 'INSERT 0 1' (rowcount is last), 'UPDATE 3' (last), 'DELETE 2' + return int(parts[-1]) + except Exception: + return 0 \ No newline at end of file diff --git a/foundation_sql/gen.py b/foundation_sql/gen.py index 2dc2534..cad5266 100644 --- a/foundation_sql/gen.py +++ b/foundation_sql/gen.py @@ -43,17 +43,20 @@ def generate_sql(self, prompt: str) -> str: Args: prompt (str): Detailed prompt for SQL generation - + Returns: str: Generated SQL template """ response = self.client.chat.completions.create( - model=self.model, # Use an appropriate model + model=self.model, messages=[ {"role": "system", "content": prompt} ] ) generated_sql = response.choices[0].message.content.strip() - sql_template = re.sub(r'^```sql\n|^```\n|\n```$', '', generated_sql, flags=re.MULTILINE) - return sql_template + + # Remove ```sql or ``` fences + sql_template = re.sub(r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE).strip() + + return sql_template \ No newline at end of file diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 6809d26..2cd09cc 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -1,20 +1,23 @@ -You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. Ensure, the SQL query is usable across sqlite and postgresql. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. +You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. The SQL template generated is a Jinja2 template. + +Primary target is PostgreSQL (tests run against Postgres). Prefer Postgres-compatible SQL. Avoid SQLite-specific functions. 1. Start with a -- comment to document the function name, parameters and docstring, explaining what the SQL query does. 2. Use jinja2 template to generate SQL 3. When accessing nested fields handle cases if they aren't defined. Use default filter with None value for such cases e.g. {{user.zip_code|default(None)}} -4. Ensure response rows can be parsed into Pydantic model. As long as the model fields are named the same as the columns in the SQL query. It also supports nested models by using double underscores to separate nested fields. +4. Ensure response rows include ALL fields in the return Pydantic model. Name columns exactly as model fields. For nested models, alias columns using dot notation with quotes, e.g. "workspace.id", "workspace.name". 5. For complex tasks, more than one queries can be run, separated by ';' 6. Only respond with a single ```sql``` block which contains all queries. 7. No other explanation is necessary -8. For insert queries, avoid any RETURNING clause. Let it return the default. -9. We use jinja2 syntax to generate SQL - so parameters don't need to be quoted e.g. use {{user.zip_code|default(None)}} and not '{{user.zip_code|default(None)}}' -10. Use double underscores (.) to separate nested fields including for multiple levels of nesting e.g. `profile.address.street` - note that the field names need to be quoted as we are using `.` -11. Use backticks (``) to quote column names and table names -12. DONOT use json_build_object to build JSON objects for nested fields -13. DONOT use '' to quote jinja variables. The binding would take care of that automatically. -14. Pay special attention to primary key (usually id fields). Sometimes, they are auto-generated in the schema, in which case insert queries should not set them. Otherwise, they must already be set in the model and then inserted into the table as well. +8. For insert queries: + - If the function returns a primitive count/int, just perform INSERT without RETURNING. + - If the function returns a model object, perform INSERT first, then a SELECT that fetches the inserted row. In Postgres use: WHERE id = (SELECT LASTVAL()). +9. We use Jinja2 syntax to generate SQL - DO NOT wrap Jinja variables in quotes. Example: {{user.zip_code|default(None)}}, not '{{user.zip_code|default(None)}}'. +10. Quote identifiers (table/column/alias names) with double quotes ".". Quote string literals with single quotes '.'. +11. When using Postgres json_build_object, keys MUST be single-quoted string literals, e.g. json_build_object('bio', u.profile_bio). Do NOT use double quotes for keys. +12. Prefer selecting individual columns with proper aliases (including nested aliases) rather than building JSON blobs, unless explicitly requested. +13. Pay special attention to primary key (usually id fields). If auto-generated (e.g., SERIAL/BIGSERIAL), DO NOT insert a value. Otherwise include it from the model. Here is an example @@ -31,26 +34,26 @@ The SQL generated would look like the following --- Creates and returns a Task object, for the provided workspace and task_no --- Expects task_no and workspace.id are defined. If no tasks are found, returns None SELECT - t.id as `id`, - t.task_no as `task_no`, - t.title as `title`, - t.description as `description`, - t.status as `status`, - t.created_at as `created_at`, - t.updated_at as `updated_at`, - a.id as `agent.id`, - a.name as `agent.name`, - a.description as `agent.description`, - a.instructions as `agent.instructions`, - a.type as `agent.type`, - a.created_at as `agent.created_at`, - a.updated_at as `agent.updated_at`, - m.id as `agent.model.id`, - m.name as `agent.model.name`, - m.context_window as `agent.model.context_window`, - m.max_tokens as `agent.model.max_tokens`, - m.created_at as `agent.model.created_at`, - m.updated_at as `agent.model.updated_at`, + t.id as 'id', + t.task_no as 'task_no', + t.title as 'title', + t.description as 'description', + t.status as 'status', + t.created_at as 'created_at', + t.updated_at as 'updated_at', + a.id as 'agent.id', + a.name as 'agent.name', + a.description as 'agent.description', + a.instructions as 'agent.instructions', + a.type as 'agent.type', + a.created_at as 'agent.created_at', + a.updated_at as 'agent.updated_at', + m.id as 'agent.model.id', + m.name as 'agent.model.name', + m.context_window as 'agent.model.context_window', + m.max_tokens as 'agent.model.max_tokens', + m.created_at as 'agent.model.created_at', + m.updated_at as 'agent.model.updated_at', FROM tasks t LEFT JOIN agents a ON t.agent_id = a.id LEFT JOIN models m ON a.model_id = m.id diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 51a0caf..20868a4 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -1,6 +1,7 @@ import os import functools from typing import Any, Callable, Optional +import inspect from foundation_sql.prompt import SQLPromptGenerator, FunctionSpec from foundation_sql.gen import SQLGenerator @@ -43,6 +44,7 @@ def __init__( api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, + adapter_mode: str = "sync", ): """ Initialize the SQL query decorator. @@ -65,7 +67,7 @@ def __init__( self.db_url = db_url if not self.db_url: - raise ValueError(f"Database URL not provided either through constructor or {db_url_env} environment variable") + raise ValueError("Database URL not provided either through constructor or DATABASE_URL environment variable") # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) @@ -77,6 +79,7 @@ def __init__( ) self.repair = repair + self.adapter_mode = adapter_mode def __call__(self, func: Callable) -> Callable: @@ -115,28 +118,77 @@ def sql_gen(kwargs: Dict[str, Any], error: Optional[str]=None, prev_template: Op return sql_template - @functools.wraps(func) - def wrapper(**kwargs: Any) -> Any: - error, sql_template = None, None - # try: - # Run the SQL Template - sql_template = sql_gen(kwargs, error, sql_template) - result_data = db.run_sql(self.db_url, sql_template, **kwargs) - + def _parse_result(result_data: Any): if fn_spec.wrapper == 'list': - parsed_result = [ - db.parse_query_to_pydantic(row, fn_spec.return_type) + return [ + db.parse_query_to_pydantic(row, fn_spec.return_type) for row in result_data.all() ] elif isinstance(result_data, int): - parsed_result = result_data + return result_data else: first_row = result_data.first() - parsed_result = db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None - - return parsed_result - - return wrapper + return db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None + + is_func_async = inspect.iscoroutinefunction(func) + use_async = is_func_async or (self.adapter_mode == "async") + + if use_async: + @functools.wraps(func) + async def async_wrapper(**kwargs: Any) -> Any: + last_exc: Optional[Exception] = None + error: Optional[str] = None + sql_template: Optional[str] = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + database = db.get_db_with_adapter(self.db_url, "async") + for _ in range(attempts): + sql_template = sql_gen(kwargs, error, sql_template) + try: + result_data = await database.run_sql_async(sql_template, **kwargs) + try: + return _parse_result(result_data) + except Exception as parse_err: + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + if last_exc: + raise last_exc + raise RuntimeError("SQL generation failed without explicit exception") + + return async_wrapper + else: + @functools.wraps(func) + def wrapper(**kwargs: Any) -> Any: + last_exc: Optional[Exception] = None + error: Optional[str] = None + sql_template: Optional[str] = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + for _ in range(attempts): + sql_template = sql_gen(kwargs, error, sql_template) + try: + result_data = db.run_sql(self.db_url, sql_template, **kwargs) + try: + return _parse_result(result_data) + except Exception as parse_err: + # Capture validation/parse issues and retry regeneration + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + # Exhausted attempts + if last_exc: + raise last_exc + # Fallback: shouldn't happen, but raise generic error + raise RuntimeError("SQL generation failed without explicit exception") + + return wrapper diff --git a/plan.md b/plan.md new file mode 100644 index 0000000..cc37176 --- /dev/null +++ b/plan.md @@ -0,0 +1,105 @@ +# Foundation DB Layer Refactor Plan + +This document outlines the implementation plan for refactoring the Foundation DB layer to support both synchronous and asynchronous operations using a pluggable adapter architecture. + +## Phase 1: Core Abstraction and Synchronous Implementation + +This phase focuses on establishing the core `EngineAdapter` abstraction and refactoring the existing synchronous functionality to use the new `SQLAlchemyAdapter`. + +### Checklist + +- [x] **Create `foundation_sql/db_drivers.py` file.** + - This file will house the `EngineAdapter` abstract class and its concrete implementations. +- [x] **Adapter Selection & Factory (Sync by default, explicit async flag via decorator).** + - Add an adapter factory in `foundation_sql/db.py` that instantiates adapters. + - Keep existing public functions stable; do not change `db.run_sql(db_url, template, **ctx)` or `db.get_db(db_url)` signatures. + - Introduce a new helper for internal use only (by the decorator): `get_db_with_adapter(db_url, mode: Literal["sync","async"])` or equivalent, so existing API remains unchanged. + - The decorator controls adapter selection via an explicit flag (see Phase 2) and uses the helper to obtain the correct `Database`/adapter. +- [x] **Define `EngineAdapter` Abstract Base Class in `db_drivers.py`.** + - Create `EngineAdapter` with the following abstract methods: + - `init_schema(self, schema_sql: str)` + - `run_sql(self, template: str, data: dict)` + - `close(self)` + - `init_pool_async(self)` + - `init_schema_async(self, schema_sql: str)` + - `run_sql_async(self, template: str, data: dict)` + - `close_async(self)` +- [x] **Implement `SQLAlchemyAdapter` in `db_drivers.py`.** + - Create a class `SQLAlchemyAdapter` that inherits from `EngineAdapter`. + - Move the `sqlalchemy` engine creation and execution logic from `foundation_sql/db.py` into this class. + - Implement the synchronous methods: + - `__init__(self, dsn: str)`: + - Initialize `sqlalchemy.create_engine` with the given DSN. + - Store the engine instance. + - `init_schema(self, schema_sql: str)`: + - Implement the schema initialization logic currently in `db.Database.init_schema`. + - `run_sql(self, template: str, data: dict)`: + - Implement the SQL execution logic currently in `db.Database.run_sql`. + - `close(self)`: + - Dispose of the SQLAlchemy engine. + - Jinja2SQL configuration for sync: + - Use `Jinja2SQL(param_style="sqlalchemy")` (or default) to produce `:name` bindings compatible with SQLAlchemy. +- [x] **Refactor `foundation_sql/db.py` to use the Adapter Pattern.** + - Modify the `Database` class `__init__` method to accept an `EngineAdapter` instance. + - The `Database` class will delegate calls to the adapter's methods (`init_schema`, `run_sql`). + - Remove the direct `sqlalchemy` dependencies from `db.py` that are now handled by the `SQLAlchemyAdapter`. + - The `get_db` function will be updated to instantiate the `Database` class with the appropriate adapter. + - Maintain backward compatibility by keeping `get_db(db_url)` and `run_sql(db_url, template, **ctx)` unchanged; add internal helper `get_db_with_adapter(db_url, mode)` for the decorator’s explicit selection. +- [x] **Update `foundation_sql/query.py` to align with the new DB structure.** + - The `SQLQueryDecorator` currently calls `db.run_sql`. Ensure this continues to work with the refactored `db.py`. + +## Phase 2: Asynchronous Implementation with `asyncpg` + +This phase introduces asynchronous support by implementing the `AsyncpgAdapter` and updating the core `db.py` and `query.py` to handle async operations. + +### Checklist + +- [x] **Add `asyncpg` to `requirements.txt`.** +- [x] **Implement `AsyncpgAdapter` in `db_drivers.py`.** + - Create a class `AsyncpgAdapter` that inherits from `EngineAdapter`. + - `__init__(self, dsn: str)`: + - Store the DSN. + - Initialize `Jinja2SQL(param_style="asyncpg", enable_async=True)`. + - Implement the asynchronous methods: + - `init_pool_async(self)` + - `close_async(self)` + - `init_schema_async(self, schema_sql: str)` + - `run_sql_async(self, template: str, data: dict)` +- [x] **Update `foundation_sql/db.py` for Async Support.** + - Add `async` methods to the `Database` class (`init_schema_async`, `run_sql_async`, etc.). + - These methods will delegate to the corresponding `async` methods on the adapter instance. +- [x] **Update `foundation_sql/query.py` for Async Support.** + - Modify the `SQLQueryDecorator` to be async-aware and to explicitly control adapter selection via a flag. + - Add explicit adapter selection flag, e.g., `adapter_mode: Literal["sync","async"] = "sync"` (or `async_mode: bool = False`). + - If `adapter_mode == "async"` (or `async_mode is True`), the decorator uses `get_db_with_adapter(db_url, "async")` and async execution APIs. + - Else it uses the sync adapter path. + - Decorator async-awareness design: + - If the wrapped user function is `async def` OR `adapter_mode == "async"`, return an `async def wrapper`. + - Otherwise return a sync `def wrapper`. + - Unify template generation, caching, and result parsing across both paths to ensure parity. + - Keep the parsing behavior consistent with current logic (support `wrapper == 'list'`, `parse_query_to_pydantic`, integer rowcounts, and first-row mapping). + - Public API stability: + - Existing `db.run_sql(db_url, template, **ctx)` and `db.get_db(db_url)` remain available and unchanged. + - The decorator should use the new internal helper to pick adapter without altering public function signatures. + +## Phase 3: Testing, Documentation, and Finalization + +This phase ensures the new implementation is robust, well-documented, and easy to use. + +### Checklist + +- [ ] **Create Unit Tests for Both Adapters.** + - Write tests for `SQLAlchemyAdapter` to ensure existing sync functionality is not broken. + - Write new tests for `AsyncpgAdapter` to verify all async functionality. + - Add tests for `SQLQueryDecorator` covering: + - Sync path vs async path (wrapper type selection, execution, and parity of results) + - Template caching and regeneration + - Adapter selection via `adapter_mode` flag +- [ ] **Update `README.md` and Add Examples.** + - Document the new adapter-based architecture. + - Provide clear examples for both sync and async usage, including explicit adapter selection in the decorator and examples showing sync vs async wrappers. +- [ ] **Update `prompts.md`** + - Reflect the new async capabilities in the prompts documentation if necessary. +- [ ] **Code Review and Refinement.** + - Perform a final review of the code for clarity, consistency, and adherence to best practices. + - Verify public API stability (`db.run_sql`, `db.get_db`) and mark any new helper (e.g., `get_db_with_adapter`) as internal-only in docs. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4c1cdff..2c34a84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ sqlalchemy jinja2sql pydantic openai +asyncpg +psycopg2-binary \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index a4de0bf..f9bdfb2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,18 +2,24 @@ import os from foundation_sql import db from foundation_sql.query import SQLQueryDecorator +import re from dotenv import load_dotenv load_dotenv() -DB_URL = os.environ.get("DATABSE_URL", "sqlite:///:memory:") +# Prefer DATABASE_URL from env (e.g., PostgreSQL) for tests; fallback to SQLite in-memory +# Async tests explicitly read DATABASE_URL themselves and are skipped if absent. +DB_URL = os.getenv("DATABASE_URL", "sqlite:///:memory:") def create_query(schema): - return SQLQueryDecorator(schema=schema, - db_url=DB_URL, - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE_URL"), - model=os.getenv("OPENAI_MODEL")) + return SQLQueryDecorator( + schema=schema, + db_url=DB_URL, + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE_URL"), + model=os.getenv("OPENAI_MODEL"), + repair=2, + ) class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" @@ -27,15 +33,36 @@ def setUp(self): # Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) + # Capture table names from schema_sql for teardown cleanup + self._tables_to_drop = [] + if self.schema_sql: + for raw in self.schema_sql.split(';'): + stmt = raw.strip() + if not stmt: + continue + m = re.search(r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", stmt, re.IGNORECASE) + if m: + self._tables_to_drop.append(m.group(1)) else: raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") def tearDown(self): """Close the database connection after each test.""" - for _, connection in db.DATABASES.items(): - connection.get_engine().dispose() - + # Best-effort cleanup of tables created by this test to avoid cross-test interference + try: + if getattr(self, "_tables_to_drop", None): + database = db.get_db(self.db_url) + # Drop in reverse order to reduce FK issues + for t in reversed(self._tables_to_drop): + try: + database.execute(f"DROP TABLE IF EXISTS {t} CASCADE") + except Exception: + pass + finally: + for _, connection in db.DATABASES.items(): + connection.get_engine().dispose() + db.DATABASES.clear() db.DATABASES.clear() diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 0000000..1d35535 --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,73 @@ +import os +import unittest +from typing import Any, Dict, List + +from foundation_sql.db_drivers import SQLAlchemyAdapter +from foundation_sql import db + + +SYNC_DB_URL = os.environ.get("DATABASE_URL") + + +TEST_SCHEMA = """ +CREATE TABLE IF NOT EXISTS items ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL +); +""" + +INSERT_TEMPLATE = """ +INSERT INTO items (id, name) VALUES ({{ id }}, {{ name | tojson }}); +""" + +SELECT_TEMPLATE = """ +SELECT id, name FROM items ORDER BY id; +""" + + +class TestSQLAlchemyAdapter(unittest.TestCase): + def setUp(self) -> None: + if not SYNC_DB_URL: + raise unittest.SkipTest("DATABASE_URL not set; skipping Postgres-only tests") + self.adapter = SQLAlchemyAdapter(SYNC_DB_URL) + # init schema + self.adapter.init_schema(TEST_SCHEMA) + + def tearDown(self) -> None: + try: + self.adapter.close() + finally: + # Clear Database singletons that might have been created + db.DATABASES.clear() + + def test_insert_and_select(self): + # insert 2 rows + rc1 = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 1, "name": "alpha"}) + rc2 = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 2, "name": "beta"}) + self.assertIsInstance(rc1, int) + self.assertIsInstance(rc2, int) + self.assertEqual(rc1 + rc2, 2) + + rows = self.adapter.run_sql(SELECT_TEMPLATE, {}) + self.assertIsInstance(rows, list) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["id"], 1) + self.assertEqual(rows[0]["name"], "alpha") + + def test_database_facade_run_sql_parity(self): + # Ensure Database facade wraps adapter correctly + database = db.Database(SYNC_DB_URL, adapter=self.adapter) + # seed two rows using adapter path + _ = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 1, "name": "alpha"}) + _ = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 2, "name": "beta"}) + # Insert + affected = database.run_sql(INSERT_TEMPLATE, id=3, name="gamma") + self.assertEqual(affected, 1) + # Select should return QueryResult wrapper + result = database.run_sql(SELECT_TEMPLATE) + self.assertTrue(hasattr(result, "first")) + self.assertEqual(result.count(), 3) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_query_decorator_async.py b/tests/test_query_decorator_async.py new file mode 100644 index 0000000..2a522c0 --- /dev/null +++ b/tests/test_query_decorator_async.py @@ -0,0 +1,90 @@ +import os +import shutil +import unittest +import inspect +from typing import List + +from pydantic import BaseModel + +from foundation_sql.query import SQLQueryDecorator +from foundation_sql.db_drivers import AsyncpgAdapter +from foundation_sql import db + + +ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname +CACHE_DIR = "__sql__" + + +@unittest.skipUnless(ASYNC_DB_URL, "Async tests require DATABASE_URL Postgres DSN") +class TestSQLQueryDecoratorAsync(unittest.IsolatedAsyncioTestCase): + class User(BaseModel): + id: int + name: str + + TABLES_SCHEMA = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + """ + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed templates + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + async def asyncSetUp(self): + # Initialize schema on Postgres using async adapter each test for isolation + self.database = db.Database(ASYNC_DB_URL, adapter=AsyncpgAdapter(ASYNC_DB_URL)) + await self.database.init_schema_async(schema_sql=self.TABLES_SCHEMA) + # Clean table + await self.database.run_sql_async("DELETE FROM users;") + + async def asyncTearDown(self): + await self.database.close_async() + db.DATABASES.clear() + + async def test_async_wrapper_and_execution(self): + query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR, adapter_mode="async") + + @query + def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: + pass + + @query + def create_user(user: "TestSQLQueryDecoratorAsync.User") -> int: + pass + + # wrappers should be async + self.assertTrue(inspect.iscoroutinefunction(get_users)) + self.assertTrue(inspect.iscoroutinefunction(create_user)) + + users = await get_users() + self.assertEqual(len(users), 0) + rc = await create_user(user=self.User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = await get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + async def test_async_adapter_direct_use(self): + adapter = AsyncpgAdapter(ASYNC_DB_URL) + await adapter.init_pool_async() + await adapter.init_schema_async(self.TABLES_SCHEMA) + rc = await adapter.run_sql_async("INSERT INTO users (id, name) VALUES ({{ id }}, {{ name | tojson }});", {"id": 2, "name": "Bob"}) + self.assertEqual(rc, 1) + rows = await adapter.run_sql_async("SELECT id, name FROM users ORDER BY id;", {}) + self.assertIsInstance(rows, list) + self.assertGreaterEqual(len(rows), 1) + await adapter.close_async() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_query_decorator_sync.py b/tests/test_query_decorator_sync.py new file mode 100644 index 0000000..df8169b --- /dev/null +++ b/tests/test_query_decorator_sync.py @@ -0,0 +1,109 @@ +import os +import shutil +import unittest +import inspect +from typing import List + +from pydantic import BaseModel +from tests import common +from foundation_sql.query import SQLQueryDecorator +from foundation_sql import db + +DB_URL = os.environ.get("DATABASE_URL") +CACHE_DIR = "__sql__" + + +class User(BaseModel): + id: int + name: str + + +TABLES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + +class TestSQLQueryDecoratorSync(common.DatabaseTests): + schema_sql = TABLES_SCHEMA + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + # Seed SQL templates so no LLM calls are needed + with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + def setUp(self): + super().setUp() + # Ensure clean table state for each test + database = db.get_db(self.db_url) + database.run_sql("DELETE FROM users;") + + def test_sync_wrapper_and_execution(self): + query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="sync") + + @query + def get_users() -> List[User]: + """Get all users""" + pass + + @query + def create_user(user: User) -> int: + """Create user""" + pass + + # wrapper should be sync + self.assertFalse(inspect.iscoroutinefunction(get_users)) + self.assertFalse(inspect.iscoroutinefunction(create_user)) + + # execute + users = get_users() + self.assertEqual(len(users), 0) + rc = create_user(user=User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + def test_template_caching(self): + # Pre-seeded template exists; decorator should reuse without regeneration + query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="sync") + + @query + def get_users() -> List[User]: + pass + + # Change cached SQL to return in reverse order to observe effect + with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id DESC;") + + # Insert two rows and ensure order follows cached SQL + database = db.get_db(self.db_url) + database.run_sql("INSERT INTO users (id, name) VALUES (1, 'A');") + database.run_sql("INSERT INTO users (id, name) VALUES (2, 'B');") + + users = get_users() + self.assertEqual([u.id for u in users], [2, 1]) + + def test_adapter_selection_flag_async_wrapper_for_sync_func(self): + # When adapter_mode="async", even a sync function should return async wrapper + query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="async") + + @query + def get_users_async() -> List[User]: + pass + + self.assertTrue(inspect.iscoroutinefunction(get_users_async)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..e8f9111 --- /dev/null +++ b/uv.lock @@ -0,0 +1,8 @@ +version = 1 +revision = 2 +requires-python = ">=3.8" + +[[package]] +name = "foundation-sql" +version = "0.0.0" +source = { editable = "." } From 0005a010508c23533e1ca702732f26c697ac3ead Mon Sep 17 00:00:00 2001 From: ojasaklechayt Date: Tue, 12 Aug 2025 14:46:30 +0530 Subject: [PATCH 2/5] refactor: move SQL templates to __sql__ directory and enforce SQLite for common tests --- .gitignore | 4 ++- tests/common.py | 6 ++-- tests/test_adapters.py | 5 +++ tests/test_many_to_one.py | 63 ++++++++++++++++++++++++++++++++++++-- tests/test_nested_query.py | 56 ++++++++++++++++++++++++++++++--- 5 files changed, 123 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 1756956..b9e30d2 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,6 @@ ENV/ .env # Misc -.DS_Store \ No newline at end of file +.DS_Store + +__sql__ \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index f9bdfb2..03eb0c0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -7,9 +7,9 @@ from dotenv import load_dotenv load_dotenv() -# Prefer DATABASE_URL from env (e.g., PostgreSQL) for tests; fallback to SQLite in-memory -# Async tests explicitly read DATABASE_URL themselves and are skipped if absent. -DB_URL = os.getenv("DATABASE_URL", "sqlite:///:memory:") +# Force SQLite in-memory for all tests that use this common module. +# Async/Postgres-specific tests manage their own DATABASE_URL and are skipped if absent. +DB_URL = "sqlite:///:memory:" def create_query(schema): return SQLQueryDecorator( diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 1d35535..dbdf80c 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -32,6 +32,11 @@ def setUp(self) -> None: self.adapter = SQLAlchemyAdapter(SYNC_DB_URL) # init schema self.adapter.init_schema(TEST_SCHEMA) + # ensure clean state for each test run + try: + self.adapter.run_sql("DELETE FROM items;", {}) + except Exception: + pass def tearDown(self) -> None: try: diff --git a/tests/test_many_to_one.py b/tests/test_many_to_one.py index e21ba46..a638200 100644 --- a/tests/test_many_to_one.py +++ b/tests/test_many_to_one.py @@ -1,4 +1,6 @@ from typing import List, Optional +import os +import shutil from tests import common from pydantic import BaseModel @@ -30,6 +32,64 @@ class Task(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) +CACHE_DIR = "__sql__" + +class TestWorkspaceTasks(common.DatabaseTests): + schema_sql = TABLES_SCHEMA + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite-friendly templates + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + + # Create workspace: insert then fetch by last_insert_rowid() + with open(os.path.join(CACHE_DIR, "create_workspace.sql"), "w") as f: + f.write( + """ + INSERT INTO workspaces (name) VALUES ({{ name | tojson }}); + SELECT id, name FROM workspaces WHERE id = last_insert_rowid(); + """.strip() + ) + + # Add task to workspace: insert then fetch row + with open(os.path.join(CACHE_DIR, "add_task_to_workspace.sql"), "w") as f: + f.write( + """ + INSERT INTO tasks (workspace_id, title, description) + VALUES ( + {{ workspace.id }}, + {{ title | tojson }}, + {{ description | default(None) | tojson }} + ); + SELECT + id as "id", + workspace_id as "workspace.id", + title as "title", + description as "description" + FROM tasks + WHERE id = last_insert_rowid(); + """.strip() + ) + + # Get tasks for workspace: join with dotted aliases for nesting + with open(os.path.join(CACHE_DIR, "get_tasks_for_workspace.sql"), "w") as f: + f.write( + """ + SELECT + t.id as "id", + t.workspace_id as "workspace.id", + w.name as "workspace.name", + t.title as "title", + t.description as "description" + FROM tasks t + JOIN workspaces w ON w.id = t.workspace_id + WHERE w.id = {{ workspace.id }} + ORDER BY t.id; + """.strip() + ) + @query def create_workspace(name: str) -> Workspace: """ @@ -51,9 +111,6 @@ def get_tasks_for_workspace(workspace: Workspace) -> List[Task]: """ pass -class TestWorkspaceTasks(common.DatabaseTests): - schema_sql = TABLES_SCHEMA - def test_workspace_tasks(self): # Add a workspace ws = create_workspace(name="Project Alpha") diff --git a/tests/test_nested_query.py b/tests/test_nested_query.py index a1da021..f33b539 100644 --- a/tests/test_nested_query.py +++ b/tests/test_nested_query.py @@ -1,4 +1,6 @@ from typing import List, Optional +import os +import shutil from tests import common from pydantic import BaseModel @@ -34,6 +36,56 @@ class UserWithProfile(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) +CACHE_DIR = "__sql__" + +class TestNestedQuery(common.DatabaseTests): + + schema_sql = TABLES_SCHEMA + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite/Postgres portable templates + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + + # Portable SELECT using dotted aliases so our unflatten logic can build nested objects + with open(os.path.join(CACHE_DIR, "get_users_with_profile.sql"), "w") as f: + f.write( + """ + SELECT + u.id as "id", + u.name as "name", + u.email as "email", + u.role as "role", + u.profile_bio as "profile.bio", + u.address_street as "profile.address.street", + u.address_city as "profile.address.city", + u.address_zip_code as "profile.address.zip_code" + FROM users_with_profile u + ORDER BY u.id; + """.strip() + ) + + # Portable INSERT template using jinja2 variables (handled by jinja2sql) + with open(os.path.join(CACHE_DIR, "create_user_with_profile.sql"), "w") as f: + f.write( + """ + INSERT INTO users_with_profile ( + id, name, email, role, profile_bio, address_street, address_city, address_zip_code + ) VALUES ( + {{ user.id }}, + {{ user.name | tojson }}, + {{ user.email | tojson }}, + {{ user.role | tojson }}, + {{ user.profile.bio | default(None) | tojson }}, + {{ user.profile.address.street | default(None) | tojson }}, + {{ user.profile.address.city | default(None) | tojson }}, + {{ user.profile.address.zip_code | default(None) | tojson }} + ); + """.strip() + ) + @query def get_users_with_profile() -> List[UserWithProfile]: """ @@ -48,10 +100,6 @@ def create_user_with_profile(user: UserWithProfile) -> int: """ pass -class TestNestedQuery(common.DatabaseTests): - - schema_sql = TABLES_SCHEMA - def test_nested_object_creation_and_retrieval(self): # Initially, no users users = get_users_with_profile() From c4d6a61457069bf9f17d7c6b459df15f77e21838 Mon Sep 17 00:00:00 2001 From: ojasaklechayt Date: Tue, 12 Aug 2025 17:54:10 +0530 Subject: [PATCH 3/5] refactor: consolidate sync/async query tests and improve error handling --- foundation_sql/query.py | 203 ++++++++++++++++------------ tests/common.py | 2 - tests/test_adapters.py | 6 +- tests/test_query_decorator_async.py | 90 ------------ tests/test_query_decorator_sync.py | 109 --------------- tests/test_query_sync_and_async.py | 166 +++++++++++++++++++++++ tests/test_simple_query.py | 42 ++++++ 7 files changed, 325 insertions(+), 293 deletions(-) delete mode 100644 tests/test_query_decorator_async.py delete mode 100644 tests/test_query_decorator_sync.py create mode 100644 tests/test_query_sync_and_async.py diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 20868a4..18e6b4d 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -44,7 +44,6 @@ def __init__( api_key: Optional[str] = None, base_url: Optional[str] = None, model: Optional[str] = None, - adapter_mode: str = "sync", ): """ Initialize the SQL query decorator. @@ -79,43 +78,19 @@ def __init__( ) self.repair = repair - self.adapter_mode = adapter_mode - def __call__(self, func: Callable) -> Callable: - """ - Decorator implementation for SQL query generation and execution. - - Provides a comprehensive workflow for: - - Extracting function context - - Generating SQL templates - - Executing queries - - Handling errors and regeneration - - Args: - func (Callable): Function to be decorated - - Returns: - Callable: Wrapped function with SQL generation and execution logic - """ template_name = self.name or f"{func.__name__}.sql" fn_spec = FunctionSpec(func) - prompt_generator = SQLPromptGenerator( - fn_spec, - template_name, - self.system_prompt, - self.schema) - + prompt_generator = SQLPromptGenerator(fn_spec, template_name, self.system_prompt, self.schema) - def sql_gen(kwargs: Dict[str, Any], error: Optional[str]=None, prev_template: Optional[str]=None): + def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None): if self.regen or not self.cache.exists(template_name) or error: - prompt = prompt_generator.generate_prompt(kwargs, error, prev_template) sql_template = self.sql_generator.generate_sql(prompt) self.cache.set(template_name, sql_template) else: sql_template = self.cache.get(template_name) - return sql_template def _parse_result(result_data: Any): @@ -126,72 +101,46 @@ def _parse_result(result_data: Any): ] elif isinstance(result_data, int): return result_data + elif fn_spec.return_type is int: + # Best-effort mapping from non-int results to int + # 1) QueryResult-like object + if hasattr(result_data, "count") and callable(getattr(result_data, "count")): + try: + return int(result_data.count()) + except Exception: + pass + # 2) List of rows + if isinstance(result_data, list): + try: + return int(len(result_data)) + except Exception: + pass + # 3) Dict payloads: common keys or single numeric value + if isinstance(result_data, dict): + for k in ("result", "count", "affected", "rowcount"): + v = result_data.get(k) + if isinstance(v, int): + return v + vals = list(result_data.values()) + if len(vals) == 1 and isinstance(vals[0], int): + return vals[0] + # Fallback + return 0 else: first_row = result_data.first() return db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None - is_func_async = inspect.iscoroutinefunction(func) - use_async = is_func_async or (self.adapter_mode == "async") - - if use_async: - @functools.wraps(func) - async def async_wrapper(**kwargs: Any) -> Any: - last_exc: Optional[Exception] = None - error: Optional[str] = None - sql_template: Optional[str] = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 - database = db.get_db_with_adapter(self.db_url, "async") - for _ in range(attempts): - sql_template = sql_gen(kwargs, error, sql_template) - try: - result_data = await database.run_sql_async(sql_template, **kwargs) - try: - return _parse_result(result_data) - except Exception as parse_err: - last_exc = parse_err - error = f"Parsing/Validation error: {parse_err}" - continue - except Exception as exec_err: - last_exc = exec_err - error = f"Execution error: {exec_err}" - continue - if last_exc: - raise last_exc - raise RuntimeError("SQL generation failed without explicit exception") - - return async_wrapper - else: - @functools.wraps(func) - def wrapper(**kwargs: Any) -> Any: - last_exc: Optional[Exception] = None - error: Optional[str] = None - sql_template: Optional[str] = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 - for _ in range(attempts): - sql_template = sql_gen(kwargs, error, sql_template) - try: - result_data = db.run_sql(self.db_url, sql_template, **kwargs) - try: - return _parse_result(result_data) - except Exception as parse_err: - # Capture validation/parse issues and retry regeneration - last_exc = parse_err - error = f"Parsing/Validation error: {parse_err}" - continue - except Exception as exec_err: - last_exc = exec_err - error = f"Execution error: {exec_err}" - continue - # Exhausted attempts - if last_exc: - raise last_exc - # Fallback: shouldn't happen, but raise generic error - raise RuntimeError("SQL generation failed without explicit exception") - - return wrapper - + is_async = inspect.iscoroutinefunction(func) + executor = WrapSqlExecution( + func=func, + db_url=self.db_url, + repair=self.repair, + sql_gen=sql_gen, + parse_result=_parse_result, + ) + return executor.build_wrapper(is_async) - + def load_file(self, path: str) -> str: """ Load predefined table schemas. @@ -205,3 +154,81 @@ def load_file(self, path: str) -> str: with open(path, 'r') as f: return f.read() +class WrapSqlExecution: + + def __init__( + self, + func: Callable, + db_url: str, + repair: Optional[int], + sql_gen: Callable[[Dict[str, Any], Optional[str], Optional[str]], str], + parse_result: Callable[[Any], Any], + ) -> None: + self.func = func + self.db_url = db_url + self.repair = repair + self.sql_gen = sql_gen + self._parse_result = parse_result + + async def _execute_async(self, **kwargs: Any): + last_exc = None + error = None + sql_template = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + database = db.get_db_with_adapter(self.db_url, "async") + + for _ in range(attempts): + sql_template = self.sql_gen(kwargs, error, sql_template) + try: + result_data = await database.run_sql_async(sql_template, **kwargs) + try: + return self._parse_result(result_data) + except Exception as parse_err: + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + + if last_exc: + raise last_exc + raise RuntimeError("SQL generation failed without explicit exception") + + def _execute_sync(self, **kwargs: Any): + last_exc = None + error = None + sql_template = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + + for _ in range(attempts): + sql_template = self.sql_gen(kwargs, error, sql_template) + try: + result_data = db.run_sql(self.db_url, sql_template, **kwargs) + try: + return self._parse_result(result_data) + except Exception as parse_err: + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + + if last_exc: + raise last_exc + raise RuntimeError("SQL generation failed without explicit exception") + + def build_wrapper(self, is_async: bool): + if is_async: + @functools.wraps(self.func) + async def async_wrapper(**kwargs: Any): + return await self._execute_async(**kwargs) + return async_wrapper + else: + @functools.wraps(self.func) + def sync_wrapper(**kwargs: Any): + return self._execute_sync(**kwargs) + return sync_wrapper diff --git a/tests/common.py b/tests/common.py index 03eb0c0..5caf555 100644 --- a/tests/common.py +++ b/tests/common.py @@ -64,5 +64,3 @@ def tearDown(self): connection.get_engine().dispose() db.DATABASES.clear() db.DATABASES.clear() - - diff --git a/tests/test_adapters.py b/tests/test_adapters.py index dbdf80c..1e142bf 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -6,12 +6,12 @@ from foundation_sql import db -SYNC_DB_URL = os.environ.get("DATABASE_URL") +SYNC_DB_URL = "sqlite:///:memory:" TEST_SCHEMA = """ CREATE TABLE IF NOT EXISTS items ( - id SERIAL PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT NOT NULL ); """ @@ -27,8 +27,6 @@ class TestSQLAlchemyAdapter(unittest.TestCase): def setUp(self) -> None: - if not SYNC_DB_URL: - raise unittest.SkipTest("DATABASE_URL not set; skipping Postgres-only tests") self.adapter = SQLAlchemyAdapter(SYNC_DB_URL) # init schema self.adapter.init_schema(TEST_SCHEMA) diff --git a/tests/test_query_decorator_async.py b/tests/test_query_decorator_async.py deleted file mode 100644 index 2a522c0..0000000 --- a/tests/test_query_decorator_async.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import shutil -import unittest -import inspect -from typing import List - -from pydantic import BaseModel - -from foundation_sql.query import SQLQueryDecorator -from foundation_sql.db_drivers import AsyncpgAdapter -from foundation_sql import db - - -ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname -CACHE_DIR = "__sql__" - - -@unittest.skipUnless(ASYNC_DB_URL, "Async tests require DATABASE_URL Postgres DSN") -class TestSQLQueryDecoratorAsync(unittest.IsolatedAsyncioTestCase): - class User(BaseModel): - id: int - name: str - - TABLES_SCHEMA = """ - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL - ); - """ - - @classmethod - def setUpClass(cls) -> None: - # Ensure clean cache dir and seed templates - if os.path.exists(CACHE_DIR): - shutil.rmtree(CACHE_DIR) - os.makedirs(CACHE_DIR, exist_ok=True) - with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: - f.write("SELECT id, name FROM users ORDER BY id;") - with open(os.path.join(CACHE_DIR, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") - - async def asyncSetUp(self): - # Initialize schema on Postgres using async adapter each test for isolation - self.database = db.Database(ASYNC_DB_URL, adapter=AsyncpgAdapter(ASYNC_DB_URL)) - await self.database.init_schema_async(schema_sql=self.TABLES_SCHEMA) - # Clean table - await self.database.run_sql_async("DELETE FROM users;") - - async def asyncTearDown(self): - await self.database.close_async() - db.DATABASES.clear() - - async def test_async_wrapper_and_execution(self): - query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR, adapter_mode="async") - - @query - def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: - pass - - @query - def create_user(user: "TestSQLQueryDecoratorAsync.User") -> int: - pass - - # wrappers should be async - self.assertTrue(inspect.iscoroutinefunction(get_users)) - self.assertTrue(inspect.iscoroutinefunction(create_user)) - - users = await get_users() - self.assertEqual(len(users), 0) - rc = await create_user(user=self.User(id=1, name="Alice")) - self.assertEqual(rc, 1) - users = await get_users() - self.assertEqual(len(users), 1) - self.assertEqual(users[0].id, 1) - self.assertEqual(users[0].name, "Alice") - - async def test_async_adapter_direct_use(self): - adapter = AsyncpgAdapter(ASYNC_DB_URL) - await adapter.init_pool_async() - await adapter.init_schema_async(self.TABLES_SCHEMA) - rc = await adapter.run_sql_async("INSERT INTO users (id, name) VALUES ({{ id }}, {{ name | tojson }});", {"id": 2, "name": "Bob"}) - self.assertEqual(rc, 1) - rows = await adapter.run_sql_async("SELECT id, name FROM users ORDER BY id;", {}) - self.assertIsInstance(rows, list) - self.assertGreaterEqual(len(rows), 1) - await adapter.close_async() - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/tests/test_query_decorator_sync.py b/tests/test_query_decorator_sync.py deleted file mode 100644 index df8169b..0000000 --- a/tests/test_query_decorator_sync.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import shutil -import unittest -import inspect -from typing import List - -from pydantic import BaseModel -from tests import common -from foundation_sql.query import SQLQueryDecorator -from foundation_sql import db - -DB_URL = os.environ.get("DATABASE_URL") -CACHE_DIR = "__sql__" - - -class User(BaseModel): - id: int - name: str - - -TABLES_SCHEMA = """ -CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL -); -""" - - -class TestSQLQueryDecoratorSync(common.DatabaseTests): - schema_sql = TABLES_SCHEMA - - @classmethod - def setUpClass(cls) -> None: - # Ensure clean cache dir - if os.path.exists(CACHE_DIR): - shutil.rmtree(CACHE_DIR) - os.makedirs(CACHE_DIR, exist_ok=True) - # Seed SQL templates so no LLM calls are needed - with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: - f.write("SELECT id, name FROM users ORDER BY id;") - with open(os.path.join(CACHE_DIR, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") - - def setUp(self): - super().setUp() - # Ensure clean table state for each test - database = db.get_db(self.db_url) - database.run_sql("DELETE FROM users;") - - def test_sync_wrapper_and_execution(self): - query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="sync") - - @query - def get_users() -> List[User]: - """Get all users""" - pass - - @query - def create_user(user: User) -> int: - """Create user""" - pass - - # wrapper should be sync - self.assertFalse(inspect.iscoroutinefunction(get_users)) - self.assertFalse(inspect.iscoroutinefunction(create_user)) - - # execute - users = get_users() - self.assertEqual(len(users), 0) - rc = create_user(user=User(id=1, name="Alice")) - self.assertEqual(rc, 1) - users = get_users() - self.assertEqual(len(users), 1) - self.assertEqual(users[0].id, 1) - self.assertEqual(users[0].name, "Alice") - - def test_template_caching(self): - # Pre-seeded template exists; decorator should reuse without regeneration - query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="sync") - - @query - def get_users() -> List[User]: - pass - - # Change cached SQL to return in reverse order to observe effect - with open(os.path.join(CACHE_DIR, "get_users.sql"), "w") as f: - f.write("SELECT id, name FROM users ORDER BY id DESC;") - - # Insert two rows and ensure order follows cached SQL - database = db.get_db(self.db_url) - database.run_sql("INSERT INTO users (id, name) VALUES (1, 'A');") - database.run_sql("INSERT INTO users (id, name) VALUES (2, 'B');") - - users = get_users() - self.assertEqual([u.id for u in users], [2, 1]) - - def test_adapter_selection_flag_async_wrapper_for_sync_func(self): - # When adapter_mode="async", even a sync function should return async wrapper - query = SQLQueryDecorator(schema=TABLES_SCHEMA, db_url=self.db_url, cache_dir=CACHE_DIR, adapter_mode="async") - - @query - def get_users_async() -> List[User]: - pass - - self.assertTrue(inspect.iscoroutinefunction(get_users_async)) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/tests/test_query_sync_and_async.py b/tests/test_query_sync_and_async.py new file mode 100644 index 0000000..3aec2f6 --- /dev/null +++ b/tests/test_query_sync_and_async.py @@ -0,0 +1,166 @@ +import os +import shutil +import unittest +import inspect +from typing import List + +from pydantic import BaseModel + +from foundation_sql.query import SQLQueryDecorator +from foundation_sql import db +from foundation_sql.db_drivers import AsyncpgAdapter + + +SQLITE_DB_URL = "sqlite:///__test_sync.sqlite3" + +# Attempt to ensure DATABASE_URL is available by reading .env if needed +if not os.environ.get("DATABASE_URL") and os.path.exists(os.path.join(os.path.dirname(__file__), "..", ".env")): + env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".env")) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + key, val = line.split("=", 1) + key = key.strip() + val = val.strip().strip('"').strip("'") + os.environ.setdefault(key, val) + except Exception: + pass + +ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname +CACHE_DIR_SYNC = "__sql__/__sql_sync__" +CACHE_DIR_ASYNC = "__sql__/__sql_async__" + + +class TestSQLQueryDecoratorSync(unittest.TestCase): + class User(BaseModel): + id: int + name: str + + TABLES_SCHEMA = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + """ + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed templates for sync + if os.path.exists(CACHE_DIR_SYNC): + shutil.rmtree(CACHE_DIR_SYNC) + os.makedirs(CACHE_DIR_SYNC, exist_ok=True) + with open(os.path.join(CACHE_DIR_SYNC, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR_SYNC, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + # Prepare SQLite DB file cleanly + if os.path.exists("__test_sync.sqlite3"): + os.remove("__test_sync.sqlite3") + + def setUp(self) -> None: + # Initialize schema on SQLite using sync adapter + self.database = db.Database(SQLITE_DB_URL) + self.database.init_schema(schema_sql=self.TABLES_SCHEMA) + # Clean table + self.database.run_sql("DELETE FROM users;") + + def tearDown(self) -> None: + # Close SQLAlchemy engine + try: + self.database.adapter.close() + except Exception: + pass + db.DATABASES.clear() + + def test_sync_wrappers_and_execution(self): + query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC) + + @query + def get_users() -> List["TestSQLQueryDecoratorSync.User"]: + pass + + @query + def create_user(user: "TestSQLQueryDecoratorSync.User") -> int: + pass + + # wrappers should be sync + self.assertFalse(inspect.iscoroutinefunction(get_users)) + self.assertFalse(inspect.iscoroutinefunction(create_user)) + + users = get_users() + self.assertEqual(len(users), 0) + rc = create_user(user=self.User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + +@unittest.skipUnless(ASYNC_DB_URL, "Async tests require DATABASE_URL Postgres DSN") +class TestSQLQueryDecoratorAsync(unittest.IsolatedAsyncioTestCase): + class User(BaseModel): + id: int + name: str + + TABLES_SCHEMA = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + """ + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed templates for async + if os.path.exists(CACHE_DIR_ASYNC): + shutil.rmtree(CACHE_DIR_ASYNC) + os.makedirs(CACHE_DIR_ASYNC, exist_ok=True) + with open(os.path.join(CACHE_DIR_ASYNC, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR_ASYNC, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + async def asyncSetUp(self): + # Initialize schema on Postgres using async adapter each test for isolation + self.database = db.Database(ASYNC_DB_URL, adapter=AsyncpgAdapter(ASYNC_DB_URL)) + await self.database.init_schema_async(schema_sql=self.TABLES_SCHEMA) + # Clean table + await self.database.run_sql_async("DELETE FROM users;") + + async def asyncTearDown(self): + await self.database.close_async() + db.DATABASES.clear() + + async def test_async_wrappers_and_execution(self): + query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC) + + @query + async def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: + pass + + @query + async def create_user(user: "TestSQLQueryDecoratorAsync.User") -> int: + pass + + # wrappers should be async + self.assertTrue(inspect.iscoroutinefunction(get_users)) + self.assertTrue(inspect.iscoroutinefunction(create_user)) + + users = await get_users() + self.assertEqual(len(users), 0) + rc = await create_user(user=self.User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = await get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_simple_query.py b/tests/test_simple_query.py index c81720b..ef9ee1c 100644 --- a/tests/test_simple_query.py +++ b/tests/test_simple_query.py @@ -1,6 +1,8 @@ from typing import List from tests import common from pydantic import BaseModel +import os +import shutil class User(BaseModel): @@ -39,6 +41,46 @@ def create_user(user: User) -> int: class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA + CACHE_DIR = "__sql__" + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite-friendly templates for this test + if os.path.exists(cls.CACHE_DIR): + shutil.rmtree(cls.CACHE_DIR) + os.makedirs(cls.CACHE_DIR, exist_ok=True) + + # Deterministic SELECT mapping to Pydantic model fields + with open(os.path.join(cls.CACHE_DIR, "get_users.sql"), "w") as f: + f.write( + ( + """ + SELECT + id as "id", + name as "name", + email as "email", + role as "role" + FROM users + ORDER BY id; + """ + ).strip() + ) + + # Deterministic INSERT using provided user fields (string id in this schema) + with open(os.path.join(cls.CACHE_DIR, "create_user.sql"), "w") as f: + f.write( + ( + """ + INSERT INTO users (id, name, email, role) + VALUES ( + {{ user.id }}, + {{ user.name | tojson }}, + {{ user.email | tojson }}, + {{ user.role | tojson }} + ); + """ + ).strip() + ) def test_users(self): users = get_users() From b2dc527242f422e7809d949080a6943af6ac9a6a Mon Sep 17 00:00:00 2001 From: ojasaklechat41 Date: Tue, 12 Aug 2025 17:57:17 +0530 Subject: [PATCH 4/5] Delete plan.md --- plan.md | 105 -------------------------------------------------------- 1 file changed, 105 deletions(-) delete mode 100644 plan.md diff --git a/plan.md b/plan.md deleted file mode 100644 index cc37176..0000000 --- a/plan.md +++ /dev/null @@ -1,105 +0,0 @@ -# Foundation DB Layer Refactor Plan - -This document outlines the implementation plan for refactoring the Foundation DB layer to support both synchronous and asynchronous operations using a pluggable adapter architecture. - -## Phase 1: Core Abstraction and Synchronous Implementation - -This phase focuses on establishing the core `EngineAdapter` abstraction and refactoring the existing synchronous functionality to use the new `SQLAlchemyAdapter`. - -### Checklist - -- [x] **Create `foundation_sql/db_drivers.py` file.** - - This file will house the `EngineAdapter` abstract class and its concrete implementations. -- [x] **Adapter Selection & Factory (Sync by default, explicit async flag via decorator).** - - Add an adapter factory in `foundation_sql/db.py` that instantiates adapters. - - Keep existing public functions stable; do not change `db.run_sql(db_url, template, **ctx)` or `db.get_db(db_url)` signatures. - - Introduce a new helper for internal use only (by the decorator): `get_db_with_adapter(db_url, mode: Literal["sync","async"])` or equivalent, so existing API remains unchanged. - - The decorator controls adapter selection via an explicit flag (see Phase 2) and uses the helper to obtain the correct `Database`/adapter. -- [x] **Define `EngineAdapter` Abstract Base Class in `db_drivers.py`.** - - Create `EngineAdapter` with the following abstract methods: - - `init_schema(self, schema_sql: str)` - - `run_sql(self, template: str, data: dict)` - - `close(self)` - - `init_pool_async(self)` - - `init_schema_async(self, schema_sql: str)` - - `run_sql_async(self, template: str, data: dict)` - - `close_async(self)` -- [x] **Implement `SQLAlchemyAdapter` in `db_drivers.py`.** - - Create a class `SQLAlchemyAdapter` that inherits from `EngineAdapter`. - - Move the `sqlalchemy` engine creation and execution logic from `foundation_sql/db.py` into this class. - - Implement the synchronous methods: - - `__init__(self, dsn: str)`: - - Initialize `sqlalchemy.create_engine` with the given DSN. - - Store the engine instance. - - `init_schema(self, schema_sql: str)`: - - Implement the schema initialization logic currently in `db.Database.init_schema`. - - `run_sql(self, template: str, data: dict)`: - - Implement the SQL execution logic currently in `db.Database.run_sql`. - - `close(self)`: - - Dispose of the SQLAlchemy engine. - - Jinja2SQL configuration for sync: - - Use `Jinja2SQL(param_style="sqlalchemy")` (or default) to produce `:name` bindings compatible with SQLAlchemy. -- [x] **Refactor `foundation_sql/db.py` to use the Adapter Pattern.** - - Modify the `Database` class `__init__` method to accept an `EngineAdapter` instance. - - The `Database` class will delegate calls to the adapter's methods (`init_schema`, `run_sql`). - - Remove the direct `sqlalchemy` dependencies from `db.py` that are now handled by the `SQLAlchemyAdapter`. - - The `get_db` function will be updated to instantiate the `Database` class with the appropriate adapter. - - Maintain backward compatibility by keeping `get_db(db_url)` and `run_sql(db_url, template, **ctx)` unchanged; add internal helper `get_db_with_adapter(db_url, mode)` for the decorator’s explicit selection. -- [x] **Update `foundation_sql/query.py` to align with the new DB structure.** - - The `SQLQueryDecorator` currently calls `db.run_sql`. Ensure this continues to work with the refactored `db.py`. - -## Phase 2: Asynchronous Implementation with `asyncpg` - -This phase introduces asynchronous support by implementing the `AsyncpgAdapter` and updating the core `db.py` and `query.py` to handle async operations. - -### Checklist - -- [x] **Add `asyncpg` to `requirements.txt`.** -- [x] **Implement `AsyncpgAdapter` in `db_drivers.py`.** - - Create a class `AsyncpgAdapter` that inherits from `EngineAdapter`. - - `__init__(self, dsn: str)`: - - Store the DSN. - - Initialize `Jinja2SQL(param_style="asyncpg", enable_async=True)`. - - Implement the asynchronous methods: - - `init_pool_async(self)` - - `close_async(self)` - - `init_schema_async(self, schema_sql: str)` - - `run_sql_async(self, template: str, data: dict)` -- [x] **Update `foundation_sql/db.py` for Async Support.** - - Add `async` methods to the `Database` class (`init_schema_async`, `run_sql_async`, etc.). - - These methods will delegate to the corresponding `async` methods on the adapter instance. -- [x] **Update `foundation_sql/query.py` for Async Support.** - - Modify the `SQLQueryDecorator` to be async-aware and to explicitly control adapter selection via a flag. - - Add explicit adapter selection flag, e.g., `adapter_mode: Literal["sync","async"] = "sync"` (or `async_mode: bool = False`). - - If `adapter_mode == "async"` (or `async_mode is True`), the decorator uses `get_db_with_adapter(db_url, "async")` and async execution APIs. - - Else it uses the sync adapter path. - - Decorator async-awareness design: - - If the wrapped user function is `async def` OR `adapter_mode == "async"`, return an `async def wrapper`. - - Otherwise return a sync `def wrapper`. - - Unify template generation, caching, and result parsing across both paths to ensure parity. - - Keep the parsing behavior consistent with current logic (support `wrapper == 'list'`, `parse_query_to_pydantic`, integer rowcounts, and first-row mapping). - - Public API stability: - - Existing `db.run_sql(db_url, template, **ctx)` and `db.get_db(db_url)` remain available and unchanged. - - The decorator should use the new internal helper to pick adapter without altering public function signatures. - -## Phase 3: Testing, Documentation, and Finalization - -This phase ensures the new implementation is robust, well-documented, and easy to use. - -### Checklist - -- [ ] **Create Unit Tests for Both Adapters.** - - Write tests for `SQLAlchemyAdapter` to ensure existing sync functionality is not broken. - - Write new tests for `AsyncpgAdapter` to verify all async functionality. - - Add tests for `SQLQueryDecorator` covering: - - Sync path vs async path (wrapper type selection, execution, and parity of results) - - Template caching and regeneration - - Adapter selection via `adapter_mode` flag -- [ ] **Update `README.md` and Add Examples.** - - Document the new adapter-based architecture. - - Provide clear examples for both sync and async usage, including explicit adapter selection in the decorator and examples showing sync vs async wrappers. -- [ ] **Update `prompts.md`** - - Reflect the new async capabilities in the prompts documentation if necessary. -- [ ] **Code Review and Refinement.** - - Perform a final review of the code for clarity, consistency, and adherence to best practices. - - Verify public API stability (`db.run_sql`, `db.get_db`) and mark any new helper (e.g., `get_db_with_adapter`) as internal-only in docs. \ No newline at end of file From f7a8ee3706c2f9557418c6f60efa7f1a6494a27f Mon Sep 17 00:00:00 2001 From: ojasaklechat41 Date: Tue, 12 Aug 2025 17:57:35 +0530 Subject: [PATCH 5/5] Delete uv.lock --- uv.lock | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 uv.lock diff --git a/uv.lock b/uv.lock deleted file mode 100644 index e8f9111..0000000 --- a/uv.lock +++ /dev/null @@ -1,8 +0,0 @@ -version = 1 -revision = 2 -requires-python = ">=3.8" - -[[package]] -name = "foundation-sql" -version = "0.0.0" -source = { editable = "." }