From faa490a63fe1b6b946f82d2282d5629de5c4c968 Mon Sep 17 00:00:00 2001 From: Rishi Pradeep Date: Sun, 8 Jun 2025 01:52:57 +0530 Subject: [PATCH 1/4] Schema discovery implementation --- .gitignore | 3 +- foundation_sql/db.py | 27 +++++++++++++++-- foundation_sql/prompts.md | 28 +++++++++--------- foundation_sql/query.py | 16 ++++++---- tests/common.py | 18 ++++++++---- tests/test_schema_discovery.py | 54 ++++++++++++++++++++++++++++++++++ tests/utils.py | 26 ++++++++++++++++ 7 files changed, 144 insertions(+), 28 deletions(-) create mode 100644 tests/test_schema_discovery.py create mode 100644 tests/utils.py diff --git a/.gitignore b/.gitignore index b9e30d2..92a10cd 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,5 @@ ENV/ # Misc .DS_Store -__sql__ \ No newline at end of file +__sql__ +fixtures \ No newline at end of file diff --git a/foundation_sql/db.py b/foundation_sql/db.py index c2541fe..ca9a022 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -10,7 +10,8 @@ from sqlalchemy import create_engine, text from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError - +from sqlalchemy import MetaData +from sqlalchemy.schema import CreateTable from jinja2sql import Jinja2SQL from datetime import datetime @@ -232,7 +233,28 @@ def is_empty(self) -> bool: True if no rows, False otherwise """ return len(self.rows) == 0 + + +# Function to load the schema from the database +def extract_schema_from_db(db_url: str) -> str: + """Extract the schema from the database. + + Args: + db_url: Database URL to use + + Returns: + Schema as a string + """ + engine = create_engine(db_url) + metadata = MetaData() + metadata.reflect(bind=engine) + + schema_lines = [] + for table in metadata.sorted_tables: + ddl = str(CreateTable(table).compile(engine)) + schema_lines.append(ddl + ";") + return "\n\n".join(schema_lines) def get_db(db_url: str) -> Database: @@ -280,7 +302,8 @@ def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) # Check the response type and transform accordingly if model_class == int: - return int(unflattened_data["result"]) + # FIX : STILL ONLY GETS FIRST LINE OF RESPONSE + return int(next(iter(unflattened_data.values()))) elif model_class == NoneType: return None diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 6809d26..5a45a37 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -1,12 +1,12 @@ You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. Ensure, the SQL query is usable across sqlite and postgresql. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. -1. Start with a -- comment to document the function name, parameters and docstring, explaining what the SQL query does. +1. Start with a comment to document the function name, parameters and docstring, explaining what the SQL query does. Make sure to start comments with `--` (Only 2 dashes, no more , no less) 2. Use jinja2 template to generate SQL 3. When accessing nested fields handle cases if they aren't defined. Use default filter with None value for such cases e.g. -{{user.zip_code|default(None)}} + {{user.zip_code|default(None)}} 4. Ensure response rows can be parsed into Pydantic model. As long as the model fields are named the same as the columns in the SQL query. It also supports nested models by using double underscores to separate nested fields. -5. For complex tasks, more than one queries can be run, separated by ';' -6. Only respond with a single ```sql``` block which contains all queries. +5. For complex tasks, more than one queries can be run, separated by ';', Make sure queries end with ';'. +6. Only respond with a single `sql` block which contains all queries. 7. No other explanation is necessary 8. For insert queries, avoid any RETURNING clause. Let it return the default. 9. We use jinja2 syntax to generate SQL - so parameters don't need to be quoted e.g. use {{user.zip_code|default(None)}} and not '{{user.zip_code|default(None)}}' @@ -19,18 +19,18 @@ You are an expert SQL developer. Write one or more SQL queries that can perform Here is an example def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task: - """ - Creates and returns a Task object, for the provided workspace and task_no - """ - pass - +""" +Creates and returns a Task object, for the provided workspace and task_no +""" +pass The SQL generated would look like the following + ```sql - --- def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task - --- Creates and returns a Task object, for the provided workspace and task_no - --- Expects task_no and workspace.id are defined. If no tasks are found, returns None - SELECT + -- def get_task(workspace: schema.Workspace, task_no: int) -> schema.Task; + -- Creates and returns a Task object, for the provided workspace and task_no; + -- Expects task_no and workspace.id are defined. If no tasks are found, returns None; + SELECT t.id as `id`, t.task_no as `task_no`, t.title as `title`, @@ -55,7 +55,7 @@ The SQL generated would look like the following LEFT JOIN agents a ON t.agent_id = a.id LEFT JOIN models m ON a.model_id = m.id LEFT JOIN workspace_tasks wt ON t.id = wt.task_id - WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}} + WHERE t.task_no = {{task_no}} AND wt.workspace_id = {{workspace.id}}; ``` Below are the real specifications for which query needs to be generated. diff --git a/foundation_sql/query.py b/foundation_sql/query.py index 51a0caf..edabf79 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -57,15 +57,21 @@ def __init__( self.name = name self.regen = regen self.cache_dir = cache_dir - self.schema = schema or self.load_file(schema_path) + self.db_url = db_url or os.environ.get("DATABASE_URL") + if ( not self.db_url): + raise ValueError(f"Database URL not provided either through constructor or DATABASE_URL environment variable") + + if (not schema and not schema_path): + # Load the schema from the database + self.schema = db.extract_schema_from_db(self.db_url) + + else: + self.schema = schema or self.load_file(schema_path) + if system_prompt or system_prompt_path: self.system_prompt = system_prompt or self.load_file(system_prompt_path) else: self.system_prompt = DEFAULT_SYSTEM_PROMPT - - self.db_url = db_url - if not self.db_url: - raise ValueError(f"Database URL not provided either through constructor or {db_url_env} environment variable") # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) diff --git a/tests/common.py b/tests/common.py index a4de0bf..fc4f3da 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,15 +2,21 @@ import os from foundation_sql import db from foundation_sql.query import SQLQueryDecorator +from typing import Optional from dotenv import load_dotenv load_dotenv() -DB_URL = os.environ.get("DATABSE_URL", "sqlite:///:memory:") +# DB_URL = os.environ.get("DATABASE_URL", "sqlite:///:memory:") -def create_query(schema): +def create_query (schema: Optional[str] = None, db_url : Optional[str] = None): + final_db_url = ( + db_url or + os.environ.get("DATABASE_URL") or + "sqlite:///:memory:" + ) return SQLQueryDecorator(schema=schema, - db_url=DB_URL, + db_url=final_db_url, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE_URL"), model=os.getenv("OPENAI_MODEL")) @@ -18,17 +24,17 @@ def create_query(schema): class DatabaseTests(unittest.TestCase): """Base test class for database-driven tests with common setup and helper methods.""" - db_url = DB_URL + db_url = "sqlite:///:memory:" schema_sql = None schema_path = None def setUp(self): """Create a fresh database connection for each test.""" - # Re-initialize the schema for each test to ensure clean state + #Re-initialize the schema for each test to ensure clean state if (self.schema_sql or self.schema_path) and self.db_url: db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) else: - raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") + pass def tearDown(self): diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py new file mode 100644 index 0000000..206332e --- /dev/null +++ b/tests/test_schema_discovery.py @@ -0,0 +1,54 @@ +from typing import List +from tests import common +from pydantic import BaseModel +from tests.utils import BIKES_DB_PATH, create_bike_db + +class Bike(BaseModel): + make: str + model: str + price: int + + +create_bike_db() + +query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}") + +@query +def get_bikes() -> List[Bike]: + """ + Gets all bikes. + """ + pass + +@query +def create_bike(bike: Bike) -> Bike: + """ + Creates a new bike. + """ + pass + +@query +def get_total_price() -> int: + """ + Get the total price of all the bikes + """ + pass + + +class TestSchemaDiscovery(common.DatabaseTests): + db_url = f"sqlite:///{BIKES_DB_PATH}" + schema_sql = None + + def test_schema_discovery(self): + + re_bike = Bike(make="RE", model="Classic", price=600) + create_bike(bike=re_bike) + + harley_bike = Bike(make="Harley", model="A very good one", price=500) + create_bike(bike = harley_bike) + + bikes = get_bikes() + self.assertEqual(len(bikes), 2) + + price = get_total_price() + self.assertEqual(price,1100) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..07b73b3 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,26 @@ +# utils.py +import os +import sqlite3 + +BIKES_DB_PATH = os.path.abspath( + os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") +) + +def create_bike_db(): + os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) + + + if os.path.exists(BIKES_DB_PATH): + os.remove(BIKES_DB_PATH) + + conn = sqlite3.connect(BIKES_DB_PATH) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE bikes ( + make TEXT NOT NULL, + model TEXT NOT NULL, + price INTEGER NOT NULL + ); + """) + conn.commit() + conn.close() From 97229fe5509caee96bb8cd1c76b587183a790064 Mon Sep 17 00:00:00 2001 From: Rishi Pradeepkumar Date: Mon, 9 Jun 2025 04:33:11 +0530 Subject: [PATCH 2/4] Implemented support for multiple db's --- docker-compose.yml | 33 +++++++ foundation_sql/db.py | 170 ++++++++++++++++++--------------- foundation_sql/prompt.py | 66 +++++++------ foundation_sql/prompts.md | 8 +- foundation_sql/query.py | 94 ++++++++++-------- requirements.txt | 31 ++++-- tests/test_mysql.py | 76 +++++++++++++++ tests/test_postgres.py | 81 ++++++++++++++++ tests/test_schema_discovery.py | 18 ++-- tests/utils.py | 53 +++++++++- 10 files changed, 464 insertions(+), 166 deletions(-) create mode 100644 docker-compose.yml create mode 100644 tests/test_mysql.py create mode 100644 tests/test_postgres.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..d0613f8 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,33 @@ +version: "3.8" + +services: + mysql: + image: mysql:8.0 + container_name: mysql_db + restart: unless-stopped + environment: + MYSQL_ROOT_PASSWORD: rootpassword + MYSQL_DATABASE: foundation + MYSQL_USER: user + MYSQL_PASSWORD: userpassword + ports: + - "3306:3306" + volumes: + - mysql_data:/var/lib/mysql + + postgres: + image: postgres:15 + container_name: postgres_db + restart: unless-stopped + environment: + POSTGRES_USER: user + POSTGRES_PASSWORD: userpassword + POSTGRES_DB: foundation + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + +volumes: + mysql_data: + postgres_data: diff --git a/foundation_sql/db.py b/foundation_sql/db.py index ca9a022..00a7cae 100644 --- a/foundation_sql/db.py +++ b/foundation_sql/db.py @@ -26,21 +26,22 @@ logger = logging.getLogger(__name__) - class Database: """Database operations for Foundation.""" - + def __init__(self, db_url: Optional[str] = None) -> None: """Initialize the database. - + Args: db_url: SQLAlchemy database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db) If not provided, will use DATABASE_URL environment variable """ - self.db_url = db_url or os.getenv('DATABASE_URL') + self.db_url = db_url or os.getenv("DATABASE_URL") if not self.db_url: - raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable') - + raise ValueError( + "Database URL must be provided either through constructor or DATABASE_URL environment variable" + ) + # Create SQLAlchemy engine self.engine = create_engine(self.db_url) @@ -52,7 +53,9 @@ def get_engine(self) -> Engine: """ return self.engine - def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None: + def init_schema( + self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None + ) -> None: """Initialize the database schema if it doesn't exist. This method runs the schema creation script in an idempotent way. @@ -64,64 +67,68 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] If not provided, will use the default schema at data/tables.sql """ if not schema_sql: - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema_sql = f.read() # Connect and execute schema in a transaction with self.engine.begin() as conn: try: - for statement in schema_sql.split(';'): + for statement in schema_sql.split(";"): if statement.strip(): conn.execute(text(statement)) except SQLAlchemyError as e: - raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e + raise RuntimeError(f"Failed to initialize schema: {str(e)}") from e def run_sql(self, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected """ # Add datetime.now function to context if needed - if 'now' not in context: - context['now'] = datetime.now - + if "now" not in context: + context["now"] = datetime.now + # Use jinja2sql to render the template and get SQL with parameters try: - query, params = j2sql.from_string( - sql_template, - context=context - ) + query, params = j2sql.from_string(sql_template, context=context) except Exception as e: - raise ValueError(f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}") from e - + raise ValueError( + f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}" + ) from e + # Execute the SQL with parameters in a transaction with self.engine.connect() as conn: with conn.begin(): try: # Split the query into individual statements for SQLite - statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()] - + statements = [ + stmt.strip() + for stmt in query.split(";") + # Ignore empty statements and comments (POSTGRESQL DOESNT IGNORE THEM LIKE MYSQL/SQLITE AND THROWS ERROR OTHERWISE) + if stmt.strip() and not stmt.strip().startswith("--") + ] + # Track total rows affected total_rows = 0 last_result = None - + logger.debug(f"Executing statements: {len(statements)}") # Execute each statement for statement in statements: logger.debug(f"Executing statement: {statement} {params}") result = conn.execute(text(statement), params) total_rows += result.rowcount - + # Store the last result for potential return if result.returns_rows: last_result = result - + # If last result is a SELECT, return its rows if last_result and last_result.returns_rows: # Convert SQLAlchemy Row objects to plain dictionaries @@ -129,90 +136,95 @@ def run_sql(self, sql_template: str, **context) -> Any: rows = [dict(row._mapping) for row in last_result] logger.debug(f"Returning rows: {rows}") return QueryResult(rows) - + # If no rows affected, return 0 return total_rows except SQLAlchemyError as e: # The transaction will be rolled back automatically - raise RuntimeError(f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}") from e - + raise RuntimeError( + f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}" + ) from e - def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any: + def execute( + self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None + ) -> Any: """ Execute a raw SQL statement with optional parameters. - + Args: sql (str): SQL statement to execute - params (Optional[Union[tuple, dict, List[tuple]]]): + params (Optional[Union[tuple, dict, List[tuple]]]): Optional parameters for the SQL statement - Single tuple for single parameter set - List of tuples for multiple parameter sets (bulk insert) - Dictionary for named parameters - + Returns: Any: Result of the execution """ with self.engine.connect() as connection: try: # Replace '?' placeholders with SQLAlchemy named parameters - if '?' in sql: + if "?" in sql: # Count the number of placeholders - placeholder_count = sql.count('?') - + placeholder_count = sql.count("?") + # Replace '?' with named parameters - named_sql = sql.replace('?', ':{}'.format) - + named_sql = sql.replace("?", ":{}".format) + # Prepare parameters if params is None: named_params = {} elif isinstance(params, (tuple, list)): # Convert tuple/list to named dictionary - named_params = {f'p{i}': val for i, val in enumerate(params)} - named_sql = named_sql.format(p=lambda i: f':p{i}') + named_params = {f"p{i}": val for i, val in enumerate(params)} + named_sql = named_sql.format(p=lambda i: f":p{i}") elif isinstance(params, dict): named_params = params else: - raise ValueError("Invalid parameter type. Must be tuple, dict, or list of tuples.") - + raise ValueError( + "Invalid parameter type. Must be tuple, dict, or list of tuples." + ) + # Execute with named parameters result = connection.execute(text(named_sql), named_params) else: # If no '?' placeholders, use as-is result = connection.execute(text(sql), params or {}) - + # If it's a SELECT query, return the rows if result.returns_rows: return result.fetchall() - + # For INSERT, UPDATE, DELETE, return the number of rows affected return result.rowcount - + except SQLAlchemyError as e: raise RuntimeError(f"Database execution error: {str(e)}") from e class QueryResult: """A clean abstraction over query results that doesn't leak implementation details.""" - + def __init__(self, rows: List[Dict[str, Any]]): """Initialize with a list of row dictionaries. - + Args: rows: List of dictionaries representing database rows """ self.rows = rows - + def first(self) -> Optional[Dict[str, Any]]: """Get the first row as a dictionary or None if no rows. - + Returns: First row as a dictionary or None """ return self.rows[0] if self.rows else None - + def all(self) -> List[Dict[str, Any]]: """Get all rows as a list of dictionaries. - + Returns: List of dictionaries representing all rows """ @@ -220,28 +232,28 @@ def all(self) -> List[Dict[str, Any]]: def count(self) -> int: """Get the number of rows. - + Returns: Number of rows """ return len(self.rows) - + def is_empty(self) -> bool: """Check if the result contains any rows. - + Returns: True if no rows, False otherwise """ return len(self.rows) == 0 - + # Function to load the schema from the database def extract_schema_from_db(db_url: str) -> str: """Extract the schema from the database. - + Args: db_url: Database URL to use - + Returns: Schema as a string """ @@ -259,25 +271,26 @@ def extract_schema_from_db(db_url: str) -> str: def get_db(db_url: str) -> Database: """Get the database instance. - + Args: db_url: Database URL to use - + Returns: Database instance """ if db_url not in DATABASES: DATABASES[db_url] = Database(db_url) - + return DATABASES[db_url] + def run_sql(db_url: str, sql_template: str, **context) -> Any: """Run an SQL template string with jinja2sql for rendering and parameter substitution. - + Args: sql_template: SQL template string with jinja2sql syntax **context: Context variables for template rendering - + Returns: For SELECT queries: A QueryResult object with methods for data access For INSERT/UPDATE/DELETE queries: The number of rows affected @@ -285,18 +298,20 @@ def run_sql(db_url: str, sql_template: str, **context) -> Any: return get_db(db_url).run_sql(sql_template, **context) -def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) -> Optional[BaseModel]: +def parse_query_to_pydantic( + data: Dict[str, Any], model_class: Type[BaseModel] +) -> Optional[BaseModel]: """Parse query result data into a Pydantic model, handling nested models. - + Args: data: Dictionary containing query results with optional nested fields model_class: The Pydantic model class to instantiate - + Returns: Instance of the Pydantic model or None if data is None """ if not data: - return None + return None unflattened_data = unflatten_dict(data) @@ -306,25 +321,24 @@ def parse_query_to_pydantic(data: Dict[str, Any], model_class: Type[BaseModel]) return int(next(iter(unflattened_data.values()))) elif model_class == NoneType: return None - - return model_class(**unflattened_data) + return model_class(**unflattened_data) def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: """Convert a flattened dictionary with keys like 'parent.child.grandchild' (using NESTED_SPLITTER) into a nested dictionary structure. - + Args: flat_dict: Dictionary with flattened keys using NESTED_SPLITTER for nesting - + Returns: - Nested dictionary structure where nested objects with all None values + Nested dictionary structure where nested objects with all None values are replaced by None at the parent level. """ grouped_keys = {} direct_keys = {} - + # First, categorize the keys for key, value in flat_dict.items(): if NESTED_SPLITTER in key: @@ -334,27 +348,27 @@ def unflatten_dict(flat_dict: Dict[str, Any]) -> Dict[str, Any]: grouped_keys[prefix][rest] = value else: direct_keys[key] = value - + # Process each group and add to result result = dict(direct_keys) # Start with the direct keys - + for prefix, nested_dict in grouped_keys.items(): # Check if this prefix contains nested structures has_nested = any(NESTED_SPLITTER in key for key in nested_dict.keys()) - + if has_nested: # Recursively unflatten the nested structure nested_result = unflatten_dict(nested_dict) - + # Check if all values in the nested result are None after unflattening is_all_none = False if isinstance(nested_result, dict): is_all_none = all(v is None for v in nested_result.values()) - + result[prefix] = None if is_all_none else nested_result else: # Check if all values are None for a flat nested dict is_all_none = all(v is None for v in nested_dict.values()) result[prefix] = None if is_all_none else nested_dict - + return result diff --git a/foundation_sql/prompt.py b/foundation_sql/prompt.py index dd2132d..b4f9a17 100644 --- a/foundation_sql/prompt.py +++ b/foundation_sql/prompt.py @@ -27,7 +27,6 @@ def _model_fields(self): return {} return {k: str(v) for k, v in self.return_type.model_fields.items()} - def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ Extract named parameters and their types from a function. @@ -40,8 +39,8 @@ def _extract_kwargs(self, func: Callable) -> Dict[str, Type]: """ signature = inspect.signature(func) return { - name: param.annotation - for name, param in signature.parameters.items() + name: param.annotation + for name, param in signature.parameters.items() if param.annotation is not param.empty } @@ -65,73 +64,84 @@ def serialize_value(v): def _extract_return_model(self, func: Callable) -> (Type[BaseModel], Optional[str]): """ Extract the return model type from a function's type annotations. - + Args: func (Callable): Function to analyze - + Returns: Tuple containing: - Pydantic model class - Wrapper type ('list' or None) - + Raises: ValueError: If return type is invalid or not a Pydantic model """ hints = get_type_hints(func) - if 'return' not in hints: - raise ValueError(f'Function {func.__name__} must have a return type annotation') - - return_type = hints['return'] + if "return" not in hints: + raise ValueError( + f"Function {func.__name__} must have a return type annotation" + ) + + return_type = hints["return"] wrapper = None - + # Handle Optional[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is Union: + if hasattr(return_type, "__origin__") and return_type.__origin__ is Union: args = return_type.__args__ if len(args) == 2 and args[1] is type(None): return_type = args[0] - + # Handle List[Model] - if hasattr(return_type, '__origin__') and return_type.__origin__ is list: - wrapper = 'list' + if hasattr(return_type, "__origin__") and return_type.__origin__ is list: + wrapper = "list" return_type = return_type.__args__[0] - - return return_type, wrapper + return return_type, wrapper class SQLPromptGenerator: """ Generates prompts for SQL template generation based on function context and predefined schemas. - + Attributes: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template """ - - def __init__(self, func_spec: FunctionSpec, + + def __init__( + self, + func_spec: FunctionSpec, template_name: str, system_prompt: str, - schema: Optional[str] = None - ): + schema: Optional[str] = None, + db_type: Optional[str] = None, + ): """ Initialize the SQL prompt generator. - + Args: func (FunctionSpec): Function to generate SQL for template_name (str): Name of the SQL template system_prompt (str): System prompt for SQL generation - schema (Optional[str]): SQL schema definitions + schema (Optional[str]): SQL schema definSQLitions error_prompt (Optional[str]): Error prompt for SQL generation + db_type (Optional[str]): Type of the database (e.g., 'sqlite', 'postgres') """ self.func_spec = func_spec self.template_name = template_name self.schema = schema self.system_prompt = system_prompt - - def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, prev_template: Optional[str] = None) -> str: + self.db_type = db_type + + def generate_prompt( + self, + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ) -> str: """ Generate a comprehensive prompt for SQL template generation. - + Returns: str: Detailed prompt with function context and schema """ @@ -148,7 +158,6 @@ def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, p Review the error and suggest an improved SQL template that works. """ - return f""" {self.system_prompt} ---------------- @@ -159,6 +168,7 @@ def generate_prompt(self, kwargs: Dict[str, Any], error: Optional[str] = None, p Function Signature: {self.func_spec.signature} Function Docstring: {self.func_spec.docstring} Function Arguments: {self.func_spec.kwargs_json(kwargs)} +Database Type: {self.db_type} Return model: {self.func_spec.return_type.__name__} Model fields: {json.dumps({k: str(v) for k, v in self.func_spec.model_fields.items()}, indent=2)} diff --git a/foundation_sql/prompts.md b/foundation_sql/prompts.md index 5a45a37..c3a3000 100644 --- a/foundation_sql/prompts.md +++ b/foundation_sql/prompts.md @@ -1,4 +1,4 @@ -You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. Ensure, the SQL query is usable across sqlite and postgresql. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. +You are an expert SQL developer. Write one or more SQL queries that can perform the actions as explained by the user. The SQL template generated is a jinja2 template - so jinja2 syntax can be used. 1. Start with a comment to document the function name, parameters and docstring, explaining what the SQL query does. Make sure to start comments with `--` (Only 2 dashes, no more , no less) 2. Use jinja2 template to generate SQL @@ -15,6 +15,12 @@ You are an expert SQL developer. Write one or more SQL queries that can perform 12. DONOT use json_build_object to build JSON objects for nested fields 13. DONOT use '' to quote jinja variables. The binding would take care of that automatically. 14. Pay special attention to primary key (usually id fields). Sometimes, they are auto-generated in the schema, in which case insert queries should not set them. Otherwise, they must already be set in the model and then inserted into the table as well. +15. Based on the given database type, generate SQL specific to that database. Do NOT attempt to make SQL cross-compatible. Use correct syntax, features, and quoting for that specific database. Examples of the rules for different database include, but are not limited to + +- For SQLite, use `AUTOINCREMENT` +- For PostgreSQL, use `SERIAL` or `GENERATED` and using " for quoting column or table names +- For MySQL, use `AUTO_INCREMENT` +- Avoid features unsupported by the current DB type. Here is an example diff --git a/foundation_sql/query.py b/foundation_sql/query.py index edabf79..f1c4ae6 100644 --- a/foundation_sql/query.py +++ b/foundation_sql/query.py @@ -7,38 +7,39 @@ from foundation_sql.cache import SQLTemplateCache from foundation_sql import db from typing import Callable, Dict, Optional +from sqlalchemy.engine.url import make_url from importlib import resources as impresources -DEFAULT_SYSTEM_PROMPT = impresources.read_text('foundation_sql', 'prompts.md') +DEFAULT_SYSTEM_PROMPT = impresources.read_text("foundation_sql", "prompts.md") class SQLQueryDecorator: """ Advanced decorator for generating and executing SQL queries with comprehensive features. - + Supports: - Dynamic SQL template generation - Configurable LLM backend - Persistent template caching - Robust error handling and regeneration - + Attributes: name (Optional[str]): Custom name for SQL template regen (Optional[bool]): SQL template regeneration strategy config (SQLGeneratorConfig): Configuration for SQL generation """ - + def __init__( - self, - name: Optional[str] = None, + self, + name: Optional[str] = None, regen: Optional[bool] = None, repair: Optional[int] = 0, schema: Optional[str] = None, schema_path: Optional[str] = None, system_prompt: Optional[str] = None, system_prompt_path: Optional[str] = None, - cache_dir: Optional[str] = '__sql__', + cache_dir: Optional[str] = "__sql__", db_url: Optional[str] = None, api_key: Optional[str] = None, base_url: Optional[str] = None, @@ -46,116 +47,125 @@ def __init__( ): """ Initialize the SQL query decorator. - + Args: - name (Optional[str]): Custom name for SQL file/folder. + name (Optional[str]): Custom name for SQL file/folder. Defaults to function name. regen (Optional[bool]): SQL template regeneration strategy. - config (Optional[SQLGeneratorConfig]): Custom configuration + config (Optional[SQLGeneratorConfig]): Custom configuration for SQL generation. """ self.name = name self.regen = regen self.cache_dir = cache_dir self.db_url = db_url or os.environ.get("DATABASE_URL") - if ( not self.db_url): - raise ValueError(f"Database URL not provided either through constructor or DATABASE_URL environment variable") - - if (not schema and not schema_path): + + if not self.db_url: + raise ValueError( + f"Database URL not provided either through constructor or DATABASE_URL environment variable" + ) + + # Get the type of DB + try: + self.db_type = make_url(db_url).get_backend_name() + + except Exception as e: + raise ValueError(f"Error determining database type from DATABASE_URL: {e}") + + if not schema and not schema_path: # Load the schema from the database self.schema = db.extract_schema_from_db(self.db_url) else: self.schema = schema or self.load_file(schema_path) - + if system_prompt or system_prompt_path: self.system_prompt = system_prompt or self.load_file(system_prompt_path) else: self.system_prompt = DEFAULT_SYSTEM_PROMPT - + # Initialize cache and SQL generator self.cache = SQLTemplateCache(cache_dir=cache_dir) self.sql_generator = SQLGenerator( - api_key=api_key, - base_url=base_url, - model=model + api_key=api_key, base_url=base_url, model=model ) self.repair = repair - def __call__(self, func: Callable) -> Callable: """ Decorator implementation for SQL query generation and execution. - + Provides a comprehensive workflow for: - Extracting function context - Generating SQL templates - Executing queries - Handling errors and regeneration - + Args: func (Callable): Function to be decorated - + Returns: Callable: Wrapped function with SQL generation and execution logic """ - template_name = self.name or f"{func.__name__}.sql" + template_name = self.name or f"{func.__name__}_{self.db_type}.sql" fn_spec = FunctionSpec(func) prompt_generator = SQLPromptGenerator( - fn_spec, - template_name, - self.system_prompt, - self.schema) - + fn_spec, template_name, self.system_prompt, self.schema, self.db_type + ) - def sql_gen(kwargs: Dict[str, Any], error: Optional[str]=None, prev_template: Optional[str]=None): + def sql_gen( + kwargs: Dict[str, Any], + error: Optional[str] = None, + prev_template: Optional[str] = None, + ): if self.regen or not self.cache.exists(template_name) or error: - + prompt = prompt_generator.generate_prompt(kwargs, error, prev_template) sql_template = self.sql_generator.generate_sql(prompt) self.cache.set(template_name, sql_template) else: sql_template = self.cache.get(template_name) - + return sql_template @functools.wraps(func) def wrapper(**kwargs: Any) -> Any: error, sql_template = None, None # try: - # Run the SQL Template + # Run the SQL Template sql_template = sql_gen(kwargs, error, sql_template) result_data = db.run_sql(self.db_url, sql_template, **kwargs) - if fn_spec.wrapper == 'list': + if fn_spec.wrapper == "list": parsed_result = [ - db.parse_query_to_pydantic(row, fn_spec.return_type) + db.parse_query_to_pydantic(row, fn_spec.return_type) for row in result_data.all() ] elif isinstance(result_data, int): parsed_result = result_data else: first_row = result_data.first() - parsed_result = db.parse_query_to_pydantic(first_row, fn_spec.return_type) if first_row else None + parsed_result = ( + db.parse_query_to_pydantic(first_row, fn_spec.return_type) + if first_row + else None + ) return parsed_result - - return wrapper + return wrapper - def load_file(self, path: str) -> str: """ Load predefined table schemas. - + Returns: str: SQL schema definitions """ if not path or not os.path.exists(path): raise FileNotFoundError(f"Schema file not found at {path}") - with open(path, 'r') as f: + with open(path, "r") as f: return f.read() - diff --git a/requirements.txt b/requirements.txt index 4c1cdff..89d28e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,25 @@ -# Core dependencies -python-dotenv -sqlalchemy -jinja2sql -pydantic -openai +annotated-types==0.7.0 +anyio==4.9.0 +certifi==2025.4.26 +distro==1.9.0 +exceptiongroup==1.3.0 +greenlet==3.2.3 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.10 +Jinja2==3.1.6 +jinja2sql==0.8.0 +jiter==0.10.0 +MarkupSafe==3.0.2 +openai==1.84.0 +psycopg2-binary==2.9.10 +pydantic==2.11.5 +pydantic_core==2.33.2 +PyMySQL==1.1.1 +python-dotenv==1.1.0 +sniffio==1.3.1 +SQLAlchemy==2.0.41 +tqdm==4.67.1 +typing-inspection==0.4.1 +typing_extensions==4.14.0 diff --git a/tests/test_mysql.py b/tests/test_mysql.py new file mode 100644 index 0000000..64148e7 --- /dev/null +++ b/tests/test_mysql.py @@ -0,0 +1,76 @@ +from typing import List +from tests import common +from pydantic import BaseModel +from tests.utils import setup_mysql_db + + +class Bike(BaseModel): + id: int = None # Optional for inserts + make: str + model: str + price: int + + +setup_mysql_db() + +query = common.create_query( + db_url=f"mysql+pymysql://user:userpassword@localhost:3306/foundation" +) + + +@query +def get_bikes() -> List[Bike]: + """ + Gets all bikes. + """ + pass + + +@query +def create_bike(bike: Bike) -> Bike: + """ + Creates a new bike. + """ + pass + + +@query +def get_total_price() -> int: + """ + Get the total price of all the bikes + """ + pass + + +@query +def get_last_inserted_id() -> int: + """ + Returns the last auto-incremented ID inserted into the session + """ + pass + + +class TestSchemaDiscovery(common.DatabaseTests): + db_url = f"mysql+pymysql://user:userpassword@localhost:3306/foundation" + schema_sql = None + + re_bike = Bike(make="RE", model="Classic", price=600) + create_bike(bike=re_bike) + + harley_bike = Bike(make="Harley", model="A very good one", price=500) + create_bike(bike=harley_bike) + + def test_get_bikes(self): + bikes = get_bikes() + self.assertEqual(len(bikes), 2) + + def test_last_insert_id(self): + new_bike = Bike(make="Yamaha", model="FZ", price=800) + create_bike(bike=new_bike) + + last_id = get_last_inserted_id() + self.assertTrue(isinstance(last_id, int) and last_id > 0) + + def test_total_price(self): + price = get_total_price() + self.assertEqual(price, 1900) diff --git a/tests/test_postgres.py b/tests/test_postgres.py new file mode 100644 index 0000000..d99ded0 --- /dev/null +++ b/tests/test_postgres.py @@ -0,0 +1,81 @@ +from typing import List +from pydantic import BaseModel +from tests import common +from tests.utils import setup_postgres_db + + +# Define a Bike model +class Bike(BaseModel): + id: int + make: str + model: str + price: int + + +setup_postgres_db() + + +# Point to your local PostgreSQL database +query = common.create_query( + db_url="postgresql+psycopg2://user:userpassword@localhost:5432/foundation" +) + + +# Query to insert a bike (PostgreSQL will likely use RETURNING automatically) +@query +def create_bike(bike: Bike) -> Bike: + """ + Inserts a bike and returns the created bike. + """ + pass + + +# Query to fetch all bikes +@query +def get_bikes() -> List[Bike]: + """ + Gets all bikes. + """ + pass + + +# Query to get total price of all bikes +@query +def get_total_price() -> int: + """ + Returns total price of all bikes. + """ + pass + + +# Query that naturally uses PostgreSQL-only operator +@query +def search_bikes_case_insensitive(make_query: str) -> List[Bike]: + """ + Returns bikes that match the given make (case-insensitive). + """ + pass + + +# The test class +class TestPostgresSchema(common.DatabaseTests): + db_url = "postgresql+psycopg2://user:userpassword@localhost:5432/foundation" + schema_sql = None + + re_bike = Bike(id=1, make="RE", model="Classic", price=600) + harley = Bike(id=2, make="Harley", model="Sportster", price=500) + create_bike(bike=re_bike) + create_bike(bike=harley) + + def test_get_bikes(self): + bikes = get_bikes() + self.assertEqual(len(bikes), 2) + + def test_total_price(self): + total = get_total_price() + self.assertEqual(total, 1100) + + def test_case_insensitive_search(self): + bikes = search_bikes_case_insensitive(make_query="re") + self.assertEqual(len(bikes), 1) + self.assertEqual(bikes[0].make.lower(), "re") diff --git a/tests/test_schema_discovery.py b/tests/test_schema_discovery.py index 206332e..be450c9 100644 --- a/tests/test_schema_discovery.py +++ b/tests/test_schema_discovery.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from tests.utils import BIKES_DB_PATH, create_bike_db + class Bike(BaseModel): make: str model: str @@ -13,6 +14,7 @@ class Bike(BaseModel): query = common.create_query(db_url=f"sqlite:///{BIKES_DB_PATH}") + @query def get_bikes() -> List[Bike]: """ @@ -20,6 +22,7 @@ def get_bikes() -> List[Bike]: """ pass + @query def create_bike(bike: Bike) -> Bike: """ @@ -27,6 +30,7 @@ def create_bike(bike: Bike) -> Bike: """ pass + @query def get_total_price() -> int: """ @@ -39,16 +43,16 @@ class TestSchemaDiscovery(common.DatabaseTests): db_url = f"sqlite:///{BIKES_DB_PATH}" schema_sql = None - def test_schema_discovery(self): - - re_bike = Bike(make="RE", model="Classic", price=600) - create_bike(bike=re_bike) + re_bike = Bike(make="RE", model="Classic", price=600) + create_bike(bike=re_bike) - harley_bike = Bike(make="Harley", model="A very good one", price=500) - create_bike(bike = harley_bike) + harley_bike = Bike(make="Harley", model="A very good one", price=500) + create_bike(bike=harley_bike) + def test_get_bikes(self): bikes = get_bikes() self.assertEqual(len(bikes), 2) + def test_total_price(self): price = get_total_price() - self.assertEqual(price,1100) \ No newline at end of file + self.assertEqual(price, 1100) diff --git a/tests/utils.py b/tests/utils.py index 07b73b3..d90ad40 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,26 +1,71 @@ # utils.py import os import sqlite3 +from sqlalchemy import create_engine, text BIKES_DB_PATH = os.path.abspath( os.path.join(os.path.dirname(__file__), "fixtures", "bikes.db") ) -def create_bike_db(): - os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) +def create_bike_db(): + os.makedirs(os.path.dirname(BIKES_DB_PATH), exist_ok=True) if os.path.exists(BIKES_DB_PATH): os.remove(BIKES_DB_PATH) conn = sqlite3.connect(BIKES_DB_PATH) cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ CREATE TABLE bikes ( make TEXT NOT NULL, model TEXT NOT NULL, price INTEGER NOT NULL ); - """) + """ + ) conn.commit() conn.close() + + +def setup_mysql_db(): + engine = create_engine( + "mysql+pymysql://user:userpassword@localhost:3306/foundation" + ) + with engine.connect() as connection: + connection.execute(text("DROP TABLE IF EXISTS bikes;")) + connection.execute( + text( + """ + CREATE TABLE bikes ( + id INT AUTO_INCREMENT PRIMARY KEY, + make VARCHAR(255) NOT NULL, + model VARCHAR(255) NOT NULL, + price INT NOT NULL + ); + """ + ) + ) + connection.commit() + + +def setup_postgres_db(): + engine = create_engine( + "postgresql+psycopg2://user:userpassword@localhost:5432/foundation" + ) + with engine.connect() as connection: + connection.execute(text("DROP TABLE IF EXISTS bikes;")) + connection.execute( + text( + """ + CREATE TABLE bikes ( + id SERIAL PRIMARY KEY, + make VARCHAR(255) NOT NULL, + model VARCHAR(255) NOT NULL, + price INT NOT NULL + ); + """ + ) + ) + connection.commit() From 346c7e325b86e92fa9b196457bccfa376797da5c Mon Sep 17 00:00:00 2001 From: Rishi Pradeepkumar Date: Mon, 9 Jun 2025 04:37:32 +0530 Subject: [PATCH 3/4] readme changes --- README.md | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 8082505..7bebdc2 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,8 @@ CREATE TABLE IF NOT EXISTS users ( ) """ -query = query.SQLQueryDecorator(schema=TABLES_SCHEMA, - db_url=DB_URL, +query = query.SQLQueryDecorator(schema=TABLES_SCHEMA, + db_url=DB_URL, api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_API_BASE_URL"), model=os.getenv("OPENAI_MODEL")) @@ -110,14 +110,14 @@ class TestQuery(unittest.TestCase): db.get_db(self.db_url).init_schema(schema_sql=self.schema_sql, schema_path=self.schema_path) else: raise ValueError("At least one of schema_sql, schema_path must be provided along with db_url") - + def test_users(self): users = get_users() self.assertEqual(len(users), 0) - + user = User(id="xxx", name="John Doe", email="john@example.com", role="user") create_user(user=user) - + users = get_users() self.assertEqual(len(users), 1) self.assertEqual(users[0], user) @@ -126,7 +126,7 @@ class TestQuery(unittest.TestCase): """Close the database connection after each test.""" for _, connection in db.DATABASES.items(): connection.get_engine().dispose() - + db.DATABASES.clear() ``` @@ -138,9 +138,9 @@ Running these tests would generate the following SQL files -- Creates a new user. -- Expects user.name, user.email and user.role to be defined INSERT INTO `users` ( - `id`, - `name`, - `email`, + `id`, + `name`, + `email`, `role` ) VALUES ( @@ -156,12 +156,12 @@ VALUES ( -- def get_users() -> List[tests.test_simple_query.User] -- Gets all users. -SELECT +SELECT `id` as `id`, `name` as `name`, `email` as `email`, `role` as `role` -FROM +FROM `users` ``` @@ -182,7 +182,18 @@ FROM cp .env_template .env ``` -- Run tests: `python -m unittest discover tests` +## Run tests + +```bash +docker compose up --build +python -m unittest discover tests +``` + +Run a particular test file + +```bash +python -m unittest -v tests/{test_file.py} +``` ## Project Structure From 059a0c4f1e036b615aced4306e7bad46ab0e1c4a Mon Sep 17 00:00:00 2001 From: Rishi Pradeepkumar Date: Mon, 9 Jun 2025 04:41:48 +0530 Subject: [PATCH 4/4] Updated comments on test files --- tests/test_postgres.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_postgres.py b/tests/test_postgres.py index d99ded0..391b1f3 100644 --- a/tests/test_postgres.py +++ b/tests/test_postgres.py @@ -15,13 +15,12 @@ class Bike(BaseModel): setup_postgres_db() -# Point to your local PostgreSQL database query = common.create_query( db_url="postgresql+psycopg2://user:userpassword@localhost:5432/foundation" ) -# Query to insert a bike (PostgreSQL will likely use RETURNING automatically) +# Query to insert a bike @query def create_bike(bike: Bike) -> Bike: """ @@ -48,7 +47,7 @@ def get_total_price() -> int: pass -# Query that naturally uses PostgreSQL-only operator +# Query that should naturally generate a PostgreSQL-only operator @query def search_bikes_case_insensitive(make_query: str) -> List[Bike]: """ @@ -57,7 +56,6 @@ def search_bikes_case_insensitive(make_query: str) -> List[Bike]: pass -# The test class class TestPostgresSchema(common.DatabaseTests): db_url = "postgresql+psycopg2://user:userpassword@localhost:5432/foundation" schema_sql = None