diff --git a/foundation_sql/cache.py b/foundation_sql/cache.py index 00f4699..3f21ab2 100644 --- a/foundation_sql/cache.py +++ b/foundation_sql/cache.py @@ -1,83 +1,84 @@ import os from typing import Optional + class SQLTemplateCache: """ Simple file-based cache using file names as keys. - + Attributes: cache_dir (str): Directory to store cached templates """ - + def __init__(self, cache_dir: str): """ Initialize the SQL template cache. - + Args: cache_dir (str): Directory to store cached templates """ self.cache_dir = cache_dir os.makedirs(cache_dir, exist_ok=True) - + def _get_cache_path(self, key: str) -> str: """ Get the full path for a cache file. - + Args: key (str): File name to use as cache key - + Returns: str: Full path to the cache file """ return os.path.join(self.cache_dir, key) - + def set(self, key: str, template: str) -> None: """ Store an SQL template in the cache. - + Args: key (str): File name to use as cache key template (str): SQL template to cache """ cache_file = self._get_cache_path(key) - with open(cache_file, 'w') as f: + with open(cache_file, "w") as f: f.write(template) - + def get(self, key: str) -> Optional[str]: """ Retrieve a cached SQL template. - + Args: key (str): File name to use as cache key - + Returns: Optional[str]: Cached SQL template or None if not found """ cache_file = self._get_cache_path(key) if os.path.exists(cache_file): - with open(cache_file, 'r') as f: + with open(cache_file, "r") as f: return f.read() return None - + def exists(self, key: str) -> bool: """ Check if a cache entry exists. - + Args: key (str): File name to use as cache key - + Returns: bool: True if cache entry exists, False otherwise """ cache_file = self._get_cache_path(key) return os.path.exists(cache_file) - + def clear(self, key: Optional[str] = None) -> None: """ Clear cached templates. - + Args: - key (Optional[str]): Specific key to clear. + key (Optional[str]): Specific key to clear. If None, clears entire cache. """ if key: diff --git a/foundation_sql/db.py b/foundation_sql/db.py index 59a7418..61a678b 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -4,16 +4,18 @@ import logging import os +from datetime import datetime from types import NoneType -from typing import Dict, Any, Optional, Type, Union, List +from typing import Any, Dict, List, Optional, Type, Union + +from jinja2sql import Jinja2SQL from pydantic import BaseModel -from sqlalchemy import text +from sqlalchemy import MetaData, create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.schema import CreateTable -from datetime import datetime -from foundation_sql.db_drivers import EngineAdapter, SQLAlchemyAdapter -from foundation_sql.db_drivers import AsyncpgAdapter +from foundation_sql.db_drivers import AsyncpgAdapter, EngineAdapter, SQLAlchemyAdapter NESTED_SPLITTER = "." # Singleton instance @@ -24,22 +26,25 @@ logger = logging.getLogger(__name__) - class Database: """Database operations for Foundation delegated to an EngineAdapter.""" - - def __init__(self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None) -> None: + + def __init__( + self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None + ) -> None: """Initialize the database facade. - + Args: db_url: Database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) If not provided, will use DATABASE_URL environment variable adapter: Optional explicit adapter. If not provided, a default sync adapter is created. """ - self.db_url = db_url or os.getenv('DATABASE_URL') + self.db_url = db_url or os.getenv("DATABASE_URL") if not self.db_url: - raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable') - + raise ValueError( + "Database URL must be provided either through constructor or DATABASE_URL environment variable" + ) + # Default to sync SQLAlchemy adapter unless explicitly provided self.adapter: EngineAdapter = adapter or SQLAlchemyAdapter(self.db_url) @@ -54,7 +59,9 @@ def get_engine(self) -> Engine: return self.adapter.engine raise RuntimeError("Engine is not available for this adapter") - def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None: + def init_schema( + self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None + ) -> None: """Initialize the database schema if it doesn't exist. This method runs the schema creation script in an idempotent way. @@ -66,26 +73,26 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] If not provided, will use the default schema at data/tables.sql """ if not schema_sql: - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema_sql = f.read() # Delegate to adapter self.adapter.init_schema(schema_sql) def run_sql(self, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected """ # Add datetime.now function to context if needed - if 'now' not in context: - context['now'] = datetime.now - + if "now" not in context: + context["now"] = datetime.now + # Delegate to adapter; it returns either list[dict] rows or int rowcount result = self.adapter.run_sql(sql_template, context) if isinstance(result, int): @@ -94,109 +101,116 @@ def run_sql(self, sql_template: str, **context) -> Any: return QueryResult(result) # ---------- Async delegates (Phase 2) ---------- - async def init_schema_async(self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None) -> None: - if not hasattr(self.adapter, 'init_schema_async'): + async def init_schema_async( + self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None + ) -> None: + if not hasattr(self.adapter, "init_schema_async"): raise NotImplementedError("Async schema init not supported by this adapter") if not schema_sql: - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema_sql = f.read() await self.adapter.init_schema_async(schema_sql) # type: ignore[attr-defined] async def run_sql_async(self, sql_template: str, **context) -> Any: - if not hasattr(self.adapter, 'run_sql_async'): + if not hasattr(self.adapter, "run_sql_async"): raise NotImplementedError("Async run_sql not supported by this adapter") - if 'now' not in context: - context['now'] = datetime.now + if "now" not in context: + context["now"] = datetime.now result = await self.adapter.run_sql_async(sql_template, context) # type: ignore[attr-defined] if isinstance(result, int): return result return QueryResult(result) async def close_async(self) -> None: - if hasattr(self.adapter, 'close_async'): + if hasattr(self.adapter, "close_async"): await self.adapter.close_async() # type: ignore[attr-defined] - - def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any: + def execute( + self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None + ) -> Any: """ Execute a raw SQL statement with optional parameters. - + Args: sql (str): SQL statement to execute - params (Optional[Union[tuple, dict, List[tuple]]]): + params (Optional[Union[tuple, dict, List[tuple]]]): Optional parameters for the SQL statement - Single tuple for single parameter set - List of tuples for multiple parameter sets (bulk insert) - Dictionary for named parameters - + Returns: Any: Result of the execution """ # Only supported on SQLAlchemy adapter path for now if not isinstance(self.adapter, SQLAlchemyAdapter): - raise NotImplementedError("execute() is only supported for SQLAlchemy adapter") + raise NotImplementedError( + "execute() is only supported for SQLAlchemy adapter" + ) with self.adapter.engine.connect() as connection: try: # Replace '?' placeholders with SQLAlchemy named parameters - if '?' in sql: + if "?" in sql: # Count the number of placeholders - placeholder_count = sql.count('?') - + placeholder_count = sql.count("?") + # Replace '?' with named parameters - named_sql = sql.replace('?', ':{}'.format) - + named_sql = sql.replace("?", ":{}".format) + # Prepare parameters if params is None: named_params = {} elif isinstance(params, (tuple, list)): # Convert tuple/list to named dictionary - named_params = {f'p{i}': val for i, val in enumerate(params)} - named_sql = named_sql.format(p=lambda i: f':p{i}') + named_params = {f"p{i}": val for i, val in enumerate(params)} + named_sql = named_sql.format(p=lambda i: f":p{i}") elif isinstance(params, dict): named_params = params else: - raise ValueError("Invalid parameter type. Must be tuple, dict, or list of tuples.") - + raise ValueError( + "Invalid parameter type. Must be tuple, dict, or list of tuples." + ) + # Execute with named parameters result = connection.execute(text(named_sql), named_params) else: # If no '?' placeholders, use as-is result = connection.execute(text(sql), params or {}) - + # If it's a SELECT query, return the rows if result.returns_rows: return result.fetchall() - + # For INSERT, UPDATE, DELETE, return the number of rows affected return result.rowcount - + except SQLAlchemyError as e: raise RuntimeError(f"Database execution error: {str(e)}") from e class QueryResult: """A clean abstraction over query results that doesn't leak implementation details.""" - + def __init__(self, rows: List[Dict[str, Any]]): """Initialize with a list of row dictionaries. - + Args: rows: List of dictionaries representing database rows """ self.rows = rows - + def first(self) -> Optional[Dict[str, Any]]: """Get the first row as a dictionary or None if no rows. - + Returns: First row as a dictionary or None """ return self.rows[0] if self.rows else None - + def all(self) -> List[Dict[str, Any]]: """Get all rows as a list of dictionaries. - + Returns: List of dictionaries representing all rows """ @@ -204,36 +218,64 @@ def all(self) -> List[Dict[str, Any]]: def count(self) -> int: """Get the number of rows. - + Returns: Number of rows """ return len(self.rows) - + def is_empty(self) -> bool: """Check if the result contains any rows. - + Returns: True if no rows, False otherwise """ return len(self.rows) == 0 + def scalar(self) -> Optional[Any]: + """Get the first value from the first row.""" + if self.rows and self.rows[0]: + return next(iter(self.rows[0].values()), None) + return None + + +# Function to load the schema from the database +def extract_schema_from_db(db_url: str) -> str: + """Extract the schema from the database. + + Args: + db_url: Database URL to use + + Returns: + Schema as a string + """ + engine = create_engine(db_url) + metadata = MetaData() + metadata.reflect(bind=engine) + + schema_lines = [] + for table in metadata.sorted_tables: + ddl = str(CreateTable(table).compile(engine)) + schema_lines.append(ddl + ";") + + return "\n\n".join(schema_lines) def get_db(db_url: str) -> Database: """Get the database instance. - + Args: db_url: Database URL to use - + Returns: Database instance """ if db_url not in DATABASES: DATABASES[db_url] = Database(db_url) - + return DATABASES[db_url] + def get_db_with_adapter(db_url: str, mode: str) -> Database: """Internal helper for selecting adapter explicitly. mode: "sync" | "async" (async not implemented yet) @@ -244,13 +286,14 @@ def get_db_with_adapter(db_url: str, mode: str) -> Database: return Database(db_url, adapter=AsyncpgAdapter(db_url)) raise ValueError(f"Unknown adapter mode: {mode}") + def run_sql(db_url: str, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected @@ -258,45 +301,47 @@ def run_sql(db_url: str, sql_template: str, **context) -> Any: return get_db(db_url).run_sql(sql_template, **context) -def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) -> Optional[BaseModel]: +def parse_query_to_pydantic( + data: Dict[str, Any], model_class: Type[BaseModel] +) -> Optional[BaseModel]: """Parse query result data into a Pydantic model, handling nested models. - + Args: data: Dictionary containing query results with optional nested fields model_class: The Pydantic model class to instantiate - + Returns: Instance of the Pydantic model or None if data is None """ if not data: - return None + return None unflattened_data = unflatten_dict(data) # Check the response type and transform accordingly if model_class == int: - return int(unflattened_data["result"]) + # FIX : STILL ONLY GETS FIRST LINE OF RESPONSE + return int(next(iter(unflattened_data.values()))) elif model_class == NoneType: return None - - return model_class(**unflattened_data) + return model_class(**unflattened_data) def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: """Convert a flattened dictionary with keys like 'parent.child.grandchild' (using NESTED_SPLITTER) into a nested dictionary structure. - + Args: flat_dict: Dictionary with flattened keys using NESTED_SPLITTER for nesting - + Returns: - Nested dictionary structure where nested objects with all None values + Nested dictionary structure where nested objects with all None values are replaced by None at the parent level. """ grouped_keys = {} direct_keys = {} - + # First, categorize the keys for key, value in flat_dict.items(): if NESTED_SPLITTER in key: @@ -306,27 +351,27 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: grouped_keys[prefix][rest] = value else: direct_keys[key] = value - + # Process each group and add to result result = dict(direct_keys) # Start with the direct keys - + for prefix, nested_dict in grouped_keys.items(): # Check if this prefix contains nested structures has_nested = any(NESTED_SPLITTER in key for key in nested_dict.keys()) - + if has_nested: # Recursively unflatten the nested structure nested_result = unflatten_dict(nested_dict) - + # Check if all values in the nested result are None after unflattening is_all_none = False if isinstance(nested_result, dict): is_all_none = all(v is None for v in nested_result.values()) - + result[prefix] = None if is_all_none else nested_result else: # Check if all values are None for a flat nested dict is_all_none = all(v is None for v in nested_dict.values()) result[prefix] = None if is_all_none else nested_dict - + return result diff --git a/foundation_sql/db_drivers.py b/foundation_sql/db_drivers.py index 28f64d2..72739b0 100644 --- a/foundation_sql/db_drivers.py +++ b/foundation_sql/db_drivers.py @@ -3,16 +3,15 @@ import logging import re # Add missing import from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, List +from datetime import datetime +from typing import Any, Dict, List, Optional +import asyncpg +from jinja2sql import Jinja2SQL from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError -from jinja2sql import Jinja2SQL -from datetime import datetime -import asyncpg - logger = logging.getLogger(__name__) @@ -21,25 +20,26 @@ class EngineAdapter(ABC): # ---------- Sync API ---------- @abstractmethod - def init_schema(self, schema_sql: str) -> None: - ... + def init_schema(self, schema_sql: str) -> None: ... @abstractmethod - def run_sql(self, template: str, data: Dict[str, Any]) -> Any: - ... + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: ... @abstractmethod - def close(self) -> None: - ... + def close(self) -> None: ... # ---------- Async API (optional) ---------- - async def init_pool_async(self) -> None: # pragma: no cover - to be implemented by async adapters + async def init_pool_async( + self, + ) -> None: # pragma: no cover - to be implemented by async adapters raise NotImplementedError async def init_schema_async(self, schema_sql: str) -> None: # pragma: no cover raise NotImplementedError - async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover + async def run_sql_async( + self, template: str, data: Dict[str, Any] + ) -> Any: # pragma: no cover raise NotImplementedError async def close_async(self) -> None: # pragma: no cover @@ -59,20 +59,22 @@ def __init__(self, dsn: str) -> None: def init_schema(self, schema_sql: str) -> None: with self.engine.begin() as conn: try: - for statement in schema_sql.split(';'): + for statement in schema_sql.split(";"): if statement.strip(): conn.execute(text(statement)) except SQLAlchemyError as e: - raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + raise RuntimeError(f"Failed to initialize schema: {str(e)}") from e def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # ensure now is available - if 'now' not in data: - data['now'] = datetime.now + if "now" not in data: + data["now"] = datetime.now try: # Normalize templates: remove explicit tojson filters to allow binding - normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + normalized = re.sub( + r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template + ) query, params = self.j2sql.from_string(normalized, context=data) except Exception as e: raise ValueError( @@ -82,7 +84,9 @@ def run_sql(self, template: str, data: Dict[str, Any]) -> Any: with self.engine.connect() as conn: with conn.begin(): try: - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + statements = [ + stmt.strip() for stmt in query.split(";") if stmt.strip() + ] total_rows = 0 last_result = None @@ -137,23 +141,25 @@ async def init_schema_async(self, schema_sql: str) -> None: assert self.pool is not None async with self.pool.acquire() as conn: # Execute statements sequentially to ensure order - for statement in [s.strip() for s in schema_sql.split(';') if s.strip()]: + for statement in [s.strip() for s in schema_sql.split(";") if s.strip()]: try: logger.debug(f"Executing schema statement: {statement}") await conn.execute(statement) except Exception as e: logger.error(f"Failed to execute schema statement: {statement}") - raise RuntimeError(f"Failed to execute schema statement: {str(e)}") from e + raise RuntimeError( + f"Failed to execute schema statement: {str(e)}" + ) from e async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: await self.init_pool_async() assert self.pool is not None - if 'now' not in data: - data['now'] = datetime.now + if "now" not in data: + data["now"] = datetime.now # Special handling for templates without parameters - if '{{' not in template: + if "{{" not in template: # No template variables, execute directly async with self.pool.acquire() as conn: is_select = template.strip().lower().startswith("select") @@ -167,15 +173,19 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: try: # Add the same template normalization as SQLAlchemyAdapter # Normalize templates: remove explicit tojson filters to allow binding - normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + normalized = re.sub( + r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template + ) # jinja2sql generates SQL with $1, $2... placeholders and params in correct order query, params_list = self.j2sql.from_string(normalized, context=data) - + # Debug: log parameter types logger.debug(f"Rendered query: {query}") logger.debug(f"Parameter values: {params_list}") - logger.debug(f"Parameter types: {[(type(p).__name__, p) for p in params_list]}") - + logger.debug( + f"Parameter types: {[(type(p).__name__, p) for p in params_list]}" + ) + except Exception as e: raise ValueError( f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" @@ -184,34 +194,40 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: async with self.pool.acquire() as conn: try: # Handle multiple statements like SQLAlchemyAdapter - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + statements = [stmt.strip() for stmt in query.split(";") if stmt.strip()] total_rows = 0 last_result = None logger.debug(f"Executing {len(statements)} statement(s)") for i, statement in enumerate(statements): logger.debug(f"Statement {i+1}: {statement}") - + is_select = statement.strip().lower().startswith("select") - + if is_select: last_result = await conn.fetch(statement, *params_list) total_rows += len(last_result) logger.debug(f"SELECT returned {len(last_result)} rows") else: # Check if this is a schema operation or data operation - is_schema_op = any(keyword in statement.upper() for keyword in - ['CREATE', 'DROP', 'ALTER', 'TRUNCATE']) - + is_schema_op = any( + keyword in statement.upper() + for keyword in ["CREATE", "DROP", "ALTER", "TRUNCATE"] + ) + if is_schema_op: # Schema operations typically don't use parameters - logger.debug("Executing as schema operation (no parameters)") + logger.debug( + "Executing as schema operation (no parameters)" + ) status = await conn.execute(statement) else: # Data operations use parameters - logger.debug(f"Executing as data operation with parameters: {params_list}") + logger.debug( + f"Executing as data operation with parameters: {params_list}" + ) status = await conn.execute(statement, *params_list) - + stmt_rows = _parse_rowcount(status) total_rows += stmt_rows logger.debug(f"Statement affected {stmt_rows} rows") @@ -221,23 +237,29 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: rows = [dict(r) for r in last_result] logger.debug(f"Returning {len(rows)} rows") return rows - + logger.debug(f"Returning row count: {total_rows}") return total_rows - + except Exception as e: - error_msg = (f"Failed to execute SQL: {str(e)}\n" - f"Rendered SQL: {query}\n" - f"Parameters: {params_list}\n" - f"Parameter types: {[type(p).__name__ for p in params_list]}") + error_msg = ( + f"Failed to execute SQL: {str(e)}\n" + f"Rendered SQL: {query}\n" + f"Parameters: {params_list}\n" + f"Parameter types: {[type(p).__name__ for p in params_list]}" + ) logger.error(error_msg) raise RuntimeError(error_msg) from e # Sync methods are not supported for asyncpg adapter - def init_schema(self, schema_sql: str) -> None: # pragma: no cover - sync not supported + def init_schema( + self, schema_sql: str + ) -> None: # pragma: no cover - sync not supported raise NotImplementedError("Use init_schema_async with AsyncpgAdapter") - def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover - sync not supported + def run_sql( + self, template: str, data: Dict[str, Any] + ) -> Any: # pragma: no cover - sync not supported raise NotImplementedError("Use run_sql_async with AsyncpgAdapter") def close(self) -> None: # pragma: no cover - sync not supported @@ -251,4 +273,4 @@ def _parse_rowcount(status: str) -> int: # Common patterns: 'INSERT 0 1' (rowcount is last), 'UPDATE 3' (last), 'DELETE 2' return int(parts[-1]) except Exception: - return 0 \ No newline at end of file + return 0 diff --git a/foundation_sql/gen.py b/foundation_sql/gen.py index cad5266..cdd53c7 100644 --- a/foundation_sql/gen.py +++ b/foundation_sql/gen.py @@ -1,27 +1,26 @@ import logging -from openai import OpenAI import re +from openai import OpenAI + + class SQLGenerator: """ Advanced SQL template generator with configurable LLM backend. - + Supports: - Configurable language models - Persistent template caching - Flexible generation parameters - Comprehensive error handling """ - + def __init__( - self, - api_key: str, - base_url: str, - model: str = "llama-3.3-70b-versatile" + self, api_key: str, base_url: str, model: str = "llama-3.3-70b-versatile" ): """ Initialize the SQL generator. - + Args: api_key (str): API key for the LLM service base_url (str): Base URL for the LLM service @@ -36,27 +35,25 @@ def __init__( def client(self): return OpenAI(api_key=self.api_key, base_url=self.base_url) - def generate_sql(self, prompt: str) -> str: """ Generate an SQL template based on the provided prompt. - + Args: prompt (str): Detailed prompt for SQL generation - + Returns: str: Generated SQL template """ response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": prompt} - ] + model=self.model, messages=[{"role": "system", "content": prompt}] ) - + generated_sql = response.choices[0].message.content.strip() - + # Remove ```sql or ``` fences - sql_template = re.sub(r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE).strip() - - return sql_template \ No newline at end of file + sql_template = re.sub( + r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE + ).strip() + + return sql_template diff --git a/foundation_sql/prompt.py b/foundation_sql/prompt.py index dd2132d..edd2cd1 100644 --- a/foundation_sql/prompt.py +++ b/foundation_sql/prompt.py @@ -1,8 +1,8 @@ import inspect import json +from datetime import datetime from types import NoneType from typing import Any, Callable, Dict, Optional, Type, Union, get_type_hints -from datetime import datetime from pydantic import BaseModel @@ -27,7 +27,6 @@ def _model_fields(self): return {} return {k: str(v) for k, v in self.return_type.model_fields.items()} - def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ Extract named parameters and their types from a function. @@ -40,8 +39,8 @@ def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ signature = inspect.signature(func) return { - name: param.annotation - for name, param in signature.parameters.items() + name: param.annotation + for name, param in signature.parameters.items() if param.annotation is not param.empty } @@ -65,57 +64,60 @@ def serialize_value(v): def _extract_return_model(self, func: Callable) -> (Type[BaseModel], Optional[str]): """ Extract the return model type from a function's type annotations. - + Args: func (Callable): Function to analyze - + Returns: Tuple containing: - Pydantic model class - Wrapper type ('list' or None) - + Raises: ValueError: If return type is invalid or not a Pydantic model """ hints = get_type_hints(func) - if 'return' not in hints: - raise ValueError(f'Function {func.__name__} must have a return type annotation') - - return_type = hints['return'] + if "return" not in hints: + raise ValueError( + f"Function {func.__name__} must have a return type annotation" + ) + + return_type = hints["return"] wrapper = None - + # Handle Optional[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is Union: + if hasattr(return_type, "__origin__") and return_type.__origin__ is Union: args = return_type.__args__ if len(args) == 2 and args[1] is type(None): return_type = args[0] - + # Handle List[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is list: - wrapper = 'list' + if hasattr(return_type, "__origin__") and return_type.__origin__ is list: + wrapper = "list" return_type = return_type.__args__[0] - - return return_type, wrapper + return return_type, wrapper class SQLPromptGenerator: """ Generates prompts for SQL template generation based on function context and predefined schemas. - + Attributes: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template """ - - def __init__(self, func_spec: FunctionSpec, + + def __init__( + self, + func_spec: FunctionSpec, template_name: str, system_prompt: str, - schema: Optional[str] = None - ): + schema: Optional[str] = None, + ): """ Initialize the SQL prompt generator. - + Args: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template @@ -128,10 +130,15 @@ def __init__(self, func_spec: FunctionSpec, self.schema = schema self.system_prompt = system_prompt - def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None) -> str: + def generate_prompt( + self, + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ) -> str: """ Generate a comprehensive prompt for SQL template generation. - + Returns: str: Detailed prompt with function context and schema """ @@ -148,7 +155,6 @@ def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, p Review the error and suggest an improved SQL template that works. """ - return f""" {self.system_prompt} ---------------- diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 2cd09cc..db9b994 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -2,7 +2,7 @@ You are an expert SQL developer. Write one or more SQL queries that can perform Primary target is PostgreSQL (tests run against Postgres). Prefer Postgres-compatible SQL. Avoid SQLite-specific functions. -1. Start with a -- comment to document the function name, parameters and docstring, explaining what the SQL query does. +1. Start with a comment to document the function name, parameters and docstring, explaining what the SQL query does. Make sure to start comments with `--` (Only 2 dashes, no more , no less) 2. Use jinja2 template to generate SQL 3. When accessing nested fields handle cases if they aren't defined. Use default filter with None value for such cases e.g. {{user.zip_code|default(None)}} @@ -22,13 +22,13 @@ Primary target is PostgreSQL (tests run against Postgres). Prefer Postgres-compa Here is an example def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task: - """ - Creates and returns a Task object, for the provided workspace and task_no - """ - pass - +""" +Creates and returns a Task object, for the provided workspace and task_no +""" +pass The SQL generated would look like the following + ```sql --- def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task --- Creates and returns a Task object, for the provided workspace and task_no @@ -58,7 +58,7 @@ The SQL generated would look like the following LEFT JOIN agents a ON t.agent_id = a.id LEFT JOIN models m ON a.model_id = m.id LEFT JOIN workspace_tasks wt ON t.id = wt.task_id - WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}} + WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}}; ``` Below are the real specifications for which query needs to be generated. diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 18e6b4d..dbe3cde 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -1,45 +1,44 @@ -import os import functools -from typing import Any, Callable, Optional import inspect +import os +from importlib import resources as impresources +from typing import Any, Callable, Dict, Optional -from foundation_sql.prompt import SQLPromptGenerator, FunctionSpec -from foundation_sql.gen import SQLGenerator -from foundation_sql.cache import SQLTemplateCache from foundation_sql import db -from typing import Callable, Dict, Optional - -from importlib import resources as impresources +from foundation_sql.cache import SQLTemplateCache +from foundation_sql.gen import SQLGenerator +from foundation_sql.prompt import FunctionSpec, SQLPromptGenerator -DEFAULT_SYSTEM_PROMPT = impresources.read_text('foundation_sql', 'prompts.md') +DEFAULT_SYSTEM_PROMPT = impresources.read_text("foundation_sql", "prompts.md") class SQLQueryDecorator: """ Advanced decorator for generating and executing SQL queries with comprehensive features. - + Supports: - Dynamic SQL template generation - Configurable LLM backend - Persistent template caching - Robust error handling and regeneration - + Attributes: name (Optional[str]): Custom name for SQL template regen (Optional[bool]): SQL template regeneration strategy config (SQLGeneratorConfig): Configuration for SQL generation """ - + def __init__( - self, - name: Optional[str] = None, + self, + name: Optional[str] = None, regen: Optional[bool] = None, repair: Optional[int] = 0, schema: Optional[str] = None, schema_path: Optional[str] = None, + schema_inspect: bool = False, system_prompt: Optional[str] = None, system_prompt_path: Optional[str] = None, - cache_dir: Optional[str] = '__sql__', + cache_dir: Optional[str] = "__sql__", db_url: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, @@ -47,18 +46,36 @@ def __init__( ): """ Initialize the SQL query decorator. - + Args: - name (Optional[str]): Custom name for SQL file/folder. + name (Optional[str]): Custom name for SQL file/folder. Defaults to function name. regen (Optional[bool]): SQL template regeneration strategy. - config (Optional[SQLGeneratorConfig]): Custom configuration + config (Optional[SQLGeneratorConfig]): Custom configuration for SQL generation. """ self.name = name self.regen = regen self.cache_dir = cache_dir - self.schema = schema or self.load_file(schema_path) + self.db_url = db_url or os.environ.get("DATABASE_URL") + if not self.db_url: + raise ValueError( + "Database URL not provided either through constructor or DATABASE_URL environment variable" + ) + + if schema_inspect: + if schema or schema_path: + raise ValueError( + "Cannot provide 'schema' or 'schema_path' when 'schema_inspect' is True." + ) + self.schema = db.extract_schema_from_db(self.db_url) + else: + if not schema and not schema_path: + raise ValueError( + "Must provide either 'schema' or 'schema_path' when 'schema_inspect' is False." + ) + self.schema = schema or self.load_file(schema_path) + if system_prompt or system_prompt_path: self.system_prompt = system_prompt or self.load_file(system_prompt_path) else: @@ -66,25 +83,31 @@ def __init__( self.db_url = db_url if not self.db_url: - raise ValueError("Database URL not provided either through constructor or DATABASE_URL environment variable") - + raise ValueError( + "Database URL not provided either through constructor or DATABASE_URL environment variable" + ) + # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) self.sql_generator = SQLGenerator( - api_key=api_key, - base_url=base_url, - model=model + api_key=api_key, base_url=base_url, model=model ) self.repair = repair - + def __call__(self, func: Callable) -> Callable: template_name = self.name or f"{func.__name__}.sql" fn_spec = FunctionSpec(func) - prompt_generator = SQLPromptGenerator(fn_spec, template_name, self.system_prompt, self.schema) + prompt_generator = SQLPromptGenerator( + fn_spec, template_name, self.system_prompt, self.schema + ) - def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None): + def sql_gen( + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ): if self.regen or not self.cache.exists(template_name) or error: prompt = prompt_generator.generate_prompt(kwargs, error, prev_template) sql_template = self.sql_generator.generate_sql(prompt) @@ -94,7 +117,7 @@ def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: return sql_template def _parse_result(result_data: Any): - if fn_spec.wrapper == 'list': + if fn_spec.wrapper == "list": return [ db.parse_query_to_pydantic(row, fn_spec.return_type) for row in result_data.all() @@ -104,10 +127,12 @@ def _parse_result(result_data: Any): elif fn_spec.return_type is int: # Best-effort mapping from non-int results to int # 1) QueryResult-like object - if hasattr(result_data, "count") and callable(getattr(result_data, "count")): + if hasattr(result_data, "scalar") and callable( + getattr(result_data, "scalar") + ): try: - return int(result_data.count()) - except Exception: + return int(result_data.scalar()) + except (ValueError, TypeError): pass # 2) List of rows if isinstance(result_data, list): @@ -128,7 +153,11 @@ def _parse_result(result_data: Any): return 0 else: first_row = result_data.first() - return db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None + return ( + db.parse_query_to_pydantic(first_row, fn_spec.return_type) + if first_row + else None + ) is_async = inspect.iscoroutinefunction(func) executor = WrapSqlExecution( @@ -140,20 +169,20 @@ def _parse_result(result_data: Any): ) return executor.build_wrapper(is_async) - def load_file(self, path: str) -> str: """ Load predefined table schemas. - + Returns: str: SQL schema definitions """ if not path or not os.path.exists(path): raise FileNotFoundError(f"Schema file not found at {path}") - with open(path, 'r') as f: + with open(path, "r") as f: return f.read() + class WrapSqlExecution: def __init__( @@ -174,7 +203,9 @@ async def _execute_async(self, **kwargs: Any): last_exc = None error = None sql_template = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + attempts = ( + self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + ) database = db.get_db_with_adapter(self.db_url, "async") for _ in range(attempts): @@ -200,7 +231,9 @@ def _execute_sync(self, **kwargs: Any): last_exc = None error = None sql_template = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + attempts = ( + self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + ) for _ in range(attempts): sql_template = self.sql_gen(kwargs, error, sql_template) @@ -223,12 +256,16 @@ def _execute_sync(self, **kwargs: Any): def build_wrapper(self, is_async: bool): if is_async: + @functools.wraps(self.func) async def async_wrapper(**kwargs: Any): return await self._execute_async(**kwargs) + return async_wrapper else: + @functools.wraps(self.func) def sync_wrapper(**kwargs: Any): return self._execute_sync(**kwargs) + return sync_wrapper diff --git a/requirements.txt b/requirements.txt index 2c34a84..0a58e02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ jinja2sql pydantic openai asyncpg -psycopg2-binary \ No newline at end of file +psycopg2-binary + +# Development dependencies +isort +black diff --git a/setup.py b/setup.py index 56f620f..882fd81 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,12 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() with open("requirements.txt", "r", encoding="utf-8") as fh: - requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")] + requirements = [ + line.strip() for line in fh if line.strip() and not line.startswith("#") + ] setup( name="foundation-sql", diff --git a/tests/common.py b/tests/common.py index 5caf555..ec76cb7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,26 +1,32 @@ -import unittest import os -from foundation_sql import db -from foundation_sql.query import SQLQueryDecorator import re +import unittest +from typing import Optional from dotenv import load_dotenv + +from foundation_sql import db +from foundation_sql.query import SQLQueryDecorator + load_dotenv() # Force SQLite in-memory for all tests that use this common module. # Async/Postgres-specific tests manage their own DATABASE_URL and are skipped if absent. DB_URL = "sqlite:///:memory:" -def create_query(schema): + +def create_query(schema=None, schema_inspect=False, db_url=DB_URL): return SQLQueryDecorator( schema=schema, - db_url=DB_URL, + schema_inspect=schema_inspect, + db_url=db_url, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE_URL"), model=os.getenv("OPENAI_MODEL"), repair=2, ) + class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" @@ -32,20 +38,25 @@ def setUp(self): """Create a fresh database connection for each test.""" # Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: - db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) + db.get_db(self.db_url).init_schema( + schema_sql=self.schema_sql, schema_path=self.schema_path + ) # Capture table names from schema_sql for teardown cleanup self._tables_to_drop = [] if self.schema_sql: - for raw in self.schema_sql.split(';'): + for raw in self.schema_sql.split(";"): stmt = raw.strip() if not stmt: continue - m = re.search(r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", stmt, re.IGNORECASE) + m = re.search( + r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", + stmt, + re.IGNORECASE, + ) if m: self._tables_to_drop.append(m.group(1)) else: - raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") - + pass def tearDown(self): """Close the database connection after each test.""" diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 1e142bf..fac697a 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -2,9 +2,8 @@ import unittest from typing import Any, Dict, List -from foundation_sql.db_drivers import SQLAlchemyAdapter from foundation_sql import db - +from foundation_sql.db_drivers import SQLAlchemyAdapter SYNC_DB_URL = "sqlite:///:memory:" diff --git a/tests/test_db_parser.py b/tests/test_db_parser.py index 5673a8a..5cc73e1 100644 --- a/tests/test_db_parser.py +++ b/tests/test_db_parser.py @@ -5,12 +5,11 @@ import unittest from datetime import datetime from enum import Enum -from typing import Optional +from typing import List, Optional -from foundation_sql import db from pydantic import BaseModel, Field -from enum import Enum -from typing import Optional, List + +from foundation_sql import db from tests import common # --- Test SQL Schema --- @@ -46,6 +45,7 @@ ); """ + # --- Pydantic Models and Enums (previously from schema) --- class Model(BaseModel): id: str @@ -54,10 +54,12 @@ class Model(BaseModel): max_tokens: Optional[int] = None created_at: Optional[datetime] = None + class AgentType(str, Enum): GENERALIST = "generalist" SPECIALIST = "specialist" + class Agent(BaseModel): id: str name: str @@ -67,11 +69,13 @@ class Agent(BaseModel): model: Optional[Model] = None created_at: Optional[datetime] = None + class TaskStatus(str, Enum): NEW = "new" IN_PROGRESS = "in_progress" ASSIGNED_FOR_WORK = "assigned_for_work" + class Task(BaseModel): id: str task_no: Optional[int] = None @@ -79,9 +83,10 @@ class Task(BaseModel): description: Optional[str] = None status: TaskStatus agent: Optional[Agent] = None - parent_task: Optional['Task'] = None + parent_task: Optional["Task"] = None created_at: Optional[datetime] = None + Task.model_rebuild() @@ -99,19 +104,22 @@ def test_parse_basic_model(self): "name": "Test Model", "context_window": 4096, "max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into a Model result = db.parse_query_to_pydantic(data, Model) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.name, "Test Model") self.assertEqual(result.context_window, 4096) self.assertEqual(result.max_tokens, 1024) - self.assertEqual(result.created_at, datetime.strptime("2025-01-01 12:00:00", "%Y-%m-%d %H:%M:%S")) + self.assertEqual( + result.created_at, + datetime.strptime("2025-01-01 12:00:00", "%Y-%m-%d %H:%M:%S"), + ) def test_parse_with_enum(self): """Test parsing a model with enum fields.""" @@ -126,12 +134,12 @@ def test_parse_with_enum(self): "model.name": "Test Model", "model.context_window": 4096, "model.max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into an Agent result = db.parse_query_to_pydantic(data, Agent) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") @@ -143,32 +151,33 @@ def test_parse_with_enum(self): def test_parse_with_optional_enum(self): """Test parsing a model with optional enum fields.""" + # Create a simple test class with optional enum class TestStatus(str, Enum): ACTIVE = "active" INACTIVE = "inactive" - + class TestModel(BaseModel): name: str status: Optional[TestStatus] = None - + # Test with enum value present data1 = { "id": "00000000-0000-0000-0000-000000000001", "name": "Test Item", "status": "active", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + result1 = db.parse_query_to_pydantic(data1, TestModel) self.assertEqual(result1.status, TestStatus.ACTIVE) self.assertIsInstance(result1.status, TestStatus) - + # Test with enum value absent data2 = { "id": "00000000-0000-0000-0000-000000000002", "name": "Test Item 2", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } result2 = db.parse_query_to_pydantic(data2, TestModel) @@ -187,17 +196,17 @@ def test_parse_with_nested_model(self): "model.name": "Test Model", "model.context_window": 4096, "model.max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into an Agent result = db.parse_query_to_pydantic(data, Agent) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.name, "Test Agent") - + # Check nested model self.assertIsNotNone(result.model) self.assertEqual(result.model.id, "00000000-0000-0000-0000-000000000002") @@ -207,28 +216,29 @@ def test_parse_with_nested_model(self): def test_parse_with_optional_nested_model(self): """Test parsing a model with an optional nested model that is present.""" + # Create a simple test that doesn't involve complex nested models # First, let's create a simpler test class with an optional nested model class SimpleModel(BaseModel): name: str value: int - + class ContainerModel(BaseModel): id: str title: str nested: Optional[SimpleModel] = None - + # Create test data for a container with a nested model data = { "id": "test-id-123", "title": "Test Container", "nested.name": "Test Nested", - "nested.value": 42 + "nested.value": 42, } - + # Parse the data into a ContainerModel result = db.parse_query_to_pydantic(data, ContainerModel) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "test-id-123") @@ -246,18 +256,18 @@ def test_parse_without_optional_nested_model(self): "title": "Test Task", "description": "A test task", "status": "new", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into a Task result = db.parse_query_to_pydantic(data, Task) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.title, "Test Task") self.assertEqual(result.status, TaskStatus.NEW) - + # Check that agent is None self.assertIsNone(result.agent) @@ -266,7 +276,7 @@ def test_parse_empty_data(self): # Test with None result1 = db.parse_query_to_pydantic(None, Model) self.assertIsNone(result1) - + # Test with empty dict result2 = db.parse_query_to_pydantic({}, Model) self.assertIsNone(result2) @@ -283,7 +293,6 @@ def test_parse_complex_nested_structure(self): "description": "A subtask", "status": "in_progress", "created_at": "2025-01-01 12:00:00", - # Agent fields with double-underscore notation "agent.id": "00000000-0000-0000-0000-000000000002", "agent.name": "Test Agent", @@ -291,40 +300,38 @@ def test_parse_complex_nested_structure(self): "agent.instructions": "Test instructions", "agent.type": "specialist", "agent.created_at": "2025-01-01 11:00:00", - # Agent's model fields with double-underscore notation "agent.model.id": "00000000-0000-0000-0000-000000000004", "agent.model.name": "Agent Model", "agent.model.context_window": 4096, "agent.model.max_tokens": 1024, "agent.model.created_at": "2025-01-01 10:00:00", - # Parent task fields with double-underscore notation "parent_task.id": "00000000-0000-0000-0000-000000000003", "parent_task.task_no": 1, "parent_task.title": "Parent Task", "parent_task.description": "A parent task", "parent_task.status": "assigned_for_work", - "parent_task.created_at": "2025-01-01 09:00:00" + "parent_task.created_at": "2025-01-01 09:00:00", } - + # Parse the data into a Task - this should work by properly handling the nested fields # with double-underscore notation result = db.parse_query_to_pydantic(data, Task) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.title, "Subtask") self.assertEqual(result.status, TaskStatus.IN_PROGRESS) - + # Check nested agent self.assertIsNotNone(result.agent) self.assertEqual(result.agent.id, "00000000-0000-0000-0000-000000000002") self.assertEqual(result.agent.name, "Test Agent") self.assertIsNotNone(result.agent.model) self.assertEqual(result.agent.model.name, "Agent Model") - + # Check parent task self.assertIsNotNone(result.parent_task) self.assertEqual(result.parent_task.id, "00000000-0000-0000-0000-000000000003") @@ -332,5 +339,5 @@ def test_parse_complex_nested_structure(self): self.assertEqual(result.parent_task.status, TaskStatus.ASSIGNED_FOR_WORK) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_generated_id.py b/tests/test_generated_id.py index 68cc80f..f2dbc83 100644 --- a/tests/test_generated_id.py +++ b/tests/test_generated_id.py @@ -1,13 +1,17 @@ from typing import List, Optional -from tests import common + from pydantic import BaseModel, Field +from tests import common + + class User(BaseModel): id: Optional[int] = Field(default=None) name: str email: str role: str + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -22,6 +26,7 @@ class User(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) + @query def get_users() -> List[User]: """ @@ -29,6 +34,7 @@ def get_users() -> List[User]: """ pass + @query def create_user(user: User) -> User: """ @@ -36,17 +42,18 @@ def create_user(user: User) -> User: """ pass + class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA - + def test_users(self): users = get_users() self.assertEqual(len(users), 0) - + user = User(name="John Doe", email="john@example.com", role="user") create_user(user=user) - + users = get_users() self.assertEqual(len(users), 1) self.assertEqual(users[0].id, 1) diff --git a/tests/test_many_to_one.py b/tests/test_many_to_one.py index a638200..cc3bc66 100644 --- a/tests/test_many_to_one.py +++ b/tests/test_many_to_one.py @@ -1,20 +1,24 @@ -from typing import List, Optional import os import shutil -from tests import common +from typing import List, Optional + from pydantic import BaseModel +from tests import common + class Workspace(BaseModel): id: int name: str + class Task(BaseModel): id: int workspace: Workspace title: str description: Optional[str] = None + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS workspaces ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -34,6 +38,7 @@ class Task(BaseModel): CACHE_DIR = "__sql__" + class TestWorkspaceTasks(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -90,6 +95,7 @@ def setUpClass(cls) -> None: """.strip() ) + @query def create_workspace(name: str) -> Workspace: """ @@ -97,13 +103,17 @@ def create_workspace(name: str) -> Workspace: """ pass + @query -def add_task_to_workspace(workspace: Workspace, title: str, description: Optional[str] = None) -> Task: +def add_task_to_workspace( + workspace: Workspace, title: str, description: Optional[str] = None +) -> Task: """ Inserts a new task into the workspace and returns the Task object. """ pass + @query def get_tasks_for_workspace(workspace: Workspace) -> List[Task]: """ @@ -117,8 +127,12 @@ def test_workspace_tasks(self): self.assertIsInstance(ws, Workspace) # Add tasks - task1 = add_task_to_workspace(workspace=ws, title="Setup repo", description="Initialize git repository") - task2 = add_task_to_workspace(workspace=ws, title="Write docs", description="Document the setup process") + task1 = add_task_to_workspace( + workspace=ws, title="Setup repo", description="Initialize git repository" + ) + task2 = add_task_to_workspace( + workspace=ws, title="Write docs", description="Document the setup process" + ) self.assertIsInstance(task1, Task) self.assertIsInstance(task2, Task) diff --git a/tests/test_nested_query.py b/tests/test_nested_query.py index f33b539..1b4650d 100644 --- a/tests/test_nested_query.py +++ b/tests/test_nested_query.py @@ -1,18 +1,23 @@ -from typing import List, Optional import os import shutil -from tests import common +from typing import List, Optional + from pydantic import BaseModel +from tests import common + + class Address(BaseModel): street: str city: str zip_code: str + class Profile(BaseModel): bio: Optional[str] = None address: Optional[Address] = None + class UserWithProfile(BaseModel): id: str name: str @@ -20,6 +25,7 @@ class UserWithProfile(BaseModel): role: str profile: Optional[Profile] = None + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users_with_profile ( id VARCHAR(36) PRIMARY KEY, @@ -38,6 +44,7 @@ class UserWithProfile(BaseModel): CACHE_DIR = "__sql__" + class TestNestedQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -86,6 +93,7 @@ def setUpClass(cls) -> None: """.strip() ) + @query def get_users_with_profile() -> List[UserWithProfile]: """ @@ -93,6 +101,7 @@ def get_users_with_profile() -> List[UserWithProfile]: """ pass + @query def create_user_with_profile(user: UserWithProfile) -> int: """ @@ -104,66 +113,66 @@ def test_nested_object_creation_and_retrieval(self): # Initially, no users users = get_users_with_profile() self.assertEqual(len(users), 0) - + # Create a user with a full nested profile address = Address(street="123 Main St", city="Anytown", zip_code="12345") profile = Profile(bio="Software Engineer", address=address) user = UserWithProfile( - id="nested_user_1", - name="Jane Doe", - email="jane@example.com", - role="user", - profile=profile + id="nested_user_1", + name="Jane Doe", + email="jane@example.com", + role="user", + profile=profile, ) - + # Store the user create_user_with_profile(user=user) - + # Retrieve and verify retrieved_users = get_users_with_profile() self.assertEqual(len(retrieved_users), 1) - + retrieved_user = retrieved_users[0] self.assertEqual(retrieved_user.id, "nested_user_1") self.assertEqual(retrieved_user.name, "Jane Doe") self.assertEqual(retrieved_user.email, "jane@example.com") self.assertEqual(retrieved_user.role, "user") - + # Check nested profile self.assertIsNotNone(retrieved_user.profile) self.assertEqual(retrieved_user.profile.bio, "Software Engineer") - + # Check nested address self.assertIsNotNone(retrieved_user.profile.address) self.assertEqual(retrieved_user.profile.address.street, "123 Main St") self.assertEqual(retrieved_user.profile.address.city, "Anytown") self.assertEqual(retrieved_user.profile.address.zip_code, "12345") - + def test_nested_object_with_partial_data(self): # Create a user with a partial profile user = UserWithProfile( - id="nested_user_2", - name="John Smith", - email="john@example.com", - role="guest", - profile=Profile(bio="Data Scientist") + id="nested_user_2", + name="John Smith", + email="john@example.com", + role="guest", + profile=Profile(bio="Data Scientist"), ) - + # Store the user create_user_with_profile(user=user) - + # Retrieve and verify retrieved_users = get_users_with_profile() self.assertEqual(len(retrieved_users), 1) - + # Find the newly added user retrieved_user = next(u for u in retrieved_users if u.id == "nested_user_2") - + self.assertEqual(retrieved_user.id, "nested_user_2") self.assertEqual(retrieved_user.name, "John Smith") self.assertEqual(retrieved_user.email, "john@example.com") self.assertEqual(retrieved_user.role, "guest") - + # Check partial profile self.assertIsNotNone(retrieved_user.profile) self.assertEqual(retrieved_user.profile.bio, "Data Scientist") diff --git a/tests/test_query_sync_and_async.py b/tests/test_query_sync_and_async.py index 3aec2f6..6517ad8 100644 --- a/tests/test_query_sync_and_async.py +++ b/tests/test_query_sync_and_async.py @@ -1,20 +1,21 @@ +import inspect import os import shutil import unittest -import inspect from typing import List from pydantic import BaseModel -from foundation_sql.query import SQLQueryDecorator from foundation_sql import db from foundation_sql.db_drivers import AsyncpgAdapter - +from foundation_sql.query import SQLQueryDecorator SQLITE_DB_URL = "sqlite:///__test_sync.sqlite3" # Attempt to ensure DATABASE_URL is available by reading .env if needed -if not os.environ.get("DATABASE_URL") and os.path.exists(os.path.join(os.path.dirname(__file__), "..", ".env")): +if not os.environ.get("DATABASE_URL") and os.path.exists( + os.path.join(os.path.dirname(__file__), "..", ".env") +): env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".env")) try: with open(env_path, "r") as f: @@ -30,7 +31,9 @@ except Exception: pass -ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname +ASYNC_DB_URL = os.environ.get( + "DATABASE_URL" +) # e.g., postgresql://user:pass@localhost:5432/dbname CACHE_DIR_SYNC = "__sql__/__sql_sync__" CACHE_DIR_ASYNC = "__sql__/__sql_async__" @@ -56,7 +59,9 @@ def setUpClass(cls) -> None: with open(os.path.join(CACHE_DIR_SYNC, "get_users.sql"), "w") as f: f.write("SELECT id, name FROM users ORDER BY id;") with open(os.path.join(CACHE_DIR_SYNC, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + f.write( + "INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});" + ) # Prepare SQLite DB file cleanly if os.path.exists("__test_sync.sqlite3"): @@ -78,7 +83,9 @@ def tearDown(self) -> None: db.DATABASES.clear() def test_sync_wrappers_and_execution(self): - query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC) + query = SQLQueryDecorator( + schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC + ) @query def get_users() -> List["TestSQLQueryDecoratorSync.User"]: @@ -124,7 +131,9 @@ def setUpClass(cls) -> None: with open(os.path.join(CACHE_DIR_ASYNC, "get_users.sql"), "w") as f: f.write("SELECT id, name FROM users ORDER BY id;") with open(os.path.join(CACHE_DIR_ASYNC, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + f.write( + "INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});" + ) async def asyncSetUp(self): # Initialize schema on Postgres using async adapter each test for isolation @@ -138,7 +147,9 @@ async def asyncTearDown(self): db.DATABASES.clear() async def test_async_wrappers_and_execution(self): - query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC) + query = SQLQueryDecorator( + schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC + ) @query async def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py new file mode 100644 index 0000000..c2915e1 --- /dev/null +++ b/tests/test_schema_discovery.py @@ -0,0 +1,91 @@ +import os +import sqlite3 +from typing import List + +from pydantic import BaseModel + +from tests import common + +# --- Start of moved code from tests/utils.py --- + +BIKES_DB_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") +) + + +def create_bike_db(): + os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) + + if os.path.exists(BIKES_DB_PATH): + os.remove(BIKES_DB_PATH) + + conn = sqlite3.connect(BIKES_DB_PATH) + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE bikes ( + make TEXT NOT NULL, + model TEXT NOT NULL, + price INTEGER NOT NULL + ); + """ + ) + conn.commit() + conn.close() + + +# --- End of moved code --- + + +class Bike(BaseModel): + make: str + model: str + price: int + + +create_bike_db() + +query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}", schema_inspect=True) + + +@query +def get_bikes() -> List[Bike]: + """ + Gets all bikes. + """ + pass + + +@query +def create_bike(bike: Bike) -> Bike: + """ + Creates a new bike. + """ + pass + + +@query +def get_total_price() -> int: + """ + Calculates the total price of all the bikes. + """ + pass + + +class TestSchemaDiscovery(common.DatabaseTests): + db_url = f"sqlite:///{BIKES_DB_PATH}" + schema_sql = None + + def test_schema_discovery(self): + + re_bike = Bike(make="RE", model="Classic", price=600) + create_bike(bike=re_bike) + + harley_bike = Bike(make="Harley", model="A very good one", price=500) + create_bike(bike=harley_bike) + + bikes = get_bikes() + self.assertEqual(len(bikes), 2) + + price = get_total_price() + self.assertEqual(price, 1100) diff --git a/tests/test_simple_query.py b/tests/test_simple_query.py index ef9ee1c..4cc7390 100644 --- a/tests/test_simple_query.py +++ b/tests/test_simple_query.py @@ -1,8 +1,10 @@ -from typing import List -from tests import common -from pydantic import BaseModel import os import shutil +from typing import List + +from pydantic import BaseModel + +from tests import common class User(BaseModel): @@ -11,6 +13,7 @@ class User(BaseModel): email: str role: str + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users ( id VARCHAR(36) PRIMARY KEY, @@ -24,6 +27,7 @@ class User(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) + @query def get_users() -> List[User]: """ @@ -31,6 +35,7 @@ def get_users() -> List[User]: """ pass + @query def create_user(user: User) -> int: """ @@ -38,6 +43,7 @@ def create_user(user: User) -> int: """ pass + class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -81,14 +87,14 @@ def setUpClass(cls) -> None: """ ).strip() ) - + def test_users(self): users = get_users() self.assertEqual(len(users), 0) - + user = User(id="xxx", name="John Doe", email="john@example.com", role="user") create_user(user=user) - + users = get_users() self.assertEqual(len(users), 1) self.assertEqual(users[0], user)