From faa490a63fe1b6b946f82d2282d5629de5c4c968 Mon Sep 17 00:00:00 2001 From: Rishi Pradeep Date: Sun, 8 Jun 2025 01:52:57 +0530 Subject: [PATCH 1/3] Schema discovery implementation --- .gitignore | 3 +- foundation_sql/db.py | 27 +++++++++++++++-- foundation_sql/prompts.md | 28 +++++++++--------- foundation_sql/query.py | 16 ++++++---- tests/common.py | 18 ++++++++---- tests/test_schema_discovery.py | 54 ++++++++++++++++++++++++++++++++++ tests/utils.py | 26 ++++++++++++++++ 7 files changed, 144 insertions(+), 28 deletions(-) create mode 100644 tests/test_schema_discovery.py create mode 100644 tests/utils.py diff --git a/.gitignore b/.gitignore index b9e30d2..92a10cd 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,5 @@ ENV/ # Misc .DS_Store -__sql__ \ No newline at end of file +__sql__ +fixtures \ No newline at end of file diff --git a/foundation_sql/db.py b/foundation_sql/db.py index c2541fe..ca9a022 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -10,7 +10,8 @@ from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError - +from sqlalchemy import MetaData +from sqlalchemy.schema import CreateTable from jinja2sql import Jinja2SQL from datetime import datetime @@ -232,7 +233,28 @@ def is_empty(self) -> bool: True if no rows, False otherwise """ return len(self.rows) == 0 + + +# Function to load the schema from the database +def extract_schema_from_db(db_url: str) -> str: + """Extract the schema from the database. + + Args: + db_url: Database URL to use + + Returns: + Schema as a string + """ + engine = create_engine(db_url) + metadata = MetaData() + metadata.reflect(bind=engine) + + schema_lines = [] + for table in metadata.sorted_tables: + ddl = str(CreateTable(table).compile(engine)) + schema_lines.append(ddl + ";") + return "\n\n".join(schema_lines) def get_db(db_url: str) -> Database: @@ -280,7 +302,8 @@ def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) # Check the response type and transform accordingly if model_class == int: - return int(unflattened_data["result"]) + # FIX : STILL ONLY GETS FIRST LINE OF RESPONSE + return int(next(iter(unflattened_data.values()))) elif model_class == NoneType: return None diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 6809d26..5a45a37 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -1,12 +1,12 @@ You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. Ensure, the SQL query is usable across sqlite and postgresql. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. -1. Start with a -- comment to document the function name, parameters and docstring, explaining what the SQL query does. +1. Start with a comment to document the function name, parameters and docstring, explaining what the SQL query does. Make sure to start comments with `--` (Only 2 dashes, no more , no less) 2. Use jinja2 template to generate SQL 3. When accessing nested fields handle cases if they aren't defined. Use default filter with None value for such cases e.g. -{{user.zip_code|default(None)}} + {{user.zip_code|default(None)}} 4. Ensure response rows can be parsed into Pydantic model. As long as the model fields are named the same as the columns in the SQL query. It also supports nested models by using double underscores to separate nested fields. -5. For complex tasks, more than one queries can be run, separated by ';' -6. Only respond with a single ```sql``` block which contains all queries. +5. For complex tasks, more than one queries can be run, separated by ';', Make sure queries end with ';'. +6. Only respond with a single `sql` block which contains all queries. 7. No other explanation is necessary 8. For insert queries, avoid any RETURNING clause. Let it return the default. 9. We use jinja2 syntax to generate SQL - so parameters don't need to be quoted e.g. use {{user.zip_code|default(None)}} and not '{{user.zip_code|default(None)}}' @@ -19,18 +19,18 @@ You are an expert SQL developer. Write one or more SQL queries that can perform Here is an example def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task: - """ - Creates and returns a Task object, for the provided workspace and task_no - """ - pass - +""" +Creates and returns a Task object, for the provided workspace and task_no +""" +pass The SQL generated would look like the following + ```sql - --- def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task - --- Creates and returns a Task object, for the provided workspace and task_no - --- Expects task_no and workspace.id are defined. If no tasks are found, returns None - SELECT + -- def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task; + -- Creates and returns a Task object, for the provided workspace and task_no; + -- Expects task_no and workspace.id are defined. If no tasks are found, returns None; + SELECT t.id as `id`, t.task_no as `task_no`, t.title as `title`, @@ -55,7 +55,7 @@ The SQL generated would look like the following LEFT JOIN agents a ON t.agent_id = a.id LEFT JOIN models m ON a.model_id = m.id LEFT JOIN workspace_tasks wt ON t.id = wt.task_id - WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}} + WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}}; ``` Below are the real specifications for which query needs to be generated. diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 51a0caf..edabf79 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -57,15 +57,21 @@ def __init__( self.name = name self.regen = regen self.cache_dir = cache_dir - self.schema = schema or self.load_file(schema_path) + self.db_url = db_url or os.environ.get("DATABASE_URL") + if ( not self.db_url): + raise ValueError(f"Database URL not provided either through constructor or DATABASE_URL environment variable") + + if (not schema and not schema_path): + # Load the schema from the database + self.schema = db.extract_schema_from_db(self.db_url) + + else: + self.schema = schema or self.load_file(schema_path) + if system_prompt or system_prompt_path: self.system_prompt = system_prompt or self.load_file(system_prompt_path) else: self.system_prompt = DEFAULT_SYSTEM_PROMPT - - self.db_url = db_url - if not self.db_url: - raise ValueError(f"Database URL not provided either through constructor or {db_url_env} environment variable") # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) diff --git a/tests/common.py b/tests/common.py index a4de0bf..fc4f3da 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,15 +2,21 @@ import os from foundation_sql import db from foundation_sql.query import SQLQueryDecorator +from typing import Optional from dotenv import load_dotenv load_dotenv() -DB_URL = os.environ.get("DATABSE_URL", "sqlite:///:memory:") +# DB_URL = os.environ.get("DATABASE_URL", "sqlite:///:memory:") -def create_query(schema): +def create_query (schema: Optional[str] = None, db_url : Optional[str] = None): + final_db_url = ( + db_url or + os.environ.get("DATABASE_URL") or + "sqlite:///:memory:" + ) return SQLQueryDecorator(schema=schema, - db_url=DB_URL, + db_url=final_db_url, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE_URL"), model=os.getenv("OPENAI_MODEL")) @@ -18,17 +24,17 @@ def create_query(schema): class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" - db_url = DB_URL + db_url = "sqlite:///:memory:" schema_sql = None schema_path = None def setUp(self): """Create a fresh database connection for each test.""" - # Re-initialize the schema for each test to ensure clean state + #Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) else: - raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") + pass def tearDown(self): diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py new file mode 100644 index 0000000..206332e --- /dev/null +++ b/tests/test_schema_discovery.py @@ -0,0 +1,54 @@ +from typing import List +from tests import common +from pydantic import BaseModel +from tests.utils import BIKES_DB_PATH, create_bike_db + +class Bike(BaseModel): + make: str + model: str + price: int + + +create_bike_db() + +query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}") + +@query +def get_bikes() -> List[Bike]: + """ + Gets all bikes. + """ + pass + +@query +def create_bike(bike: Bike) -> Bike: + """ + Creates a new bike. + """ + pass + +@query +def get_total_price() -> int: + """ + Get the total price of all the bikes + """ + pass + + +class TestSchemaDiscovery(common.DatabaseTests): + db_url = f"sqlite:///{BIKES_DB_PATH}" + schema_sql = None + + def test_schema_discovery(self): + + re_bike = Bike(make="RE", model="Classic", price=600) + create_bike(bike=re_bike) + + harley_bike = Bike(make="Harley", model="A very good one", price=500) + create_bike(bike = harley_bike) + + bikes = get_bikes() + self.assertEqual(len(bikes), 2) + + price = get_total_price() + self.assertEqual(price,1100) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..07b73b3 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,26 @@ +# utils.py +import os +import sqlite3 + +BIKES_DB_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") +) + +def create_bike_db(): + os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) + + + if os.path.exists(BIKES_DB_PATH): + os.remove(BIKES_DB_PATH) + + conn = sqlite3.connect(BIKES_DB_PATH) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE bikes ( + make TEXT NOT NULL, + model TEXT NOT NULL, + price INTEGER NOT NULL + ); + """) + conn.commit() + conn.close() From ebcdc09ebe1b5e61b39c410977b3364339e1fe5d Mon Sep 17 00:00:00 2001 From: Rishi Pradeep Date: Sat, 4 Oct 2025 04:33:14 +0530 Subject: [PATCH 2/3] Fixed formatting and resolve comments --- .gitignore | 3 +- foundation_sql/cache.py | 39 ++++--- foundation_sql/db.py | 179 ++++++++++++++++------------- foundation_sql/db_drivers.py | 114 ++++++++++-------- foundation_sql/gen.py | 37 +++--- foundation_sql/prompt.py | 60 +++++----- foundation_sql/query.py | 112 +++++++++++------- requirements.txt | 6 +- setup.py | 6 +- tests/common.py | 34 ++++-- tests/test_adapters.py | 3 +- tests/test_db_parser.py | 89 +++++++------- tests/test_generated_id.py | 15 ++- tests/test_many_to_one.py | 24 +++- tests/test_nested_query.py | 57 +++++---- tests/test_query_sync_and_async.py | 29 +++-- tests/test_schema_discovery.py | 49 +++++++- tests/test_simple_query.py | 18 ++- 18 files changed, 526 insertions(+), 348 deletions(-) diff --git a/.gitignore b/.gitignore index 92a10cd..b9e30d2 100644 --- a/.gitignore +++ b/.gitignore @@ -37,5 +37,4 @@ ENV/ # Misc .DS_Store -__sql__ -fixtures \ No newline at end of file +__sql__ \ No newline at end of file diff --git a/foundation_sql/cache.py b/foundation_sql/cache.py index 00f4699..3f21ab2 100644 --- a/foundation_sql/cache.py +++ b/foundation_sql/cache.py @@ -1,83 +1,84 @@ import os from typing import Optional + class SQLTemplateCache: """ Simple file-based cache using file names as keys. - + Attributes: cache_dir (str): Directory to store cached templates """ - + def __init__(self, cache_dir: str): """ Initialize the SQL template cache. - + Args: cache_dir (str): Directory to store cached templates """ self.cache_dir = cache_dir os.makedirs(cache_dir, exist_ok=True) - + def _get_cache_path(self, key: str) -> str: """ Get the full path for a cache file. - + Args: key (str): File name to use as cache key - + Returns: str: Full path to the cache file """ return os.path.join(self.cache_dir, key) - + def set(self, key: str, template: str) -> None: """ Store an SQL template in the cache. - + Args: key (str): File name to use as cache key template (str): SQL template to cache """ cache_file = self._get_cache_path(key) - with open(cache_file, 'w') as f: + with open(cache_file, "w") as f: f.write(template) - + def get(self, key: str) -> Optional[str]: """ Retrieve a cached SQL template. - + Args: key (str): File name to use as cache key - + Returns: Optional[str]: Cached SQL template or None if not found """ cache_file = self._get_cache_path(key) if os.path.exists(cache_file): - with open(cache_file, 'r') as f: + with open(cache_file, "r") as f: return f.read() return None - + def exists(self, key: str) -> bool: """ Check if a cache entry exists. - + Args: key (str): File name to use as cache key - + Returns: bool: True if cache entry exists, False otherwise """ cache_file = self._get_cache_path(key) return os.path.exists(cache_file) - + def clear(self, key: Optional[str] = None) -> None: """ Clear cached templates. - + Args: - key (Optional[str]): Specific key to clear. + key (Optional[str]): Specific key to clear. If None, clears entire cache. """ if key: diff --git a/foundation_sql/db.py b/foundation_sql/db.py index e55828d..61a678b 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -4,18 +4,18 @@ import logging import os +from datetime import datetime from types import NoneType -from typing import Dict, Any, Optional, Type, Union, List +from typing import Any, Dict, List, Optional, Type, Union + +from jinja2sql import Jinja2SQL from pydantic import BaseModel -from sqlalchemy import text +from sqlalchemy import MetaData, create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import MetaData from sqlalchemy.schema import CreateTable -from jinja2sql import Jinja2SQL -from datetime import datetime -from foundation_sql.db_drivers import EngineAdapter, SQLAlchemyAdapter -from foundation_sql.db_drivers import AsyncpgAdapter + +from foundation_sql.db_drivers import AsyncpgAdapter, EngineAdapter, SQLAlchemyAdapter NESTED_SPLITTER = "." # Singleton instance @@ -26,22 +26,25 @@ logger = logging.getLogger(__name__) - class Database: """Database operations for Foundation delegated to an EngineAdapter.""" - - def __init__(self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None) -> None: + + def __init__( + self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None + ) -> None: """Initialize the database facade. - + Args: db_url: Database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) If not provided, will use DATABASE_URL environment variable adapter: Optional explicit adapter. If not provided, a default sync adapter is created. """ - self.db_url = db_url or os.getenv('DATABASE_URL') + self.db_url = db_url or os.getenv("DATABASE_URL") if not self.db_url: - raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable') - + raise ValueError( + "Database URL must be provided either through constructor or DATABASE_URL environment variable" + ) + # Default to sync SQLAlchemy adapter unless explicitly provided self.adapter: EngineAdapter = adapter or SQLAlchemyAdapter(self.db_url) @@ -56,7 +59,9 @@ def get_engine(self) -> Engine: return self.adapter.engine raise RuntimeError("Engine is not available for this adapter") - def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None: + def init_schema( + self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None + ) -> None: """Initialize the database schema if it doesn't exist. This method runs the schema creation script in an idempotent way. @@ -68,26 +73,26 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] If not provided, will use the default schema at data/tables.sql """ if not schema_sql: - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema_sql = f.read() # Delegate to adapter self.adapter.init_schema(schema_sql) def run_sql(self, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected """ # Add datetime.now function to context if needed - if 'now' not in context: - context['now'] = datetime.now - + if "now" not in context: + context["now"] = datetime.now + # Delegate to adapter; it returns either list[dict] rows or int rowcount result = self.adapter.run_sql(sql_template, context) if isinstance(result, int): @@ -96,109 +101,116 @@ def run_sql(self, sql_template: str, **context) -> Any: return QueryResult(result) # ---------- Async delegates (Phase 2) ---------- - async def init_schema_async(self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None) -> None: - if not hasattr(self.adapter, 'init_schema_async'): + async def init_schema_async( + self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None + ) -> None: + if not hasattr(self.adapter, "init_schema_async"): raise NotImplementedError("Async schema init not supported by this adapter") if not schema_sql: - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema_sql = f.read() await self.adapter.init_schema_async(schema_sql) # type: ignore[attr-defined] async def run_sql_async(self, sql_template: str, **context) -> Any: - if not hasattr(self.adapter, 'run_sql_async'): + if not hasattr(self.adapter, "run_sql_async"): raise NotImplementedError("Async run_sql not supported by this adapter") - if 'now' not in context: - context['now'] = datetime.now + if "now" not in context: + context["now"] = datetime.now result = await self.adapter.run_sql_async(sql_template, context) # type: ignore[attr-defined] if isinstance(result, int): return result return QueryResult(result) async def close_async(self) -> None: - if hasattr(self.adapter, 'close_async'): + if hasattr(self.adapter, "close_async"): await self.adapter.close_async() # type: ignore[attr-defined] - - def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any: + def execute( + self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None + ) -> Any: """ Execute a raw SQL statement with optional parameters. - + Args: sql (str): SQL statement to execute - params (Optional[Union[tuple, dict, List[tuple]]]): + params (Optional[Union[tuple, dict, List[tuple]]]): Optional parameters for the SQL statement - Single tuple for single parameter set - List of tuples for multiple parameter sets (bulk insert) - Dictionary for named parameters - + Returns: Any: Result of the execution """ # Only supported on SQLAlchemy adapter path for now if not isinstance(self.adapter, SQLAlchemyAdapter): - raise NotImplementedError("execute() is only supported for SQLAlchemy adapter") + raise NotImplementedError( + "execute() is only supported for SQLAlchemy adapter" + ) with self.adapter.engine.connect() as connection: try: # Replace '?' placeholders with SQLAlchemy named parameters - if '?' in sql: + if "?" in sql: # Count the number of placeholders - placeholder_count = sql.count('?') - + placeholder_count = sql.count("?") + # Replace '?' with named parameters - named_sql = sql.replace('?', ':{}'.format) - + named_sql = sql.replace("?", ":{}".format) + # Prepare parameters if params is None: named_params = {} elif isinstance(params, (tuple, list)): # Convert tuple/list to named dictionary - named_params = {f'p{i}': val for i, val in enumerate(params)} - named_sql = named_sql.format(p=lambda i: f':p{i}') + named_params = {f"p{i}": val for i, val in enumerate(params)} + named_sql = named_sql.format(p=lambda i: f":p{i}") elif isinstance(params, dict): named_params = params else: - raise ValueError("Invalid parameter type. Must be tuple, dict, or list of tuples.") - + raise ValueError( + "Invalid parameter type. Must be tuple, dict, or list of tuples." + ) + # Execute with named parameters result = connection.execute(text(named_sql), named_params) else: # If no '?' placeholders, use as-is result = connection.execute(text(sql), params or {}) - + # If it's a SELECT query, return the rows if result.returns_rows: return result.fetchall() - + # For INSERT, UPDATE, DELETE, return the number of rows affected return result.rowcount - + except SQLAlchemyError as e: raise RuntimeError(f"Database execution error: {str(e)}") from e class QueryResult: """A clean abstraction over query results that doesn't leak implementation details.""" - + def __init__(self, rows: List[Dict[str, Any]]): """Initialize with a list of row dictionaries. - + Args: rows: List of dictionaries representing database rows """ self.rows = rows - + def first(self) -> Optional[Dict[str, Any]]: """Get the first row as a dictionary or None if no rows. - + Returns: First row as a dictionary or None """ return self.rows[0] if self.rows else None - + def all(self) -> List[Dict[str, Any]]: """Get all rows as a list of dictionaries. - + Returns: List of dictionaries representing all rows """ @@ -206,28 +218,34 @@ def all(self) -> List[Dict[str, Any]]: def count(self) -> int: """Get the number of rows. - + Returns: Number of rows """ return len(self.rows) - + def is_empty(self) -> bool: """Check if the result contains any rows. - + Returns: True if no rows, False otherwise """ return len(self.rows) == 0 - + + def scalar(self) -> Optional[Any]: + """Get the first value from the first row.""" + if self.rows and self.rows[0]: + return next(iter(self.rows[0].values()), None) + return None + # Function to load the schema from the database def extract_schema_from_db(db_url: str) -> str: """Extract the schema from the database. - + Args: db_url: Database URL to use - + Returns: Schema as a string """ @@ -245,18 +263,19 @@ def extract_schema_from_db(db_url: str) -> str: def get_db(db_url: str) -> Database: """Get the database instance. - + Args: db_url: Database URL to use - + Returns: Database instance """ if db_url not in DATABASES: DATABASES[db_url] = Database(db_url) - + return DATABASES[db_url] + def get_db_with_adapter(db_url: str, mode: str) -> Database: """Internal helper for selecting adapter explicitly. mode: "sync" | "async" (async not implemented yet) @@ -267,13 +286,14 @@ def get_db_with_adapter(db_url: str, mode: str) -> Database: return Database(db_url, adapter=AsyncpgAdapter(db_url)) raise ValueError(f"Unknown adapter mode: {mode}") + def run_sql(db_url: str, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected @@ -281,18 +301,20 @@ def run_sql(db_url: str, sql_template: str, **context) -> Any: return get_db(db_url).run_sql(sql_template, **context) -def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) -> Optional[BaseModel]: +def parse_query_to_pydantic( + data: Dict[str, Any], model_class: Type[BaseModel] +) -> Optional[BaseModel]: """Parse query result data into a Pydantic model, handling nested models. - + Args: data: Dictionary containing query results with optional nested fields model_class: The Pydantic model class to instantiate - + Returns: Instance of the Pydantic model or None if data is None """ if not data: - return None + return None unflattened_data = unflatten_dict(data) @@ -302,25 +324,24 @@ def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) return int(next(iter(unflattened_data.values()))) elif model_class == NoneType: return None - - return model_class(**unflattened_data) + return model_class(**unflattened_data) def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: """Convert a flattened dictionary with keys like 'parent.child.grandchild' (using NESTED_SPLITTER) into a nested dictionary structure. - + Args: flat_dict: Dictionary with flattened keys using NESTED_SPLITTER for nesting - + Returns: - Nested dictionary structure where nested objects with all None values + Nested dictionary structure where nested objects with all None values are replaced by None at the parent level. """ grouped_keys = {} direct_keys = {} - + # First, categorize the keys for key, value in flat_dict.items(): if NESTED_SPLITTER in key: @@ -330,27 +351,27 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: grouped_keys[prefix][rest] = value else: direct_keys[key] = value - + # Process each group and add to result result = dict(direct_keys) # Start with the direct keys - + for prefix, nested_dict in grouped_keys.items(): # Check if this prefix contains nested structures has_nested = any(NESTED_SPLITTER in key for key in nested_dict.keys()) - + if has_nested: # Recursively unflatten the nested structure nested_result = unflatten_dict(nested_dict) - + # Check if all values in the nested result are None after unflattening is_all_none = False if isinstance(nested_result, dict): is_all_none = all(v is None for v in nested_result.values()) - + result[prefix] = None if is_all_none else nested_result else: # Check if all values are None for a flat nested dict is_all_none = all(v is None for v in nested_dict.values()) result[prefix] = None if is_all_none else nested_dict - + return result diff --git a/foundation_sql/db_drivers.py b/foundation_sql/db_drivers.py index 28f64d2..72739b0 100644 --- a/foundation_sql/db_drivers.py +++ b/foundation_sql/db_drivers.py @@ -3,16 +3,15 @@ import logging import re # Add missing import from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, List +from datetime import datetime +from typing import Any, Dict, List, Optional +import asyncpg +from jinja2sql import Jinja2SQL from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError -from jinja2sql import Jinja2SQL -from datetime import datetime -import asyncpg - logger = logging.getLogger(__name__) @@ -21,25 +20,26 @@ class EngineAdapter(ABC): # ---------- Sync API ---------- @abstractmethod - def init_schema(self, schema_sql: str) -> None: - ... + def init_schema(self, schema_sql: str) -> None: ... @abstractmethod - def run_sql(self, template: str, data: Dict[str, Any]) -> Any: - ... + def run_sql(self, template: str, data: Dict[str, Any]) -> Any: ... @abstractmethod - def close(self) -> None: - ... + def close(self) -> None: ... # ---------- Async API (optional) ---------- - async def init_pool_async(self) -> None: # pragma: no cover - to be implemented by async adapters + async def init_pool_async( + self, + ) -> None: # pragma: no cover - to be implemented by async adapters raise NotImplementedError async def init_schema_async(self, schema_sql: str) -> None: # pragma: no cover raise NotImplementedError - async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover + async def run_sql_async( + self, template: str, data: Dict[str, Any] + ) -> Any: # pragma: no cover raise NotImplementedError async def close_async(self) -> None: # pragma: no cover @@ -59,20 +59,22 @@ def __init__(self, dsn: str) -> None: def init_schema(self, schema_sql: str) -> None: with self.engine.begin() as conn: try: - for statement in schema_sql.split(';'): + for statement in schema_sql.split(";"): if statement.strip(): conn.execute(text(statement)) except SQLAlchemyError as e: - raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + raise RuntimeError(f"Failed to initialize schema: {str(e)}") from e def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # ensure now is available - if 'now' not in data: - data['now'] = datetime.now + if "now" not in data: + data["now"] = datetime.now try: # Normalize templates: remove explicit tojson filters to allow binding - normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + normalized = re.sub( + r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template + ) query, params = self.j2sql.from_string(normalized, context=data) except Exception as e: raise ValueError( @@ -82,7 +84,9 @@ def run_sql(self, template: str, data: Dict[str, Any]) -> Any: with self.engine.connect() as conn: with conn.begin(): try: - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + statements = [ + stmt.strip() for stmt in query.split(";") if stmt.strip() + ] total_rows = 0 last_result = None @@ -137,23 +141,25 @@ async def init_schema_async(self, schema_sql: str) -> None: assert self.pool is not None async with self.pool.acquire() as conn: # Execute statements sequentially to ensure order - for statement in [s.strip() for s in schema_sql.split(';') if s.strip()]: + for statement in [s.strip() for s in schema_sql.split(";") if s.strip()]: try: logger.debug(f"Executing schema statement: {statement}") await conn.execute(statement) except Exception as e: logger.error(f"Failed to execute schema statement: {statement}") - raise RuntimeError(f"Failed to execute schema statement: {str(e)}") from e + raise RuntimeError( + f"Failed to execute schema statement: {str(e)}" + ) from e async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: await self.init_pool_async() assert self.pool is not None - if 'now' not in data: - data['now'] = datetime.now + if "now" not in data: + data["now"] = datetime.now # Special handling for templates without parameters - if '{{' not in template: + if "{{" not in template: # No template variables, execute directly async with self.pool.acquire() as conn: is_select = template.strip().lower().startswith("select") @@ -167,15 +173,19 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: try: # Add the same template normalization as SQLAlchemyAdapter # Normalize templates: remove explicit tojson filters to allow binding - normalized = re.sub(r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template) + normalized = re.sub( + r"\{\{\s*([^}]+?)\s*\|\s*tojson\s*\}\}", r"{{ \1 }}", template + ) # jinja2sql generates SQL with $1, $2... placeholders and params in correct order query, params_list = self.j2sql.from_string(normalized, context=data) - + # Debug: log parameter types logger.debug(f"Rendered query: {query}") logger.debug(f"Parameter values: {params_list}") - logger.debug(f"Parameter types: {[(type(p).__name__, p) for p in params_list]}") - + logger.debug( + f"Parameter types: {[(type(p).__name__, p) for p in params_list]}" + ) + except Exception as e: raise ValueError( f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" @@ -184,34 +194,40 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: async with self.pool.acquire() as conn: try: # Handle multiple statements like SQLAlchemyAdapter - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] + statements = [stmt.strip() for stmt in query.split(";") if stmt.strip()] total_rows = 0 last_result = None logger.debug(f"Executing {len(statements)} statement(s)") for i, statement in enumerate(statements): logger.debug(f"Statement {i+1}: {statement}") - + is_select = statement.strip().lower().startswith("select") - + if is_select: last_result = await conn.fetch(statement, *params_list) total_rows += len(last_result) logger.debug(f"SELECT returned {len(last_result)} rows") else: # Check if this is a schema operation or data operation - is_schema_op = any(keyword in statement.upper() for keyword in - ['CREATE', 'DROP', 'ALTER', 'TRUNCATE']) - + is_schema_op = any( + keyword in statement.upper() + for keyword in ["CREATE", "DROP", "ALTER", "TRUNCATE"] + ) + if is_schema_op: # Schema operations typically don't use parameters - logger.debug("Executing as schema operation (no parameters)") + logger.debug( + "Executing as schema operation (no parameters)" + ) status = await conn.execute(statement) else: # Data operations use parameters - logger.debug(f"Executing as data operation with parameters: {params_list}") + logger.debug( + f"Executing as data operation with parameters: {params_list}" + ) status = await conn.execute(statement, *params_list) - + stmt_rows = _parse_rowcount(status) total_rows += stmt_rows logger.debug(f"Statement affected {stmt_rows} rows") @@ -221,23 +237,29 @@ async def run_sql_async(self, template: str, data: Dict[str, Any]) -> Any: rows = [dict(r) for r in last_result] logger.debug(f"Returning {len(rows)} rows") return rows - + logger.debug(f"Returning row count: {total_rows}") return total_rows - + except Exception as e: - error_msg = (f"Failed to execute SQL: {str(e)}\n" - f"Rendered SQL: {query}\n" - f"Parameters: {params_list}\n" - f"Parameter types: {[type(p).__name__ for p in params_list]}") + error_msg = ( + f"Failed to execute SQL: {str(e)}\n" + f"Rendered SQL: {query}\n" + f"Parameters: {params_list}\n" + f"Parameter types: {[type(p).__name__ for p in params_list]}" + ) logger.error(error_msg) raise RuntimeError(error_msg) from e # Sync methods are not supported for asyncpg adapter - def init_schema(self, schema_sql: str) -> None: # pragma: no cover - sync not supported + def init_schema( + self, schema_sql: str + ) -> None: # pragma: no cover - sync not supported raise NotImplementedError("Use init_schema_async with AsyncpgAdapter") - def run_sql(self, template: str, data: Dict[str, Any]) -> Any: # pragma: no cover - sync not supported + def run_sql( + self, template: str, data: Dict[str, Any] + ) -> Any: # pragma: no cover - sync not supported raise NotImplementedError("Use run_sql_async with AsyncpgAdapter") def close(self) -> None: # pragma: no cover - sync not supported @@ -251,4 +273,4 @@ def _parse_rowcount(status: str) -> int: # Common patterns: 'INSERT 0 1' (rowcount is last), 'UPDATE 3' (last), 'DELETE 2' return int(parts[-1]) except Exception: - return 0 \ No newline at end of file + return 0 diff --git a/foundation_sql/gen.py b/foundation_sql/gen.py index cad5266..cdd53c7 100644 --- a/foundation_sql/gen.py +++ b/foundation_sql/gen.py @@ -1,27 +1,26 @@ import logging -from openai import OpenAI import re +from openai import OpenAI + + class SQLGenerator: """ Advanced SQL template generator with configurable LLM backend. - + Supports: - Configurable language models - Persistent template caching - Flexible generation parameters - Comprehensive error handling """ - + def __init__( - self, - api_key: str, - base_url: str, - model: str = "llama-3.3-70b-versatile" + self, api_key: str, base_url: str, model: str = "llama-3.3-70b-versatile" ): """ Initialize the SQL generator. - + Args: api_key (str): API key for the LLM service base_url (str): Base URL for the LLM service @@ -36,27 +35,25 @@ def __init__( def client(self): return OpenAI(api_key=self.api_key, base_url=self.base_url) - def generate_sql(self, prompt: str) -> str: """ Generate an SQL template based on the provided prompt. - + Args: prompt (str): Detailed prompt for SQL generation - + Returns: str: Generated SQL template """ response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": prompt} - ] + model=self.model, messages=[{"role": "system", "content": prompt}] ) - + generated_sql = response.choices[0].message.content.strip() - + # Remove ```sql or ``` fences - sql_template = re.sub(r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE).strip() - - return sql_template \ No newline at end of file + sql_template = re.sub( + r"^```sql\s*|^```\s*|```$", "", generated_sql, flags=re.MULTILINE + ).strip() + + return sql_template diff --git a/foundation_sql/prompt.py b/foundation_sql/prompt.py index dd2132d..edd2cd1 100644 --- a/foundation_sql/prompt.py +++ b/foundation_sql/prompt.py @@ -1,8 +1,8 @@ import inspect import json +from datetime import datetime from types import NoneType from typing import Any, Callable, Dict, Optional, Type, Union, get_type_hints -from datetime import datetime from pydantic import BaseModel @@ -27,7 +27,6 @@ def _model_fields(self): return {} return {k: str(v) for k, v in self.return_type.model_fields.items()} - def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ Extract named parameters and their types from a function. @@ -40,8 +39,8 @@ def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ signature = inspect.signature(func) return { - name: param.annotation - for name, param in signature.parameters.items() + name: param.annotation + for name, param in signature.parameters.items() if param.annotation is not param.empty } @@ -65,57 +64,60 @@ def serialize_value(v): def _extract_return_model(self, func: Callable) -> (Type[BaseModel], Optional[str]): """ Extract the return model type from a function's type annotations. - + Args: func (Callable): Function to analyze - + Returns: Tuple containing: - Pydantic model class - Wrapper type ('list' or None) - + Raises: ValueError: If return type is invalid or not a Pydantic model """ hints = get_type_hints(func) - if 'return' not in hints: - raise ValueError(f'Function {func.__name__} must have a return type annotation') - - return_type = hints['return'] + if "return" not in hints: + raise ValueError( + f"Function {func.__name__} must have a return type annotation" + ) + + return_type = hints["return"] wrapper = None - + # Handle Optional[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is Union: + if hasattr(return_type, "__origin__") and return_type.__origin__ is Union: args = return_type.__args__ if len(args) == 2 and args[1] is type(None): return_type = args[0] - + # Handle List[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is list: - wrapper = 'list' + if hasattr(return_type, "__origin__") and return_type.__origin__ is list: + wrapper = "list" return_type = return_type.__args__[0] - - return return_type, wrapper + return return_type, wrapper class SQLPromptGenerator: """ Generates prompts for SQL template generation based on function context and predefined schemas. - + Attributes: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template """ - - def __init__(self, func_spec: FunctionSpec, + + def __init__( + self, + func_spec: FunctionSpec, template_name: str, system_prompt: str, - schema: Optional[str] = None - ): + schema: Optional[str] = None, + ): """ Initialize the SQL prompt generator. - + Args: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template @@ -128,10 +130,15 @@ def __init__(self, func_spec: FunctionSpec, self.schema = schema self.system_prompt = system_prompt - def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None) -> str: + def generate_prompt( + self, + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ) -> str: """ Generate a comprehensive prompt for SQL template generation. - + Returns: str: Detailed prompt with function context and schema """ @@ -148,7 +155,6 @@ def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, p Review the error and suggest an improved SQL template that works. """ - return f""" {self.system_prompt} ---------------- diff --git a/foundation_sql/query.py b/foundation_sql/query.py index fb88fe2..dbe3cde 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -1,45 +1,44 @@ -import os import functools -from typing import Any, Callable, Optional import inspect +import os +from importlib import resources as impresources +from typing import Any, Callable, Dict, Optional -from foundation_sql.prompt import SQLPromptGenerator, FunctionSpec -from foundation_sql.gen import SQLGenerator -from foundation_sql.cache import SQLTemplateCache from foundation_sql import db -from typing import Callable, Dict, Optional - -from importlib import resources as impresources +from foundation_sql.cache import SQLTemplateCache +from foundation_sql.gen import SQLGenerator +from foundation_sql.prompt import FunctionSpec, SQLPromptGenerator -DEFAULT_SYSTEM_PROMPT = impresources.read_text('foundation_sql', 'prompts.md') +DEFAULT_SYSTEM_PROMPT = impresources.read_text("foundation_sql", "prompts.md") class SQLQueryDecorator: """ Advanced decorator for generating and executing SQL queries with comprehensive features. - + Supports: - Dynamic SQL template generation - Configurable LLM backend - Persistent template caching - Robust error handling and regeneration - + Attributes: name (Optional[str]): Custom name for SQL template regen (Optional[bool]): SQL template regeneration strategy config (SQLGeneratorConfig): Configuration for SQL generation """ - + def __init__( - self, - name: Optional[str] = None, + self, + name: Optional[str] = None, regen: Optional[bool] = None, repair: Optional[int] = 0, schema: Optional[str] = None, schema_path: Optional[str] = None, + schema_inspect: bool = False, system_prompt: Optional[str] = None, system_prompt_path: Optional[str] = None, - cache_dir: Optional[str] = '__sql__', + cache_dir: Optional[str] = "__sql__", db_url: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, @@ -47,28 +46,36 @@ def __init__( ): """ Initialize the SQL query decorator. - + Args: - name (Optional[str]): Custom name for SQL file/folder. + name (Optional[str]): Custom name for SQL file/folder. Defaults to function name. regen (Optional[bool]): SQL template regeneration strategy. - config (Optional[SQLGeneratorConfig]): Custom configuration + config (Optional[SQLGeneratorConfig]): Custom configuration for SQL generation. """ self.name = name self.regen = regen self.cache_dir = cache_dir self.db_url = db_url or os.environ.get("DATABASE_URL") - if ( not self.db_url): - raise ValueError(f"Database URL not provided either through constructor or DATABASE_URL environment variable") - - if (not schema and not schema_path): - # Load the schema from the database - self.schema = db.extract_schema_from_db(self.db_url) + if not self.db_url: + raise ValueError( + "Database URL not provided either through constructor or DATABASE_URL environment variable" + ) + if schema_inspect: + if schema or schema_path: + raise ValueError( + "Cannot provide 'schema' or 'schema_path' when 'schema_inspect' is True." + ) + self.schema = db.extract_schema_from_db(self.db_url) else: + if not schema and not schema_path: + raise ValueError( + "Must provide either 'schema' or 'schema_path' when 'schema_inspect' is False." + ) self.schema = schema or self.load_file(schema_path) - + if system_prompt or system_prompt_path: self.system_prompt = system_prompt or self.load_file(system_prompt_path) else: @@ -76,26 +83,31 @@ def __init__( self.db_url = db_url if not self.db_url: - raise ValueError("Database URL not provided either through constructor or DATABASE_URL environment variable") + raise ValueError( + "Database URL not provided either through constructor or DATABASE_URL environment variable" + ) - # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) self.sql_generator = SQLGenerator( - api_key=api_key, - base_url=base_url, - model=model + api_key=api_key, base_url=base_url, model=model ) self.repair = repair - + def __call__(self, func: Callable) -> Callable: template_name = self.name or f"{func.__name__}.sql" fn_spec = FunctionSpec(func) - prompt_generator = SQLPromptGenerator(fn_spec, template_name, self.system_prompt, self.schema) + prompt_generator = SQLPromptGenerator( + fn_spec, template_name, self.system_prompt, self.schema + ) - def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None): + def sql_gen( + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ): if self.regen or not self.cache.exists(template_name) or error: prompt = prompt_generator.generate_prompt(kwargs, error, prev_template) sql_template = self.sql_generator.generate_sql(prompt) @@ -105,7 +117,7 @@ def sql_gen(kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: return sql_template def _parse_result(result_data: Any): - if fn_spec.wrapper == 'list': + if fn_spec.wrapper == "list": return [ db.parse_query_to_pydantic(row, fn_spec.return_type) for row in result_data.all() @@ -115,10 +127,12 @@ def _parse_result(result_data: Any): elif fn_spec.return_type is int: # Best-effort mapping from non-int results to int # 1) QueryResult-like object - if hasattr(result_data, "count") and callable(getattr(result_data, "count")): + if hasattr(result_data, "scalar") and callable( + getattr(result_data, "scalar") + ): try: - return int(result_data.count()) - except Exception: + return int(result_data.scalar()) + except (ValueError, TypeError): pass # 2) List of rows if isinstance(result_data, list): @@ -139,7 +153,11 @@ def _parse_result(result_data: Any): return 0 else: first_row = result_data.first() - return db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None + return ( + db.parse_query_to_pydantic(first_row, fn_spec.return_type) + if first_row + else None + ) is_async = inspect.iscoroutinefunction(func) executor = WrapSqlExecution( @@ -151,20 +169,20 @@ def _parse_result(result_data: Any): ) return executor.build_wrapper(is_async) - def load_file(self, path: str) -> str: """ Load predefined table schemas. - + Returns: str: SQL schema definitions """ if not path or not os.path.exists(path): raise FileNotFoundError(f"Schema file not found at {path}") - with open(path, 'r') as f: + with open(path, "r") as f: return f.read() + class WrapSqlExecution: def __init__( @@ -185,7 +203,9 @@ async def _execute_async(self, **kwargs: Any): last_exc = None error = None sql_template = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + attempts = ( + self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + ) database = db.get_db_with_adapter(self.db_url, "async") for _ in range(attempts): @@ -211,7 +231,9 @@ def _execute_sync(self, **kwargs: Any): last_exc = None error = None sql_template = None - attempts = self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + attempts = ( + self.repair + 1 if isinstance(self.repair, int) and self.repair >= 0 else 1 + ) for _ in range(attempts): sql_template = self.sql_gen(kwargs, error, sql_template) @@ -234,12 +256,16 @@ def _execute_sync(self, **kwargs: Any): def build_wrapper(self, is_async: bool): if is_async: + @functools.wraps(self.func) async def async_wrapper(**kwargs: Any): return await self._execute_async(**kwargs) + return async_wrapper else: + @functools.wraps(self.func) def sync_wrapper(**kwargs: Any): return self._execute_sync(**kwargs) + return sync_wrapper diff --git a/requirements.txt b/requirements.txt index 2c34a84..0a58e02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ jinja2sql pydantic openai asyncpg -psycopg2-binary \ No newline at end of file +psycopg2-binary + +# Development dependencies +isort +black diff --git a/setup.py b/setup.py index 56f620f..882fd81 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,12 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() with open("requirements.txt", "r", encoding="utf-8") as fh: - requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")] + requirements = [ + line.strip() for line in fh if line.strip() and not line.startswith("#") + ] setup( name="foundation-sql", diff --git a/tests/common.py b/tests/common.py index 5d7c2fd..ec76cb7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,52 +1,62 @@ -import unittest import os -from foundation_sql import db -from foundation_sql.query import SQLQueryDecorator -from typing import Optional import re +import unittest +from typing import Optional from dotenv import load_dotenv + +from foundation_sql import db +from foundation_sql.query import SQLQueryDecorator + load_dotenv() # Force SQLite in-memory for all tests that use this common module. # Async/Postgres-specific tests manage their own DATABASE_URL and are skipped if absent. DB_URL = "sqlite:///:memory:" -def create_query(schema): + +def create_query(schema=None, schema_inspect=False, db_url=DB_URL): return SQLQueryDecorator( schema=schema, - db_url=DB_URL, + schema_inspect=schema_inspect, + db_url=db_url, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE_URL"), model=os.getenv("OPENAI_MODEL"), repair=2, ) + class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" - db_url = "sqlite:///:memory:" + db_url = DB_URL schema_sql = None schema_path = None def setUp(self): """Create a fresh database connection for each test.""" - #Re-initialize the schema for each test to ensure clean state + # Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: - db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) + db.get_db(self.db_url).init_schema( + schema_sql=self.schema_sql, schema_path=self.schema_path + ) # Capture table names from schema_sql for teardown cleanup self._tables_to_drop = [] if self.schema_sql: - for raw in self.schema_sql.split(';'): + for raw in self.schema_sql.split(";"): stmt = raw.strip() if not stmt: continue - m = re.search(r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", stmt, re.IGNORECASE) + m = re.search( + r"CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+([a-zA-Z_][a-zA-Z0-9_\.]*)", + stmt, + re.IGNORECASE, + ) if m: self._tables_to_drop.append(m.group(1)) else: pass - def tearDown(self): """Close the database connection after each test.""" diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 1e142bf..fac697a 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -2,9 +2,8 @@ import unittest from typing import Any, Dict, List -from foundation_sql.db_drivers import SQLAlchemyAdapter from foundation_sql import db - +from foundation_sql.db_drivers import SQLAlchemyAdapter SYNC_DB_URL = "sqlite:///:memory:" diff --git a/tests/test_db_parser.py b/tests/test_db_parser.py index 5673a8a..5cc73e1 100644 --- a/tests/test_db_parser.py +++ b/tests/test_db_parser.py @@ -5,12 +5,11 @@ import unittest from datetime import datetime from enum import Enum -from typing import Optional +from typing import List, Optional -from foundation_sql import db from pydantic import BaseModel, Field -from enum import Enum -from typing import Optional, List + +from foundation_sql import db from tests import common # --- Test SQL Schema --- @@ -46,6 +45,7 @@ ); """ + # --- Pydantic Models and Enums (previously from schema) --- class Model(BaseModel): id: str @@ -54,10 +54,12 @@ class Model(BaseModel): max_tokens: Optional[int] = None created_at: Optional[datetime] = None + class AgentType(str, Enum): GENERALIST = "generalist" SPECIALIST = "specialist" + class Agent(BaseModel): id: str name: str @@ -67,11 +69,13 @@ class Agent(BaseModel): model: Optional[Model] = None created_at: Optional[datetime] = None + class TaskStatus(str, Enum): NEW = "new" IN_PROGRESS = "in_progress" ASSIGNED_FOR_WORK = "assigned_for_work" + class Task(BaseModel): id: str task_no: Optional[int] = None @@ -79,9 +83,10 @@ class Task(BaseModel): description: Optional[str] = None status: TaskStatus agent: Optional[Agent] = None - parent_task: Optional['Task'] = None + parent_task: Optional["Task"] = None created_at: Optional[datetime] = None + Task.model_rebuild() @@ -99,19 +104,22 @@ def test_parse_basic_model(self): "name": "Test Model", "context_window": 4096, "max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into a Model result = db.parse_query_to_pydantic(data, Model) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.name, "Test Model") self.assertEqual(result.context_window, 4096) self.assertEqual(result.max_tokens, 1024) - self.assertEqual(result.created_at, datetime.strptime("2025-01-01 12:00:00", "%Y-%m-%d %H:%M:%S")) + self.assertEqual( + result.created_at, + datetime.strptime("2025-01-01 12:00:00", "%Y-%m-%d %H:%M:%S"), + ) def test_parse_with_enum(self): """Test parsing a model with enum fields.""" @@ -126,12 +134,12 @@ def test_parse_with_enum(self): "model.name": "Test Model", "model.context_window": 4096, "model.max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into an Agent result = db.parse_query_to_pydantic(data, Agent) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") @@ -143,32 +151,33 @@ def test_parse_with_enum(self): def test_parse_with_optional_enum(self): """Test parsing a model with optional enum fields.""" + # Create a simple test class with optional enum class TestStatus(str, Enum): ACTIVE = "active" INACTIVE = "inactive" - + class TestModel(BaseModel): name: str status: Optional[TestStatus] = None - + # Test with enum value present data1 = { "id": "00000000-0000-0000-0000-000000000001", "name": "Test Item", "status": "active", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + result1 = db.parse_query_to_pydantic(data1, TestModel) self.assertEqual(result1.status, TestStatus.ACTIVE) self.assertIsInstance(result1.status, TestStatus) - + # Test with enum value absent data2 = { "id": "00000000-0000-0000-0000-000000000002", "name": "Test Item 2", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } result2 = db.parse_query_to_pydantic(data2, TestModel) @@ -187,17 +196,17 @@ def test_parse_with_nested_model(self): "model.name": "Test Model", "model.context_window": 4096, "model.max_tokens": 1024, - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into an Agent result = db.parse_query_to_pydantic(data, Agent) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.name, "Test Agent") - + # Check nested model self.assertIsNotNone(result.model) self.assertEqual(result.model.id, "00000000-0000-0000-0000-000000000002") @@ -207,28 +216,29 @@ def test_parse_with_nested_model(self): def test_parse_with_optional_nested_model(self): """Test parsing a model with an optional nested model that is present.""" + # Create a simple test that doesn't involve complex nested models # First, let's create a simpler test class with an optional nested model class SimpleModel(BaseModel): name: str value: int - + class ContainerModel(BaseModel): id: str title: str nested: Optional[SimpleModel] = None - + # Create test data for a container with a nested model data = { "id": "test-id-123", "title": "Test Container", "nested.name": "Test Nested", - "nested.value": 42 + "nested.value": 42, } - + # Parse the data into a ContainerModel result = db.parse_query_to_pydantic(data, ContainerModel) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "test-id-123") @@ -246,18 +256,18 @@ def test_parse_without_optional_nested_model(self): "title": "Test Task", "description": "A test task", "status": "new", - "created_at": "2025-01-01 12:00:00" + "created_at": "2025-01-01 12:00:00", } - + # Parse the data into a Task result = db.parse_query_to_pydantic(data, Task) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.title, "Test Task") self.assertEqual(result.status, TaskStatus.NEW) - + # Check that agent is None self.assertIsNone(result.agent) @@ -266,7 +276,7 @@ def test_parse_empty_data(self): # Test with None result1 = db.parse_query_to_pydantic(None, Model) self.assertIsNone(result1) - + # Test with empty dict result2 = db.parse_query_to_pydantic({}, Model) self.assertIsNone(result2) @@ -283,7 +293,6 @@ def test_parse_complex_nested_structure(self): "description": "A subtask", "status": "in_progress", "created_at": "2025-01-01 12:00:00", - # Agent fields with double-underscore notation "agent.id": "00000000-0000-0000-0000-000000000002", "agent.name": "Test Agent", @@ -291,40 +300,38 @@ def test_parse_complex_nested_structure(self): "agent.instructions": "Test instructions", "agent.type": "specialist", "agent.created_at": "2025-01-01 11:00:00", - # Agent's model fields with double-underscore notation "agent.model.id": "00000000-0000-0000-0000-000000000004", "agent.model.name": "Agent Model", "agent.model.context_window": 4096, "agent.model.max_tokens": 1024, "agent.model.created_at": "2025-01-01 10:00:00", - # Parent task fields with double-underscore notation "parent_task.id": "00000000-0000-0000-0000-000000000003", "parent_task.task_no": 1, "parent_task.title": "Parent Task", "parent_task.description": "A parent task", "parent_task.status": "assigned_for_work", - "parent_task.created_at": "2025-01-01 09:00:00" + "parent_task.created_at": "2025-01-01 09:00:00", } - + # Parse the data into a Task - this should work by properly handling the nested fields # with double-underscore notation result = db.parse_query_to_pydantic(data, Task) - + # Assertions self.assertIsNotNone(result) self.assertEqual(result.id, "00000000-0000-0000-0000-000000000001") self.assertEqual(result.title, "Subtask") self.assertEqual(result.status, TaskStatus.IN_PROGRESS) - + # Check nested agent self.assertIsNotNone(result.agent) self.assertEqual(result.agent.id, "00000000-0000-0000-0000-000000000002") self.assertEqual(result.agent.name, "Test Agent") self.assertIsNotNone(result.agent.model) self.assertEqual(result.agent.model.name, "Agent Model") - + # Check parent task self.assertIsNotNone(result.parent_task) self.assertEqual(result.parent_task.id, "00000000-0000-0000-0000-000000000003") @@ -332,5 +339,5 @@ def test_parse_complex_nested_structure(self): self.assertEqual(result.parent_task.status, TaskStatus.ASSIGNED_FOR_WORK) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_generated_id.py b/tests/test_generated_id.py index 68cc80f..f2dbc83 100644 --- a/tests/test_generated_id.py +++ b/tests/test_generated_id.py @@ -1,13 +1,17 @@ from typing import List, Optional -from tests import common + from pydantic import BaseModel, Field +from tests import common + + class User(BaseModel): id: Optional[int] = Field(default=None) name: str email: str role: str + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -22,6 +26,7 @@ class User(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) + @query def get_users() -> List[User]: """ @@ -29,6 +34,7 @@ def get_users() -> List[User]: """ pass + @query def create_user(user: User) -> User: """ @@ -36,17 +42,18 @@ def create_user(user: User) -> User: """ pass + class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA - + def test_users(self): users = get_users() self.assertEqual(len(users), 0) - + user = User(name="John Doe", email="john@example.com", role="user") create_user(user=user) - + users = get_users() self.assertEqual(len(users), 1) self.assertEqual(users[0].id, 1) diff --git a/tests/test_many_to_one.py b/tests/test_many_to_one.py index a638200..cc3bc66 100644 --- a/tests/test_many_to_one.py +++ b/tests/test_many_to_one.py @@ -1,20 +1,24 @@ -from typing import List, Optional import os import shutil -from tests import common +from typing import List, Optional + from pydantic import BaseModel +from tests import common + class Workspace(BaseModel): id: int name: str + class Task(BaseModel): id: int workspace: Workspace title: str description: Optional[str] = None + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS workspaces ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -34,6 +38,7 @@ class Task(BaseModel): CACHE_DIR = "__sql__" + class TestWorkspaceTasks(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -90,6 +95,7 @@ def setUpClass(cls) -> None: """.strip() ) + @query def create_workspace(name: str) -> Workspace: """ @@ -97,13 +103,17 @@ def create_workspace(name: str) -> Workspace: """ pass + @query -def add_task_to_workspace(workspace: Workspace, title: str, description: Optional[str] = None) -> Task: +def add_task_to_workspace( + workspace: Workspace, title: str, description: Optional[str] = None +) -> Task: """ Inserts a new task into the workspace and returns the Task object. """ pass + @query def get_tasks_for_workspace(workspace: Workspace) -> List[Task]: """ @@ -117,8 +127,12 @@ def test_workspace_tasks(self): self.assertIsInstance(ws, Workspace) # Add tasks - task1 = add_task_to_workspace(workspace=ws, title="Setup repo", description="Initialize git repository") - task2 = add_task_to_workspace(workspace=ws, title="Write docs", description="Document the setup process") + task1 = add_task_to_workspace( + workspace=ws, title="Setup repo", description="Initialize git repository" + ) + task2 = add_task_to_workspace( + workspace=ws, title="Write docs", description="Document the setup process" + ) self.assertIsInstance(task1, Task) self.assertIsInstance(task2, Task) diff --git a/tests/test_nested_query.py b/tests/test_nested_query.py index f33b539..1b4650d 100644 --- a/tests/test_nested_query.py +++ b/tests/test_nested_query.py @@ -1,18 +1,23 @@ -from typing import List, Optional import os import shutil -from tests import common +from typing import List, Optional + from pydantic import BaseModel +from tests import common + + class Address(BaseModel): street: str city: str zip_code: str + class Profile(BaseModel): bio: Optional[str] = None address: Optional[Address] = None + class UserWithProfile(BaseModel): id: str name: str @@ -20,6 +25,7 @@ class UserWithProfile(BaseModel): role: str profile: Optional[Profile] = None + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users_with_profile ( id VARCHAR(36) PRIMARY KEY, @@ -38,6 +44,7 @@ class UserWithProfile(BaseModel): CACHE_DIR = "__sql__" + class TestNestedQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -86,6 +93,7 @@ def setUpClass(cls) -> None: """.strip() ) + @query def get_users_with_profile() -> List[UserWithProfile]: """ @@ -93,6 +101,7 @@ def get_users_with_profile() -> List[UserWithProfile]: """ pass + @query def create_user_with_profile(user: UserWithProfile) -> int: """ @@ -104,66 +113,66 @@ def test_nested_object_creation_and_retrieval(self): # Initially, no users users = get_users_with_profile() self.assertEqual(len(users), 0) - + # Create a user with a full nested profile address = Address(street="123 Main St", city="Anytown", zip_code="12345") profile = Profile(bio="Software Engineer", address=address) user = UserWithProfile( - id="nested_user_1", - name="Jane Doe", - email="jane@example.com", - role="user", - profile=profile + id="nested_user_1", + name="Jane Doe", + email="jane@example.com", + role="user", + profile=profile, ) - + # Store the user create_user_with_profile(user=user) - + # Retrieve and verify retrieved_users = get_users_with_profile() self.assertEqual(len(retrieved_users), 1) - + retrieved_user = retrieved_users[0] self.assertEqual(retrieved_user.id, "nested_user_1") self.assertEqual(retrieved_user.name, "Jane Doe") self.assertEqual(retrieved_user.email, "jane@example.com") self.assertEqual(retrieved_user.role, "user") - + # Check nested profile self.assertIsNotNone(retrieved_user.profile) self.assertEqual(retrieved_user.profile.bio, "Software Engineer") - + # Check nested address self.assertIsNotNone(retrieved_user.profile.address) self.assertEqual(retrieved_user.profile.address.street, "123 Main St") self.assertEqual(retrieved_user.profile.address.city, "Anytown") self.assertEqual(retrieved_user.profile.address.zip_code, "12345") - + def test_nested_object_with_partial_data(self): # Create a user with a partial profile user = UserWithProfile( - id="nested_user_2", - name="John Smith", - email="john@example.com", - role="guest", - profile=Profile(bio="Data Scientist") + id="nested_user_2", + name="John Smith", + email="john@example.com", + role="guest", + profile=Profile(bio="Data Scientist"), ) - + # Store the user create_user_with_profile(user=user) - + # Retrieve and verify retrieved_users = get_users_with_profile() self.assertEqual(len(retrieved_users), 1) - + # Find the newly added user retrieved_user = next(u for u in retrieved_users if u.id == "nested_user_2") - + self.assertEqual(retrieved_user.id, "nested_user_2") self.assertEqual(retrieved_user.name, "John Smith") self.assertEqual(retrieved_user.email, "john@example.com") self.assertEqual(retrieved_user.role, "guest") - + # Check partial profile self.assertIsNotNone(retrieved_user.profile) self.assertEqual(retrieved_user.profile.bio, "Data Scientist") diff --git a/tests/test_query_sync_and_async.py b/tests/test_query_sync_and_async.py index 3aec2f6..6517ad8 100644 --- a/tests/test_query_sync_and_async.py +++ b/tests/test_query_sync_and_async.py @@ -1,20 +1,21 @@ +import inspect import os import shutil import unittest -import inspect from typing import List from pydantic import BaseModel -from foundation_sql.query import SQLQueryDecorator from foundation_sql import db from foundation_sql.db_drivers import AsyncpgAdapter - +from foundation_sql.query import SQLQueryDecorator SQLITE_DB_URL = "sqlite:///__test_sync.sqlite3" # Attempt to ensure DATABASE_URL is available by reading .env if needed -if not os.environ.get("DATABASE_URL") and os.path.exists(os.path.join(os.path.dirname(__file__), "..", ".env")): +if not os.environ.get("DATABASE_URL") and os.path.exists( + os.path.join(os.path.dirname(__file__), "..", ".env") +): env_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".env")) try: with open(env_path, "r") as f: @@ -30,7 +31,9 @@ except Exception: pass -ASYNC_DB_URL = os.environ.get("DATABASE_URL") # e.g., postgresql://user:pass@localhost:5432/dbname +ASYNC_DB_URL = os.environ.get( + "DATABASE_URL" +) # e.g., postgresql://user:pass@localhost:5432/dbname CACHE_DIR_SYNC = "__sql__/__sql_sync__" CACHE_DIR_ASYNC = "__sql__/__sql_async__" @@ -56,7 +59,9 @@ def setUpClass(cls) -> None: with open(os.path.join(CACHE_DIR_SYNC, "get_users.sql"), "w") as f: f.write("SELECT id, name FROM users ORDER BY id;") with open(os.path.join(CACHE_DIR_SYNC, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + f.write( + "INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});" + ) # Prepare SQLite DB file cleanly if os.path.exists("__test_sync.sqlite3"): @@ -78,7 +83,9 @@ def tearDown(self) -> None: db.DATABASES.clear() def test_sync_wrappers_and_execution(self): - query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC) + query = SQLQueryDecorator( + schema=self.TABLES_SCHEMA, db_url=SQLITE_DB_URL, cache_dir=CACHE_DIR_SYNC + ) @query def get_users() -> List["TestSQLQueryDecoratorSync.User"]: @@ -124,7 +131,9 @@ def setUpClass(cls) -> None: with open(os.path.join(CACHE_DIR_ASYNC, "get_users.sql"), "w") as f: f.write("SELECT id, name FROM users ORDER BY id;") with open(os.path.join(CACHE_DIR_ASYNC, "create_user.sql"), "w") as f: - f.write("INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});") + f.write( + "INSERT INTO users (id, name) VALUES ({{ user.id }}, {{ user.name | tojson }});" + ) async def asyncSetUp(self): # Initialize schema on Postgres using async adapter each test for isolation @@ -138,7 +147,9 @@ async def asyncTearDown(self): db.DATABASES.clear() async def test_async_wrappers_and_execution(self): - query = SQLQueryDecorator(schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC) + query = SQLQueryDecorator( + schema=self.TABLES_SCHEMA, db_url=ASYNC_DB_URL, cache_dir=CACHE_DIR_ASYNC + ) @query async def get_users() -> List["TestSQLQueryDecoratorAsync.User"]: diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py index 206332e..c2915e1 100644 --- a/tests/test_schema_discovery.py +++ b/tests/test_schema_discovery.py @@ -1,7 +1,41 @@ +import os +import sqlite3 from typing import List -from tests import common + from pydantic import BaseModel -from tests.utils import BIKES_DB_PATH, create_bike_db + +from tests import common + +# --- Start of moved code from tests/utils.py --- + +BIKES_DB_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") +) + + +def create_bike_db(): + os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) + + if os.path.exists(BIKES_DB_PATH): + os.remove(BIKES_DB_PATH) + + conn = sqlite3.connect(BIKES_DB_PATH) + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE bikes ( + make TEXT NOT NULL, + model TEXT NOT NULL, + price INTEGER NOT NULL + ); + """ + ) + conn.commit() + conn.close() + + +# --- End of moved code --- + class Bike(BaseModel): make: str @@ -11,7 +45,8 @@ class Bike(BaseModel): create_bike_db() -query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}") +query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}", schema_inspect=True) + @query def get_bikes() -> List[Bike]: @@ -20,6 +55,7 @@ def get_bikes() -> List[Bike]: """ pass + @query def create_bike(bike: Bike) -> Bike: """ @@ -27,10 +63,11 @@ def create_bike(bike: Bike) -> Bike: """ pass + @query def get_total_price() -> int: """ - Get the total price of all the bikes + Calculates the total price of all the bikes. """ pass @@ -45,10 +82,10 @@ def test_schema_discovery(self): create_bike(bike=re_bike) harley_bike = Bike(make="Harley", model="A very good one", price=500) - create_bike(bike = harley_bike) + create_bike(bike=harley_bike) bikes = get_bikes() self.assertEqual(len(bikes), 2) price = get_total_price() - self.assertEqual(price,1100) \ No newline at end of file + self.assertEqual(price, 1100) diff --git a/tests/test_simple_query.py b/tests/test_simple_query.py index ef9ee1c..4cc7390 100644 --- a/tests/test_simple_query.py +++ b/tests/test_simple_query.py @@ -1,8 +1,10 @@ -from typing import List -from tests import common -from pydantic import BaseModel import os import shutil +from typing import List + +from pydantic import BaseModel + +from tests import common class User(BaseModel): @@ -11,6 +13,7 @@ class User(BaseModel): email: str role: str + TABLES_SCHEMA = """ CREATE TABLE IF NOT EXISTS users ( id VARCHAR(36) PRIMARY KEY, @@ -24,6 +27,7 @@ class User(BaseModel): query = common.create_query(schema=TABLES_SCHEMA) + @query def get_users() -> List[User]: """ @@ -31,6 +35,7 @@ def get_users() -> List[User]: """ pass + @query def create_user(user: User) -> int: """ @@ -38,6 +43,7 @@ def create_user(user: User) -> int: """ pass + class TestQuery(common.DatabaseTests): schema_sql = TABLES_SCHEMA @@ -81,14 +87,14 @@ def setUpClass(cls) -> None: """ ).strip() ) - + def test_users(self): users = get_users() self.assertEqual(len(users), 0) - + user = User(id="xxx", name="John Doe", email="john@example.com", role="user") create_user(user=user) - + users = get_users() self.assertEqual(len(users), 1) self.assertEqual(users[0], user) From 1d7346cb41e8679320acfd50846e6f3deaa90602 Mon Sep 17 00:00:00 2001 From: Rishi Pradeep Date: Sat, 4 Oct 2025 04:34:09 +0530 Subject: [PATCH 3/3] Delete utils --- tests/utils.py | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 tests/utils.py diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 07b73b3..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -# utils.py -import os -import sqlite3 - -BIKES_DB_PATH = os.path.abspath( - os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") -) - -def create_bike_db(): - os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) - - - if os.path.exists(BIKES_DB_PATH): - os.remove(BIKES_DB_PATH) - - conn = sqlite3.connect(BIKES_DB_PATH) - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE bikes ( - make TEXT NOT NULL, - model TEXT NOT NULL, - price INTEGER NOT NULL - ); - """) - conn.commit() - conn.close()