diff --git a/foundation_sql/db.py b/foundation_sql/db.py index c2541fe..59a7418 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -1,5 +1,5 @@ """ -Database operations module for Foundation. +Database operations module for Foundation (adapter-based). """ import logging @@ -7,19 +7,18 @@ from types import NoneType from typing import Dict, Any, Optional, Type, Union, List from pydantic import BaseModel -from sqlalchemy import create_engine, text +from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError -from jinja2sql import Jinja2SQL from datetime import datetime +from foundation_sql.db_drivers import EngineAdapter, SQLAlchemyAdapter +from foundation_sql.db_drivers import AsyncpgAdapter NESTED_SPLITTER = "." # Singleton instance DATABASES = {} -j2sql = Jinja2SQL() - # logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -27,29 +26,33 @@ class Database: - """Database operations for Foundation.""" + """Database operations for Foundation delegated to an EngineAdapter.""" - def __init__(self, db_url: Optional[str] = None) -> None: - """Initialize the database. + def __init__(self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None) -> None: + """Initialize the database facade. Args: - db_url: SQLAlchemy database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) + db_url: Database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) If not provided, will use DATABASE_URL environment variable + adapter: Optional explicit adapter. If not provided, a default sync adapter is created. """ self.db_url = db_url or os.getenv('DATABASE_URL') if not self.db_url: raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable') - # Create SQLAlchemy engine - self.engine = create_engine(self.db_url) + # Default to sync SQLAlchemy adapter unless explicitly provided + self.adapter: EngineAdapter = adapter or SQLAlchemyAdapter(self.db_url) def get_engine(self) -> Engine: - """Get the SQLAlchemy engine. + """Get the underlying SQLAlchemy engine if available. Returns: SQLAlchemy Engine instance """ - return self.engine + # Only available for SQLAlchemyAdapter + if isinstance(self.adapter, SQLAlchemyAdapter): + return self.adapter.engine + raise RuntimeError("Engine is not available for this adapter") def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None: """Initialize the database schema if it doesn't exist. @@ -65,15 +68,8 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] if not schema_sql: with open(schema_path, 'r') as f: schema_sql = f.read() - - # Connect and execute schema in a transaction - with self.engine.begin() as conn: - try: - for statement in schema_sql.split(';'): - if statement.strip(): - conn.execute(text(statement)) - except SQLAlchemyError as e: - raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + # Delegate to adapter + self.adapter.init_schema(schema_sql) def run_sql(self, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. @@ -90,50 +86,35 @@ def run_sql(self, sql_template: str, **context) -> Any: if 'now' not in context: context['now'] = datetime.now - # Use jinja2sql to render the template and get SQL with parameters - try: - query, params = j2sql.from_string( - sql_template, - context=context - ) - except Exception as e: - raise ValueError(f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}") from e - - # Execute the SQL with parameters in a transaction - with self.engine.connect() as conn: - with conn.begin(): - try: - # Split the query into individual statements for SQLite - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] - - # Track total rows affected - total_rows = 0 - last_result = None - - logger.debug(f"Executing statements: {len(statements)}") - # Execute each statement - for statement in statements: - logger.debug(f"Executing statement: {statement} {params}") - result = conn.execute(text(statement), params) - total_rows += result.rowcount - - # Store the last result for potential return - if result.returns_rows: - last_result = result - - # If last result is a SELECT, return its rows - if last_result and last_result.returns_rows: - # Convert SQLAlchemy Row objects to plain dictionaries - # to avoid leaking SQLAlchemy implementation details - rows = [dict(row._mapping) for row in last_result] - logger.debug(f"Returning rows: {rows}") - return QueryResult(rows) - - # If no rows affected, return 0 - return total_rows - except SQLAlchemyError as e: - # The transaction will be rolled back automatically - raise RuntimeError(f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}") from e + # Delegate to adapter; it returns either list[dict] rows or int rowcount + result = self.adapter.run_sql(sql_template, context) + if isinstance(result, int): + return result + # assume list of dicts + return QueryResult(result) + + # ---------- Async delegates (Phase 2) ---------- + async def init_schema_async(self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None) -> None: + if not hasattr(self.adapter, 'init_schema_async'): + raise NotImplementedError("Async schema init not supported by this adapter") + if not schema_sql: + with open(schema_path, 'r') as f: + schema_sql = f.read() + await self.adapter.init_schema_async(schema_sql) # type: ignore[attr-defined] + + async def run_sql_async(self, sql_template: str, **context) -> Any: + if not hasattr(self.adapter, 'run_sql_async'): + raise NotImplementedError("Async run_sql not supported by this adapter") + if 'now' not in context: + context['now'] = datetime.now + result = await self.adapter.run_sql_async(sql_template, context) # type: ignore[attr-defined] + if isinstance(result, int): + return result + return QueryResult(result) + + async def close_async(self) -> None: + if hasattr(self.adapter, 'close_async'): + await self.adapter.close_async() # type: ignore[attr-defined] def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any: @@ -151,7 +132,11 @@ def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = Returns: Any: Result of the execution """ - with self.engine.connect() as connection: + # Only supported on SQLAlchemy adapter path for now + if not isinstance(self.adapter, SQLAlchemyAdapter): + raise NotImplementedError("execute() is only supported for SQLAlchemy adapter") + + with self.adapter.engine.connect() as connection: try: # Replace '?' placeholders with SQLAlchemy named parameters if '?' in sql: @@ -249,6 +234,16 @@ def get_db(db_url: str) -> Database: return DATABASES[db_url] +def get_db_with_adapter(db_url: str, mode: str) -> Database: + """Internal helper for selecting adapter explicitly. + mode: "sync" | "async" (async not implemented yet) + """ + if mode == "sync": + return Database(db_url, adapter=SQLAlchemyAdapter(db_url)) + if mode == "async": + return Database(db_url, adapter=AsyncpgAdapter(db_url)) + raise ValueError(f"Unknown adapter mode: {mode}") + def run_sql(db_url: str, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. diff --git a/foundation_sql/db_drivers.py b/foundation_sql/db_drivers.py new file mode 100644 index 0000000..28f64d2 --- /dev/null +++ b/foundation_sql/db_drivers.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import logging +import re # Add missing import +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, List + +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError + +from jinja2sql import Jinja2SQL +from datetime import datetime +import asyncpg + +logger = logging.getLogger(__name__) + + +class EngineAdapter(ABC): + """Abstract adapter for DB engines (sync and async).""" + + # ---------- Sync API ---------- + @abstractmethod + def init_schema(self, schema_sql: str) -> None: + ... + + @abstractmethod + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: + ... + + @abstractmethod + def close(self) -> None: + ... + + # ---------- Async API (optional) ---------- + async def init_pool_async(self) -> None: # pragma: no cover - to be implemented by async adapters + raise NotImplementedError + + async def init_schema_async(self, schema_sql: str) -> None: # pragma: no cover + raise NotImplementedError + + async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover + raise NotImplementedError + + async def close_async(self) -> None: # pragma: no cover + raise NotImplementedError + + +class SQLAlchemyAdapter(EngineAdapter): + """Synchronous adapter backed by SQLAlchemy engine.""" + + def __init__(self, dsn: str) -> None: + self.dsn = dsn + self.engine: Engine = create_engine(self.dsn) + # Use a valid Jinja2SQL param style compatible with SQLAlchemy bound params + # "named" produces :name parameters which SQLAlchemy understands via text() bindings + self.j2sql = Jinja2SQL(param_style="named") + + def init_schema(self, schema_sql: str) -> None: + with self.engine.begin() as conn: + try: + for statement in schema_sql.split(';'): + if statement.strip(): + conn.execute(text(statement)) + except SQLAlchemyError as e: + raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: + # ensure now is available + if 'now' not in data: + data['now'] = datetime.now + + try: + # Normalize templates: remove explicit tojson filters to allow binding + normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + query, params = self.j2sql.from_string(normalized, context=data) + except Exception as e: + raise ValueError( + f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" + ) from e + + with self.engine.connect() as conn: + with conn.begin(): + try: + statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + total_rows = 0 + last_result = None + + logger.debug(f"Executing statements: {len(statements)}") + for statement in statements: + logger.debug(f"Executing statement: {statement} {params}") + result = conn.execute(text(statement), params) + total_rows += result.rowcount + if result.returns_rows: + last_result = result + + if last_result and last_result.returns_rows: + rows = [dict(row._mapping) for row in last_result] + logger.debug(f"Returning rows: {rows}") + # Return the same shape as Database currently does: a QueryResult-like object + # The Database facade will wrap rows if needed; here we just return rows + return rows + + return total_rows + except SQLAlchemyError as e: + raise RuntimeError( + f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}" + ) from e + + def close(self) -> None: + try: + self.engine.dispose() + except Exception: + logger.exception("Error disposing SQLAlchemy engine") + + +class AsyncpgAdapter(EngineAdapter): + """Adapter for asyncpg (async PostgreSQL driver).""" + + def __init__(self, dsn: str) -> None: + self.dsn = dsn + self.pool: Optional[asyncpg.Pool] = None + # jinja2sql will render SQL & parameters in asyncpg style ($1, $2, ...) + self.j2sql = Jinja2SQL(param_style="asyncpg") + + async def init_pool_async(self) -> None: + if self.pool is None: + self.pool = await asyncpg.create_pool(self.dsn) + + async def close_async(self) -> None: + if self.pool is not None: + await self.pool.close() + self.pool = None + + async def init_schema_async(self, schema_sql: str) -> None: + await self.init_pool_async() + assert self.pool is not None + async with self.pool.acquire() as conn: + # Execute statements sequentially to ensure order + for statement in [s.strip() for s in schema_sql.split(';') if s.strip()]: + try: + logger.debug(f"Executing schema statement: {statement}") + await conn.execute(statement) + except Exception as e: + logger.error(f"Failed to execute schema statement: {statement}") + raise RuntimeError(f"Failed to execute schema statement: {str(e)}") from e + + async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: + await self.init_pool_async() + assert self.pool is not None + + if 'now' not in data: + data['now'] = datetime.now + + # Special handling for templates without parameters + if '{{' not in template: + # No template variables, execute directly + async with self.pool.acquire() as conn: + is_select = template.strip().lower().startswith("select") + if is_select: + records = await conn.fetch(template) + return [dict(r) for r in records] + else: + status = await conn.execute(template) + return _parse_rowcount(status) + + try: + # Add the same template normalization as SQLAlchemyAdapter + # Normalize templates: remove explicit tojson filters to allow binding + normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + # jinja2sql generates SQL with $1, $2... placeholders and params in correct order + query, params_list = self.j2sql.from_string(normalized, context=data) + + # Debug: log parameter types + logger.debug(f"Rendered query: {query}") + logger.debug(f"Parameter values: {params_list}") + logger.debug(f"Parameter types: {[(type(p).__name__, p) for p in params_list]}") + + except Exception as e: + raise ValueError( + f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" + ) from e + + async with self.pool.acquire() as conn: + try: + # Handle multiple statements like SQLAlchemyAdapter + statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + total_rows = 0 + last_result = None + + logger.debug(f"Executing {len(statements)} statement(s)") + for i, statement in enumerate(statements): + logger.debug(f"Statement {i+1}: {statement}") + + is_select = statement.strip().lower().startswith("select") + + if is_select: + last_result = await conn.fetch(statement, *params_list) + total_rows += len(last_result) + logger.debug(f"SELECT returned {len(last_result)} rows") + else: + # Check if this is a schema operation or data operation + is_schema_op = any(keyword in statement.upper() for keyword in + ['CREATE', 'DROP', 'ALTER', 'TRUNCATE']) + + if is_schema_op: + # Schema operations typically don't use parameters + logger.debug("Executing as schema operation (no parameters)") + status = await conn.execute(statement) + else: + # Data operations use parameters + logger.debug(f"Executing as data operation with parameters: {params_list}") + status = await conn.execute(statement, *params_list) + + stmt_rows = _parse_rowcount(status) + total_rows += stmt_rows + logger.debug(f"Statement affected {stmt_rows} rows") + + # Return results similar to SQLAlchemyAdapter + if last_result is not None: + rows = [dict(r) for r in last_result] + logger.debug(f"Returning {len(rows)} rows") + return rows + + logger.debug(f"Returning row count: {total_rows}") + return total_rows + + except Exception as e: + error_msg = (f"Failed to execute SQL: {str(e)}\n" + f"Rendered SQL: {query}\n" + f"Parameters: {params_list}\n" + f"Parameter types: {[type(p).__name__ for p in params_list]}") + logger.error(error_msg) + raise RuntimeError(error_msg) from e + + # Sync methods are not supported for asyncpg adapter + def init_schema(self, schema_sql: str) -> None: # pragma: no cover - sync not supported + raise NotImplementedError("Use init_schema_async with AsyncpgAdapter") + + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover - sync not supported + raise NotImplementedError("Use run_sql_async with AsyncpgAdapter") + + def close(self) -> None: # pragma: no cover - sync not supported + raise NotImplementedError("Use close_async with AsyncpgAdapter") + + +def _parse_rowcount(status: str) -> int: + """Parse rowcount from asyncpg status like 'INSERT 0 1', 'UPDATE 3', etc.""" + try: + parts = status.split() + # Common patterns: 'INSERT 0 1' (rowcount is last), 'UPDATE 3' (last), 'DELETE 2' + return int(parts[-1]) + except Exception: + return 0 \ No newline at end of file diff --git a/foundation_sql/gen.py b/foundation_sql/gen.py index 2dc2534..cad5266 100644 --- a/foundation_sql/gen.py +++ b/foundation_sql/gen.py @@ -43,17 +43,20 @@ def generate_sql(self, prompt: str) -> str: Args: prompt (str): Detailed prompt for SQL generation - + Returns: str: Generated SQL template """ response = self.client.chat.completions.create( - model=self.model, # Use an appropriate model + model=self.model, messages=[ {"role": "system", "content": prompt} ] ) generated_sql = response.choices[0].message.content.strip() - sql_template = re.sub(r'^```sql\n|^```\n|\n```$', '', generated_sql, flags=re.MULTILINE) - return sql_template + + # Remove ```sql or ``` fences + sql_template = re.sub(r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE).strip() + + return sql_template \ No newline at end of file diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 6809d26..2cd09cc 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -1,20 +1,23 @@ -You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. Ensure, the SQL query is usable across sqlite and postgresql. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. +You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. The SQL template generated is a Jinja2 template. + +Primary target is PostgreSQL (tests run against Postgres). Prefer Postgres-compatible SQL. Avoid SQLite-specific functions. 1. Start with a -- comment to document the function name, parameters and docstring, explaining what the SQL query does. 2. Use jinja2 template to generate SQL 3. When accessing nested fields handle cases if they aren't defined. Use default filter with None value for such cases e.g. {{user.zip_code|default(None)}} -4. Ensure response rows can be parsed into Pydantic model. As long as the model fields are named the same as the columns in the SQL query. It also supports nested models by using double underscores to separate nested fields. +4. Ensure response rows include ALL fields in the return Pydantic model. Name columns exactly as model fields. For nested models, alias columns using dot notation with quotes, e.g. "workspace.id", "workspace.name". 5. For complex tasks, more than one queries can be run, separated by ';' 6. Only respond with a single ```sql``` block which contains all queries. 7. No other explanation is necessary -8. For insert queries, avoid any RETURNING clause. Let it return the default. -9. We use jinja2 syntax to generate SQL - so parameters don't need to be quoted e.g. use {{user.zip_code|default(None)}} and not '{{user.zip_code|default(None)}}' -10. Use double underscores (.) to separate nested fields including for multiple levels of nesting e.g. `profile.address.street` - note that the field names need to be quoted as we are using `.` -11. Use backticks (``) to quote column names and table names -12. DONOT use json_build_object to build JSON objects for nested fields -13. DONOT use '' to quote jinja variables. The binding would take care of that automatically. -14. Pay special attention to primary key (usually id fields). Sometimes, they are auto-generated in the schema, in which case insert queries should not set them. Otherwise, they must already be set in the model and then inserted into the table as well. +8. For insert queries: + - If the function returns a primitive count/int, just perform INSERT without RETURNING. + - If the function returns a model object, perform INSERT first, then a SELECT that fetches the inserted row. In Postgres use: WHERE id = (SELECT LASTVAL()). +9. We use Jinja2 syntax to generate SQL - DO NOT wrap Jinja variables in quotes. Example: {{user.zip_code|default(None)}}, not '{{user.zip_code|default(None)}}'. +10. Quote identifiers (table/column/alias names) with double quotes ".". Quote string literals with single quotes '.'. +11. When using Postgres json_build_object, keys MUST be single-quoted string literals, e.g. json_build_object('bio', u.profile_bio). Do NOT use double quotes for keys. +12. Prefer selecting individual columns with proper aliases (including nested aliases) rather than building JSON blobs, unless explicitly requested. +13. Pay special attention to primary key (usually id fields). If auto-generated (e.g., SERIAL/BIGSERIAL), DO NOT insert a value. Otherwise include it from the model. Here is an example @@ -31,26 +34,26 @@ The SQL generated would look like the following --- Creates and returns a Task object, for the provided workspace and task_no --- Expects task_no and workspace.id are defined. If no tasks are found, returns None SELECT - t.id as `id`, - t.task_no as `task_no`, - t.title as `title`, - t.description as `description`, - t.status as `status`, - t.created_at as `created_at`, - t.updated_at as `updated_at`, - a.id as `agent.id`, - a.name as `agent.name`, - a.description as `agent.description`, - a.instructions as `agent.instructions`, - a.type as `agent.type`, - a.created_at as `agent.created_at`, - a.updated_at as `agent.updated_at`, - m.id as `agent.model.id`, - m.name as `agent.model.name`, - m.context_window as `agent.model.context_window`, - m.max_tokens as `agent.model.max_tokens`, - m.created_at as `agent.model.created_at`, - m.updated_at as `agent.model.updated_at`, + t.id as 'id', + t.task_no as 'task_no', + t.title as 'title', + t.description as 'description', + t.status as 'status', + t.created_at as 'created_at', + t.updated_at as 'updated_at', + a.id as 'agent.id', + a.name as 'agent.name', + a.description as 'agent.description', + a.instructions as 'agent.instructions', + a.type as 'agent.type', + a.created_at as 'agent.created_at', + a.updated_at as 'agent.updated_at', + m.id as 'agent.model.id', + m.name as 'agent.model.name', + m.context_window as 'agent.model.context_window', + m.max_tokens as 'agent.model.max_tokens', + m.created_at as 'agent.model.created_at', + m.updated_at as 'agent.model.updated_at', FROM tasks t LEFT JOIN agents a ON t.agent_id = a.id LEFT JOIN models m ON a.model_id = m.id diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 51a0caf..18e6b4d 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -1,6 +1,7 @@ import os import functools from typing import Any, Callable, Optional +import inspect from foundation_sql.prompt import SQLPromptGenerator, FunctionSpec from foundation_sql.gen import SQLGenerator @@ -65,7 +66,7 @@ def __init__( self.db_url = db_url if not self.db_url: - raise ValueError(f"Database URL not provided either through constructor or {db_url_env} environment variable") + raise ValueError("Database URL not provided either through constructor or DATABASE_URL environment variable") # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) @@ -77,69 +78,69 @@ def __init__( ) self.repair = repair - def __call__(self, func: Callable) -> Callable: - """ - Decorator implementation for SQL query generation and execution. - - Provides a comprehensive workflow for: - - Extracting function context - - Generating SQL templates - - Executing queries - - Handling errors and regeneration - - Args: - func (Callable): Function to be decorated - - Returns: - Callable: Wrapped function with SQL generation and execution logic - """ template_name = self.name or f"{func.__name__}.sql" fn_spec = FunctionSpec(func) - prompt_generator = SQLPromptGenerator( - fn_spec, - template_name, - self.system_prompt, - self.schema) + prompt_generator = SQLPromptGenerator(fn_spec, template_name, self.system_prompt, self.schema) - - def sql_gen(kwargs: Dict[str, Any], error: Optional[str]=None, prev_template: Optional[str]=None): + def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None): if self.regen or not self.cache.exists(template_name) or error: - prompt = prompt_generator.generate_prompt(kwargs, error, prev_template) sql_template = self.sql_generator.generate_sql(prompt) self.cache.set(template_name, sql_template) else: sql_template = self.cache.get(template_name) - return sql_template - @functools.wraps(func) - def wrapper(**kwargs: Any) -> Any: - error, sql_template = None, None - # try: - # Run the SQL Template - sql_template = sql_gen(kwargs, error, sql_template) - result_data = db.run_sql(self.db_url, sql_template, **kwargs) - + def _parse_result(result_data: Any): if fn_spec.wrapper == 'list': - parsed_result = [ - db.parse_query_to_pydantic(row, fn_spec.return_type) + return [ + db.parse_query_to_pydantic(row, fn_spec.return_type) for row in result_data.all() ] elif isinstance(result_data, int): - parsed_result = result_data + return result_data + elif fn_spec.return_type is int: + # Best-effort mapping from non-int results to int + # 1) QueryResult-like object + if hasattr(result_data, "count") and callable(getattr(result_data, "count")): + try: + return int(result_data.count()) + except Exception: + pass + # 2) List of rows + if isinstance(result_data, list): + try: + return int(len(result_data)) + except Exception: + pass + # 3) Dict payloads: common keys or single numeric value + if isinstance(result_data, dict): + for k in ("result", "count", "affected", "rowcount"): + v = result_data.get(k) + if isinstance(v, int): + return v + vals = list(result_data.values()) + if len(vals) == 1 and isinstance(vals[0], int): + return vals[0] + # Fallback + return 0 else: first_row = result_data.first() - parsed_result = db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None - - return parsed_result - - return wrapper - + return db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None + + is_async = inspect.iscoroutinefunction(func) + executor = WrapSqlExecution( + func=func, + db_url=self.db_url, + repair=self.repair, + sql_gen=sql_gen, + parse_result=_parse_result, + ) + return executor.build_wrapper(is_async) - + def load_file(self, path: str) -> str: """ Load predefined table schemas. @@ -153,3 +154,81 @@ def load_file(self, path: str) -> str: with open(path, 'r') as f: return f.read() +class WrapSqlExecution: + + def __init__( + self, + func: Callable, + db_url: str, + repair: Optional[int], + sql_gen: Callable[[Dict[str, Any], Optional[str], Optional[str]], str], + parse_result: Callable[[Any], Any], + ) -> None: + self.func = func + self.db_url = db_url + self.repair = repair + self.sql_gen = sql_gen + self._parse_result = parse_result + + async def _execute_async(self, **kwargs: Any): + last_exc = None + error = None + sql_template = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + database = db.get_db_with_adapter(self.db_url, "async") + + for _ in range(attempts): + sql_template = self.sql_gen(kwargs, error, sql_template) + try: + result_data = await database.run_sql_async(sql_template, **kwargs) + try: + return self._parse_result(result_data) + except Exception as parse_err: + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + + if last_exc: + raise last_exc + raise RuntimeError("SQL generation failed without explicit exception") + + def _execute_sync(self, **kwargs: Any): + last_exc = None + error = None + sql_template = None + attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + + for _ in range(attempts): + sql_template = self.sql_gen(kwargs, error, sql_template) + try: + result_data = db.run_sql(self.db_url, sql_template, **kwargs) + try: + return self._parse_result(result_data) + except Exception as parse_err: + last_exc = parse_err + error = f"Parsing/Validation error: {parse_err}" + continue + except Exception as exec_err: + last_exc = exec_err + error = f"Execution error: {exec_err}" + continue + + if last_exc: + raise last_exc + raise RuntimeError("SQL generation failed without explicit exception") + + def build_wrapper(self, is_async: bool): + if is_async: + @functools.wraps(self.func) + async def async_wrapper(**kwargs: Any): + return await self._execute_async(**kwargs) + return async_wrapper + else: + @functools.wraps(self.func) + def sync_wrapper(**kwargs: Any): + return self._execute_sync(**kwargs) + return sync_wrapper diff --git a/requirements.txt b/requirements.txt index 4c1cdff..2c34a84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,5 @@ sqlalchemy jinja2sql pydantic openai +asyncpg +psycopg2-binary \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index a4de0bf..5caf555 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,18 +2,24 @@ import os from foundation_sql import db from foundation_sql.query import SQLQueryDecorator +import re from dotenv import load_dotenv load_dotenv() -DB_URL = os.environ.get("DATABSE_URL", "sqlite:///:memory:") +# Force SQLite in-memory for all tests that use this common module. +# Async/Postgres-specific tests manage their own DATABASE_URL and are skipped if absent. +DB_URL = "sqlite:///:memory:" def create_query(schema): - return SQLQueryDecorator(schema=schema, - db_url=DB_URL, - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE_URL"), - model=os.getenv("OPENAI_MODEL")) + return SQLQueryDecorator( + schema=schema, + db_url=DB_URL, + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE_URL"), + model=os.getenv("OPENAI_MODEL"), + repair=2, + ) class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" @@ -27,15 +33,34 @@ def setUp(self): # Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) + # Capture table names from schema_sql for teardown cleanup + self._tables_to_drop = [] + if self.schema_sql: + for raw in self.schema_sql.split(';'): + stmt = raw.strip() + if not stmt: + continue + m = re.search(r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", stmt, re.IGNORECASE) + if m: + self._tables_to_drop.append(m.group(1)) else: raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") def tearDown(self): """Close the database connection after each test.""" - for _, connection in db.DATABASES.items(): - connection.get_engine().dispose() - + # Best-effort cleanup of tables created by this test to avoid cross-test interference + try: + if getattr(self, "_tables_to_drop", None): + database = db.get_db(self.db_url) + # Drop in reverse order to reduce FK issues + for t in reversed(self._tables_to_drop): + try: + database.execute(f"DROP TABLE IF EXISTS {t} CASCADE") + except Exception: + pass + finally: + for _, connection in db.DATABASES.items(): + connection.get_engine().dispose() + db.DATABASES.clear() db.DATABASES.clear() - - diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 0000000..1e142bf --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,76 @@ +import os +import unittest +from typing import Any, Dict, List + +from foundation_sql.db_drivers import SQLAlchemyAdapter +from foundation_sql import db + + +SYNC_DB_URL = "sqlite:///:memory:" + + +TEST_SCHEMA = """ +CREATE TABLE IF NOT EXISTS items ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + +INSERT_TEMPLATE = """ +INSERT INTO items (id, name) VALUES ({{ id }}, {{ name | tojson }}); +""" + +SELECT_TEMPLATE = """ +SELECT id, name FROM items ORDER BY id; +""" + + +class TestSQLAlchemyAdapter(unittest.TestCase): + def setUp(self) -> None: + self.adapter = SQLAlchemyAdapter(SYNC_DB_URL) + # init schema + self.adapter.init_schema(TEST_SCHEMA) + # ensure clean state for each test run + try: + self.adapter.run_sql("DELETE FROM items;", {}) + except Exception: + pass + + def tearDown(self) -> None: + try: + self.adapter.close() + finally: + # Clear Database singletons that might have been created + db.DATABASES.clear() + + def test_insert_and_select(self): + # insert 2 rows + rc1 = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 1, "name": "alpha"}) + rc2 = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 2, "name": "beta"}) + self.assertIsInstance(rc1, int) + self.assertIsInstance(rc2, int) + self.assertEqual(rc1 + rc2, 2) + + rows = self.adapter.run_sql(SELECT_TEMPLATE, {}) + self.assertIsInstance(rows, list) + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0]["id"], 1) + self.assertEqual(rows[0]["name"], "alpha") + + def test_database_facade_run_sql_parity(self): + # Ensure Database facade wraps adapter correctly + database = db.Database(SYNC_DB_URL, adapter=self.adapter) + # seed two rows using adapter path + _ = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 1, "name": "alpha"}) + _ = self.adapter.run_sql(INSERT_TEMPLATE, {"id": 2, "name": "beta"}) + # Insert + affected = database.run_sql(INSERT_TEMPLATE, id=3, name="gamma") + self.assertEqual(affected, 1) + # Select should return QueryResult wrapper + result = database.run_sql(SELECT_TEMPLATE) + self.assertTrue(hasattr(result, "first")) + self.assertEqual(result.count(), 3) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_many_to_one.py b/tests/test_many_to_one.py index e21ba46..a638200 100644 --- a/tests/test_many_to_one.py +++ b/tests/test_many_to_one.py @@ -1,4 +1,6 @@ from typing import List, Optional +import os +import shutil from tests import common from pydantic import BaseModel @@ -30,6 +32,64 @@ class Task(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) +CACHE_DIR = "__sql__" + +class TestWorkspaceTasks(common.DatabaseTests): + schema_sql = TABLES_SCHEMA + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite-friendly templates + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + + # Create workspace: insert then fetch by last_insert_rowid() + with open(os.path.join(CACHE_DIR, "create_workspace.sql"), "w") as f: + f.write( + """ + INSERT INTO workspaces (name) VALUES ({{ name | tojson }}); + SELECT id, name FROM workspaces WHERE id = last_insert_rowid(); + """.strip() + ) + + # Add task to workspace: insert then fetch row + with open(os.path.join(CACHE_DIR, "add_task_to_workspace.sql"), "w") as f: + f.write( + """ + INSERT INTO tasks (workspace_id, title, description) + VALUES ( + {{ workspace.id }}, + {{ title | tojson }}, + {{ description | default(None) | tojson }} + ); + SELECT + id as "id", + workspace_id as "workspace.id", + title as "title", + description as "description" + FROM tasks + WHERE id = last_insert_rowid(); + """.strip() + ) + + # Get tasks for workspace: join with dotted aliases for nesting + with open(os.path.join(CACHE_DIR, "get_tasks_for_workspace.sql"), "w") as f: + f.write( + """ + SELECT + t.id as "id", + t.workspace_id as "workspace.id", + w.name as "workspace.name", + t.title as "title", + t.description as "description" + FROM tasks t + JOIN workspaces w ON w.id = t.workspace_id + WHERE w.id = {{ workspace.id }} + ORDER BY t.id; + """.strip() + ) + @query def create_workspace(name: str) -> Workspace: """ @@ -51,9 +111,6 @@ def get_tasks_for_workspace(workspace: Workspace) -> List[Task]: """ pass -class TestWorkspaceTasks(common.DatabaseTests): - schema_sql = TABLES_SCHEMA - def test_workspace_tasks(self): # Add a workspace ws = create_workspace(name="Project Alpha") diff --git a/tests/test_nested_query.py b/tests/test_nested_query.py index a1da021..f33b539 100644 --- a/tests/test_nested_query.py +++ b/tests/test_nested_query.py @@ -1,4 +1,6 @@ from typing import List, Optional +import os +import shutil from tests import common from pydantic import BaseModel @@ -34,6 +36,56 @@ class UserWithProfile(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) +CACHE_DIR = "__sql__" + +class TestNestedQuery(common.DatabaseTests): + + schema_sql = TABLES_SCHEMA + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite/Postgres portable templates + if os.path.exists(CACHE_DIR): + shutil.rmtree(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) + + # Portable SELECT using dotted aliases so our unflatten logic can build nested objects + with open(os.path.join(CACHE_DIR, "get_users_with_profile.sql"), "w") as f: + f.write( + """ + SELECT + u.id as "id", + u.name as "name", + u.email as "email", + u.role as "role", + u.profile_bio as "profile.bio", + u.address_street as "profile.address.street", + u.address_city as "profile.address.city", + u.address_zip_code as "profile.address.zip_code" + FROM users_with_profile u + ORDER BY u.id; + """.strip() + ) + + # Portable INSERT template using jinja2 variables (handled by jinja2sql) + with open(os.path.join(CACHE_DIR, "create_user_with_profile.sql"), "w") as f: + f.write( + """ + INSERT INTO users_with_profile ( + id, name, email, role, profile_bio, address_street, address_city, address_zip_code + ) VALUES ( + {{ user.id }}, + {{ user.name | tojson }}, + {{ user.email | tojson }}, + {{ user.role | tojson }}, + {{ user.profile.bio | default(None) | tojson }}, + {{ user.profile.address.street | default(None) | tojson }}, + {{ user.profile.address.city | default(None) | tojson }}, + {{ user.profile.address.zip_code | default(None) | tojson }} + ); + """.strip() + ) + @query def get_users_with_profile() -> List[UserWithProfile]: """ @@ -48,10 +100,6 @@ def create_user_with_profile(user: UserWithProfile) -> int: """ pass -class TestNestedQuery(common.DatabaseTests): - - schema_sql = TABLES_SCHEMA - def test_nested_object_creation_and_retrieval(self): # Initially, no users users = get_users_with_profile() diff --git a/tests/test_query_sync_and_async.py b/tests/test_query_sync_and_async.py new file mode 100644 index 0000000..3aec2f6 --- /dev/null +++ b/tests/test_query_sync_and_async.py @@ -0,0 +1,166 @@ +import os +import shutil +import unittest +import inspect +from typing import List + +from pydantic import BaseModel + +from foundation_sql.query import SQLQueryDecorator +from foundation_sql import db +from foundation_sql.db_drivers import AsyncpgAdapter + + +SQLITE_DB_URL = "sqlite:///__test_sync.sqlite3" + +# Attempt to ensure DATABASE_URL is available by reading .env if needed +if not os.environ.get("DATABASE_URL") and os.path.exists(os.path.join(os.path.dirname(__file__), "..", ".env")): + env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".env")) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + key, val = line.split("=", 1) + key = key.strip() + val = val.strip().strip('"').strip("'") + os.environ.setdefault(key, val) + except Exception: + pass + +ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname +CACHE_DIR_SYNC = "__sql__/__sql_sync__" +CACHE_DIR_ASYNC = "__sql__/__sql_async__" + + +class TestSQLQueryDecoratorSync(unittest.TestCase): + class User(BaseModel): + id: int + name: str + + TABLES_SCHEMA = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + """ + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed templates for sync + if os.path.exists(CACHE_DIR_SYNC): + shutil.rmtree(CACHE_DIR_SYNC) + os.makedirs(CACHE_DIR_SYNC, exist_ok=True) + with open(os.path.join(CACHE_DIR_SYNC, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR_SYNC, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + # Prepare SQLite DB file cleanly + if os.path.exists("__test_sync.sqlite3"): + os.remove("__test_sync.sqlite3") + + def setUp(self) -> None: + # Initialize schema on SQLite using sync adapter + self.database = db.Database(SQLITE_DB_URL) + self.database.init_schema(schema_sql=self.TABLES_SCHEMA) + # Clean table + self.database.run_sql("DELETE FROM users;") + + def tearDown(self) -> None: + # Close SQLAlchemy engine + try: + self.database.adapter.close() + except Exception: + pass + db.DATABASES.clear() + + def test_sync_wrappers_and_execution(self): + query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC) + + @query + def get_users() -> List["TestSQLQueryDecoratorSync.User"]: + pass + + @query + def create_user(user: "TestSQLQueryDecoratorSync.User") -> int: + pass + + # wrappers should be sync + self.assertFalse(inspect.iscoroutinefunction(get_users)) + self.assertFalse(inspect.iscoroutinefunction(create_user)) + + users = get_users() + self.assertEqual(len(users), 0) + rc = create_user(user=self.User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + +@unittest.skipUnless(ASYNC_DB_URL, "Async tests require DATABASE_URL Postgres DSN") +class TestSQLQueryDecoratorAsync(unittest.IsolatedAsyncioTestCase): + class User(BaseModel): + id: int + name: str + + TABLES_SCHEMA = """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + """ + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed templates for async + if os.path.exists(CACHE_DIR_ASYNC): + shutil.rmtree(CACHE_DIR_ASYNC) + os.makedirs(CACHE_DIR_ASYNC, exist_ok=True) + with open(os.path.join(CACHE_DIR_ASYNC, "get_users.sql"), "w") as f: + f.write("SELECT id, name FROM users ORDER BY id;") + with open(os.path.join(CACHE_DIR_ASYNC, "create_user.sql"), "w") as f: + f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + + async def asyncSetUp(self): + # Initialize schema on Postgres using async adapter each test for isolation + self.database = db.Database(ASYNC_DB_URL, adapter=AsyncpgAdapter(ASYNC_DB_URL)) + await self.database.init_schema_async(schema_sql=self.TABLES_SCHEMA) + # Clean table + await self.database.run_sql_async("DELETE FROM users;") + + async def asyncTearDown(self): + await self.database.close_async() + db.DATABASES.clear() + + async def test_async_wrappers_and_execution(self): + query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC) + + @query + async def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: + pass + + @query + async def create_user(user: "TestSQLQueryDecoratorAsync.User") -> int: + pass + + # wrappers should be async + self.assertTrue(inspect.iscoroutinefunction(get_users)) + self.assertTrue(inspect.iscoroutinefunction(create_user)) + + users = await get_users() + self.assertEqual(len(users), 0) + rc = await create_user(user=self.User(id=1, name="Alice")) + self.assertEqual(rc, 1) + users = await get_users() + self.assertEqual(len(users), 1) + self.assertEqual(users[0].id, 1) + self.assertEqual(users[0].name, "Alice") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_simple_query.py b/tests/test_simple_query.py index c81720b..ef9ee1c 100644 --- a/tests/test_simple_query.py +++ b/tests/test_simple_query.py @@ -1,6 +1,8 @@ from typing import List from tests import common from pydantic import BaseModel +import os +import shutil class User(BaseModel): @@ -39,6 +41,46 @@ def create_user(user: User) -> int: class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA + CACHE_DIR = "__sql__" + + @classmethod + def setUpClass(cls) -> None: + # Ensure clean cache dir and seed SQLite-friendly templates for this test + if os.path.exists(cls.CACHE_DIR): + shutil.rmtree(cls.CACHE_DIR) + os.makedirs(cls.CACHE_DIR, exist_ok=True) + + # Deterministic SELECT mapping to Pydantic model fields + with open(os.path.join(cls.CACHE_DIR, "get_users.sql"), "w") as f: + f.write( + ( + """ + SELECT + id as "id", + name as "name", + email as "email", + role as "role" + FROM users + ORDER BY id; + """ + ).strip() + ) + + # Deterministic INSERT using provided user fields (string id in this schema) + with open(os.path.join(cls.CACHE_DIR, "create_user.sql"), "w") as f: + f.write( + ( + """ + INSERT INTO users (id, name, email, role) + VALUES ( + {{ user.id }}, + {{ user.name | tojson }}, + {{ user.email | tojson }}, + {{ user.role | tojson }} + ); + """ + ).strip() + ) def test_users(self): users = get_users()