Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The Delete shortcut in the data grid now follows a custom binding.
- Find Next (Cmd+G) and Find Previous (Cmd+Shift+G) now work in the editor.
- Pagination buttons no longer fire their page shortcut twice.
- Running a PostgreSQL script with a `DO $$ ... $$` block or a dollar-quoted function body no longer fails with an unterminated dollar-quoted string error. (#1559)
- AWS IAM connections no longer ask for a password on connect or reconnect. IAM supplies the credentials, so the prompt was never needed. The same now holds for any auth mode that replaces the password, such as a Postgres password file.
- Oracle connection failures show the listener's actual reason (such as an unknown service name) instead of a generic "server closed the connection" message. (#483)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ struct ConfirmDestructiveOperationChatTool: ChatTool {
isError: true
)
}
guard !QueryClassifier.isMultiStatement(query) else {
let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else {
return ChatToolResult(
content: "Multi-statement queries are not supported. Send one statement at a time.",
isError: true
)
}

let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)
let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType)
guard tier == .destructive else {
return ChatToolResult(
Expand Down
7 changes: 4 additions & 3 deletions TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ struct ExecuteQueryChatTool: ChatTool {
guard (query as NSString).length <= 102_400 else {
return ChatToolResult(content: "Query exceeds 100KB limit", isError: true)
}
guard !QueryClassifier.isMultiStatement(query) else {

let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else {
return ChatToolResult(
content: "Multi-statement queries are not supported. Send one statement at a time.",
isError: true
Expand All @@ -66,8 +69,6 @@ struct ExecuteQueryChatTool: ChatTool {
clamp: 1...300
) ?? mcpSettings.queryTimeoutSeconds

let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType)
if tier == .destructive {
return ChatToolResult(
Expand Down
2 changes: 1 addition & 1 deletion TablePro/Core/Coordinators/QueryExecutionCoordinator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ final class QueryExecutionCoordinator {
let fullQuery = tab.content.query
guard !fullQuery.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }

let statements = SQLStatementScanner.allStatements(in: fullQuery)
let statements = SQLStatementScanner.allStatements(in: fullQuery, dialect: parent.sqlDialect)
guard !statements.isEmpty else { return }

if AppSettingsManager.shared.editor.queryParametersEnabled {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ public struct ConfirmDestructiveOperationTool: MCPToolImplementation {
)
}

guard !QueryClassifier.isMultiStatement(query) else {
let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else {
throw MCPProtocolError.invalidParams(
detail: "Multi-statement queries are not supported. Send one statement at a time."
)
}

let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType)
guard tier == .destructive else {
throw MCPProtocolError.invalidParams(
Expand Down
6 changes: 3 additions & 3 deletions TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ public struct ExecuteQueryTool: MCPToolImplementation {
throw MCPProtocolError.invalidParams(detail: "Query exceeds 100KB limit")
}

guard !QueryClassifier.isMultiStatement(query) else {
let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else {
throw MCPProtocolError.invalidParams(
detail: "Multi-statement queries are not supported. Send one statement at a time."
)
Expand All @@ -86,8 +88,6 @@ public struct ExecuteQueryTool: MCPToolImplementation {
try await throwIfCancelled(context)
await context.progress.emit(progress: 0.0, total: 1.0, message: "Connecting")

let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId)

if let database {
_ = try await services.connectionBridge.switchDatabase(
connectionId: connectionId,
Expand Down
4 changes: 3 additions & 1 deletion TablePro/Core/Services/Execution/DefaultExecutionGate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ internal actor DefaultExecutionGate: ExecutionGate {
let tier = request.sql.map { QueryClassifier.classifyTier($0, databaseType: request.databaseType) }
let isDangerous = request.sql.map { QueryClassifier.isDangerousQuery($0, databaseType: request.databaseType) } ?? false
let isDestructive = request.kind.declaresDestructive || tier == .destructive || isDangerous
let isMultiStatement = request.sql.map { QueryClassifier.isMultiStatement($0) } ?? false
let isMultiStatement = request.sql.map {
QueryClassifier.isMultiStatement($0, databaseType: request.databaseType)
} ?? false
let effectiveWrite = await resolveEffectiveWrite(request, tier: tier)

if let denial = capabilityDenial(
Expand Down
8 changes: 6 additions & 2 deletions TablePro/Core/Utilities/SQL/QueryClassifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//

import Foundation
import TableProPluginKit

enum QueryTier {
case safe
Expand Down Expand Up @@ -129,8 +130,11 @@ enum QueryClassifier {
return .safe
}

static func isMultiStatement(_ sql: String) -> Bool {
SQLStatementScanner.allStatements(in: sql).count > 1
static func isMultiStatement(_ sql: String, databaseType: DatabaseType) -> Bool {
SQLStatementScanner.allStatements(
in: sql,
dialect: SqlDialect.from(databaseTypeId: databaseType.rawValue)
).count > 1
Comment on lines +134 to +137
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Classify PostgreSQL DO blocks before unblocking them

For PostgreSQL, this now makes DO $$ BEGIN DROP TABLE t; END $$; count as a single statement, but classifyTier still only checks top-level prefixes and returns .safe for DO. The AI/MCP execute_query paths rely on this multi-statement guard plus the tier check before running with pre-cleared capabilities, so a destructive operation hidden inside a dollar-quoted DO block can bypass the dedicated confirmation tool and safe-mode write/destructive classification. Please either classify DO blocks conservatively for PostgreSQL or keep a separate non-UI guard for these executable blocks.

Useful? React with 👍 / 👎.

}

static func isExplainStatement(_ sql: String) -> Bool {
Expand Down
56 changes: 2 additions & 54 deletions TablePro/Core/Utilities/SQL/SQLFileParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,6 @@ final class SQLFileParser: Sendable {
private static let kCapitalE: unichar = 0x45
private static let kSmallE: unichar = 0x65

private static func isIdentifierStart(_ ch: unichar) -> Bool {
(ch >= 0x41 && ch <= 0x5A) || (ch >= 0x61 && ch <= 0x7A) || ch == 0x5F
}

private static func isIdentifierPart(_ ch: unichar) -> Bool {
isIdentifierStart(ch) || (ch >= 0x30 && ch <= 0x39)
}

private enum DollarQuoteScan {
case opener(length: Int, tag: String)
case notOpener
case needsMoreData
}

nonisolated private static func needsLookahead(
_ char: unichar,
state: ParserState,
Expand Down Expand Up @@ -163,44 +149,6 @@ final class SQLFileParser: Sendable {
}
}

private static func scanDollarQuoteOpener(
at pos: Int, in buffer: NSString, bufLen: Int
) -> DollarQuoteScan {
var p = pos + 1
while p < bufLen {
let ch = buffer.character(at: p)
if ch == kDollar {
let tagLen = p - pos - 1
if tagLen == 0 {
return .opener(length: 2, tag: "")
}
let firstChar = buffer.character(at: pos + 1)
if !isIdentifierStart(firstChar) {
return .notOpener
}
let tag = buffer.substring(with: NSRange(location: pos + 1, length: tagLen))
return .opener(length: tagLen + 2, tag: tag)
}
if !isIdentifierPart(ch) {
return .notOpener
}
p += 1
}
return .needsMoreData
}

private static func matchesDollarClose(
at pos: Int, tag: String, in buffer: NSString, bufLen: Int
) -> Bool {
let closeLen = (tag as NSString).length + 2
guard pos + closeLen <= bufLen else { return false }
if buffer.character(at: pos) != kDollar { return false }
if buffer.character(at: pos + closeLen - 1) != kDollar { return false }
if tag.isEmpty { return true }
let tagRange = NSRange(location: pos + 1, length: (tag as NSString).length)
return buffer.substring(with: tagRange) == tag
}

private struct StepResult {
var advanced: Bool
var deferred: Bool
Expand Down Expand Up @@ -255,7 +203,7 @@ final class SQLFileParser: Sendable {
}

if ctx.dialect.supportsDollarQuotes && char == kDollar {
switch scanDollarQuoteOpener(at: i, in: nsBuffer, bufLen: bufLen) {
switch SqlDollarQuote.scanOpener(at: i, in: nsBuffer, bufLen: bufLen) {
case .opener(let length, let tag):
(ctx.hasStatementContent, ctx.statementStartLine) = markContent(
ctx.hasStatementContent, ctx.statementStartLine, ctx.currentLine)
Expand Down Expand Up @@ -452,7 +400,7 @@ final class SQLFileParser: Sendable {
i = pos
return StepResult(advanced: true, deferred: true)
}
if matchesDollarClose(at: pos, tag: ctx.dollarTag, in: nsBuffer, bufLen: bufLen) {
if SqlDollarQuote.matchesClose(at: pos, tag: ctx.dollarTag, in: nsBuffer, bufLen: bufLen) {
pos += closeLen
ctx.state = .normal
ctx.dollarTag = ""
Expand Down
42 changes: 34 additions & 8 deletions TablePro/Core/Utilities/SQL/SQLStatementScanner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
//

import Foundation
import TableProPluginKit

enum SQLStatementScanner {
struct LocatedStatement {
let sql: String
let offset: Int
}

/// Returns statements with trailing semicolons stripped for driver execution.
static func allStatements(in sql: String) -> [String] {
/// Returns statements with trailing semicolons stripped, for driver execution.
static func allStatements(in sql: String, dialect: SqlDialect = .generic) -> [String] {
var results: [String] = []
scan(sql: sql, cursorPosition: nil) { rawSQL, _ in
scan(sql: sql, cursorPosition: nil, dialect: dialect) { rawSQL, _ in
var trimmed = rawSQL.trimmingCharacters(in: .whitespacesAndNewlines)
if trimmed.hasSuffix(";") {
trimmed = String(trimmed.dropLast())
Expand All @@ -28,7 +29,7 @@ enum SQLStatementScanner {
return results
}

/// Returns statements preserving trailing semicolons for display/history/favorites.
/// Returns statements preserving trailing semicolons, for display/history/favorites.
static func allStatementsPreservingSemicolons(in sql: String) -> [String] {
var results: [String] = []
scan(sql: sql, cursorPosition: nil) { rawSQL, _ in
Expand All @@ -44,8 +45,8 @@ enum SQLStatementScanner {
return results
}

static func statementAtCursor(in sql: String, cursorPosition: Int) -> String {
var result = locatedStatementAtCursor(in: sql, cursorPosition: cursorPosition)
static func statementAtCursor(in sql: String, cursorPosition: Int, dialect: SqlDialect = .generic) -> String {
var result = locatedStatementAtCursor(in: sql, cursorPosition: cursorPosition, dialect: dialect)
.sql
.trimmingCharacters(in: .whitespacesAndNewlines)
if result.hasSuffix(";") {
Expand All @@ -55,9 +56,9 @@ enum SQLStatementScanner {
return result
}

static func locatedStatementAtCursor(in sql: String, cursorPosition: Int) -> LocatedStatement {
static func locatedStatementAtCursor(in sql: String, cursorPosition: Int, dialect: SqlDialect = .generic) -> LocatedStatement {
var result = LocatedStatement(sql: "", offset: 0)
scan(sql: sql, cursorPosition: cursorPosition) { rawSQL, offset in
scan(sql: sql, cursorPosition: cursorPosition, dialect: dialect) { rawSQL, offset in
result = LocatedStatement(sql: rawSQL, offset: offset)
return false
}
Expand All @@ -75,10 +76,12 @@ enum SQLStatementScanner {
private static let star = UInt16(UnicodeScalar("*").value)
private static let newline = UInt16(UnicodeScalar("\n").value)
private static let backslash = UInt16(UnicodeScalar("\\").value)
private static let dollar = UInt16(UnicodeScalar("$").value)

private static func scan(
sql: String,
cursorPosition: Int?,
dialect: SqlDialect = .generic,
onStatement: (_ rawSQL: String, _ offset: Int) -> Bool
) {
let nsQuery = sql as NSString
Expand All @@ -97,6 +100,9 @@ enum SQLStatementScanner {
var stringCharVal: UInt16 = 0
var inLineComment = false
var inBlockComment = false
var inDollarQuote = false
var dollarTag = ""
let dollarQuotesEnabled = dialect.supportsDollarQuotes
var i = 0

while i < length {
Expand All @@ -118,6 +124,18 @@ enum SQLStatementScanner {
continue
}

if inDollarQuote {
if ch == dollar,
SqlDollarQuote.matchesClose(at: i, tag: dollarTag, in: nsQuery, bufLen: length) {
inDollarQuote = false
i += (dollarTag as NSString).length + 2
dollarTag = ""
continue
}
i += 1
continue
}

if !inString && ch == dash && i + 1 < length && nsQuery.character(at: i + 1) == dash {
inLineComment = true
i += 2
Expand Down Expand Up @@ -148,6 +166,14 @@ enum SQLStatementScanner {
}
}

if dollarQuotesEnabled, !inString, ch == dollar,
case .opener(let openerLength, let tag) = SqlDollarQuote.scanOpener(at: i, in: nsQuery, bufLen: length) {
inDollarQuote = true
dollarTag = tag
i += openerLength
continue
}

if ch == semicolonChar && !inString {
let stmtEnd = i + 1

Expand Down
74 changes: 74 additions & 0 deletions TablePro/Core/Utilities/SQL/SqlDollarQuote.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//
// SqlDollarQuote.swift
// TablePro
//

import Foundation

enum SqlDollarQuote {
enum Opener {
case opener(length: Int, tag: String)
case notOpener
case needsMoreData
}

static let dollar: unichar = 0x24

static func isIdentifierStart(_ ch: unichar) -> Bool {
(ch >= 0x41 && ch <= 0x5A) || (ch >= 0x61 && ch <= 0x7A) || ch == 0x5F
}

static func isIdentifierPart(_ ch: unichar) -> Bool {
isIdentifierStart(ch) || (ch >= 0x30 && ch <= 0x39)
}

/// Whether a `$` following this character is part of the preceding identifier,
/// per PostgreSQL's rule that a dollar quote must be separated from a
/// preceding identifier by whitespace (so `a$$b` is one identifier, not an
/// opener).
static func isIdentifierContinuation(_ ch: unichar) -> Bool {
isIdentifierPart(ch) || ch == dollar
}

/// Resolves a `$` at `pos` to a dollar-quote opener, a positional parameter
/// like `$1`, or a non-tag dollar. A `$` glued to a preceding identifier is
/// not an opener. Returns `needsMoreData` when the buffer ends mid-tag; a
/// whole-string caller treats that as `notOpener`.
static func scanOpener(at pos: Int, in buffer: NSString, bufLen: Int) -> Opener {
if pos > 0, isIdentifierContinuation(buffer.character(at: pos - 1)) {
return .notOpener
}
var p = pos + 1
while p < bufLen {
let ch = buffer.character(at: p)
if ch == dollar {
let tagLen = p - pos - 1
if tagLen == 0 {
return .opener(length: 2, tag: "")
}
if !isIdentifierStart(buffer.character(at: pos + 1)) {
return .notOpener
}
let tag = buffer.substring(with: NSRange(location: pos + 1, length: tagLen))
return .opener(length: tagLen + 2, tag: tag)
}
if !isIdentifierPart(ch) {
return .notOpener
}
p += 1
}
return .needsMoreData
}

/// Whether the closing delimiter for `tag` starts at `pos`. The tag match is
/// exact and case-sensitive, per PostgreSQL.
static func matchesClose(at pos: Int, tag: String, in buffer: NSString, bufLen: Int) -> Bool {
let closeLen = (tag as NSString).length + 2
guard pos + closeLen <= bufLen else { return false }
if buffer.character(at: pos) != dollar { return false }
if buffer.character(at: pos + closeLen - 1) != dollar { return false }
if tag.isEmpty { return true }
let tagRange = NSRange(location: pos + 1, length: (tag as NSString).length)
return buffer.substring(with: tagRange) == tag
}
}
Loading
Loading