Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 62 additions & 67 deletions foundation_sql/db.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,58 @@
"""
Database operations module for Foundation.
Database operations module for Foundation (adapter-based).
"""

import logging
import os
from types import NoneType
from typing import Dict, Any, Optional, Type, Union, List
from pydantic import BaseModel
from sqlalchemy import create_engine, text
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError

from jinja2sql import Jinja2SQL
from datetime import datetime
from foundation_sql.db_drivers import EngineAdapter, SQLAlchemyAdapter
from foundation_sql.db_drivers import AsyncpgAdapter

NESTED_SPLITTER = "."
# Singleton instance
DATABASES = {}

j2sql = Jinja2SQL()

# logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger(__name__)



class Database:
"""Database operations for Foundation."""
"""Database operations for Foundation delegated to an EngineAdapter."""

def __init__(self, db_url: Optional[str] = None) -> None:
"""Initialize the database.
def __init__(self, db_url: Optional[str] = None, adapter: Optional[EngineAdapter] = None) -> None:
"""Initialize the database facade.

Args:
db_url: SQLAlchemy database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db)
db_url: Database URL (e.g., postgresql://user:pass@host/db, sqlite:///path/to/db)
If not provided, will use DATABASE_URL environment variable
adapter: Optional explicit adapter. If not provided, a default sync adapter is created.
"""
self.db_url = db_url or os.getenv('DATABASE_URL')
if not self.db_url:
raise ValueError('Database URL must be provided either through constructor or DATABASE_URL environment variable')

# Create SQLAlchemy engine
self.engine = create_engine(self.db_url)
# Default to sync SQLAlchemy adapter unless explicitly provided
self.adapter: EngineAdapter = adapter or SQLAlchemyAdapter(self.db_url)

def get_engine(self) -> Engine:
"""Get the SQLAlchemy engine.
"""Get the underlying SQLAlchemy engine if available.

Returns:
SQLAlchemy Engine instance
"""
return self.engine
# Only available for SQLAlchemyAdapter
if isinstance(self.adapter, SQLAlchemyAdapter):
return self.adapter.engine
raise RuntimeError("Engine is not available for this adapter")

def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str] = None) -> None:
"""Initialize the database schema if it doesn't exist.
Expand All @@ -65,15 +68,8 @@ def init_schema(self, schema_sql:Optional[str]=None, schema_path: Optional[str]
if not schema_sql:
with open(schema_path, 'r') as f:
schema_sql = f.read()

# Connect and execute schema in a transaction
with self.engine.begin() as conn:
try:
for statement in schema_sql.split(';'):
if statement.strip():
conn.execute(text(statement))
except SQLAlchemyError as e:
raise RuntimeError(f'Failed to initialize schema: {str(e)}') from e
# Delegate to adapter
self.adapter.init_schema(schema_sql)

def run_sql(self, sql_template: str, **context) -> Any:
"""Run an SQL template string with jinja2sql for rendering and parameter substitution.
Expand All @@ -90,50 +86,35 @@ def run_sql(self, sql_template: str, **context) -> Any:
if 'now' not in context:
context['now'] = datetime.now

# Use jinja2sql to render the template and get SQL with parameters
try:
query, params = j2sql.from_string(
sql_template,
context=context
)
except Exception as e:
raise ValueError(f"Failed to render SQL. Likely SQL template & Parameter mismatch: {str(e)}") from e

# Execute the SQL with parameters in a transaction
with self.engine.connect() as conn:
with conn.begin():
try:
# Split the query into individual statements for SQLite
statements = [stmt.strip() for stmt in query.split(';') if stmt.strip()]

# Track total rows affected
total_rows = 0
last_result = None

logger.debug(f"Executing statements: {len(statements)}")
# Execute each statement
for statement in statements:
logger.debug(f"Executing statement: {statement} {params}")
result = conn.execute(text(statement), params)
total_rows += result.rowcount

# Store the last result for potential return
if result.returns_rows:
last_result = result

# If last result is a SELECT, return its rows
if last_result and last_result.returns_rows:
# Convert SQLAlchemy Row objects to plain dictionaries
# to avoid leaking SQLAlchemy implementation details
rows = [dict(row._mapping) for row in last_result]
logger.debug(f"Returning rows: {rows}")
return QueryResult(rows)

# If no rows affected, return 0
return total_rows
except SQLAlchemyError as e:
# The transaction will be rolled back automatically
raise RuntimeError(f"Failed to execute SQL: {str(e)}\nRendered SQL: {query}") from e
# Delegate to adapter; it returns either list[dict] rows or int rowcount
result = self.adapter.run_sql(sql_template, context)
if isinstance(result, int):
return result
# assume list of dicts
return QueryResult(result)

# ---------- Async delegates (Phase 2) ----------
async def init_schema_async(self, schema_sql: Optional[str] = None, schema_path: Optional[str] = None) -> None:
if not hasattr(self.adapter, 'init_schema_async'):
raise NotImplementedError("Async schema init not supported by this adapter")
if not schema_sql:
with open(schema_path, 'r') as f:
schema_sql = f.read()
await self.adapter.init_schema_async(schema_sql) # type: ignore[attr-defined]

async def run_sql_async(self, sql_template: str, **context) -> Any:
if not hasattr(self.adapter, 'run_sql_async'):
raise NotImplementedError("Async run_sql not supported by this adapter")
if 'now' not in context:
context['now'] = datetime.now
result = await self.adapter.run_sql_async(sql_template, context) # type: ignore[attr-defined]
if isinstance(result, int):
return result
return QueryResult(result)

async def close_async(self) -> None:
if hasattr(self.adapter, 'close_async'):
await self.adapter.close_async() # type: ignore[attr-defined]


def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] = None) -> Any:
Expand All @@ -151,7 +132,11 @@ def execute(self, sql: str, params: Optional[Union[tuple, dict, List[tuple]]] =
Returns:
Any: Result of the execution
"""
with self.engine.connect() as connection:
# Only supported on SQLAlchemy adapter path for now
if not isinstance(self.adapter, SQLAlchemyAdapter):
raise NotImplementedError("execute() is only supported for SQLAlchemy adapter")

with self.adapter.engine.connect() as connection:
try:
# Replace '?' placeholders with SQLAlchemy named parameters
if '?' in sql:
Expand Down Expand Up @@ -249,6 +234,16 @@ def get_db(db_url: str) -> Database:

return DATABASES[db_url]

def get_db_with_adapter(db_url: str, mode: str) -> Database:
"""Internal helper for selecting adapter explicitly.
mode: "sync" | "async" (async not implemented yet)
"""
if mode == "sync":
return Database(db_url, adapter=SQLAlchemyAdapter(db_url))
if mode == "async":
return Database(db_url, adapter=AsyncpgAdapter(db_url))
raise ValueError(f"Unknown adapter mode: {mode}")

def run_sql(db_url: str, sql_template: str, **context) -> Any:
"""Run an SQL template string with jinja2sql for rendering and parameter substitution.

Expand Down
Loading