diff --git a/server/lib/execute-sql.ts b/server/lib/execute-sql.ts new file mode 100644 index 0000000..c65b5dc --- /dev/null +++ b/server/lib/execute-sql.ts @@ -0,0 +1,51 @@ +// Transport-agnostic helpers shared by every SQL execution path (the ConnectRPC ExecuteSQL +// handler and the MCP execution tools). Identity resolution, permission enforcement, and +// response shaping stay in each transport; what's identical — how a batch is made safe to run +// and how a Postgres error is rendered — lives here so the two paths can't drift. + +// The shape of detectRequiredPermissions() this module needs. +interface StatementAnalysis { + statementCount: number + transactionSafe: boolean +} + +// Build the SQL actually sent to the server. A safe multi-statement batch is wrapped in a +// transaction so a mid-batch failure rolls back; without it, PostgreSQL's Simple Query protocol +// runs each statement in autocommit, leaving 1..N-1 committed when statement N fails. Statements +// that cannot run inside a transaction (CREATE DATABASE, VACUUM, CREATE INDEX CONCURRENTLY) are +// excluded upstream via `transactionSafe`. The `\n;\n` before COMMIT terminates the user's last +// statement even when it lacks a trailing semicolon or ends in a line comment (a bare `;` would +// be swallowed by the comment). +export function buildExecutableSql(rawSql: string, analysis: StatementAnalysis): string { + return analysis.statementCount > 1 && analysis.transactionSafe ? `BEGIN;\n${rawSql}\n;\nCOMMIT;` : rawSql +} + +// Render a thrown postgres.js error into a readable message: the base message, plus line context +// derived from the error position, plus PostgreSQL's DETAIL and HINT when present. `sql` is the +// statement the position refers to. +export function formatExecutionError(err: unknown, sql: string): string { + const baseMessage = err instanceof Error ? err.message : 'Query execution failed' + let fullError = baseMessage + + const pgErr = err as Record + const pos = pgErr?.position + if (typeof pos === 'string' && pos) { + const charPos = parseInt(pos, 10) + if (charPos > 0) { + const before = sql.slice(0, charPos - 1) + const lineNumber = before.split('\n').length + const lines = sql.split('\n') + const offendingLine = lines[lineNumber - 1] + if (offendingLine !== undefined) { + fullError = `ERROR at Line ${lineNumber}: ${baseMessage}\nLINE ${lineNumber}: ${offendingLine}` + } + } + } + if (typeof pgErr?.detail === 'string' && pgErr.detail) { + fullError += `\nDETAIL: ${pgErr.detail}` + } + if (typeof pgErr?.hint === 'string' && pgErr.hint) { + fullError += `\nHINT: ${pgErr.hint}` + } + return fullError +} diff --git a/server/mcp.ts b/server/mcp.ts index da72b0f..a667580 100644 --- a/server/mcp.ts +++ b/server/mcp.ts @@ -15,6 +15,7 @@ import { getConnections, getAgentByToken, type AgentConfig } from './lib/config' import { withConnection, buildConnectionDetails, type ConnectionDetails } from './lib/db' import { getAgentPermissions, type Permission } from './lib/iam' import { detectRequiredPermissions } from './lib/sql-permissions' +import { buildExecutableSql, formatExecutionError } from './lib/execute-sql' import { auditSQL } from './lib/audit' declare const __APP_VERSION__: string @@ -484,11 +485,9 @@ async function execute(principal: Principal, tool: string, expectedPerm: Permiss } requireAll(have, analysis.permissions, tool) - // Wrap safe multi-statement batches in a transaction so a mid-batch failure rolls back. - // The `\n;\n` before COMMIT terminates the user's last statement even when it lacks a - // trailing semicolon or ends in a line comment (a bare `;` would be swallowed by the comment). - const finalSql = - analysis.statementCount > 1 && analysis.transactionSafe ? `BEGIN;\n${rawSql}\n;\nCOMMIT;` : rawSql + // Wrap safe multi-statement batches in a transaction so a mid-batch failure rolls back + // (see buildExecutableSql). + const finalSql = buildExecutableSql(rawSql, analysis) const result = await runAndAudit(principal, tool, connection, details, finalSql, rawSql) const rowCount = result.count ?? result.rows.length @@ -567,7 +566,9 @@ async function runAndAudit( } catch (err) { const message = err instanceof Error ? err.message : 'Query execution failed' auditSQL(actor, connection, details.database, auditSql, false, Date.now() - start, undefined, message, opts) - throw new Error(message) + // Surface the same line/DETAIL/HINT context the UI gets; audit keeps the bare message. + // Format against execSql (what Postgres ran) so the error `position` maps to the right line. + throw new Error(formatExecutionError(err, execSql)) } } diff --git a/server/services/query-service.ts b/server/services/query-service.ts index ffc46de..7be029f 100644 --- a/server/services/query-service.ts +++ b/server/services/query-service.ts @@ -7,6 +7,7 @@ import type postgres from "postgres"; import { getUserFromContext } from "../connect"; import { hasPermission, requirePermission, requirePermissions, requireAnyPermission } from "../lib/iam"; import { detectRequiredPermissions } from "../lib/sql-permissions"; +import { buildExecutableSql, formatExecutionError } from "../lib/execute-sql"; import { auditSQL, auditExport, listAuditEvents } from "../lib/audit"; // Track active queries by queryId -> { pid, connectionDetails, email } @@ -156,6 +157,11 @@ export const queryServiceHandlers: ServiceImpl = { const start = Date.now(); let backendPid = 0; + // The exact text sent to Postgres (multi-statement batches are transaction-wrapped). Built + // before the try so the catch can pass it to formatExecutionError — error `position` offsets + // index into this string, not the raw req.sql. + const executableSql = buildExecutableSql(req.sql, analysis); + try { // Get backend PID for cancellation support and monitoring correlation const pidResult = await client`SELECT pg_backend_pid() as pid`; @@ -185,17 +191,7 @@ export const queryServiceHandlers: ServiceImpl = { backendPid, }; - // Wrap multi-statement SQL in a transaction when safe. Without this, - // PostgreSQL's Simple Query protocol runs each statement in autocommit - // mode, so a failure in statement N leaves 1..N-1 committed. - // Statements like CREATE DATABASE, VACUUM, CREATE INDEX CONCURRENTLY - // cannot run inside a transaction and are excluded. - // The `\n;\n` terminates the user's last statement even when it lacks a - // trailing semicolon or ends in a line comment, so COMMIT isn't merged into it. - const sql = (analysis.statementCount > 1 && analysis.transactionSafe) - ? `BEGIN;\n${req.sql}\n;\nCOMMIT;` - : req.sql; - const result = await client.unsafe(sql); + const result = await client.unsafe(executableSql); const executionTimeMs = Date.now() - start; @@ -259,28 +255,9 @@ export const queryServiceHandlers: ServiceImpl = { const errorMessage = err instanceof Error ? err.message : "Query execution failed"; const executionTimeMs = Date.now() - start; - // Build a richer error with line context, detail, and hint from PostgreSQL - let fullError = errorMessage; - const pgErr = err as Record; - const pos = pgErr?.position; - if (typeof pos === 'string' && pos) { - const charPos = parseInt(pos, 10); - if (charPos > 0) { - const before = req.sql.slice(0, charPos - 1); - const lineNumber = before.split('\n').length; - const lines = req.sql.split('\n'); - const offendingLine = lines[lineNumber - 1]; - if (offendingLine !== undefined) { - fullError = `ERROR at Line ${lineNumber}: ${errorMessage}\nLINE ${lineNumber}: ${offendingLine}`; - } - } - } - if (typeof pgErr?.detail === 'string' && pgErr.detail) { - fullError += `\nDETAIL: ${pgErr.detail}`; - } - if (typeof pgErr?.hint === 'string' && pgErr.hint) { - fullError += `\nHINT: ${pgErr.hint}`; - } + // Build a richer error with line context, detail, and hint from PostgreSQL. Format against + // the executed SQL so the error `position` maps to the right line. + const fullError = formatExecutionError(err, executableSql); auditSQL(user.email, req.connectionId, details.database, req.sql, false, executionTimeMs, undefined, errorMessage) yield { diff --git a/tests/execute-sql.test.ts b/tests/execute-sql.test.ts new file mode 100644 index 0000000..04d964e --- /dev/null +++ b/tests/execute-sql.test.ts @@ -0,0 +1,57 @@ +import { describe, it, expect } from 'vitest' +import { buildExecutableSql, formatExecutionError } from '../server/lib/execute-sql' + +describe('buildExecutableSql', () => { + it('leaves a single statement untouched', () => { + expect(buildExecutableSql('SELECT 1', { statementCount: 1, transactionSafe: true })).toBe('SELECT 1') + }) + + it('wraps a safe multi-statement batch in BEGIN/COMMIT', () => { + const sql = 'INSERT INTO t VALUES (1);\nUPDATE t SET x = 2' + expect(buildExecutableSql(sql, { statementCount: 2, transactionSafe: true })).toBe(`BEGIN;\n${sql}\n;\nCOMMIT;`) + }) + + it('does not wrap a multi-statement batch that is not transaction-safe', () => { + const sql = 'VACUUM;\nSELECT 1' + expect(buildExecutableSql(sql, { statementCount: 2, transactionSafe: false })).toBe(sql) + }) + + it('does not wrap a single transaction-unsafe statement', () => { + expect(buildExecutableSql('VACUUM', { statementCount: 1, transactionSafe: false })).toBe('VACUUM') + }) +}) + +describe('formatExecutionError', () => { + it('returns the bare message when there is no position/detail/hint', () => { + expect(formatExecutionError(new Error('syntax error'), 'SELECT')).toBe('syntax error') + }) + + it('adds line context from the error position', () => { + const sql = 'SELECT 1\nFROM nope\nWHERE x' + // position points into line 2 (1-based char offset) + const err = Object.assign(new Error('relation "nope" does not exist'), { position: '15' }) + const out = formatExecutionError(err, sql) + expect(out).toContain('ERROR at Line 2:') + expect(out).toContain('LINE 2: FROM nope') + }) + + it('appends DETAIL and HINT when present', () => { + const err = Object.assign(new Error('boom'), { detail: 'the detail', hint: 'try this' }) + const out = formatExecutionError(err, 'SELECT 1') + expect(out).toBe('boom\nDETAIL: the detail\nHINT: try this') + }) + + it('falls back for a non-Error throwable', () => { + expect(formatExecutionError('weird', 'SELECT 1')).toBe('Query execution failed') + }) + + // Callers must pass the executed SQL (Postgres `position` indexes into what actually ran). For a + // transaction-wrapped batch, BEGIN; is line 1, so the user's statements shift down by one. + it('maps position onto the executed transaction-wrapped SQL', () => { + const raw = 'INSERT INTO t VALUES (1);\nUPDATE nope SET x = 2' + const executed = buildExecutableSql(raw, { statementCount: 2, transactionSafe: true }) + const pos = executed.indexOf('UPDATE') + 1 // 1-based offset into the executed string + const err = Object.assign(new Error('relation "nope" does not exist'), { position: String(pos) }) + expect(formatExecutionError(err, executed)).toContain('LINE 3: UPDATE nope SET x = 2') + }) +})