From 5a1b2445a1ca8f557d4ba7ce7d8b260996fd1b53 Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Tue, 2 Jun 2026 22:43:06 +0700 Subject: [PATCH 1/2] fix(editor): make statement splitting dollar-quote aware for PostgreSQL (#1559) --- CHANGELOG.md | 1 + .../ConfirmDestructiveOperationChatTool.swift | 5 +- .../AI/Chat/Tools/ExecuteQueryChatTool.swift | 7 ++- .../QueryExecutionCoordinator.swift | 2 +- .../ConfirmDestructiveOperationTool.swift | 6 +- .../MCP/Protocol/Tools/ExecuteQueryTool.swift | 6 +- .../Execution/DefaultExecutionGate.swift | 4 +- .../Core/Utilities/SQL/QueryClassifier.swift | 8 ++- .../Core/Utilities/SQL/SQLFileParser.swift | 56 +---------------- .../Utilities/SQL/SQLStatementScanner.swift | 52 ++++++++++++--- .../Core/Utilities/SQL/SqlDollarQuote.swift | 62 ++++++++++++++++++ .../MainContentCoordinator+ClickHouse.swift | 5 +- .../Views/Main/MainContentCoordinator.swift | 13 ++-- .../Utilities/SQLStatementScannerTests.swift | 63 ++++++++++++++++++- 14 files changed, 203 insertions(+), 87 deletions(-) create mode 100644 TablePro/Core/Utilities/SQL/SqlDollarQuote.swift diff --git a/CHANGELOG.md b/CHANGELOG.md index c1d633510..c090e7851 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The Delete shortcut in the data grid now follows a custom binding. - Find Next (Cmd+G) and Find Previous (Cmd+Shift+G) now work in the editor. - Pagination buttons no longer fire their page shortcut twice. +- Running a PostgreSQL script with a `DO $$ ... $$` block or a dollar-quoted function body no longer fails with an unterminated dollar-quoted string error. (#1559) ## [0.48.0] - 2026-06-02 diff --git a/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift b/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift index 5c1f19deb..4822aadde 100644 --- a/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ConfirmDestructiveOperationChatTool.swift @@ -36,14 +36,15 @@ struct ConfirmDestructiveOperationChatTool: ChatTool { isError: true ) } - guard !QueryClassifier.isMultiStatement(query) else { + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else { return ChatToolResult( content: "Multi-statement queries are not supported. Send one statement at a time.", isError: true ) } - let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType) guard tier == .destructive else { return ChatToolResult( diff --git a/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift b/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift index 62c0bc10a..88f2ffb66 100644 --- a/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift +++ b/TablePro/Core/AI/Chat/Tools/ExecuteQueryChatTool.swift @@ -45,7 +45,10 @@ struct ExecuteQueryChatTool: ChatTool { guard (query as NSString).length <= 102_400 else { return ChatToolResult(content: "Query exceeds 100KB limit", isError: true) } - guard !QueryClassifier.isMultiStatement(query) else { + + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else { return ChatToolResult( content: "Multi-statement queries are not supported. Send one statement at a time.", isError: true @@ -66,8 +69,6 @@ struct ExecuteQueryChatTool: ChatTool { clamp: 1...300 ) ?? mcpSettings.queryTimeoutSeconds - let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) - let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType) if tier == .destructive { return ChatToolResult( diff --git a/TablePro/Core/Coordinators/QueryExecutionCoordinator.swift b/TablePro/Core/Coordinators/QueryExecutionCoordinator.swift index 8a199cd55..4df3e6855 100644 --- a/TablePro/Core/Coordinators/QueryExecutionCoordinator.swift +++ b/TablePro/Core/Coordinators/QueryExecutionCoordinator.swift @@ -23,7 +23,7 @@ final class QueryExecutionCoordinator { let fullQuery = tab.content.query guard !fullQuery.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return } - let statements = SQLStatementScanner.allStatements(in: fullQuery) + let statements = SQLStatementScanner.allStatements(in: fullQuery, dialect: parent.sqlDialect) guard !statements.isEmpty else { return } if AppSettingsManager.shared.editor.queryParametersEnabled { diff --git a/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift b/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift index 3479212c9..cb09a1ccf 100644 --- a/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift +++ b/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift @@ -57,14 +57,14 @@ public struct ConfirmDestructiveOperationTool: MCPToolImplementation { ) } - guard !QueryClassifier.isMultiStatement(query) else { + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else { throw MCPProtocolError.invalidParams( detail: "Multi-statement queries are not supported. Send one statement at a time." ) } - let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) - let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType) guard tier == .destructive else { throw MCPProtocolError.invalidParams( diff --git a/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift b/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift index d13c657c6..f1c01d9a2 100644 --- a/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift +++ b/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift @@ -77,7 +77,9 @@ public struct ExecuteQueryTool: MCPToolImplementation { throw MCPProtocolError.invalidParams(detail: "Query exceeds 100KB limit") } - guard !QueryClassifier.isMultiStatement(query) else { + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + guard !QueryClassifier.isMultiStatement(query, databaseType: meta.databaseType) else { throw MCPProtocolError.invalidParams( detail: "Multi-statement queries are not supported. Send one statement at a time." ) @@ -86,8 +88,6 @@ public struct ExecuteQueryTool: MCPToolImplementation { try await throwIfCancelled(context) await context.progress.emit(progress: 0.0, total: 1.0, message: "Connecting") - let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) - if let database { _ = try await services.connectionBridge.switchDatabase( connectionId: connectionId, diff --git a/TablePro/Core/Services/Execution/DefaultExecutionGate.swift b/TablePro/Core/Services/Execution/DefaultExecutionGate.swift index cbc88d7a5..a7a7c5c7c 100644 --- a/TablePro/Core/Services/Execution/DefaultExecutionGate.swift +++ b/TablePro/Core/Services/Execution/DefaultExecutionGate.swift @@ -30,7 +30,9 @@ internal actor DefaultExecutionGate: ExecutionGate { let tier = request.sql.map { QueryClassifier.classifyTier($0, databaseType: request.databaseType) } let isDangerous = request.sql.map { QueryClassifier.isDangerousQuery($0, databaseType: request.databaseType) } ?? false let isDestructive = request.kind.declaresDestructive || tier == .destructive || isDangerous - let isMultiStatement = request.sql.map { QueryClassifier.isMultiStatement($0) } ?? false + let isMultiStatement = request.sql.map { + QueryClassifier.isMultiStatement($0, databaseType: request.databaseType) + } ?? false let effectiveWrite = await resolveEffectiveWrite(request, tier: tier) if let denial = capabilityDenial( diff --git a/TablePro/Core/Utilities/SQL/QueryClassifier.swift b/TablePro/Core/Utilities/SQL/QueryClassifier.swift index 4bf166650..be3a12fa6 100644 --- a/TablePro/Core/Utilities/SQL/QueryClassifier.swift +++ b/TablePro/Core/Utilities/SQL/QueryClassifier.swift @@ -4,6 +4,7 @@ // import Foundation +import TableProPluginKit enum QueryTier { case safe @@ -129,8 +130,11 @@ enum QueryClassifier { return .safe } - static func isMultiStatement(_ sql: String) -> Bool { - SQLStatementScanner.allStatements(in: sql).count > 1 + static func isMultiStatement(_ sql: String, databaseType: DatabaseType) -> Bool { + SQLStatementScanner.allStatements( + in: sql, + dialect: SqlDialect.from(databaseTypeId: databaseType.rawValue) + ).count > 1 } static func isExplainStatement(_ sql: String) -> Bool { diff --git a/TablePro/Core/Utilities/SQL/SQLFileParser.swift b/TablePro/Core/Utilities/SQL/SQLFileParser.swift index d4590d6a8..d570eae5c 100644 --- a/TablePro/Core/Utilities/SQL/SQLFileParser.swift +++ b/TablePro/Core/Utilities/SQL/SQLFileParser.swift @@ -38,20 +38,6 @@ final class SQLFileParser: Sendable { private static let kCapitalE: unichar = 0x45 private static let kSmallE: unichar = 0x65 - private static func isIdentifierStart(_ ch: unichar) -> Bool { - (ch >= 0x41 && ch <= 0x5A) || (ch >= 0x61 && ch <= 0x7A) || ch == 0x5F - } - - private static func isIdentifierPart(_ ch: unichar) -> Bool { - isIdentifierStart(ch) || (ch >= 0x30 && ch <= 0x39) - } - - private enum DollarQuoteScan { - case opener(length: Int, tag: String) - case notOpener - case needsMoreData - } - nonisolated private static func needsLookahead( _ char: unichar, state: ParserState, @@ -163,44 +149,6 @@ final class SQLFileParser: Sendable { } } - private static func scanDollarQuoteOpener( - at pos: Int, in buffer: NSString, bufLen: Int - ) -> DollarQuoteScan { - var p = pos + 1 - while p < bufLen { - let ch = buffer.character(at: p) - if ch == kDollar { - let tagLen = p - pos - 1 - if tagLen == 0 { - return .opener(length: 2, tag: "") - } - let firstChar = buffer.character(at: pos + 1) - if !isIdentifierStart(firstChar) { - return .notOpener - } - let tag = buffer.substring(with: NSRange(location: pos + 1, length: tagLen)) - return .opener(length: tagLen + 2, tag: tag) - } - if !isIdentifierPart(ch) { - return .notOpener - } - p += 1 - } - return .needsMoreData - } - - private static func matchesDollarClose( - at pos: Int, tag: String, in buffer: NSString, bufLen: Int - ) -> Bool { - let closeLen = (tag as NSString).length + 2 - guard pos + closeLen <= bufLen else { return false } - if buffer.character(at: pos) != kDollar { return false } - if buffer.character(at: pos + closeLen - 1) != kDollar { return false } - if tag.isEmpty { return true } - let tagRange = NSRange(location: pos + 1, length: (tag as NSString).length) - return buffer.substring(with: tagRange) == tag - } - private struct StepResult { var advanced: Bool var deferred: Bool @@ -255,7 +203,7 @@ final class SQLFileParser: Sendable { } if ctx.dialect.supportsDollarQuotes && char == kDollar { - switch scanDollarQuoteOpener(at: i, in: nsBuffer, bufLen: bufLen) { + switch SqlDollarQuote.scanOpener(at: i, in: nsBuffer, bufLen: bufLen) { case .opener(let length, let tag): (ctx.hasStatementContent, ctx.statementStartLine) = markContent( ctx.hasStatementContent, ctx.statementStartLine, ctx.currentLine) @@ -452,7 +400,7 @@ final class SQLFileParser: Sendable { i = pos return StepResult(advanced: true, deferred: true) } - if matchesDollarClose(at: pos, tag: ctx.dollarTag, in: nsBuffer, bufLen: bufLen) { + if SqlDollarQuote.matchesClose(at: pos, tag: ctx.dollarTag, in: nsBuffer, bufLen: bufLen) { pos += closeLen ctx.state = .normal ctx.dollarTag = "" diff --git a/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift b/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift index 46f7f7efb..71057baab 100644 --- a/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift +++ b/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift @@ -4,6 +4,7 @@ // import Foundation +import TableProPluginKit enum SQLStatementScanner { struct LocatedStatement { @@ -11,10 +12,10 @@ enum SQLStatementScanner { let offset: Int } - /// Returns statements with trailing semicolons stripped — for driver execution. - static func allStatements(in sql: String) -> [String] { + /// Returns statements with trailing semicolons stripped, for driver execution. + static func allStatements(in sql: String, dialect: SqlDialect = .generic) -> [String] { var results: [String] = [] - scan(sql: sql, cursorPosition: nil) { rawSQL, _ in + scan(sql: sql, cursorPosition: nil, dialect: dialect) { rawSQL, _ in var trimmed = rawSQL.trimmingCharacters(in: .whitespacesAndNewlines) if trimmed.hasSuffix(";") { trimmed = String(trimmed.dropLast()) @@ -28,10 +29,10 @@ enum SQLStatementScanner { return results } - /// Returns statements preserving trailing semicolons — for display/history/favorites. - static func allStatementsPreservingSemicolons(in sql: String) -> [String] { + /// Returns statements preserving trailing semicolons, for display/history/favorites. + static func allStatementsPreservingSemicolons(in sql: String, dialect: SqlDialect = .generic) -> [String] { var results: [String] = [] - scan(sql: sql, cursorPosition: nil) { rawSQL, _ in + scan(sql: sql, cursorPosition: nil, dialect: dialect) { rawSQL, _ in let trimmed = rawSQL.trimmingCharacters(in: .whitespacesAndNewlines) let withoutSemicolon = trimmed.hasSuffix(";") ? String(trimmed.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines) @@ -44,8 +45,8 @@ enum SQLStatementScanner { return results } - static func statementAtCursor(in sql: String, cursorPosition: Int) -> String { - var result = locatedStatementAtCursor(in: sql, cursorPosition: cursorPosition) + static func statementAtCursor(in sql: String, cursorPosition: Int, dialect: SqlDialect = .generic) -> String { + var result = locatedStatementAtCursor(in: sql, cursorPosition: cursorPosition, dialect: dialect) .sql .trimmingCharacters(in: .whitespacesAndNewlines) if result.hasSuffix(";") { @@ -55,9 +56,9 @@ enum SQLStatementScanner { return result } - static func locatedStatementAtCursor(in sql: String, cursorPosition: Int) -> LocatedStatement { + static func locatedStatementAtCursor(in sql: String, cursorPosition: Int, dialect: SqlDialect = .generic) -> LocatedStatement { var result = LocatedStatement(sql: "", offset: 0) - scan(sql: sql, cursorPosition: cursorPosition) { rawSQL, offset in + scan(sql: sql, cursorPosition: cursorPosition, dialect: dialect) { rawSQL, offset in result = LocatedStatement(sql: rawSQL, offset: offset) return false } @@ -75,10 +76,18 @@ enum SQLStatementScanner { private static let star = UInt16(UnicodeScalar("*").value) private static let newline = UInt16(UnicodeScalar("\n").value) private static let backslash = UInt16(UnicodeScalar("\\").value) + private static let dollar = UInt16(UnicodeScalar("$").value) + + private static func isIdentifierContinuation(before index: Int, in buffer: NSString) -> Bool { + guard index > 0 else { return false } + let prev = buffer.character(at: index - 1) + return SqlDollarQuote.isIdentifierPart(prev) || prev == dollar + } private static func scan( sql: String, cursorPosition: Int?, + dialect: SqlDialect = .generic, onStatement: (_ rawSQL: String, _ offset: Int) -> Bool ) { let nsQuery = sql as NSString @@ -97,6 +106,9 @@ enum SQLStatementScanner { var stringCharVal: UInt16 = 0 var inLineComment = false var inBlockComment = false + var inDollarQuote = false + var dollarTag = "" + let dollarQuotesEnabled = dialect.supportsDollarQuotes var i = 0 while i < length { @@ -118,6 +130,18 @@ enum SQLStatementScanner { continue } + if inDollarQuote { + if ch == dollar, + SqlDollarQuote.matchesClose(at: i, tag: dollarTag, in: nsQuery, bufLen: length) { + inDollarQuote = false + i += (dollarTag as NSString).length + 2 + dollarTag = "" + continue + } + i += 1 + continue + } + if !inString && ch == dash && i + 1 < length && nsQuery.character(at: i + 1) == dash { inLineComment = true i += 2 @@ -148,6 +172,14 @@ enum SQLStatementScanner { } } + if dollarQuotesEnabled, !inString, ch == dollar, !isIdentifierContinuation(before: i, in: nsQuery), + case .opener(let openerLength, let tag) = SqlDollarQuote.scanOpener(at: i, in: nsQuery, bufLen: length) { + inDollarQuote = true + dollarTag = tag + i += openerLength + continue + } + if ch == semicolonChar && !inString { let stmtEnd = i + 1 diff --git a/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift b/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift new file mode 100644 index 000000000..3bb3fc132 --- /dev/null +++ b/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift @@ -0,0 +1,62 @@ +// +// SqlDollarQuote.swift +// TablePro +// + +import Foundation + +enum SqlDollarQuote { + enum Opener { + case opener(length: Int, tag: String) + case notOpener + case needsMoreData + } + + static let dollar: unichar = 0x24 + + static func isIdentifierStart(_ ch: unichar) -> Bool { + (ch >= 0x41 && ch <= 0x5A) || (ch >= 0x61 && ch <= 0x7A) || ch == 0x5F + } + + static func isIdentifierPart(_ ch: unichar) -> Bool { + isIdentifierStart(ch) || (ch >= 0x30 && ch <= 0x39) + } + + /// Resolves a `$` at `pos` to a dollar-quote opener, a positional parameter + /// like `$1`, or a non-tag dollar. Returns `needsMoreData` when the buffer + /// ends mid-tag; a whole-string caller treats that as `notOpener`. + static func scanOpener(at pos: Int, in buffer: NSString, bufLen: Int) -> Opener { + var p = pos + 1 + while p < bufLen { + let ch = buffer.character(at: p) + if ch == dollar { + let tagLen = p - pos - 1 + if tagLen == 0 { + return .opener(length: 2, tag: "") + } + if !isIdentifierStart(buffer.character(at: pos + 1)) { + return .notOpener + } + let tag = buffer.substring(with: NSRange(location: pos + 1, length: tagLen)) + return .opener(length: tagLen + 2, tag: tag) + } + if !isIdentifierPart(ch) { + return .notOpener + } + p += 1 + } + return .needsMoreData + } + + /// Whether the closing delimiter for `tag` starts at `pos`. The tag match is + /// exact and case-sensitive, per PostgreSQL. + static func matchesClose(at pos: Int, tag: String, in buffer: NSString, bufLen: Int) -> Bool { + let closeLen = (tag as NSString).length + 2 + guard pos + closeLen <= bufLen else { return false } + if buffer.character(at: pos) != dollar { return false } + if buffer.character(at: pos + closeLen - 1) != dollar { return false } + if tag.isEmpty { return true } + let tagRange = NSRange(location: pos + 1, length: (tag as NSString).length) + return buffer.substring(with: tagRange) == tag + } +} diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+ClickHouse.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+ClickHouse.swift index 024d07884..18dd37d71 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+ClickHouse.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+ClickHouse.swift @@ -45,14 +45,15 @@ extension MainContentCoordinator { } else { sql = SQLStatementScanner.statementAtCursor( in: fullQuery, - cursorPosition: cursorPositions.first?.range.location ?? 0 + cursorPosition: cursorPositions.first?.range.location ?? 0, + dialect: sqlDialect ) } let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines) guard !trimmed.isEmpty else { return } - let statements = SQLStatementScanner.allStatements(in: trimmed) + let statements = SQLStatementScanner.allStatements(in: trimmed, dialect: sqlDialect) guard let stmt = statements.first else { return } let explainSQL = "\(variant.sqlPrefix) \(stmt)" diff --git a/TablePro/Views/Main/MainContentCoordinator.swift b/TablePro/Views/Main/MainContentCoordinator.swift index cabfda36d..867f0ddbd 100644 --- a/TablePro/Views/Main/MainContentCoordinator.swift +++ b/TablePro/Views/Main/MainContentCoordinator.swift @@ -87,6 +87,7 @@ final class MainContentCoordinator { @ObservationIgnored let services: AppServices let connection: DatabaseConnection var connectionId: UUID { connection.id } + var sqlDialect: SqlDialect { SqlDialect.from(databaseTypeId: connection.type.rawValue) } var activeDatabaseName: String { services.databaseManager.activeDatabaseName(for: connection) } @@ -795,7 +796,8 @@ final class MainContentCoordinator { } else { sql = SQLStatementScanner.statementAtCursor( in: fullQuery, - cursorPosition: cursorPositions.first?.range.location ?? 0 + cursorPosition: cursorPositions.first?.range.location ?? 0, + dialect: sqlDialect ) } @@ -804,7 +806,7 @@ final class MainContentCoordinator { } if services.appSettings.editor.queryParametersEnabled { - let paramStatements = SQLStatementScanner.allStatements(in: sql) + let paramStatements = SQLStatementScanner.allStatements(in: sql, dialect: sqlDialect) guard !paramStatements.isEmpty else { return } let combinedSQL = paramStatements.joined(separator: "; ") let detectedNames = SQLParameterExtractor.extractParameters(from: combinedSQL) @@ -831,7 +833,7 @@ final class MainContentCoordinator { } } - let statements = SQLStatementScanner.allStatements(in: sql) + let statements = SQLStatementScanner.allStatements(in: sql, dialect: sqlDialect) guard !statements.isEmpty else { return } tabManager.tabStructureVersion += 1 @@ -943,7 +945,8 @@ final class MainContentCoordinator { } else { sql = SQLStatementScanner.statementAtCursor( in: fullQuery, - cursorPosition: cursorPositions.first?.range.location ?? 0 + cursorPosition: cursorPositions.first?.range.location ?? 0, + dialect: sqlDialect ) } @@ -951,7 +954,7 @@ final class MainContentCoordinator { guard !trimmed.isEmpty else { return } // Use first statement only (EXPLAIN on a single statement) - let statements = SQLStatementScanner.allStatements(in: trimmed) + let statements = SQLStatementScanner.allStatements(in: trimmed, dialect: sqlDialect) guard let stmt = statements.first else { return } let level = safeModeLevel diff --git a/TableProTests/Core/Utilities/SQLStatementScannerTests.swift b/TableProTests/Core/Utilities/SQLStatementScannerTests.swift index c267e0112..f93769b36 100644 --- a/TableProTests/Core/Utilities/SQLStatementScannerTests.swift +++ b/TableProTests/Core/Utilities/SQLStatementScannerTests.swift @@ -3,8 +3,8 @@ // TableProTests // -import TableProPluginKit @testable import TablePro +import TableProPluginKit import XCTest final class SQLStatementScannerTests: XCTestCase { @@ -109,6 +109,67 @@ final class SQLStatementScannerTests: XCTestCase { ) } + // MARK: - Dollar Quoting (PostgreSQL) + + func testDollarQuotedDoBlockKeepsInternalSemicolons() { + let sql = "DO $$ BEGIN PERFORM 1; PERFORM 2; END $$; SELECT 1;" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .postgres), + ["DO $$ BEGIN PERFORM 1; PERFORM 2; END $$", "SELECT 1"] + ) + } + + func testTaggedDollarQuoteKeepsInternalSemicolons() { + let sql = "CREATE FUNCTION f() RETURNS int AS $body$ BEGIN RETURN 1; END $body$ LANGUAGE plpgsql; SELECT f();" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .postgres), + [ + "CREATE FUNCTION f() RETURNS int AS $body$ BEGIN RETURN 1; END $body$ LANGUAGE plpgsql", + "SELECT f()" + ] + ) + } + + func testNestedDifferentDollarTags() { + let sql = "DO $outer$ SELECT $inner$ a;b $inner$; END $outer$; SELECT 1;" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .postgres), + ["DO $outer$ SELECT $inner$ a;b $inner$; END $outer$", "SELECT 1"] + ) + } + + func testPositionalParameterIsNotDollarQuote() { + let sql = "SELECT $1; SELECT $2;" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .postgres), + ["SELECT $1", "SELECT $2"] + ) + } + + func testDollarPairInsideIdentifierIsNotOpener() { + let sql = "SELECT 1 AS a$$; SELECT 2;" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .postgres), + ["SELECT 1 AS a$$", "SELECT 2"] + ) + } + + func testGenericDialectIgnoresDollarQuotes() { + let sql = "DO $$ SELECT 1; SELECT 2 $$;" + XCTAssertEqual( + SQLStatementScanner.allStatements(in: sql, dialect: .generic), + ["DO $$ SELECT 1", "SELECT 2 $$"] + ) + } + + func testCursorInsideDollarBodyReturnsWholeStatement() { + let sql = "DO $$ BEGIN PERFORM 1; END $$; SELECT 1;" + XCTAssertEqual( + SQLStatementScanner.statementAtCursor(in: sql, cursorPosition: 20, dialect: .postgres), + "DO $$ BEGIN PERFORM 1; END $$" + ) + } + // MARK: - allStatementsPreservingSemicolons func testPreservingSemicolons() { From 7d5818b9087640ac559f2bb1da83052c802564eb Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Tue, 2 Jun 2026 22:57:20 +0700 Subject: [PATCH 2/2] refactor(editor): share the dollar-quote token-start guard across both SQL splitters --- .../Core/Utilities/SQL/SQLStatementScanner.swift | 12 +++--------- TablePro/Core/Utilities/SQL/SqlDollarQuote.swift | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift b/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift index 71057baab..95a4aeeee 100644 --- a/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift +++ b/TablePro/Core/Utilities/SQL/SQLStatementScanner.swift @@ -30,9 +30,9 @@ enum SQLStatementScanner { } /// Returns statements preserving trailing semicolons, for display/history/favorites. - static func allStatementsPreservingSemicolons(in sql: String, dialect: SqlDialect = .generic) -> [String] { + static func allStatementsPreservingSemicolons(in sql: String) -> [String] { var results: [String] = [] - scan(sql: sql, cursorPosition: nil, dialect: dialect) { rawSQL, _ in + scan(sql: sql, cursorPosition: nil) { rawSQL, _ in let trimmed = rawSQL.trimmingCharacters(in: .whitespacesAndNewlines) let withoutSemicolon = trimmed.hasSuffix(";") ? String(trimmed.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines) @@ -78,12 +78,6 @@ enum SQLStatementScanner { private static let backslash = UInt16(UnicodeScalar("\\").value) private static let dollar = UInt16(UnicodeScalar("$").value) - private static func isIdentifierContinuation(before index: Int, in buffer: NSString) -> Bool { - guard index > 0 else { return false } - let prev = buffer.character(at: index - 1) - return SqlDollarQuote.isIdentifierPart(prev) || prev == dollar - } - private static func scan( sql: String, cursorPosition: Int?, @@ -172,7 +166,7 @@ enum SQLStatementScanner { } } - if dollarQuotesEnabled, !inString, ch == dollar, !isIdentifierContinuation(before: i, in: nsQuery), + if dollarQuotesEnabled, !inString, ch == dollar, case .opener(let openerLength, let tag) = SqlDollarQuote.scanOpener(at: i, in: nsQuery, bufLen: length) { inDollarQuote = true dollarTag = tag diff --git a/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift b/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift index 3bb3fc132..ccb40c288 100644 --- a/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift +++ b/TablePro/Core/Utilities/SQL/SqlDollarQuote.swift @@ -22,10 +22,22 @@ enum SqlDollarQuote { isIdentifierStart(ch) || (ch >= 0x30 && ch <= 0x39) } + /// Whether a `$` following this character is part of the preceding identifier, + /// per PostgreSQL's rule that a dollar quote must be separated from a + /// preceding identifier by whitespace (so `a$$b` is one identifier, not an + /// opener). + static func isIdentifierContinuation(_ ch: unichar) -> Bool { + isIdentifierPart(ch) || ch == dollar + } + /// Resolves a `$` at `pos` to a dollar-quote opener, a positional parameter - /// like `$1`, or a non-tag dollar. Returns `needsMoreData` when the buffer - /// ends mid-tag; a whole-string caller treats that as `notOpener`. + /// like `$1`, or a non-tag dollar. A `$` glued to a preceding identifier is + /// not an opener. Returns `needsMoreData` when the buffer ends mid-tag; a + /// whole-string caller treats that as `notOpener`. static func scanOpener(at pos: Int, in buffer: NSString, bufLen: Int) -> Opener { + if pos > 0, isIdentifierContinuation(buffer.character(at: pos - 1)) { + return .notOpener + } var p = pos + 1 while p < bufLen { let ch = buffer.character(at: p)