From 8048633ef3bf8e963285fc26a8c5fdf3d5f4efd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:28:24 +0700 Subject: [PATCH 01/12] fix: change-tracking dialect default, JSON import plugin load, and cloudflared cleanup --- Plugins/JSONImportPlugin/Info.plist | 2 +- .../ChangeTracking/DataChangeManager.swift | 12 ++- .../Cloudflare/CloudflareTunnelManager.swift | 8 +- .../StructureChangeManager.swift | 5 +- .../TableStructureView+DataLoading.swift | 3 +- .../AnyChangeManagerTests.swift | 10 +-- .../DataChangeManagerExtendedTests.swift | 4 +- .../DataChangeManagerTests.swift | 77 ++++++++++++++----- .../StructureChangeManagerPKTests.swift | 12 +-- .../StructureChangeManagerUndoTests.swift | 3 +- 10 files changed, 88 insertions(+), 48 deletions(-) diff --git a/Plugins/JSONImportPlugin/Info.plist b/Plugins/JSONImportPlugin/Info.plist index 8ceb6def3..6174f79e9 100644 --- a/Plugins/JSONImportPlugin/Info.plist +++ b/Plugins/JSONImportPlugin/Info.plist @@ -3,7 +3,7 @@ TableProPluginKitVersion - 17 + 18 TableProProvidesImportFormatIds json diff --git a/TablePro/Core/ChangeTracking/DataChangeManager.swift b/TablePro/Core/ChangeTracking/DataChangeManager.swift index 1980e0f75..e7f3c5ea6 100644 --- a/TablePro/Core/ChangeTracking/DataChangeManager.swift +++ b/TablePro/Core/ChangeTracking/DataChangeManager.swift @@ -53,7 +53,7 @@ final class DataChangeManager: ChangeManaging { var primaryKeyColumns: [String] = [] /// First PK column, for contexts that need a single column (paste, filters) var primaryKeyColumn: String? { primaryKeyColumns.first } - var databaseType: DatabaseType = .mysql + var databaseType: DatabaseType? var pluginDriver: (any PluginDatabaseDriver)? var columns: [String] = [] @@ -94,7 +94,7 @@ final class DataChangeManager: ChangeManaging { tableName: String, columns: [String], primaryKeyColumns: [String], - databaseType: DatabaseType = .mysql, + databaseType: DatabaseType, triggerReload: Bool = true ) { self.tableName = tableName @@ -407,9 +407,15 @@ final class DataChangeManager: ChangeManaging { } } + guard let databaseType else { + throw DatabaseError.queryFailed( + "Cannot generate statements: table dialect not configured" + ) + } + if PluginManager.shared.editorLanguage(for: databaseType) != .sql { throw DatabaseError.queryFailed( - "Cannot generate statements for \(databaseType.rawValue) — plugin driver not initialized" + "Cannot generate statements for \(databaseType.rawValue): plugin driver not initialized" ) } diff --git a/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift b/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift index ad6c3872d..58e6361b2 100644 --- a/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift +++ b/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift @@ -310,8 +310,12 @@ actor CloudflareTunnelManager { UserDefaults.standard.removeObject(forKey: Self.stalePidsDefaultsKey) return } - guard let data = try? JSONEncoder().encode(records) else { return } - UserDefaults.standard.set(data, forKey: Self.stalePidsDefaultsKey) + do { + let data = try JSONEncoder().encode(records) + UserDefaults.standard.set(data, forKey: Self.stalePidsDefaultsKey) + } catch { + Self.logger.error("Failed to persist cloudflared PID records, leaked processes may survive to next launch: \(error.localizedDescription, privacy: .public)") + } } private static func isLiveCloudflared(_ record: CloudflaredPidRecord) -> Bool { diff --git a/TablePro/Core/SchemaTracking/StructureChangeManager.swift b/TablePro/Core/SchemaTracking/StructureChangeManager.swift index da9b9976a..8a606edd5 100644 --- a/TablePro/Core/SchemaTracking/StructureChangeManager.swift +++ b/TablePro/Core/SchemaTracking/StructureChangeManager.swift @@ -32,7 +32,6 @@ final class StructureChangeManager: ChangeManaging { var workingPrimaryKey: [String] = [] var tableName: String? - var databaseType: DatabaseType = .mysql // MARK: - Undo/Redo Support @@ -61,11 +60,9 @@ final class StructureChangeManager: ChangeManaging { columns: [ColumnInfo], indexes: [IndexInfo], foreignKeys: [ForeignKeyInfo], - primaryKey: [String], - databaseType: DatabaseType + primaryKey: [String] ) { self.tableName = tableName - self.databaseType = databaseType // Convert to definitions self.currentColumns = columns.map { EditableColumnDefinition.from($0) } diff --git a/TablePro/Views/Structure/TableStructureView+DataLoading.swift b/TablePro/Views/Structure/TableStructureView+DataLoading.swift index 920ac4470..8d5342743 100644 --- a/TablePro/Views/Structure/TableStructureView+DataLoading.swift +++ b/TablePro/Views/Structure/TableStructureView+DataLoading.swift @@ -98,8 +98,7 @@ extension TableStructureView { columns: columns, indexes: indexes, foreignKeys: foreignKeys, - primaryKey: primaryKey, - databaseType: connection.type + primaryKey: primaryKey ) } diff --git a/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift b/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift index 3d486f880..6047b4467 100644 --- a/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift +++ b/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift @@ -18,7 +18,7 @@ struct AnyChangeManagerTests { @Test("DataChangeManager wrapper: hasChanges forwards correctly") func dataManagerHasChangesForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) #expect(wrapper.hasChanges == false) @@ -32,7 +32,7 @@ struct AnyChangeManagerTests { @Test("DataChangeManager wrapper: reloadVersion forwards correctly") func dataManagerReloadVersionForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) let initialVersion = wrapper.reloadVersion @@ -44,7 +44,7 @@ struct AnyChangeManagerTests { @Test("isRowDeleted delegates correctly for DataChangeManager") func isRowDeletedDelegatesCorrectly() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) #expect(wrapper.isRowDeleted(0) == false) @@ -57,7 +57,7 @@ struct AnyChangeManagerTests { @Test("recordCellChange forwards to DataChangeManager") func recordCellChangeForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) wrapper.recordCellChange(rowIndex: 0, columnIndex: 1, columnName: "name", oldValue: "Alice", newValue: "Bob", originalRow: ["1", "Alice"]) @@ -69,7 +69,7 @@ struct AnyChangeManagerTests { @Test("No retain cycle — wrapper can be deallocated") func noRetainCycleOnWrapper() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) weak var weakWrapper: AnyChangeManager? diff --git a/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift b/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift index a106f57fa..62cd7165d 100644 --- a/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift +++ b/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift @@ -25,7 +25,8 @@ struct DataChangeManagerExtendedTests { manager.configureForTable( tableName: "test_table", columns: columns, - primaryKeyColumns: [pk].compactMap { $0 } + primaryKeyColumns: [pk].compactMap { $0 }, + databaseType: .mysql ) return manager } @@ -694,6 +695,7 @@ struct DataChangeManagerExtendedTests { tableName: "test", columns: ["a", "b"], primaryKeyColumns: ["a"], + databaseType: .mysql, triggerReload: false ) #expect(manager.reloadVersion == before) diff --git a/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift b/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift index 9ef1a5968..9763e24a0 100644 --- a/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift +++ b/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift @@ -40,13 +40,31 @@ struct DataChangeManagerTests { #expect(manager.databaseType == .postgresql) } + @Test("generateSQL throws when the table dialect is not configured") + func generateSQLThrowsWhenDialectNotConfigured() { + let manager = DataChangeManager() + manager.recordCellChange( + rowIndex: 0, + columnIndex: 1, + columnName: "name", + oldValue: "Alice", + newValue: "Bob" + ) + + #expect(manager.databaseType == nil) + #expect(throws: (any Error).self) { + _ = try manager.generateSQL() + } + } + @Test("configureForTable clears existing changes") func configureForTableClearsChanges() async { let manager = DataChangeManager() manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -61,7 +79,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "products", columns: ["id", "title"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) #expect(!manager.hasChanges) @@ -86,7 +105,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -106,7 +126,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -132,7 +153,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -153,7 +175,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -184,7 +207,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -214,7 +238,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -246,7 +271,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordRowDeletion(rowIndex: 0, originalRow: ["1", "Alice"]) @@ -260,7 +286,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -286,7 +313,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordRowDeletion(rowIndex: 2, originalRow: ["3", "Charlie"]) @@ -303,7 +331,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) let rows: [(rowIndex: Int, originalRow: [PluginCellValue])] = [ @@ -327,7 +356,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -352,7 +382,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -377,7 +408,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -397,7 +429,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -421,7 +454,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -443,7 +477,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -484,7 +519,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) let initialVersion = manager.reloadVersion @@ -506,7 +542,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( diff --git a/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift b/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift index 06fea744e..397888708 100644 --- a/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift +++ b/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift @@ -57,8 +57,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumns(), indexes: sampleIndexes(), foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) let idCol = manager.workingColumns.first { $0.name == "id" } @@ -83,8 +82,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumnsNoPK(), indexes: sampleIndexes(), foreignKeys: [], - primaryKey: ["id"], - databaseType: .postgresql + primaryKey: ["id"] ) // The working columns should have isPrimaryKey set based on the primaryKey parameter @@ -115,8 +113,7 @@ struct StructureChangeManagerPKTests { columns: columns, indexes: [], foreignKeys: [], - primaryKey: ["tenant_id", "user_id"], - databaseType: .postgresql + primaryKey: ["tenant_id", "user_id"] ) let tenantCol = manager.workingColumns.first { $0.name == "tenant_id" } @@ -140,8 +137,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumnsNoPK(), indexes: [], foreignKeys: [], - primaryKey: [], - databaseType: .postgresql + primaryKey: [] ) for col in manager.workingColumns { diff --git a/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift b/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift index 020df04c8..6d864e27f 100644 --- a/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift +++ b/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift @@ -40,8 +40,7 @@ struct StructureChangeManagerUndoTests { columns: columns, indexes: indexes, foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) } From 078e0f7ab5f0f531610d1111c150b776adb5013d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:28:39 +0700 Subject: [PATCH 02/12] refactor(plugins): extract shared SQL filter builder and numeric helpers, split oversized drivers --- CHANGELOG.md | 5 + .../CassandraConnection.swift | 851 +++++++++++++ .../CassandraPlugin.swift | 861 ------------- .../CassandraPluginError.swift | 25 + .../ClickHousePlugin.swift | 649 +--------- .../ClickHousePluginDriver+Http.swift | 309 +++++ .../ClickHousePluginDriver+Schema.swift | 340 ++++++ .../DuckDBDriverPlugin/DuckDBConnection.swift | 600 +++++++++ Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift | 612 ---------- .../DuckDBPluginError.swift | 25 + .../JSONExportPlugin/JSONExportPlugin.swift | 11 +- Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift | 805 +----------- .../MSSQLPluginDriver+DDL.swift | 184 +++ .../MSSQLPluginDriver+Schema.swift | 533 ++++++++ .../MongoDBConnection+SyncHelpers.swift | 458 +++++++ .../MongoDBConnection.swift | 460 +------ Plugins/OracleDriverPlugin/OraclePlugin.swift | 102 +- .../RedisPluginDriver+Operations.swift | 510 ++++++++ .../RedisPluginDriver+ResultBuilding.swift | 486 ++++++++ .../RedisPluginDriver+Scan.swift | 85 ++ .../RedisDriverPlugin/RedisPluginDriver.swift | 1075 +---------------- Plugins/SQLExportPlugin/SQLExportPlugin.swift | 17 +- .../PluginDatabaseDriver.swift | 170 ++- .../PluginExportUtilities.swift | 9 + docs/development/refactor-audit.mdx | 69 ++ 25 files changed, 4686 insertions(+), 4565 deletions(-) create mode 100644 Plugins/CassandraDriverPlugin/CassandraConnection.swift create mode 100644 Plugins/CassandraDriverPlugin/CassandraPluginError.swift create mode 100644 Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift create mode 100644 Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Schema.swift create mode 100644 Plugins/DuckDBDriverPlugin/DuckDBConnection.swift create mode 100644 Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift create mode 100644 Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift create mode 100644 Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift create mode 100644 Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift create mode 100644 Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift create mode 100644 Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift create mode 100644 Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift create mode 100644 docs/development/refactor-audit.mdx diff --git a/CHANGELOG.md b/CHANGELOG.md index 13e469c93..5b8fe0c1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- JSON file import works again. It failed to load in 0.48.0. +- SQL export quotes empty or malformed values in numeric columns instead of writing them unquoted, which could produce invalid INSERT statements. + ## [0.48.0] - 2026-06-02 ### Added diff --git a/Plugins/CassandraDriverPlugin/CassandraConnection.swift b/Plugins/CassandraDriverPlugin/CassandraConnection.swift new file mode 100644 index 000000000..b80b97e67 --- /dev/null +++ b/Plugins/CassandraDriverPlugin/CassandraConnection.swift @@ -0,0 +1,851 @@ +// +// CassandraConnection.swift +// CassandraDriverPlugin +// + +#if canImport(CCassandra) +import CCassandra +#endif +import Foundation +import os +import TableProPluginKit + +actor CassandraConnectionActor { + private static let logger = Logger(subsystem: "com.TablePro.CassandraDriver", category: "Connection") + + nonisolated(unsafe) private static let isoFormatter: ISO8601DateFormatter = { + let f = ISO8601DateFormatter() + f.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + return f + }() + + nonisolated(unsafe) private static let dateFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "yyyy-MM-dd" + f.timeZone = TimeZone(identifier: "UTC") + return f + }() + + private var cluster: OpaquePointer? // CassCluster* + private var session: OpaquePointer? // CassSession* + private var currentKeyspace: String? + + var isConnected: Bool { session != nil } + + var keyspace: String? { currentKeyspace } + + func connect( + host: String, + port: Int, + username: String?, + password: String?, + keyspace: String?, + sslMode: SSLMode, + sslCaCertPath: String?, + sslClientCertPath: String?, + sslClientKeyPath: String?, + sslClientKeyPassphrase: String? + ) throws { + cluster = cass_cluster_new() + guard let cluster else { + throw CassandraPluginError.connectionFailed("Failed to create cluster object") + } + + cass_cluster_set_contact_points(cluster, host) + cass_cluster_set_port(cluster, Int32(port)) + + if let username, !username.isEmpty, let password { + cass_cluster_set_credentials(cluster, username, password) + } + + if sslMode != .disabled { + guard let ssl = cass_ssl_new() else { + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to create SSL context") + } + + cass_ssl_set_verify_flags(ssl, CassandraSSLMapping.verifyFlags(for: sslMode)) + + if sslMode == .verifyCa || sslMode == .verifyIdentity { + guard let caCertPath = sslCaCertPath, !caCertPath.isEmpty else { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "Verify CA or Verify Identity requires a CA certificate path") + } + guard let certData = FileManager.default.contents(atPath: caCertPath), + let certString = String(data: certData, encoding: .utf8) else { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "Could not read CA certificate at \(caCertPath)") + } + let rc = cass_ssl_add_trusted_cert(ssl, certString) + if rc != CASS_OK { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "CA certificate at \(caCertPath) is not a valid PEM") + } + } + + let trimmedClientCertPath = sslClientCertPath?.trimmingCharacters(in: .whitespaces) ?? "" + let trimmedClientKeyPath = sslClientKeyPath?.trimmingCharacters(in: .whitespaces) ?? "" + if !trimmedClientCertPath.isEmpty || !trimmedClientKeyPath.isEmpty { + try applyClientCertificate( + to: ssl, + certPath: trimmedClientCertPath, + keyPath: trimmedClientKeyPath, + keyPassphrase: sslClientKeyPassphrase + ) { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + } + } + + cass_cluster_set_ssl(cluster, ssl) + cass_ssl_free(ssl) + } + + // Connection timeout (10 seconds) + cass_cluster_set_connect_timeout(cluster, 10_000) + cass_cluster_set_request_timeout(cluster, 30_000) + + let newSession = cass_session_new() + guard let newSession else { + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to create session") + } + + let connectFuture: OpaquePointer? + if let keyspace, !keyspace.isEmpty { + connectFuture = cass_session_connect_keyspace(newSession, cluster, keyspace) + currentKeyspace = keyspace + } else { + connectFuture = cass_session_connect(newSession, cluster) + currentKeyspace = nil + } + + guard let future = connectFuture else { + cass_session_free(newSession) + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to initiate connection") + } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + let errorMessage = extractFutureError(future) + cass_future_free(future) + cass_session_free(newSession) + cass_cluster_free(cluster) + self.cluster = nil + if let sslError = Self.classifySSLError(rc: rc, message: errorMessage) { + throw sslError + } + throw CassandraPluginError.connectionFailed(errorMessage) + } + + cass_future_free(future) + session = newSession + + Self.logger.info("Connected to Cassandra at \(host):\(port)") + } + + private func applyClientCertificate( + to ssl: OpaquePointer, + certPath: String, + keyPath: String, + keyPassphrase: String?, + cleanup: () -> Void + ) throws { + guard !certPath.isEmpty else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "A client certificate is required when a client key is set") + } + guard !keyPath.isEmpty else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "A client key is required when a client certificate is set") + } + + guard let certData = FileManager.default.contents(atPath: certPath), + let certString = String(data: certData, encoding: .utf8) else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "Could not read client certificate at \(certPath)") + } + let certResult = cass_ssl_set_cert(ssl, certString) + if certResult != CASS_OK { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "Client certificate at \(certPath) is not a valid PEM") + } + + guard let keyData = FileManager.default.contents(atPath: keyPath), + let keyString = String(data: keyData, encoding: .utf8) else { + cleanup() + throw SSLHandshakeError.clientKeyInvalid(serverMessage: "Could not read client key at \(keyPath)") + } + let passphrase = keyPassphrase?.isEmpty == false ? keyPassphrase : nil + let keyResult = cass_ssl_set_private_key(ssl, keyString, passphrase) + if keyResult != CASS_OK { + cleanup() + throw Self.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath) + } + } + + static func isEncryptedPrivateKey(_ pem: String) -> Bool { + pem.contains("ENCRYPTED PRIVATE KEY") || (pem.contains("Proc-Type:") && pem.contains("ENCRYPTED")) + } + + static func privateKeyLoadError(keyPEM: String, hasPassphrase: Bool, keyPath: String) -> SSLHandshakeError { + guard isEncryptedPrivateKey(keyPEM) else { + return .clientKeyInvalid(serverMessage: "The client key at \(keyPath) is not a valid private key") + } + if hasPassphrase { + return .clientKeyPassphraseIncorrect(serverMessage: "The passphrase for the client key at \(keyPath) is incorrect") + } + return .clientKeyPassphraseRequired(serverMessage: "The client key at \(keyPath) is encrypted. Enter its passphrase.") + } + + func close() { + if let session { + let closeFuture = cass_session_close(session) + if let closeFuture { + cass_future_wait(closeFuture) + cass_future_free(closeFuture) + } + cass_session_free(session) + self.session = nil + } + + if let cluster { + cass_cluster_free(cluster) + self.cluster = nil + } + + currentKeyspace = nil + Self.logger.info("Disconnected from Cassandra") + } + + func executeQuery(_ cql: String) throws -> CassandraRawResult { + guard let session else { + throw CassandraPluginError.notConnected + } + + let startTime = Date() + let statement = cass_statement_new(cql, 0) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to create statement") + } + + defer { cass_statement_free(statement) } + + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute query") + } + + defer { cass_future_free(future) } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(future)) + } + + let result = cass_future_get_result(future) + defer { + if let result { cass_result_free(result) } + } + + guard let result else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: [], + columnTypeNames: [], + rows: [], + rowsAffected: 0, + executionTime: executionTime + ) + } + + return extractResult(from: result, startTime: startTime) + } + + func executePrepared(_ cql: String, parameters: [PluginCellValue]) throws -> CassandraRawResult { + guard let session else { + throw CassandraPluginError.notConnected + } + + let startTime = Date() + + // Prepare + let prepareFuture = cass_session_prepare(session, cql) + guard let prepareFuture else { + throw CassandraPluginError.queryFailed("Failed to prepare statement") + } + defer { cass_future_free(prepareFuture) } + + cass_future_wait(prepareFuture) + let prepRc = cass_future_error_code(prepareFuture) + if prepRc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(prepareFuture)) + } + + let prepared = cass_future_get_prepared(prepareFuture) + guard let prepared else { + throw CassandraPluginError.queryFailed("Failed to get prepared statement") + } + defer { cass_prepared_free(prepared) } + + // Bind parameters + let statement = cass_prepared_bind(prepared) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to bind prepared statement") + } + defer { cass_statement_free(statement) } + + for (index, param) in parameters.enumerated() { + switch param { + case .text(let value): + cass_statement_bind_string(statement, index, value) + case .bytes(let data): + data.withUnsafeBytes { rawBuffer in + if let base = rawBuffer.baseAddress?.assumingMemoryBound(to: UInt8.self) { + cass_statement_bind_bytes(statement, index, base, data.count) + } else { + cass_statement_bind_null(statement, index) + } + } + case .null: + cass_statement_bind_null(statement, index) + } + } + + // Execute + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute prepared statement") + } + defer { cass_future_free(future) } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(future)) + } + + let result = cass_future_get_result(future) + defer { + if let result { cass_result_free(result) } + } + + guard let result else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: [], + columnTypeNames: [], + rows: [], + rowsAffected: 0, + executionTime: executionTime + ) + } + + return extractResult(from: result, startTime: startTime) + } + + func switchKeyspace(_ keyspace: String) throws { + _ = try executeQuery("USE \"\(escapeIdentifier(keyspace))\"") + currentKeyspace = keyspace + } + + func serverVersion() throws -> String? { + let result = try executeQuery("SELECT release_version FROM system.local WHERE key = 'local'") + return result.rows.first?.first?.asText + } + + // MARK: - Private Helpers + + private func extractResult( + from result: OpaquePointer, + startTime: Date + ) -> CassandraRawResult { + let colCount = cass_result_column_count(result) + let rowCount = cass_result_row_count(result) + + var columns: [String] = [] + var columnTypeNames: [String] = [] + + for i in 0..? + var nameLength: Int = 0 + cass_result_column_name(result, i, &namePtr, &nameLength) + if let namePtr { + columns.append(String(cString: namePtr)) + } else { + columns.append("column_\(i)") + } + + let colType = cass_result_column_type(result, i) + columnTypeNames.append(Self.cassTypeName(colType)) + } + + var rows: [[PluginCellValue]] = [] + let iterator = cass_iterator_from_result(result) + defer { + if let iterator { cass_iterator_free(iterator) } + } + + guard let iterator else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: columns, + columnTypeNames: columnTypeNames, + rows: [], + rowsAffected: Int(rowCount), + executionTime: executionTime + ) + } + + let maxRows = min(Int(rowCount), 100_000) + var count = 0 + + while cass_iterator_next(iterator) == cass_true && count < maxRows { + let row = cass_iterator_get_row(iterator) + guard let row else { continue } + + var rowData: [PluginCellValue] = [] + for col in 0.. Data? { + var bytes: UnsafePointer? + var length: Int = 0 + guard cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes else { + return nil + } + return Data(bytes: bytes, count: length) + } + + private static func extractStringValue(_ value: OpaquePointer) -> String? { + let valueType = cass_value_type(value) + + switch valueType { + case CASS_VALUE_TYPE_ASCII, CASS_VALUE_TYPE_TEXT, CASS_VALUE_TYPE_VARCHAR: + var output: UnsafePointer? + var outputLength: Int = 0 + let rc = cass_value_get_string(value, &output, &outputLength) + if rc == CASS_OK, let output { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: output), + length: outputLength, + encoding: .utf8, + freeWhenDone: false + ) + } + return nil + + case CASS_VALUE_TYPE_INT: + var intVal: Int32 = 0 + if cass_value_get_int32(value, &intVal) == CASS_OK { + return String(intVal) + } + return nil + + case CASS_VALUE_TYPE_BIGINT, CASS_VALUE_TYPE_COUNTER: + var bigintVal: Int64 = 0 + if cass_value_get_int64(value, &bigintVal) == CASS_OK { + return String(bigintVal) + } + return nil + + case CASS_VALUE_TYPE_SMALL_INT: + var smallVal: Int16 = 0 + if cass_value_get_int16(value, &smallVal) == CASS_OK { + return String(smallVal) + } + return nil + + case CASS_VALUE_TYPE_TINY_INT: + var tinyVal: Int8 = 0 + if cass_value_get_int8(value, &tinyVal) == CASS_OK { + return String(tinyVal) + } + return nil + + case CASS_VALUE_TYPE_FLOAT: + var floatVal: Float = 0 + if cass_value_get_float(value, &floatVal) == CASS_OK { + return String(floatVal) + } + return nil + + case CASS_VALUE_TYPE_DOUBLE: + var doubleVal: Double = 0 + if cass_value_get_double(value, &doubleVal) == CASS_OK { + return String(doubleVal) + } + return nil + + case CASS_VALUE_TYPE_BOOLEAN: + var boolVal: cass_bool_t = cass_false + if cass_value_get_bool(value, &boolVal) == CASS_OK { + return boolVal == cass_true ? "true" : "false" + } + return nil + + case CASS_VALUE_TYPE_UUID, CASS_VALUE_TYPE_TIMEUUID: + var uuid = CassUuid() + if cass_value_get_uuid(value, &uuid) == CASS_OK { + var buffer = [CChar](repeating: 0, count: Int(CASS_UUID_STRING_LENGTH)) + cass_uuid_string(uuid, &buffer) + return String(cString: buffer) + } + return nil + + case CASS_VALUE_TYPE_TIMESTAMP: + var timestamp: Int64 = 0 + if cass_value_get_int64(value, ×tamp) == CASS_OK { + let date = Date(timeIntervalSince1970: Double(timestamp) / 1000.0) + return isoFormatter.string(from: date) + } + return nil + + case CASS_VALUE_TYPE_BLOB: + if let data = extractBlobValue(value) { + return "0x" + data.map { String(format: "%02x", $0) }.joined() + } + return nil + + case CASS_VALUE_TYPE_INET: + var inet = CassInet() + if cass_value_get_inet(value, &inet) == CASS_OK { + var buffer = [CChar](repeating: 0, count: Int(CASS_INET_STRING_LENGTH)) + cass_inet_string(inet, &buffer) + return String(cString: buffer) + } + return nil + + case CASS_VALUE_TYPE_LIST, CASS_VALUE_TYPE_SET: + return extractCollectionString(value, open: "[", close: "]") + + case CASS_VALUE_TYPE_MAP: + return extractMapString(value) + + case CASS_VALUE_TYPE_TUPLE: + return extractCollectionString(value, open: "(", close: ")") + + case CASS_VALUE_TYPE_DATE: + var dateVal: UInt32 = 0 + if cass_value_get_uint32(value, &dateVal) == CASS_OK { + let daysSinceEpoch = Int64(dateVal) - Int64(1 << 31) + let epochSeconds = daysSinceEpoch * 86400 + let date = Date(timeIntervalSince1970: Double(epochSeconds)) + return dateFormatter.string(from: date) + } + return nil + + case CASS_VALUE_TYPE_TIME: + var timeVal: Int64 = 0 + if cass_value_get_int64(value, &timeVal) == CASS_OK { + // Cassandra time is nanoseconds since midnight + let totalSeconds = timeVal / 1_000_000_000 + let hours = totalSeconds / 3600 + let minutes = (totalSeconds % 3600) / 60 + let seconds = totalSeconds % 60 + let nanos = timeVal % 1_000_000_000 + if nanos > 0 { + let millis = nanos / 1_000_000 + return String(format: "%02lld:%02lld:%02lld.%03lld", hours, minutes, seconds, millis) + } + return String(format: "%02lld:%02lld:%02lld", hours, minutes, seconds) + } + return nil + + case CASS_VALUE_TYPE_DECIMAL, CASS_VALUE_TYPE_VARINT: + // Read as bytes and display as hex since proper numeric decoding + // requires BigInteger support not available in the C driver API + var bytes: UnsafePointer? + var length: Int = 0 + if cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes { + let data = Data(bytes: bytes, count: length) + return "0x" + data.map { String(format: "%02x", $0) }.joined() + } + return nil + + default: + // Fallback: try reading as string + var output: UnsafePointer? + var outputLength: Int = 0 + if cass_value_get_string(value, &output, &outputLength) == CASS_OK, let output { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: output), + length: outputLength, + encoding: .utf8, + freeWhenDone: false + ) + } + return "" + } + } + + private static func extractCollectionString( + _ value: OpaquePointer, + open: String, + close: String + ) -> String { + guard let iterator = cass_iterator_from_collection(value) else { + return "\(open)\(close)" + } + defer { cass_iterator_free(iterator) } + + var elements: [String] = [] + while cass_iterator_next(iterator) == cass_true { + if let elem = cass_iterator_get_value(iterator) { + elements.append(extractStringValue(elem) ?? "null") + } + } + return "\(open)\(elements.joined(separator: ", "))\(close)" + } + + private static func extractMapString(_ value: OpaquePointer) -> String { + guard let iterator = cass_iterator_from_map(value) else { + return "{}" + } + defer { cass_iterator_free(iterator) } + + var pairs: [String] = [] + while cass_iterator_next(iterator) == cass_true { + let key = cass_iterator_get_map_key(iterator) + let val = cass_iterator_get_map_value(iterator) + let keyStr = key.flatMap { extractStringValue($0) } ?? "null" + let valStr = val.flatMap { extractStringValue($0) } ?? "null" + pairs.append("\(keyStr): \(valStr)") + } + return "{\(pairs.joined(separator: ", "))}" + } + + private static func cassTypeName(_ type: CassValueType) -> String { + switch type { + case CASS_VALUE_TYPE_ASCII: return "ascii" + case CASS_VALUE_TYPE_BIGINT: return "bigint" + case CASS_VALUE_TYPE_BLOB: return "blob" + case CASS_VALUE_TYPE_BOOLEAN: return "boolean" + case CASS_VALUE_TYPE_COUNTER: return "counter" + case CASS_VALUE_TYPE_DECIMAL: return "decimal" + case CASS_VALUE_TYPE_DOUBLE: return "double" + case CASS_VALUE_TYPE_FLOAT: return "float" + case CASS_VALUE_TYPE_INT: return "int" + case CASS_VALUE_TYPE_TEXT: return "text" + case CASS_VALUE_TYPE_TIMESTAMP: return "timestamp" + case CASS_VALUE_TYPE_UUID: return "uuid" + case CASS_VALUE_TYPE_VARCHAR: return "varchar" + case CASS_VALUE_TYPE_VARINT: return "varint" + case CASS_VALUE_TYPE_TIMEUUID: return "timeuuid" + case CASS_VALUE_TYPE_INET: return "inet" + case CASS_VALUE_TYPE_DATE: return "date" + case CASS_VALUE_TYPE_TIME: return "time" + case CASS_VALUE_TYPE_SMALL_INT: return "smallint" + case CASS_VALUE_TYPE_TINY_INT: return "tinyint" + case CASS_VALUE_TYPE_LIST: return "list" + case CASS_VALUE_TYPE_MAP: return "map" + case CASS_VALUE_TYPE_SET: return "set" + case CASS_VALUE_TYPE_TUPLE: return "tuple" + case CASS_VALUE_TYPE_UDT: return "udt" + default: return "text" + } + } + + private func extractFutureError(_ future: OpaquePointer) -> String { + var message: UnsafePointer? + var messageLength: Int = 0 + cass_future_error_message(future, &message, &messageLength) + if let message { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: message), + length: messageLength, + encoding: .utf8, + freeWhenDone: false + ) ?? "Unknown error" + } + return "Unknown error" + } + + func streamQuery( + _ cql: String, + continuation: AsyncThrowingStream.Continuation + ) throws { + guard let session else { + throw CassandraPluginError.notConnected + } + + let pageSize: Int32 = 5_000 + let statement = cass_statement_new(cql, 0) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to create statement") + } + + cass_statement_set_paging_size(statement, pageSize) + + var headerSent = false + + defer { cass_statement_free(statement) } + + while true { + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute query") + } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + let errorMessage = extractFutureError(future) + cass_future_free(future) + throw CassandraPluginError.queryFailed(errorMessage) + } + + let result = cass_future_get_result(future) + cass_future_free(future) + + guard let result else { break } + + if !headerSent { + let colCount = cass_result_column_count(result) + var columns: [String] = [] + var columnTypeNames: [String] = [] + + for i in 0..? + var nameLength: Int = 0 + cass_result_column_name(result, i, &namePtr, &nameLength) + if let namePtr { + columns.append(String(cString: namePtr)) + } else { + columns.append("column_\(i)") + } + let colType = cass_result_column_type(result, i) + columnTypeNames.append(Self.cassTypeName(colType)) + } + + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames, + estimatedRowCount: nil + ))) + headerSent = true + } + + let colCount = cass_result_column_count(result) + let iterator = cass_iterator_from_result(result) + + if let iterator { + while cass_iterator_next(iterator) == cass_true { + let row = cass_iterator_get_row(iterator) + guard let row else { continue } + + var rowData: [PluginCellValue] = [] + for col in 0.. String { + value.replacingOccurrences(of: "\"", with: "\"\"") + } + + static func classifySSLError(rc: CassError, message: String) -> SSLHandshakeError? { + switch rc { + case CASS_ERROR_SSL_NO_PEER_CERT, CASS_ERROR_SSL_INVALID_PEER_CERT: + return .untrustedCertificate(serverMessage: message) + case CASS_ERROR_SSL_IDENTITY_MISMATCH: + return .hostnameMismatch(serverMessage: message) + case CASS_ERROR_SSL_INVALID_PRIVATE_KEY, CASS_ERROR_SSL_INVALID_CERT: + return .clientCertRequired(serverMessage: message) + case CASS_ERROR_SSL_PROTOCOL_ERROR: + return .cipherMismatch(serverMessage: message) + default: + break + } + let lower = message.lowercased() + if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("ssl_connect") { + return .cipherMismatch(serverMessage: message) + } + return nil + } +} + +// MARK: - Raw Result + +struct CassandraRawResult: Sendable { + let columns: [String] + let columnTypeNames: [String] + let rows: [[PluginCellValue]] + let rowsAffected: Int + let executionTime: TimeInterval +} diff --git a/Plugins/CassandraDriverPlugin/CassandraPlugin.swift b/Plugins/CassandraDriverPlugin/CassandraPlugin.swift index 2d966008b..f72a3a2fe 100644 --- a/Plugins/CassandraDriverPlugin/CassandraPlugin.swift +++ b/Plugins/CassandraDriverPlugin/CassandraPlugin.swift @@ -104,848 +104,6 @@ internal final class CassandraPlugin: NSObject, TableProPlugin, DriverPlugin { } } -// MARK: - Connection Actor - -private actor CassandraConnectionActor { - private static let logger = Logger(subsystem: "com.TablePro.CassandraDriver", category: "Connection") - - nonisolated(unsafe) private static let isoFormatter: ISO8601DateFormatter = { - let f = ISO8601DateFormatter() - f.formatOptions = [.withInternetDateTime, .withFractionalSeconds] - return f - }() - - nonisolated(unsafe) private static let dateFormatter: DateFormatter = { - let f = DateFormatter() - f.dateFormat = "yyyy-MM-dd" - f.timeZone = TimeZone(identifier: "UTC") - return f - }() - - private var cluster: OpaquePointer? // CassCluster* - private var session: OpaquePointer? // CassSession* - private var currentKeyspace: String? - - var isConnected: Bool { session != nil } - - var keyspace: String? { currentKeyspace } - - func connect( - host: String, - port: Int, - username: String?, - password: String?, - keyspace: String?, - sslMode: SSLMode, - sslCaCertPath: String?, - sslClientCertPath: String?, - sslClientKeyPath: String?, - sslClientKeyPassphrase: String? - ) throws { - cluster = cass_cluster_new() - guard let cluster else { - throw CassandraPluginError.connectionFailed("Failed to create cluster object") - } - - cass_cluster_set_contact_points(cluster, host) - cass_cluster_set_port(cluster, Int32(port)) - - if let username, !username.isEmpty, let password { - cass_cluster_set_credentials(cluster, username, password) - } - - if sslMode != .disabled { - guard let ssl = cass_ssl_new() else { - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to create SSL context") - } - - cass_ssl_set_verify_flags(ssl, CassandraSSLMapping.verifyFlags(for: sslMode)) - - if sslMode == .verifyCa || sslMode == .verifyIdentity { - guard let caCertPath = sslCaCertPath, !caCertPath.isEmpty else { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "Verify CA or Verify Identity requires a CA certificate path") - } - guard let certData = FileManager.default.contents(atPath: caCertPath), - let certString = String(data: certData, encoding: .utf8) else { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "Could not read CA certificate at \(caCertPath)") - } - let rc = cass_ssl_add_trusted_cert(ssl, certString) - if rc != CASS_OK { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "CA certificate at \(caCertPath) is not a valid PEM") - } - } - - let trimmedClientCertPath = sslClientCertPath?.trimmingCharacters(in: .whitespaces) ?? "" - let trimmedClientKeyPath = sslClientKeyPath?.trimmingCharacters(in: .whitespaces) ?? "" - if !trimmedClientCertPath.isEmpty || !trimmedClientKeyPath.isEmpty { - try applyClientCertificate( - to: ssl, - certPath: trimmedClientCertPath, - keyPath: trimmedClientKeyPath, - keyPassphrase: sslClientKeyPassphrase - ) { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - } - } - - cass_cluster_set_ssl(cluster, ssl) - cass_ssl_free(ssl) - } - - // Connection timeout (10 seconds) - cass_cluster_set_connect_timeout(cluster, 10_000) - cass_cluster_set_request_timeout(cluster, 30_000) - - let newSession = cass_session_new() - guard let newSession else { - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to create session") - } - - let connectFuture: OpaquePointer? - if let keyspace, !keyspace.isEmpty { - connectFuture = cass_session_connect_keyspace(newSession, cluster, keyspace) - currentKeyspace = keyspace - } else { - connectFuture = cass_session_connect(newSession, cluster) - currentKeyspace = nil - } - - guard let future = connectFuture else { - cass_session_free(newSession) - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to initiate connection") - } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - let errorMessage = extractFutureError(future) - cass_future_free(future) - cass_session_free(newSession) - cass_cluster_free(cluster) - self.cluster = nil - if let sslError = Self.classifySSLError(rc: rc, message: errorMessage) { - throw sslError - } - throw CassandraPluginError.connectionFailed(errorMessage) - } - - cass_future_free(future) - session = newSession - - Self.logger.info("Connected to Cassandra at \(host):\(port)") - } - - private func applyClientCertificate( - to ssl: OpaquePointer, - certPath: String, - keyPath: String, - keyPassphrase: String?, - cleanup: () -> Void - ) throws { - guard !certPath.isEmpty else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "A client certificate is required when a client key is set") - } - guard !keyPath.isEmpty else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "A client key is required when a client certificate is set") - } - - guard let certData = FileManager.default.contents(atPath: certPath), - let certString = String(data: certData, encoding: .utf8) else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "Could not read client certificate at \(certPath)") - } - let certResult = cass_ssl_set_cert(ssl, certString) - if certResult != CASS_OK { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "Client certificate at \(certPath) is not a valid PEM") - } - - guard let keyData = FileManager.default.contents(atPath: keyPath), - let keyString = String(data: keyData, encoding: .utf8) else { - cleanup() - throw SSLHandshakeError.clientKeyInvalid(serverMessage: "Could not read client key at \(keyPath)") - } - let passphrase = keyPassphrase?.isEmpty == false ? keyPassphrase : nil - let keyResult = cass_ssl_set_private_key(ssl, keyString, passphrase) - if keyResult != CASS_OK { - cleanup() - throw Self.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath) - } - } - - static func isEncryptedPrivateKey(_ pem: String) -> Bool { - pem.contains("ENCRYPTED PRIVATE KEY") || (pem.contains("Proc-Type:") && pem.contains("ENCRYPTED")) - } - - static func privateKeyLoadError(keyPEM: String, hasPassphrase: Bool, keyPath: String) -> SSLHandshakeError { - guard isEncryptedPrivateKey(keyPEM) else { - return .clientKeyInvalid(serverMessage: "The client key at \(keyPath) is not a valid private key") - } - if hasPassphrase { - return .clientKeyPassphraseIncorrect(serverMessage: "The passphrase for the client key at \(keyPath) is incorrect") - } - return .clientKeyPassphraseRequired(serverMessage: "The client key at \(keyPath) is encrypted. Enter its passphrase.") - } - - func close() { - if let session { - let closeFuture = cass_session_close(session) - if let closeFuture { - cass_future_wait(closeFuture) - cass_future_free(closeFuture) - } - cass_session_free(session) - self.session = nil - } - - if let cluster { - cass_cluster_free(cluster) - self.cluster = nil - } - - currentKeyspace = nil - Self.logger.info("Disconnected from Cassandra") - } - - func executeQuery(_ cql: String) throws -> CassandraRawResult { - guard let session else { - throw CassandraPluginError.notConnected - } - - let startTime = Date() - let statement = cass_statement_new(cql, 0) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to create statement") - } - - defer { cass_statement_free(statement) } - - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute query") - } - - defer { cass_future_free(future) } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(future)) - } - - let result = cass_future_get_result(future) - defer { - if let result { cass_result_free(result) } - } - - guard let result else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: [], - columnTypeNames: [], - rows: [], - rowsAffected: 0, - executionTime: executionTime - ) - } - - return extractResult(from: result, startTime: startTime) - } - - func executePrepared(_ cql: String, parameters: [PluginCellValue]) throws -> CassandraRawResult { - guard let session else { - throw CassandraPluginError.notConnected - } - - let startTime = Date() - - // Prepare - let prepareFuture = cass_session_prepare(session, cql) - guard let prepareFuture else { - throw CassandraPluginError.queryFailed("Failed to prepare statement") - } - defer { cass_future_free(prepareFuture) } - - cass_future_wait(prepareFuture) - let prepRc = cass_future_error_code(prepareFuture) - if prepRc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(prepareFuture)) - } - - let prepared = cass_future_get_prepared(prepareFuture) - guard let prepared else { - throw CassandraPluginError.queryFailed("Failed to get prepared statement") - } - defer { cass_prepared_free(prepared) } - - // Bind parameters - let statement = cass_prepared_bind(prepared) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to bind prepared statement") - } - defer { cass_statement_free(statement) } - - for (index, param) in parameters.enumerated() { - switch param { - case .text(let value): - cass_statement_bind_string(statement, index, value) - case .bytes(let data): - data.withUnsafeBytes { rawBuffer in - if let base = rawBuffer.baseAddress?.assumingMemoryBound(to: UInt8.self) { - cass_statement_bind_bytes(statement, index, base, data.count) - } else { - cass_statement_bind_null(statement, index) - } - } - case .null: - cass_statement_bind_null(statement, index) - } - } - - // Execute - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute prepared statement") - } - defer { cass_future_free(future) } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(future)) - } - - let result = cass_future_get_result(future) - defer { - if let result { cass_result_free(result) } - } - - guard let result else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: [], - columnTypeNames: [], - rows: [], - rowsAffected: 0, - executionTime: executionTime - ) - } - - return extractResult(from: result, startTime: startTime) - } - - func switchKeyspace(_ keyspace: String) throws { - _ = try executeQuery("USE \"\(escapeIdentifier(keyspace))\"") - currentKeyspace = keyspace - } - - func serverVersion() throws -> String? { - let result = try executeQuery("SELECT release_version FROM system.local WHERE key = 'local'") - return result.rows.first?.first?.asText - } - - // MARK: - Private Helpers - - private func extractResult( - from result: OpaquePointer, - startTime: Date - ) -> CassandraRawResult { - let colCount = cass_result_column_count(result) - let rowCount = cass_result_row_count(result) - - var columns: [String] = [] - var columnTypeNames: [String] = [] - - for i in 0..? - var nameLength: Int = 0 - cass_result_column_name(result, i, &namePtr, &nameLength) - if let namePtr { - columns.append(String(cString: namePtr)) - } else { - columns.append("column_\(i)") - } - - let colType = cass_result_column_type(result, i) - columnTypeNames.append(Self.cassTypeName(colType)) - } - - var rows: [[PluginCellValue]] = [] - let iterator = cass_iterator_from_result(result) - defer { - if let iterator { cass_iterator_free(iterator) } - } - - guard let iterator else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: columns, - columnTypeNames: columnTypeNames, - rows: [], - rowsAffected: Int(rowCount), - executionTime: executionTime - ) - } - - let maxRows = min(Int(rowCount), 100_000) - var count = 0 - - while cass_iterator_next(iterator) == cass_true && count < maxRows { - let row = cass_iterator_get_row(iterator) - guard let row else { continue } - - var rowData: [PluginCellValue] = [] - for col in 0.. Data? { - var bytes: UnsafePointer? - var length: Int = 0 - guard cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes else { - return nil - } - return Data(bytes: bytes, count: length) - } - - private static func extractStringValue(_ value: OpaquePointer) -> String? { - let valueType = cass_value_type(value) - - switch valueType { - case CASS_VALUE_TYPE_ASCII, CASS_VALUE_TYPE_TEXT, CASS_VALUE_TYPE_VARCHAR: - var output: UnsafePointer? - var outputLength: Int = 0 - let rc = cass_value_get_string(value, &output, &outputLength) - if rc == CASS_OK, let output { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: output), - length: outputLength, - encoding: .utf8, - freeWhenDone: false - ) - } - return nil - - case CASS_VALUE_TYPE_INT: - var intVal: Int32 = 0 - if cass_value_get_int32(value, &intVal) == CASS_OK { - return String(intVal) - } - return nil - - case CASS_VALUE_TYPE_BIGINT, CASS_VALUE_TYPE_COUNTER: - var bigintVal: Int64 = 0 - if cass_value_get_int64(value, &bigintVal) == CASS_OK { - return String(bigintVal) - } - return nil - - case CASS_VALUE_TYPE_SMALL_INT: - var smallVal: Int16 = 0 - if cass_value_get_int16(value, &smallVal) == CASS_OK { - return String(smallVal) - } - return nil - - case CASS_VALUE_TYPE_TINY_INT: - var tinyVal: Int8 = 0 - if cass_value_get_int8(value, &tinyVal) == CASS_OK { - return String(tinyVal) - } - return nil - - case CASS_VALUE_TYPE_FLOAT: - var floatVal: Float = 0 - if cass_value_get_float(value, &floatVal) == CASS_OK { - return String(floatVal) - } - return nil - - case CASS_VALUE_TYPE_DOUBLE: - var doubleVal: Double = 0 - if cass_value_get_double(value, &doubleVal) == CASS_OK { - return String(doubleVal) - } - return nil - - case CASS_VALUE_TYPE_BOOLEAN: - var boolVal: cass_bool_t = cass_false - if cass_value_get_bool(value, &boolVal) == CASS_OK { - return boolVal == cass_true ? "true" : "false" - } - return nil - - case CASS_VALUE_TYPE_UUID, CASS_VALUE_TYPE_TIMEUUID: - var uuid = CassUuid() - if cass_value_get_uuid(value, &uuid) == CASS_OK { - var buffer = [CChar](repeating: 0, count: Int(CASS_UUID_STRING_LENGTH)) - cass_uuid_string(uuid, &buffer) - return String(cString: buffer) - } - return nil - - case CASS_VALUE_TYPE_TIMESTAMP: - var timestamp: Int64 = 0 - if cass_value_get_int64(value, ×tamp) == CASS_OK { - let date = Date(timeIntervalSince1970: Double(timestamp) / 1000.0) - return isoFormatter.string(from: date) - } - return nil - - case CASS_VALUE_TYPE_BLOB: - if let data = extractBlobValue(value) { - return "0x" + data.map { String(format: "%02x", $0) }.joined() - } - return nil - - case CASS_VALUE_TYPE_INET: - var inet = CassInet() - if cass_value_get_inet(value, &inet) == CASS_OK { - var buffer = [CChar](repeating: 0, count: Int(CASS_INET_STRING_LENGTH)) - cass_inet_string(inet, &buffer) - return String(cString: buffer) - } - return nil - - case CASS_VALUE_TYPE_LIST, CASS_VALUE_TYPE_SET: - return extractCollectionString(value, open: "[", close: "]") - - case CASS_VALUE_TYPE_MAP: - return extractMapString(value) - - case CASS_VALUE_TYPE_TUPLE: - return extractCollectionString(value, open: "(", close: ")") - - case CASS_VALUE_TYPE_DATE: - var dateVal: UInt32 = 0 - if cass_value_get_uint32(value, &dateVal) == CASS_OK { - let daysSinceEpoch = Int64(dateVal) - Int64(1 << 31) - let epochSeconds = daysSinceEpoch * 86400 - let date = Date(timeIntervalSince1970: Double(epochSeconds)) - return dateFormatter.string(from: date) - } - return nil - - case CASS_VALUE_TYPE_TIME: - var timeVal: Int64 = 0 - if cass_value_get_int64(value, &timeVal) == CASS_OK { - // Cassandra time is nanoseconds since midnight - let totalSeconds = timeVal / 1_000_000_000 - let hours = totalSeconds / 3600 - let minutes = (totalSeconds % 3600) / 60 - let seconds = totalSeconds % 60 - let nanos = timeVal % 1_000_000_000 - if nanos > 0 { - let millis = nanos / 1_000_000 - return String(format: "%02lld:%02lld:%02lld.%03lld", hours, minutes, seconds, millis) - } - return String(format: "%02lld:%02lld:%02lld", hours, minutes, seconds) - } - return nil - - case CASS_VALUE_TYPE_DECIMAL, CASS_VALUE_TYPE_VARINT: - // Read as bytes and display as hex since proper numeric decoding - // requires BigInteger support not available in the C driver API - var bytes: UnsafePointer? - var length: Int = 0 - if cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes { - let data = Data(bytes: bytes, count: length) - return "0x" + data.map { String(format: "%02x", $0) }.joined() - } - return nil - - default: - // Fallback: try reading as string - var output: UnsafePointer? - var outputLength: Int = 0 - if cass_value_get_string(value, &output, &outputLength) == CASS_OK, let output { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: output), - length: outputLength, - encoding: .utf8, - freeWhenDone: false - ) - } - return "" - } - } - - private static func extractCollectionString( - _ value: OpaquePointer, - open: String, - close: String - ) -> String { - guard let iterator = cass_iterator_from_collection(value) else { - return "\(open)\(close)" - } - defer { cass_iterator_free(iterator) } - - var elements: [String] = [] - while cass_iterator_next(iterator) == cass_true { - if let elem = cass_iterator_get_value(iterator) { - elements.append(extractStringValue(elem) ?? "null") - } - } - return "\(open)\(elements.joined(separator: ", "))\(close)" - } - - private static func extractMapString(_ value: OpaquePointer) -> String { - guard let iterator = cass_iterator_from_map(value) else { - return "{}" - } - defer { cass_iterator_free(iterator) } - - var pairs: [String] = [] - while cass_iterator_next(iterator) == cass_true { - let key = cass_iterator_get_map_key(iterator) - let val = cass_iterator_get_map_value(iterator) - let keyStr = key.flatMap { extractStringValue($0) } ?? "null" - let valStr = val.flatMap { extractStringValue($0) } ?? "null" - pairs.append("\(keyStr): \(valStr)") - } - return "{\(pairs.joined(separator: ", "))}" - } - - private static func cassTypeName(_ type: CassValueType) -> String { - switch type { - case CASS_VALUE_TYPE_ASCII: return "ascii" - case CASS_VALUE_TYPE_BIGINT: return "bigint" - case CASS_VALUE_TYPE_BLOB: return "blob" - case CASS_VALUE_TYPE_BOOLEAN: return "boolean" - case CASS_VALUE_TYPE_COUNTER: return "counter" - case CASS_VALUE_TYPE_DECIMAL: return "decimal" - case CASS_VALUE_TYPE_DOUBLE: return "double" - case CASS_VALUE_TYPE_FLOAT: return "float" - case CASS_VALUE_TYPE_INT: return "int" - case CASS_VALUE_TYPE_TEXT: return "text" - case CASS_VALUE_TYPE_TIMESTAMP: return "timestamp" - case CASS_VALUE_TYPE_UUID: return "uuid" - case CASS_VALUE_TYPE_VARCHAR: return "varchar" - case CASS_VALUE_TYPE_VARINT: return "varint" - case CASS_VALUE_TYPE_TIMEUUID: return "timeuuid" - case CASS_VALUE_TYPE_INET: return "inet" - case CASS_VALUE_TYPE_DATE: return "date" - case CASS_VALUE_TYPE_TIME: return "time" - case CASS_VALUE_TYPE_SMALL_INT: return "smallint" - case CASS_VALUE_TYPE_TINY_INT: return "tinyint" - case CASS_VALUE_TYPE_LIST: return "list" - case CASS_VALUE_TYPE_MAP: return "map" - case CASS_VALUE_TYPE_SET: return "set" - case CASS_VALUE_TYPE_TUPLE: return "tuple" - case CASS_VALUE_TYPE_UDT: return "udt" - default: return "text" - } - } - - private func extractFutureError(_ future: OpaquePointer) -> String { - var message: UnsafePointer? - var messageLength: Int = 0 - cass_future_error_message(future, &message, &messageLength) - if let message { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: message), - length: messageLength, - encoding: .utf8, - freeWhenDone: false - ) ?? "Unknown error" - } - return "Unknown error" - } - - func streamQuery( - _ cql: String, - continuation: AsyncThrowingStream.Continuation - ) throws { - guard let session else { - throw CassandraPluginError.notConnected - } - - let pageSize: Int32 = 5_000 - let statement = cass_statement_new(cql, 0) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to create statement") - } - - cass_statement_set_paging_size(statement, pageSize) - - var headerSent = false - - defer { cass_statement_free(statement) } - - while true { - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute query") - } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - let errorMessage = extractFutureError(future) - cass_future_free(future) - throw CassandraPluginError.queryFailed(errorMessage) - } - - let result = cass_future_get_result(future) - cass_future_free(future) - - guard let result else { break } - - if !headerSent { - let colCount = cass_result_column_count(result) - var columns: [String] = [] - var columnTypeNames: [String] = [] - - for i in 0..? - var nameLength: Int = 0 - cass_result_column_name(result, i, &namePtr, &nameLength) - if let namePtr { - columns.append(String(cString: namePtr)) - } else { - columns.append("column_\(i)") - } - let colType = cass_result_column_type(result, i) - columnTypeNames.append(Self.cassTypeName(colType)) - } - - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames, - estimatedRowCount: nil - ))) - headerSent = true - } - - let colCount = cass_result_column_count(result) - let iterator = cass_iterator_from_result(result) - - if let iterator { - while cass_iterator_next(iterator) == cass_true { - let row = cass_iterator_get_row(iterator) - guard let row else { continue } - - var rowData: [PluginCellValue] = [] - for col in 0.. String { - value.replacingOccurrences(of: "\"", with: "\"\"") - } - - static func classifySSLError(rc: CassError, message: String) -> SSLHandshakeError? { - switch rc { - case CASS_ERROR_SSL_NO_PEER_CERT, CASS_ERROR_SSL_INVALID_PEER_CERT: - return .untrustedCertificate(serverMessage: message) - case CASS_ERROR_SSL_IDENTITY_MISMATCH: - return .hostnameMismatch(serverMessage: message) - case CASS_ERROR_SSL_INVALID_PRIVATE_KEY, CASS_ERROR_SSL_INVALID_CERT: - return .clientCertRequired(serverMessage: message) - case CASS_ERROR_SSL_PROTOCOL_ERROR: - return .cipherMismatch(serverMessage: message) - default: - break - } - let lower = message.lowercased() - if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("ssl_connect") { - return .cipherMismatch(serverMessage: message) - } - return nil - } -} - -// MARK: - Raw Result - -private struct CassandraRawResult: Sendable { - let columns: [String] - let columnTypeNames: [String] - let rows: [[PluginCellValue]] - let rowsAffected: Int - let executionTime: TimeInterval -} - // MARK: - Plugin Driver internal final class CassandraPluginDriver: PluginDatabaseDriver, @unchecked Sendable { @@ -1428,22 +586,3 @@ internal final class CassandraPluginDriver: PluginDatabaseDriver, @unchecked Sen } } -// MARK: - Errors - -internal enum CassandraPluginError: Error { - case connectionFailed(String) - case notConnected - case queryFailed(String) - case unsupportedOperation -} - -extension CassandraPluginError: PluginDriverError { - var pluginErrorMessage: String { - switch self { - case .connectionFailed(let msg): return msg - case .notConnected: return String(localized: "Not connected to database") - case .queryFailed(let msg): return msg - case .unsupportedOperation: return String(localized: "Operation not supported by Cassandra") - } - } -} diff --git a/Plugins/CassandraDriverPlugin/CassandraPluginError.swift b/Plugins/CassandraDriverPlugin/CassandraPluginError.swift new file mode 100644 index 000000000..e338b852a --- /dev/null +++ b/Plugins/CassandraDriverPlugin/CassandraPluginError.swift @@ -0,0 +1,25 @@ +// +// CassandraPluginError.swift +// CassandraDriverPlugin +// + +import Foundation +import TableProPluginKit + +internal enum CassandraPluginError: Error { + case connectionFailed(String) + case notConnected + case queryFailed(String) + case unsupportedOperation +} + +extension CassandraPluginError: PluginDriverError { + var pluginErrorMessage: String { + switch self { + case .connectionFailed(let msg): return msg + case .notConnected: return String(localized: "Not connected to database") + case .queryFailed(let msg): return msg + case .unsupportedOperation: return String(localized: "Operation not supported by Cassandra") + } + } +} diff --git a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift index 16d96e512..3dac36eaf 100644 --- a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift +++ b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift @@ -115,7 +115,7 @@ final class ClickHousePlugin: NSObject, TableProPlugin, DriverPlugin { // MARK: - Error Types -private struct ClickHouseError: Error, PluginDriverError { +struct ClickHouseError: Error, PluginDriverError { let message: String var pluginErrorMessage: String { message } @@ -126,7 +126,7 @@ private struct ClickHouseError: Error, PluginDriverError { // MARK: - Internal Query Result -private struct CHQueryResult { +struct CHQueryResult { let columns: [String] let columnTypeNames: [String] let rows: [[PluginCellValue]] @@ -137,19 +137,19 @@ private struct CHQueryResult { // MARK: - Plugin Driver final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { - private let config: DriverConnectionConfig + let config: DriverConnectionConfig private var _serverVersion: String? - private let lock = NSLock() - private var session: URLSession? - private var currentTask: URLSessionDataTask? - private var _currentDatabase: String - private var _lastQueryId: String? - private let _queryTimeout = HttpQueryTimeoutBox() + let lock = NSLock() + var session: URLSession? + var currentTask: URLSessionDataTask? + var _currentDatabase: String + var _lastQueryId: String? + let _queryTimeout = HttpQueryTimeoutBox() - private static let logger = Logger(subsystem: "com.TablePro", category: "ClickHousePluginDriver") + static let logger = Logger(subsystem: "com.TablePro", category: "ClickHousePluginDriver") - private static let selectPrefixes: Set = [ + static let selectPrefixes: Set = [ "SELECT", "SHOW", "DESCRIBE", "DESC", "EXISTS", "EXPLAIN", "WITH" ] @@ -282,335 +282,6 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) } - // MARK: - Schema Operations - - func fetchTables(schema: String?) async throws -> [PluginTableInfo] { - let sql = """ - SELECT name, engine FROM system.tables - WHERE database = currentDatabase() AND name NOT LIKE '.%' - ORDER BY name - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginTableInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let engine = row[safe: 1]?.asText - let tableType = (engine?.contains("View") == true) ? "VIEW" : "TABLE" - return PluginTableInfo(name: name, type: tableType) - } - } - - func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - - let pkSql = """ - SELECT primary_key, sorting_key FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let pkResult = try await execute(query: pkSql) - let primaryKey = pkResult.rows.first.flatMap { $0[safe: 0]?.asText } ?? "" - let sortingKey = pkResult.rows.first.flatMap { $0[safe: 1]?.asText } ?? "" - let keyString = primaryKey.isEmpty ? sortingKey : primaryKey - let pkColumns = Set(keyString.split(separator: ",").map { $0.trimmingCharacters(in: .whitespaces) }) - - let sql = """ - SELECT name, type, default_kind, default_expression, comment - FROM system.columns - WHERE database = currentDatabase() AND table = '\(escapedTable)' - ORDER BY position - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginColumnInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let dataType = (row[safe: 1]?.asText) ?? "String" - let defaultKind = row[safe: 2]?.asText - let defaultExpr = row[safe: 3]?.asText - let comment = row[safe: 4]?.asText - - let isNullable = dataType.hasPrefix("Nullable(") - - var defaultValue: String? - if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { - defaultValue = expr - } - - var extra: String? - if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { - extra = kind - } - - return PluginColumnInfo( - name: name, - dataType: dataType, - isNullable: isNullable, - isPrimaryKey: pkColumns.contains(name), - defaultValue: defaultValue, - extra: extra, - comment: (comment?.isEmpty == false) ? comment : nil, - allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) - ) - } - } - - func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { - // Pre-fetch PK columns for all tables. Falls back to sorting_key when - // primary_key is empty (MergeTree without explicit PRIMARY KEY clause). - // Note: expression-based keys like toDate(col) won't match bare column names. - let pkSql = """ - SELECT name, primary_key, sorting_key FROM system.tables - WHERE database = currentDatabase() - """ - let pkResult = try await execute(query: pkSql) - var pkLookup: [String: Set] = [:] - for row in pkResult.rows { - guard let tableName = row[safe: 0]?.asText else { continue } - let primaryKey = (row[safe: 1]?.asText) ?? "" - let sortingKey = (row[safe: 2]?.asText) ?? "" - let keyString = primaryKey.isEmpty ? sortingKey : primaryKey - guard !keyString.isEmpty else { continue } - let cols = Set(keyString.split(separator: ",").map { String($0).trimmingCharacters(in: .whitespaces) }) - pkLookup[tableName] = cols - } - - let sql = """ - SELECT table, name, type, default_kind, default_expression, comment - FROM system.columns - WHERE database = currentDatabase() - ORDER BY table, position - """ - let result = try await execute(query: sql) - var columnsByTable: [String: [PluginColumnInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let colName = row[safe: 1]?.asText else { continue } - let dataType = (row[safe: 2]?.asText) ?? "String" - let defaultKind = row[safe: 3]?.asText - let defaultExpr = row[safe: 4]?.asText - let comment = row[safe: 5]?.asText - - let isNullable = dataType.hasPrefix("Nullable(") - - var defaultValue: String? - if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { - defaultValue = expr - } - - var extra: String? - if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { - extra = kind - } - - let colInfo = PluginColumnInfo( - name: colName, - dataType: dataType, - isNullable: isNullable, - isPrimaryKey: pkLookup[tableName]?.contains(colName) == true, - defaultValue: defaultValue, - extra: extra, - comment: (comment?.isEmpty == false) ? comment : nil, - allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) - ) - columnsByTable[tableName, default: []].append(colInfo) - } - return columnsByTable - } - - static func unwrapTypeWrappers(_ value: String) -> String { - for prefix in ["Nullable(", "LowCardinality("] { - if value.hasPrefix(prefix), value.hasSuffix(")") { - let start = value.index(value.startIndex, offsetBy: prefix.count) - let end = value.index(before: value.endIndex) - return unwrapTypeWrappers(String(value[start.. [PluginIndexInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - var indexes: [PluginIndexInfo] = [] - - let sortingKeySql = """ - SELECT sorting_key FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let sortingResult = try await execute(query: sortingKeySql) - if let row = sortingResult.rows.first, - let sortingKey = row[safe: 0]?.asText, !sortingKey.isEmpty { - let columns = sortingKey.components(separatedBy: ",").map { - $0.trimmingCharacters(in: .whitespacesAndNewlines) - } - indexes.append(PluginIndexInfo( - name: "PRIMARY (sorting key)", - columns: columns, - isUnique: false, - isPrimary: true, - type: "SORTING KEY" - )) - } - - let caps = ClickHouseCapabilities.parse(serverVersion) - guard caps.hasDataSkippingIndicesTable else { return indexes } - let skippingSql = """ - SELECT name, expr FROM system.data_skipping_indices - WHERE database = currentDatabase() AND table = '\(escapedTable)' - """ - let skippingResult = try await execute(query: skippingSql) - for row in skippingResult.rows { - guard let idxName = row[safe: 0]?.asText else { continue } - let expr = (row[safe: 1]?.asText) ?? "" - let columns = expr.components(separatedBy: ",").map { - $0.trimmingCharacters(in: .whitespacesAndNewlines) - } - indexes.append(PluginIndexInfo( - name: idxName, - columns: columns, - isUnique: false, - isPrimary: false, - type: "DATA_SKIPPING" - )) - } - - return indexes - } - - func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { - [] - } - - func fetchApproximateRowCount(table: String, schema: String?) async throws -> Int? { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT sum(rows) FROM system.parts - WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 - """ - let result = try await execute(query: sql) - if let row = result.rows.first, let cell = row.first, let str = cell.asText { - return Int(str) - } - return nil - } - - func fetchTableDDL(table: String, schema: String?) async throws -> String { - let escapedTable = table.replacingOccurrences(of: "`", with: "``") - let sql = "SHOW CREATE TABLE `\(escapedTable)`" - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchViewDefinition(view: String, schema: String?) async throws -> String { - let escapedView = view.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT as_select FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedView)' - """ - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - - let engineSql = """ - SELECT engine, comment FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let engineResult = try await execute(query: engineSql) - let engine = engineResult.rows.first.flatMap { $0[safe: 0]?.asText } - let tableComment = engineResult.rows.first.flatMap { $0[safe: 1]?.asText } - - let partsSql = """ - SELECT sum(rows), sum(bytes_on_disk) - FROM system.parts - WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 - """ - let partsResult = try await execute(query: partsSql) - if let row = partsResult.rows.first { - let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 - return PluginTableMetadata( - tableName: table, - dataSize: sizeBytes, - totalSize: sizeBytes, - rowCount: rowCount, - comment: (tableComment?.isEmpty == false) ? tableComment : nil, - engine: engine - ) - } - - return PluginTableMetadata(tableName: table, engine: engine) - } - - func fetchDatabases() async throws -> [String] { - let result = try await execute(query: "SHOW DATABASES") - return result.rows.compactMap { $0.first?.asText } - } - - func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { - let escapedDb = database.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT count() AS table_count, sum(total_bytes) AS size_bytes - FROM system.tables WHERE database = '\(escapedDb)' - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let tableCount = (row[safe: 0]?.asText).flatMap { Int($0) } ?? 0 - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata( - name: database, - tableCount: tableCount, - sizeBytes: sizeBytes - ) - } - return PluginDatabaseMetadata(name: database) - } - - func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { - let sql = """ - SELECT database, count() AS table_count, sum(total_bytes) AS size_bytes - FROM system.tables - GROUP BY database - ORDER BY database - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginDatabaseMetadata? in - guard let name = row[safe: 0]?.asText else { return nil } - let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 - let sizeBytes = (row[safe: 2]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata(name: name, tableCount: tableCount, sizeBytes: sizeBytes) - } - } - - func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { - PluginCreateDatabaseFormSpec(fields: [], footnote: nil) - } - - func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { - let escapedName = request.name.replacingOccurrences(of: "`", with: "``") - _ = try await execute(query: "CREATE DATABASE `\(escapedName)`") - } - - func dropDatabase(name: String) async throws { - let escapedName = name.replacingOccurrences(of: "`", with: "``") - _ = try await execute(query: "DROP DATABASE `\(escapedName)`") - } - - // MARK: - All Tables Metadata - - func allTablesMetadataSQL(schema: String?) -> String? { - """ - SELECT - database as `schema`, - name, - engine as kind, - total_rows as estimated_rows, - formatReadableSize(total_bytes) as total_size, - comment - FROM system.tables - WHERE database = currentDatabase() - ORDER BY name - """ - } - // MARK: - DML Statement Generation func generateStatements( @@ -806,304 +477,6 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { } } - // MARK: - Private HTTP Layer - - private func executeRaw(_ query: String, queryId: String? = nil) async throws -> CHQueryResult { - lock.lock() - guard let session = self.session else { - lock.unlock() - throw ClickHouseError.notConnected - } - let database = _currentDatabase - if let queryId { - _lastQueryId = queryId - } - lock.unlock() - - var request = try buildRequest(query: query, database: database, queryId: queryId) - request.timeoutInterval = _queryTimeout.requestTimeoutInterval - let isSelect = Self.isSelectLikeQuery(query) - - let (data, response) = try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in - let task = session.dataTask(with: request) { data, response, error in - if let error { - continuation.resume(throwing: error) - return - } - guard let data, let response else { - continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) - return - } - continuation.resume(returning: (data, response)) - } - - self.lock.lock() - self.currentTask = task - self.lock.unlock() - - task.resume() - } - } onCancel: { - self.lock.lock() - self.currentTask?.cancel() - self.currentTask = nil - self.lock.unlock() - } - - lock.lock() - currentTask = nil - lock.unlock() - - if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { - let body = String(data: data, encoding: .utf8) ?? "Unknown error" - Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") - throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) - } - - if isSelect { - return parseTabSeparatedResponse(data) - } - - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - private func executeRawWithParams(_ query: String, params: [String: String?], queryId: String? = nil) async throws -> CHQueryResult { - lock.lock() - guard let session = self.session else { - lock.unlock() - throw ClickHouseError.notConnected - } - let database = _currentDatabase - if let queryId { - _lastQueryId = queryId - } - lock.unlock() - - var request = try buildRequest(query: query, database: database, queryId: queryId, params: params) - request.timeoutInterval = _queryTimeout.requestTimeoutInterval - let isSelect = Self.isSelectLikeQuery(query) - - let (data, response) = try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in - let task = session.dataTask(with: request) { data, response, error in - if let error { - continuation.resume(throwing: error) - return - } - guard let data, let response else { - continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) - return - } - continuation.resume(returning: (data, response)) - } - self.lock.lock() - self.currentTask = task - self.lock.unlock() - task.resume() - } - } onCancel: { - self.lock.lock() - self.currentTask?.cancel() - self.currentTask = nil - self.lock.unlock() - } - - lock.lock() - currentTask = nil - lock.unlock() - - if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { - let body = String(data: data, encoding: .utf8) ?? "Unknown error" - Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") - throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) - } - - if isSelect { - return parseTabSeparatedResponse(data) - } - - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - private func buildRequest(query: String, database: String, queryId: String? = nil, params: [String: String?]? = nil) throws -> URLRequest { - let useTLS = config.ssl.isEnabled - - var components = URLComponents() - components.scheme = useTLS ? "https" : "http" - components.host = config.host - components.port = config.port - components.path = "/" - - var queryItems = [URLQueryItem]() - if !database.isEmpty { - queryItems.append(URLQueryItem(name: "database", value: database)) - } - if let queryId { - queryItems.append(URLQueryItem(name: "query_id", value: queryId)) - } - queryItems.append(URLQueryItem(name: "send_progress_in_http_headers", value: "1")) - if let params { - for (key, value) in params.sorted(by: { $0.key < $1.key }) { - queryItems.append(URLQueryItem(name: "param_\(key)", value: value)) - } - } - if !queryItems.isEmpty { - components.queryItems = queryItems - } - - guard let url = components.url else { - throw ClickHouseError(message: "Failed to construct request URL") - } - - var request = URLRequest(url: url) - request.httpMethod = "POST" - - let credentials = "\(config.username):\(config.password)" - if let credData = credentials.data(using: .utf8) { - request.setValue("Basic \(credData.base64EncodedString())", forHTTPHeaderField: "Authorization") - } - - let trimmedQuery = query.trimmingCharacters(in: .whitespacesAndNewlines) - .replacingOccurrences(of: ";+$", with: "", options: .regularExpression) - - if Self.isSelectLikeQuery(trimmedQuery) { - request.httpBody = (trimmedQuery + " FORMAT TabSeparatedWithNamesAndTypes").data(using: .utf8) - } else { - request.httpBody = trimmedQuery.data(using: .utf8) - } - - return request - } - - private static func isSelectLikeQuery(_ query: String) -> Bool { - let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) - guard let firstWord = trimmed.split(separator: " ", maxSplits: 1).first else { - return false - } - return selectPrefixes.contains(firstWord.uppercased()) - } - - private func parseTabSeparatedResponse(_ data: Data) -> CHQueryResult { - guard let text = String(data: data, encoding: .utf8), !text.isEmpty else { - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - let lines = text.components(separatedBy: "\n") - - guard lines.count >= 2 else { - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - let columns = lines[0].components(separatedBy: "\t") - let columnTypes = lines[1].components(separatedBy: "\t") - - var rows: [[PluginCellValue]] = [] - var truncated = false - for i in 2..= PluginRowLimits.emergencyMax { - truncated = true - break - } - } - - return CHQueryResult( - columns: columns, - columnTypeNames: columnTypes, - rows: rows, - affectedRows: rows.count, - isTruncated: truncated - ) - } - - private static func unescapeTsvField(_ field: String) -> String { - var result = "" - result.reserveCapacity((field as NSString).length) - var iterator = field.makeIterator() - - while let char = iterator.next() { - if char == "\\" { - if let next = iterator.next() { - switch next { - case "\\": result.append("\\") - case "t": result.append("\t") - case "n": result.append("\n") - default: - result.append("\\") - result.append(next) - } - } else { - result.append("\\") - } - } else { - result.append(char) - } - } - - return result - } - - /// Convert `?` placeholders to `{p1:String}` and build parameter map for ClickHouse HTTP params. - private static func buildClickHouseParams( - query: String, - parameters: [PluginCellValue] - ) -> (String, [String: String?]) { - var converted = "" - var paramIndex = 0 - var inSingleQuote = false - var inDoubleQuote = false - var isEscaped = false - - for char in query { - if isEscaped { - isEscaped = false - converted.append(char) - continue - } - if char == "\\" && (inSingleQuote || inDoubleQuote) { - isEscaped = true - converted.append(char) - continue - } - if char == "'" && !inDoubleQuote { - inSingleQuote.toggle() - } else if char == "\"" && !inSingleQuote { - inDoubleQuote.toggle() - } - if char == "?" && !inSingleQuote && !inDoubleQuote && paramIndex < parameters.count { - paramIndex += 1 - converted.append("{p\(paramIndex):String}") - } else { - converted.append(char) - } - } - - var paramMap: [String: String?] = [:] - for i in 0.. AsyncThrowingStream { diff --git a/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift b/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift new file mode 100644 index 000000000..99006be07 --- /dev/null +++ b/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift @@ -0,0 +1,309 @@ +// +// ClickHousePluginDriver+Http.swift +// ClickHouseDriverPlugin +// + +import Foundation +import os +import TableProPluginKit + +extension ClickHousePluginDriver { + // MARK: - Private HTTP Layer + + func executeRaw(_ query: String, queryId: String? = nil) async throws -> CHQueryResult { + lock.lock() + guard let session = self.session else { + lock.unlock() + throw ClickHouseError.notConnected + } + let database = _currentDatabase + if let queryId { + _lastQueryId = queryId + } + lock.unlock() + + var request = try buildRequest(query: query, database: database, queryId: queryId) + request.timeoutInterval = _queryTimeout.requestTimeoutInterval + let isSelect = Self.isSelectLikeQuery(query) + + let (data, response) = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in + let task = session.dataTask(with: request) { data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let data, let response else { + continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) + return + } + continuation.resume(returning: (data, response)) + } + + self.lock.lock() + self.currentTask = task + self.lock.unlock() + + task.resume() + } + } onCancel: { + self.lock.lock() + self.currentTask?.cancel() + self.currentTask = nil + self.lock.unlock() + } + + lock.lock() + currentTask = nil + lock.unlock() + + if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { + let body = String(data: data, encoding: .utf8) ?? "Unknown error" + Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") + throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + if isSelect { + return parseTabSeparatedResponse(data) + } + + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + func executeRawWithParams(_ query: String, params: [String: String?], queryId: String? = nil) async throws -> CHQueryResult { + lock.lock() + guard let session = self.session else { + lock.unlock() + throw ClickHouseError.notConnected + } + let database = _currentDatabase + if let queryId { + _lastQueryId = queryId + } + lock.unlock() + + var request = try buildRequest(query: query, database: database, queryId: queryId, params: params) + request.timeoutInterval = _queryTimeout.requestTimeoutInterval + let isSelect = Self.isSelectLikeQuery(query) + + let (data, response) = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in + let task = session.dataTask(with: request) { data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let data, let response else { + continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) + return + } + continuation.resume(returning: (data, response)) + } + self.lock.lock() + self.currentTask = task + self.lock.unlock() + task.resume() + } + } onCancel: { + self.lock.lock() + self.currentTask?.cancel() + self.currentTask = nil + self.lock.unlock() + } + + lock.lock() + currentTask = nil + lock.unlock() + + if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { + let body = String(data: data, encoding: .utf8) ?? "Unknown error" + Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") + throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + if isSelect { + return parseTabSeparatedResponse(data) + } + + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + func buildRequest(query: String, database: String, queryId: String? = nil, params: [String: String?]? = nil) throws -> URLRequest { + let useTLS = config.ssl.isEnabled + + var components = URLComponents() + components.scheme = useTLS ? "https" : "http" + components.host = config.host + components.port = config.port + components.path = "/" + + var queryItems = [URLQueryItem]() + if !database.isEmpty { + queryItems.append(URLQueryItem(name: "database", value: database)) + } + if let queryId { + queryItems.append(URLQueryItem(name: "query_id", value: queryId)) + } + queryItems.append(URLQueryItem(name: "send_progress_in_http_headers", value: "1")) + if let params { + for (key, value) in params.sorted(by: { $0.key < $1.key }) { + queryItems.append(URLQueryItem(name: "param_\(key)", value: value)) + } + } + if !queryItems.isEmpty { + components.queryItems = queryItems + } + + guard let url = components.url else { + throw ClickHouseError(message: "Failed to construct request URL") + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + + let credentials = "\(config.username):\(config.password)" + if let credData = credentials.data(using: .utf8) { + request.setValue("Basic \(credData.base64EncodedString())", forHTTPHeaderField: "Authorization") + } + + let trimmedQuery = query.trimmingCharacters(in: .whitespacesAndNewlines) + .replacingOccurrences(of: ";+$", with: "", options: .regularExpression) + + if Self.isSelectLikeQuery(trimmedQuery) { + request.httpBody = (trimmedQuery + " FORMAT TabSeparatedWithNamesAndTypes").data(using: .utf8) + } else { + request.httpBody = trimmedQuery.data(using: .utf8) + } + + return request + } + + static func isSelectLikeQuery(_ query: String) -> Bool { + let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) + guard let firstWord = trimmed.split(separator: " ", maxSplits: 1).first else { + return false + } + return selectPrefixes.contains(firstWord.uppercased()) + } + + func parseTabSeparatedResponse(_ data: Data) -> CHQueryResult { + guard let text = String(data: data, encoding: .utf8), !text.isEmpty else { + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + let lines = text.components(separatedBy: "\n") + + guard lines.count >= 2 else { + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + let columns = lines[0].components(separatedBy: "\t") + let columnTypes = lines[1].components(separatedBy: "\t") + + var rows: [[PluginCellValue]] = [] + var truncated = false + for i in 2..= PluginRowLimits.emergencyMax { + truncated = true + break + } + } + + return CHQueryResult( + columns: columns, + columnTypeNames: columnTypes, + rows: rows, + affectedRows: rows.count, + isTruncated: truncated + ) + } + + static func unescapeTsvField(_ field: String) -> String { + var result = "" + result.reserveCapacity((field as NSString).length) + var iterator = field.makeIterator() + + while let char = iterator.next() { + if char == "\\" { + if let next = iterator.next() { + switch next { + case "\\": result.append("\\") + case "t": result.append("\t") + case "n": result.append("\n") + default: + result.append("\\") + result.append(next) + } + } else { + result.append("\\") + } + } else { + result.append(char) + } + } + + return result + } + + /// Convert `?` placeholders to `{p1:String}` and build parameter map for ClickHouse HTTP params. + static func buildClickHouseParams( + query: String, + parameters: [PluginCellValue] + ) -> (String, [String: String?]) { + var converted = "" + var paramIndex = 0 + var inSingleQuote = false + var inDoubleQuote = false + var isEscaped = false + + for char in query { + if isEscaped { + isEscaped = false + converted.append(char) + continue + } + if char == "\\" && (inSingleQuote || inDoubleQuote) { + isEscaped = true + converted.append(char) + continue + } + if char == "'" && !inDoubleQuote { + inSingleQuote.toggle() + } else if char == "\"" && !inSingleQuote { + inDoubleQuote.toggle() + } + if char == "?" && !inSingleQuote && !inDoubleQuote && paramIndex < parameters.count { + paramIndex += 1 + converted.append("{p\(paramIndex):String}") + } else { + converted.append(char) + } + } + + var paramMap: [String: String?] = [:] + for i in 0.. [PluginTableInfo] { + let sql = """ + SELECT name, engine FROM system.tables + WHERE database = currentDatabase() AND name NOT LIKE '.%' + ORDER BY name + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginTableInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let engine = row[safe: 1]?.asText + let tableType = (engine?.contains("View") == true) ? "VIEW" : "TABLE" + return PluginTableInfo(name: name, type: tableType) + } + } + + func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + + let pkSql = """ + SELECT primary_key, sorting_key FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let pkResult = try await execute(query: pkSql) + let primaryKey = pkResult.rows.first.flatMap { $0[safe: 0]?.asText } ?? "" + let sortingKey = pkResult.rows.first.flatMap { $0[safe: 1]?.asText } ?? "" + let keyString = primaryKey.isEmpty ? sortingKey : primaryKey + let pkColumns = Set(keyString.split(separator: ",").map { $0.trimmingCharacters(in: .whitespaces) }) + + let sql = """ + SELECT name, type, default_kind, default_expression, comment + FROM system.columns + WHERE database = currentDatabase() AND table = '\(escapedTable)' + ORDER BY position + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginColumnInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let dataType = (row[safe: 1]?.asText) ?? "String" + let defaultKind = row[safe: 2]?.asText + let defaultExpr = row[safe: 3]?.asText + let comment = row[safe: 4]?.asText + + let isNullable = dataType.hasPrefix("Nullable(") + + var defaultValue: String? + if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { + defaultValue = expr + } + + var extra: String? + if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { + extra = kind + } + + return PluginColumnInfo( + name: name, + dataType: dataType, + isNullable: isNullable, + isPrimaryKey: pkColumns.contains(name), + defaultValue: defaultValue, + extra: extra, + comment: (comment?.isEmpty == false) ? comment : nil, + allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) + ) + } + } + + func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { + // Pre-fetch PK columns for all tables. Falls back to sorting_key when + // primary_key is empty (MergeTree without explicit PRIMARY KEY clause). + // Note: expression-based keys like toDate(col) won't match bare column names. + let pkSql = """ + SELECT name, primary_key, sorting_key FROM system.tables + WHERE database = currentDatabase() + """ + let pkResult = try await execute(query: pkSql) + var pkLookup: [String: Set] = [:] + for row in pkResult.rows { + guard let tableName = row[safe: 0]?.asText else { continue } + let primaryKey = (row[safe: 1]?.asText) ?? "" + let sortingKey = (row[safe: 2]?.asText) ?? "" + let keyString = primaryKey.isEmpty ? sortingKey : primaryKey + guard !keyString.isEmpty else { continue } + let cols = Set(keyString.split(separator: ",").map { String($0).trimmingCharacters(in: .whitespaces) }) + pkLookup[tableName] = cols + } + + let sql = """ + SELECT table, name, type, default_kind, default_expression, comment + FROM system.columns + WHERE database = currentDatabase() + ORDER BY table, position + """ + let result = try await execute(query: sql) + var columnsByTable: [String: [PluginColumnInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let colName = row[safe: 1]?.asText else { continue } + let dataType = (row[safe: 2]?.asText) ?? "String" + let defaultKind = row[safe: 3]?.asText + let defaultExpr = row[safe: 4]?.asText + let comment = row[safe: 5]?.asText + + let isNullable = dataType.hasPrefix("Nullable(") + + var defaultValue: String? + if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { + defaultValue = expr + } + + var extra: String? + if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { + extra = kind + } + + let colInfo = PluginColumnInfo( + name: colName, + dataType: dataType, + isNullable: isNullable, + isPrimaryKey: pkLookup[tableName]?.contains(colName) == true, + defaultValue: defaultValue, + extra: extra, + comment: (comment?.isEmpty == false) ? comment : nil, + allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) + ) + columnsByTable[tableName, default: []].append(colInfo) + } + return columnsByTable + } + + static func unwrapTypeWrappers(_ value: String) -> String { + for prefix in ["Nullable(", "LowCardinality("] { + if value.hasPrefix(prefix), value.hasSuffix(")") { + let start = value.index(value.startIndex, offsetBy: prefix.count) + let end = value.index(before: value.endIndex) + return unwrapTypeWrappers(String(value[start.. [PluginIndexInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + var indexes: [PluginIndexInfo] = [] + + let sortingKeySql = """ + SELECT sorting_key FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let sortingResult = try await execute(query: sortingKeySql) + if let row = sortingResult.rows.first, + let sortingKey = row[safe: 0]?.asText, !sortingKey.isEmpty { + let columns = sortingKey.components(separatedBy: ",").map { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + } + indexes.append(PluginIndexInfo( + name: "PRIMARY (sorting key)", + columns: columns, + isUnique: false, + isPrimary: true, + type: "SORTING KEY" + )) + } + + let caps = ClickHouseCapabilities.parse(serverVersion) + guard caps.hasDataSkippingIndicesTable else { return indexes } + let skippingSql = """ + SELECT name, expr FROM system.data_skipping_indices + WHERE database = currentDatabase() AND table = '\(escapedTable)' + """ + let skippingResult = try await execute(query: skippingSql) + for row in skippingResult.rows { + guard let idxName = row[safe: 0]?.asText else { continue } + let expr = (row[safe: 1]?.asText) ?? "" + let columns = expr.components(separatedBy: ",").map { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + } + indexes.append(PluginIndexInfo( + name: idxName, + columns: columns, + isUnique: false, + isPrimary: false, + type: "DATA_SKIPPING" + )) + } + + return indexes + } + + func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { + [] + } + + func fetchApproximateRowCount(table: String, schema: String?) async throws -> Int? { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT sum(rows) FROM system.parts + WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 + """ + let result = try await execute(query: sql) + if let row = result.rows.first, let cell = row.first, let str = cell.asText { + return Int(str) + } + return nil + } + + func fetchTableDDL(table: String, schema: String?) async throws -> String { + let escapedTable = table.replacingOccurrences(of: "`", with: "``") + let sql = "SHOW CREATE TABLE `\(escapedTable)`" + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchViewDefinition(view: String, schema: String?) async throws -> String { + let escapedView = view.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT as_select FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedView)' + """ + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + + let engineSql = """ + SELECT engine, comment FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let engineResult = try await execute(query: engineSql) + let engine = engineResult.rows.first.flatMap { $0[safe: 0]?.asText } + let tableComment = engineResult.rows.first.flatMap { $0[safe: 1]?.asText } + + let partsSql = """ + SELECT sum(rows), sum(bytes_on_disk) + FROM system.parts + WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 + """ + let partsResult = try await execute(query: partsSql) + if let row = partsResult.rows.first { + let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 + return PluginTableMetadata( + tableName: table, + dataSize: sizeBytes, + totalSize: sizeBytes, + rowCount: rowCount, + comment: (tableComment?.isEmpty == false) ? tableComment : nil, + engine: engine + ) + } + + return PluginTableMetadata(tableName: table, engine: engine) + } + + func fetchDatabases() async throws -> [String] { + let result = try await execute(query: "SHOW DATABASES") + return result.rows.compactMap { $0.first?.asText } + } + + func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { + let escapedDb = database.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT count() AS table_count, sum(total_bytes) AS size_bytes + FROM system.tables WHERE database = '\(escapedDb)' + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let tableCount = (row[safe: 0]?.asText).flatMap { Int($0) } ?? 0 + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata( + name: database, + tableCount: tableCount, + sizeBytes: sizeBytes + ) + } + return PluginDatabaseMetadata(name: database) + } + + func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { + let sql = """ + SELECT database, count() AS table_count, sum(total_bytes) AS size_bytes + FROM system.tables + GROUP BY database + ORDER BY database + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginDatabaseMetadata? in + guard let name = row[safe: 0]?.asText else { return nil } + let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 + let sizeBytes = (row[safe: 2]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata(name: name, tableCount: tableCount, sizeBytes: sizeBytes) + } + } + + func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { + PluginCreateDatabaseFormSpec(fields: [], footnote: nil) + } + + func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { + let escapedName = request.name.replacingOccurrences(of: "`", with: "``") + _ = try await execute(query: "CREATE DATABASE `\(escapedName)`") + } + + func dropDatabase(name: String) async throws { + let escapedName = name.replacingOccurrences(of: "`", with: "``") + _ = try await execute(query: "DROP DATABASE `\(escapedName)`") + } + + // MARK: - All Tables Metadata + + func allTablesMetadataSQL(schema: String?) -> String? { + """ + SELECT + database as `schema`, + name, + engine as kind, + total_rows as estimated_rows, + formatReadableSize(total_bytes) as total_size, + comment + FROM system.tables + WHERE database = currentDatabase() + ORDER BY name + """ + } + +} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift b/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift new file mode 100644 index 000000000..666add7c4 --- /dev/null +++ b/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift @@ -0,0 +1,600 @@ +// +// DuckDBConnection.swift +// DuckDBDriverPlugin +// + +import CDuckDB +import Foundation +import os +import TableProPluginKit + +actor DuckDBConnectionActor { + private static let logger = Logger(subsystem: "com.TablePro", category: "DuckDBConnectionActor") + + private var database: duckdb_database? + private var connection: duckdb_connection? + + var isConnected: Bool { connection != nil } + + var connectionHandleForInterrupt: duckdb_connection? { connection } + + func open(path: String) throws { + var db: duckdb_database? + var errorPtr: UnsafeMutablePointer? + let state = duckdb_open_ext(path, &db, nil, &errorPtr) + + if state == DuckDBError { + let detail: String + if let errPtr = errorPtr { + detail = String(cString: errPtr) + duckdb_free(errPtr) + } else { + detail = "unknown error" + } + throw DuckDBPluginError.connectionFailed( + "Failed to open DuckDB database at '\(path)': \(detail)" + ) + } + + guard let openedDB = db else { + throw DuckDBPluginError.connectionFailed( + "Failed to open DuckDB database at '\(path)'" + ) + } + + var conn: duckdb_connection? + let connState = duckdb_connect(openedDB, &conn) + + if connState == DuckDBError { + duckdb_close(&db) + throw DuckDBPluginError.connectionFailed("Failed to create DuckDB connection") + } + + database = db + connection = conn + } + + func close() { + if connection != nil { + duckdb_disconnect(&connection) + connection = nil + } + if database != nil { + duckdb_close(&database) + database = nil + } + } + + func executeQuery(_ query: String) throws -> DuckDBRawResult { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + let startTime = Date() + var result = duckdb_result() + + let state = duckdb_query(conn, query, &result) + + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + defer { + duckdb_destroy_result(&result) + } + + var raw = Self.extractResult(from: &result, startTime: startTime) + Self.patchTzColumns(&raw, query: query, connection: conn) + return raw + } + + func executePrepared(_ query: String, parameters: [PluginCellValue]) throws -> DuckDBRawResult { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + let startTime = Date() + var stmtOpt: duckdb_prepared_statement? + + let prepState = duckdb_prepare(conn, query, &stmtOpt) + if prepState == DuckDBError { + let errorMsg: String + if let s = stmtOpt, let errPtr = duckdb_prepare_error(s) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Failed to prepare statement" + } + duckdb_destroy_prepare(&stmtOpt) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + guard let stmt = stmtOpt else { + throw DuckDBPluginError.queryFailed("Failed to prepare statement") + } + + defer { + duckdb_destroy_prepare(&stmtOpt) + } + + for (index, param) in parameters.enumerated() { + let paramIdx = idx_t(index + 1) + let bindState: duckdb_state + switch param { + case .null: + bindState = duckdb_bind_null(stmt, paramIdx) + case .text(let value): + bindState = duckdb_bind_varchar(stmt, paramIdx, value) + case .bytes(let data): + bindState = data.withUnsafeBytes { rawBuffer -> duckdb_state in + guard let baseAddress = rawBuffer.baseAddress else { + return duckdb_bind_null(stmt, paramIdx) + } + return duckdb_bind_blob(stmt, paramIdx, baseAddress, idx_t(data.count)) + } + } + if bindState == DuckDBError { + throw DuckDBPluginError.queryFailed("Failed to bind parameter at index \(index)") + } + } + + var result = duckdb_result() + let execState = duckdb_execute_prepared(stmt, &result) + + if execState == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Failed to execute prepared statement" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + defer { + duckdb_destroy_result(&result) + } + + var raw = Self.extractResult(from: &result, startTime: startTime) + Self.patchTzColumns(&raw, query: query, connection: conn) + return raw + } + + func streamQuery( + _ query: String, + continuation: AsyncThrowingStream.Continuation + ) throws { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + var result = duckdb_result() + let state = duckdb_query(conn, query, &result) + + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + let colCount = duckdb_column_count(&result) + var columns: [String] = [] + var columnTypeNames: [String] = [] + var columnTypes: [duckdb_type] = [] + for i in 0...Continuation + ) throws { + let castExprs = columns.enumerated().map { i, name in + castExpression(for: columnTypes[i], column: name) + } + let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) + + var result = duckdb_result() + let state = duckdb_query(connection, wrappedQuery, &result) + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + defer { duckdb_destroy_result(&result) } + + try Self.streamResultRows( + &result, + columns: columns, + columnTypeNames: columnTypeNames, + continuation: continuation + ) + } + + private static func streamResultRows( + _ result: inout duckdb_result, + columns: [String], + columnTypeNames: [String], + continuation: AsyncThrowingStream.Continuation + ) throws { + let colCount = duckdb_column_count(&result) + let rowCount = duckdb_row_count(&result) + + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames, + estimatedRowCount: Int(rowCount) + ))) + + let maxRows = min(rowCount, UInt64(PluginRowLimits.emergencyMax)) + if rowCount > UInt64(PluginRowLimits.emergencyMax) { + Self.logger.warning("streamQuery truncating result from \(rowCount) to \(maxRows) rows") + } + + for row in 0.. DuckDBRawResult { + let colCount = duckdb_column_count(&result) + let rowCount = duckdb_row_count(&result) + let rowsChanged = duckdb_rows_changed(&result) + + var columns: [String] = [] + var columnTypeNames: [String] = [] + var columnTypes: [duckdb_type] = [] + + for i in 0.. UInt64(PluginRowLimits.emergencyMax) { + truncated = true + } + + for row in 0.. String { + switch type { + case DUCKDB_TYPE_BOOLEAN: return "BOOLEAN" + case DUCKDB_TYPE_TINYINT: return "TINYINT" + case DUCKDB_TYPE_SMALLINT: return "SMALLINT" + case DUCKDB_TYPE_INTEGER: return "INTEGER" + case DUCKDB_TYPE_BIGINT: return "BIGINT" + case DUCKDB_TYPE_UTINYINT: return "UTINYINT" + case DUCKDB_TYPE_USMALLINT: return "USMALLINT" + case DUCKDB_TYPE_UINTEGER: return "UINTEGER" + case DUCKDB_TYPE_UBIGINT: return "UBIGINT" + case DUCKDB_TYPE_FLOAT: return "FLOAT" + case DUCKDB_TYPE_DOUBLE: return "DOUBLE" + case DUCKDB_TYPE_TIMESTAMP: return "TIMESTAMP" + case DUCKDB_TYPE_DATE: return "DATE" + case DUCKDB_TYPE_TIME: return "TIME" + case DUCKDB_TYPE_INTERVAL: return "INTERVAL" + case DUCKDB_TYPE_HUGEINT: return "HUGEINT" + case DUCKDB_TYPE_VARCHAR: return "VARCHAR" + case DUCKDB_TYPE_BLOB: return "BLOB" + case DUCKDB_TYPE_DECIMAL: return "DECIMAL" + case DUCKDB_TYPE_TIMESTAMP_S: return "TIMESTAMP_S" + case DUCKDB_TYPE_TIMESTAMP_MS: return "TIMESTAMP_MS" + case DUCKDB_TYPE_TIMESTAMP_NS: return "TIMESTAMP_NS" + case DUCKDB_TYPE_ENUM: return "ENUM" + case DUCKDB_TYPE_LIST: return "LIST" + case DUCKDB_TYPE_STRUCT: return "STRUCT" + case DUCKDB_TYPE_MAP: return "MAP" + case DUCKDB_TYPE_UUID: return "UUID" + case DUCKDB_TYPE_UNION: return "UNION" + case DUCKDB_TYPE_BIT: return "BIT" + case DUCKDB_TYPE_TIMESTAMP_TZ: return "TIMESTAMPTZ" + case DUCKDB_TYPE_TIME_TZ: return "TIMETZ" + case DUCKDB_TYPE_TIME_NS: return "TIME_NS" + case DUCKDB_TYPE_UHUGEINT: return "UHUGEINT" + case DUCKDB_TYPE_ARRAY: return "ARRAY" + case DUCKDB_TYPE_GEOMETRY: return "GEOMETRY" + default: return "VARCHAR" + } + } + + private static func extractFallbackValue( + _ result: inout duckdb_result, col: idx_t, row: idx_t, type: duckdb_type + ) -> String? { + switch type { + case DUCKDB_TYPE_TIMESTAMP, DUCKDB_TYPE_TIMESTAMP_S, DUCKDB_TYPE_TIMESTAMP_MS, DUCKDB_TYPE_TIMESTAMP_NS: + let ts = duckdb_value_timestamp(&result, col, row) + return formatTimestamp(ts) + + case DUCKDB_TYPE_DATE: + let date = duckdb_value_date(&result, col, row) + let d = duckdb_from_date(date) + return String(format: "\(formatYearISO(d.year))-%02d-%02d", d.month, d.day) + + case DUCKDB_TYPE_TIME, DUCKDB_TYPE_TIME_NS: + let time = duckdb_value_time(&result, col, row) + return formatTime(duckdb_from_time(time)) + + case DUCKDB_TYPE_BOOLEAN: + return duckdb_value_boolean(&result, col, row) ? "true" : "false" + + case DUCKDB_TYPE_TINYINT: + return String(duckdb_value_int8(&result, col, row)) + case DUCKDB_TYPE_SMALLINT: + return String(duckdb_value_int16(&result, col, row)) + case DUCKDB_TYPE_INTEGER: + return String(duckdb_value_int32(&result, col, row)) + case DUCKDB_TYPE_BIGINT: + return String(duckdb_value_int64(&result, col, row)) + case DUCKDB_TYPE_UTINYINT: + return String(duckdb_value_uint8(&result, col, row)) + case DUCKDB_TYPE_USMALLINT: + return String(duckdb_value_uint16(&result, col, row)) + case DUCKDB_TYPE_UINTEGER: + return String(duckdb_value_uint32(&result, col, row)) + case DUCKDB_TYPE_UBIGINT: + return String(duckdb_value_uint64(&result, col, row)) + case DUCKDB_TYPE_FLOAT: + return String(duckdb_value_float(&result, col, row)) + case DUCKDB_TYPE_DOUBLE: + return String(duckdb_value_double(&result, col, row)) + + case DUCKDB_TYPE_HUGEINT: + let h = duckdb_value_hugeint(&result, col, row) + return formatHugeInt(upper: h.upper, lower: h.lower) + + case DUCKDB_TYPE_UHUGEINT: + let u = duckdb_value_uhugeint(&result, col, row) + return formatUHugeInt(upper: u.upper, lower: u.lower) + + default: + return nil + } + } + + static func patchTzColumns( + _ raw: inout DuckDBRawResult, query: String, connection: duckdb_connection + ) { + let patchedColIndices = raw.columnTypes.enumerated().compactMap { idx, type in + isUnrenderable(type) ? idx : nil + } + guard !patchedColIndices.isEmpty, !raw.rows.isEmpty else { return } + + let castExprs = raw.columns.enumerated().map { i, name in + castExpression(for: raw.columnTypes[i], column: name) + } + let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) + + var patchResult = duckdb_result() + guard duckdb_query(connection, wrappedQuery, &patchResult) == DuckDBSuccess else { return } + defer { duckdb_destroy_result(&patchResult) } + + let patchRowCount = min(duckdb_row_count(&patchResult), UInt64(raw.rows.count)) + for row in 0.. Bool { + switch type { + case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ, DUCKDB_TYPE_GEOMETRY: + return true + default: + return false + } + } + + static func castExpression(for type: duckdb_type, column: String) -> String { + let quoted = quoteIdentifier(column) + switch type { + case DUCKDB_TYPE_GEOMETRY: + return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE ST_AsText(\(quoted)) END AS \(quoted)" + case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ: + return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE CAST(\(quoted) AS VARCHAR) END AS \(quoted)" + default: + return quoted + } + } + + static func buildWrappedQuery(originalQuery: String, castExprs: [String]) -> String { + var trimmed = originalQuery.trimmingCharacters(in: .whitespacesAndNewlines) + if trimmed.hasSuffix(";") { + trimmed = String(trimmed.dropLast()) + } + return "SELECT \(castExprs.joined(separator: ", ")) FROM (\(trimmed)) AS _tp_cast" + } + + static func quoteIdentifier(_ ident: String) -> String { + "\"\(ident.replacingOccurrences(of: "\"", with: "\"\""))\"" + } + + static func formatTimestamp(_ ts: duckdb_timestamp) -> String { + let parts = duckdb_from_timestamp(ts) + let d = parts.date + let t = parts.time + let micros = t.micros % 1_000_000 + let yearPart = formatYearISO(d.year) + if micros == 0 { + return String( + format: "\(yearPart)-%02d-%02d %02d:%02d:%02d", + d.month, d.day, t.hour, t.min, t.sec + ) + } + return String( + format: "\(yearPart)-%02d-%02d %02d:%02d:%02d.%06d", + d.month, d.day, t.hour, t.min, t.sec, micros + ) + } + + static func formatYearISO(_ year: Int32) -> String { + if year < 0 { + return String(format: "-%04d", -Int(year)) + } + return String(format: "%04d", year) + } + + private static func formatTime(_ t: duckdb_time_struct) -> String { + let micros = t.micros % 1_000_000 + if micros == 0 { + return String(format: "%02d:%02d:%02d", t.hour, t.min, t.sec) + } + return String(format: "%02d:%02d:%02d.%06d", t.hour, t.min, t.sec, micros) + } + + static func formatHugeInt(upper: Int64, lower: UInt64) -> String { + HugeIntFormatter.format(upper: upper, lower: lower) + } + + static func formatUHugeInt(upper: UInt64, lower: UInt64) -> String { + HugeIntFormatter.formatUnsigned(upper: upper, lower: lower) + } +} + +struct DuckDBRawResult: @unchecked Sendable { + let columns: [String] + let columnTypeNames: [String] + let columnTypes: [duckdb_type] + var rows: [[PluginCellValue]] + let rowsAffected: Int + let executionTime: TimeInterval + let isTruncated: Bool +} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift index a3874a070..d00210de1 100644 --- a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift +++ b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift @@ -103,599 +103,6 @@ final class DuckDBPlugin: NSObject, TableProPlugin, DriverPlugin { } } -// MARK: - DuckDB Connection Actor - -private actor DuckDBConnectionActor { - private static let logger = Logger(subsystem: "com.TablePro", category: "DuckDBConnectionActor") - - private var database: duckdb_database? - private var connection: duckdb_connection? - - var isConnected: Bool { connection != nil } - - var connectionHandleForInterrupt: duckdb_connection? { connection } - - func open(path: String) throws { - var db: duckdb_database? - var errorPtr: UnsafeMutablePointer? - let state = duckdb_open_ext(path, &db, nil, &errorPtr) - - if state == DuckDBError { - let detail: String - if let errPtr = errorPtr { - detail = String(cString: errPtr) - duckdb_free(errPtr) - } else { - detail = "unknown error" - } - throw DuckDBPluginError.connectionFailed( - "Failed to open DuckDB database at '\(path)': \(detail)" - ) - } - - guard let openedDB = db else { - throw DuckDBPluginError.connectionFailed( - "Failed to open DuckDB database at '\(path)'" - ) - } - - var conn: duckdb_connection? - let connState = duckdb_connect(openedDB, &conn) - - if connState == DuckDBError { - duckdb_close(&db) - throw DuckDBPluginError.connectionFailed("Failed to create DuckDB connection") - } - - database = db - connection = conn - } - - func close() { - if connection != nil { - duckdb_disconnect(&connection) - connection = nil - } - if database != nil { - duckdb_close(&database) - database = nil - } - } - - func executeQuery(_ query: String) throws -> DuckDBRawResult { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - let startTime = Date() - var result = duckdb_result() - - let state = duckdb_query(conn, query, &result) - - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - defer { - duckdb_destroy_result(&result) - } - - var raw = Self.extractResult(from: &result, startTime: startTime) - Self.patchTzColumns(&raw, query: query, connection: conn) - return raw - } - - func executePrepared(_ query: String, parameters: [PluginCellValue]) throws -> DuckDBRawResult { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - let startTime = Date() - var stmtOpt: duckdb_prepared_statement? - - let prepState = duckdb_prepare(conn, query, &stmtOpt) - if prepState == DuckDBError { - let errorMsg: String - if let s = stmtOpt, let errPtr = duckdb_prepare_error(s) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Failed to prepare statement" - } - duckdb_destroy_prepare(&stmtOpt) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - guard let stmt = stmtOpt else { - throw DuckDBPluginError.queryFailed("Failed to prepare statement") - } - - defer { - duckdb_destroy_prepare(&stmtOpt) - } - - for (index, param) in parameters.enumerated() { - let paramIdx = idx_t(index + 1) - let bindState: duckdb_state - switch param { - case .null: - bindState = duckdb_bind_null(stmt, paramIdx) - case .text(let value): - bindState = duckdb_bind_varchar(stmt, paramIdx, value) - case .bytes(let data): - bindState = data.withUnsafeBytes { rawBuffer -> duckdb_state in - guard let baseAddress = rawBuffer.baseAddress else { - return duckdb_bind_null(stmt, paramIdx) - } - return duckdb_bind_blob(stmt, paramIdx, baseAddress, idx_t(data.count)) - } - } - if bindState == DuckDBError { - throw DuckDBPluginError.queryFailed("Failed to bind parameter at index \(index)") - } - } - - var result = duckdb_result() - let execState = duckdb_execute_prepared(stmt, &result) - - if execState == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Failed to execute prepared statement" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - defer { - duckdb_destroy_result(&result) - } - - var raw = Self.extractResult(from: &result, startTime: startTime) - Self.patchTzColumns(&raw, query: query, connection: conn) - return raw - } - - func streamQuery( - _ query: String, - continuation: AsyncThrowingStream.Continuation - ) throws { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - var result = duckdb_result() - let state = duckdb_query(conn, query, &result) - - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - let colCount = duckdb_column_count(&result) - var columns: [String] = [] - var columnTypeNames: [String] = [] - var columnTypes: [duckdb_type] = [] - for i in 0...Continuation - ) throws { - let castExprs = columns.enumerated().map { i, name in - castExpression(for: columnTypes[i], column: name) - } - let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) - - var result = duckdb_result() - let state = duckdb_query(connection, wrappedQuery, &result) - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - defer { duckdb_destroy_result(&result) } - - try Self.streamResultRows( - &result, - columns: columns, - columnTypeNames: columnTypeNames, - continuation: continuation - ) - } - - private static func streamResultRows( - _ result: inout duckdb_result, - columns: [String], - columnTypeNames: [String], - continuation: AsyncThrowingStream.Continuation - ) throws { - let colCount = duckdb_column_count(&result) - let rowCount = duckdb_row_count(&result) - - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames, - estimatedRowCount: Int(rowCount) - ))) - - let maxRows = min(rowCount, UInt64(PluginRowLimits.emergencyMax)) - if rowCount > UInt64(PluginRowLimits.emergencyMax) { - Self.logger.warning("streamQuery truncating result from \(rowCount) to \(maxRows) rows") - } - - for row in 0.. DuckDBRawResult { - let colCount = duckdb_column_count(&result) - let rowCount = duckdb_row_count(&result) - let rowsChanged = duckdb_rows_changed(&result) - - var columns: [String] = [] - var columnTypeNames: [String] = [] - var columnTypes: [duckdb_type] = [] - - for i in 0.. UInt64(PluginRowLimits.emergencyMax) { - truncated = true - } - - for row in 0.. String { - switch type { - case DUCKDB_TYPE_BOOLEAN: return "BOOLEAN" - case DUCKDB_TYPE_TINYINT: return "TINYINT" - case DUCKDB_TYPE_SMALLINT: return "SMALLINT" - case DUCKDB_TYPE_INTEGER: return "INTEGER" - case DUCKDB_TYPE_BIGINT: return "BIGINT" - case DUCKDB_TYPE_UTINYINT: return "UTINYINT" - case DUCKDB_TYPE_USMALLINT: return "USMALLINT" - case DUCKDB_TYPE_UINTEGER: return "UINTEGER" - case DUCKDB_TYPE_UBIGINT: return "UBIGINT" - case DUCKDB_TYPE_FLOAT: return "FLOAT" - case DUCKDB_TYPE_DOUBLE: return "DOUBLE" - case DUCKDB_TYPE_TIMESTAMP: return "TIMESTAMP" - case DUCKDB_TYPE_DATE: return "DATE" - case DUCKDB_TYPE_TIME: return "TIME" - case DUCKDB_TYPE_INTERVAL: return "INTERVAL" - case DUCKDB_TYPE_HUGEINT: return "HUGEINT" - case DUCKDB_TYPE_VARCHAR: return "VARCHAR" - case DUCKDB_TYPE_BLOB: return "BLOB" - case DUCKDB_TYPE_DECIMAL: return "DECIMAL" - case DUCKDB_TYPE_TIMESTAMP_S: return "TIMESTAMP_S" - case DUCKDB_TYPE_TIMESTAMP_MS: return "TIMESTAMP_MS" - case DUCKDB_TYPE_TIMESTAMP_NS: return "TIMESTAMP_NS" - case DUCKDB_TYPE_ENUM: return "ENUM" - case DUCKDB_TYPE_LIST: return "LIST" - case DUCKDB_TYPE_STRUCT: return "STRUCT" - case DUCKDB_TYPE_MAP: return "MAP" - case DUCKDB_TYPE_UUID: return "UUID" - case DUCKDB_TYPE_UNION: return "UNION" - case DUCKDB_TYPE_BIT: return "BIT" - case DUCKDB_TYPE_TIMESTAMP_TZ: return "TIMESTAMPTZ" - case DUCKDB_TYPE_TIME_TZ: return "TIMETZ" - case DUCKDB_TYPE_TIME_NS: return "TIME_NS" - case DUCKDB_TYPE_UHUGEINT: return "UHUGEINT" - case DUCKDB_TYPE_ARRAY: return "ARRAY" - case DUCKDB_TYPE_GEOMETRY: return "GEOMETRY" - default: return "VARCHAR" - } - } - - private static func extractFallbackValue( - _ result: inout duckdb_result, col: idx_t, row: idx_t, type: duckdb_type - ) -> String? { - switch type { - case DUCKDB_TYPE_TIMESTAMP, DUCKDB_TYPE_TIMESTAMP_S, DUCKDB_TYPE_TIMESTAMP_MS, DUCKDB_TYPE_TIMESTAMP_NS: - let ts = duckdb_value_timestamp(&result, col, row) - return formatTimestamp(ts) - - case DUCKDB_TYPE_DATE: - let date = duckdb_value_date(&result, col, row) - let d = duckdb_from_date(date) - return String(format: "\(formatYearISO(d.year))-%02d-%02d", d.month, d.day) - - case DUCKDB_TYPE_TIME, DUCKDB_TYPE_TIME_NS: - let time = duckdb_value_time(&result, col, row) - return formatTime(duckdb_from_time(time)) - - case DUCKDB_TYPE_BOOLEAN: - return duckdb_value_boolean(&result, col, row) ? "true" : "false" - - case DUCKDB_TYPE_TINYINT: - return String(duckdb_value_int8(&result, col, row)) - case DUCKDB_TYPE_SMALLINT: - return String(duckdb_value_int16(&result, col, row)) - case DUCKDB_TYPE_INTEGER: - return String(duckdb_value_int32(&result, col, row)) - case DUCKDB_TYPE_BIGINT: - return String(duckdb_value_int64(&result, col, row)) - case DUCKDB_TYPE_UTINYINT: - return String(duckdb_value_uint8(&result, col, row)) - case DUCKDB_TYPE_USMALLINT: - return String(duckdb_value_uint16(&result, col, row)) - case DUCKDB_TYPE_UINTEGER: - return String(duckdb_value_uint32(&result, col, row)) - case DUCKDB_TYPE_UBIGINT: - return String(duckdb_value_uint64(&result, col, row)) - case DUCKDB_TYPE_FLOAT: - return String(duckdb_value_float(&result, col, row)) - case DUCKDB_TYPE_DOUBLE: - return String(duckdb_value_double(&result, col, row)) - - case DUCKDB_TYPE_HUGEINT: - let h = duckdb_value_hugeint(&result, col, row) - return formatHugeInt(upper: h.upper, lower: h.lower) - - case DUCKDB_TYPE_UHUGEINT: - let u = duckdb_value_uhugeint(&result, col, row) - return formatUHugeInt(upper: u.upper, lower: u.lower) - - default: - return nil - } - } - - static func patchTzColumns( - _ raw: inout DuckDBRawResult, query: String, connection: duckdb_connection - ) { - let patchedColIndices = raw.columnTypes.enumerated().compactMap { idx, type in - isUnrenderable(type) ? idx : nil - } - guard !patchedColIndices.isEmpty, !raw.rows.isEmpty else { return } - - let castExprs = raw.columns.enumerated().map { i, name in - castExpression(for: raw.columnTypes[i], column: name) - } - let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) - - var patchResult = duckdb_result() - guard duckdb_query(connection, wrappedQuery, &patchResult) == DuckDBSuccess else { return } - defer { duckdb_destroy_result(&patchResult) } - - let patchRowCount = min(duckdb_row_count(&patchResult), UInt64(raw.rows.count)) - for row in 0.. Bool { - switch type { - case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ, DUCKDB_TYPE_GEOMETRY: - return true - default: - return false - } - } - - static func castExpression(for type: duckdb_type, column: String) -> String { - let quoted = quoteIdentifier(column) - switch type { - case DUCKDB_TYPE_GEOMETRY: - return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE ST_AsText(\(quoted)) END AS \(quoted)" - case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ: - return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE CAST(\(quoted) AS VARCHAR) END AS \(quoted)" - default: - return quoted - } - } - - static func buildWrappedQuery(originalQuery: String, castExprs: [String]) -> String { - var trimmed = originalQuery.trimmingCharacters(in: .whitespacesAndNewlines) - if trimmed.hasSuffix(";") { - trimmed = String(trimmed.dropLast()) - } - return "SELECT \(castExprs.joined(separator: ", ")) FROM (\(trimmed)) AS _tp_cast" - } - - static func quoteIdentifier(_ ident: String) -> String { - "\"\(ident.replacingOccurrences(of: "\"", with: "\"\""))\"" - } - - static func formatTimestamp(_ ts: duckdb_timestamp) -> String { - let parts = duckdb_from_timestamp(ts) - let d = parts.date - let t = parts.time - let micros = t.micros % 1_000_000 - let yearPart = formatYearISO(d.year) - if micros == 0 { - return String( - format: "\(yearPart)-%02d-%02d %02d:%02d:%02d", - d.month, d.day, t.hour, t.min, t.sec - ) - } - return String( - format: "\(yearPart)-%02d-%02d %02d:%02d:%02d.%06d", - d.month, d.day, t.hour, t.min, t.sec, micros - ) - } - - static func formatYearISO(_ year: Int32) -> String { - if year < 0 { - return String(format: "-%04d", -Int(year)) - } - return String(format: "%04d", year) - } - - private static func formatTime(_ t: duckdb_time_struct) -> String { - let micros = t.micros % 1_000_000 - if micros == 0 { - return String(format: "%02d:%02d:%02d", t.hour, t.min, t.sec) - } - return String(format: "%02d:%02d:%02d.%06d", t.hour, t.min, t.sec, micros) - } - - static func formatHugeInt(upper: Int64, lower: UInt64) -> String { - HugeIntFormatter.format(upper: upper, lower: lower) - } - - static func formatUHugeInt(upper: UInt64, lower: UInt64) -> String { - HugeIntFormatter.formatUnsigned(upper: upper, lower: lower) - } -} - -private struct DuckDBRawResult: @unchecked Sendable { - let columns: [String] - let columnTypeNames: [String] - let columnTypes: [duckdb_type] - var rows: [[PluginCellValue]] - let rowsAffected: Int - let executionTime: TimeInterval - let isTruncated: Bool -} - // MARK: - DuckDB Plugin Driver final class DuckDBPluginDriver: PluginDatabaseDriver, @unchecked Sendable { @@ -1470,22 +877,3 @@ final class DuckDBPluginDriver: PluginDatabaseDriver, @unchecked Sendable { } } -// MARK: - Errors - -enum DuckDBPluginError: Error { - case connectionFailed(String) - case notConnected - case queryFailed(String) - case unsupportedOperation -} - -extension DuckDBPluginError: PluginDriverError { - var pluginErrorMessage: String { - switch self { - case .connectionFailed(let msg): return msg - case .notConnected: return String(localized: "Not connected to database") - case .queryFailed(let msg): return msg - case .unsupportedOperation: return String(localized: "Operation not supported") - } - } -} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift b/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift new file mode 100644 index 000000000..28ae42001 --- /dev/null +++ b/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift @@ -0,0 +1,25 @@ +// +// DuckDBPluginError.swift +// DuckDBDriverPlugin +// + +import Foundation +import TableProPluginKit + +enum DuckDBPluginError: Error { + case connectionFailed(String) + case notConnected + case queryFailed(String) + case unsupportedOperation +} + +extension DuckDBPluginError: PluginDriverError { + var pluginErrorMessage: String { + switch self { + case .connectionFailed(let msg): return msg + case .notConnected: return String(localized: "Not connected to database") + case .queryFailed(let msg): return msg + case .unsupportedOperation: return String(localized: "Operation not supported") + } + } +} diff --git a/Plugins/JSONExportPlugin/JSONExportPlugin.swift b/Plugins/JSONExportPlugin/JSONExportPlugin.swift index 6aa73363e..2868e2890 100644 --- a/Plugins/JSONExportPlugin/JSONExportPlugin.swift +++ b/Plugins/JSONExportPlugin/JSONExportPlugin.swift @@ -157,7 +157,7 @@ final class JSONExportPlugin: ExportFormatPlugin, SettablePlugin { return val.lowercased() } - let isNumericCol = isNumericColumnType(columnTypeName) + let isNumericCol = PluginExportUtilities.isNumericColumnType(columnTypeName) if isNumericCol && isValidIntegerLiteral(val) { if let intVal = Int(val) { @@ -183,15 +183,6 @@ final class JSONExportPlugin: ExportFormatPlugin, SettablePlugin { return "\"\(PluginExportUtilities.escapeJSONString(val))\"" } - private func isNumericColumnType(_ typeName: String) -> Bool { - let numericPrefixes = [ - "int", "bigint", "decimal", "float", "double", "numeric", - "real", "smallint", "tinyint", "mediumint", "integer", "number" - ] - let lower = typeName.lowercased() - return numericPrefixes.contains { lower.hasPrefix($0) } - } - private func isValidIntegerLiteral(_ val: String) -> Bool { guard !val.isEmpty else { return false } let digits = val.hasPrefix("-") || val.hasPrefix("+") ? String(val.dropFirst()) : val diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift index 16b33de1b..468243764 100644 --- a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift +++ b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift @@ -160,16 +160,16 @@ final class MSSQLPlugin: NSObject, TableProPlugin, DriverPlugin { final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { private let config: DriverConnectionConfig - private var freeTDSConn: FreeTDSConnection? - private var _currentSchema: String + var freeTDSConn: FreeTDSConnection? + var _currentSchema: String private var _serverVersion: String? /// IDENTITY columns observed during `fetchColumns`, keyed by table name. /// `generateMssqlInsert` reads this to skip IDENTITY columns: SQL Server /// rejects explicit values for IDENTITY columns unless IDENTITY_INSERT is ON, /// and the value the user typed is server-allocated anyway. - private var identityColumnsByTable: [String: Set] = [:] - private let identityCacheLock = NSLock() + var identityColumnsByTable: [String: Set] = [:] + let identityCacheLock = NSLock() private static let logger = Logger(subsystem: "com.TablePro", category: "MSSQLPluginDriver") @@ -554,527 +554,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { return nil } - // MARK: - Schema Operations - - func fetchTables(schema: String?) async throws -> [PluginTableInfo] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT t.TABLE_NAME, t.TABLE_TYPE - FROM INFORMATION_SCHEMA.TABLES t - WHERE t.TABLE_SCHEMA = '\(esc)' - AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') - ORDER BY t.TABLE_NAME - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginTableInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let rawType = row[safe: 1]?.asText - let tableType = (rawType == "VIEW") ? "VIEW" : "TABLE" - return PluginTableInfo(name: name, type: tableType) - } - } - - func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - c.COLUMN_NAME, - c.DATA_TYPE, - c.CHARACTER_MAXIMUM_LENGTH, - c.NUMERIC_PRECISION, - c.NUMERIC_SCALE, - c.IS_NULLABLE, - c.COLUMN_DEFAULT, - COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, - CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK - FROM INFORMATION_SCHEMA.COLUMNS c - LEFT JOIN ( - SELECT kcu.COLUMN_NAME - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc - JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu - ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME - AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA - WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' - AND tc.TABLE_SCHEMA = '\(esc)' - AND tc.TABLE_NAME = '\(escapedTable)' - ) pk ON c.COLUMN_NAME = pk.COLUMN_NAME - WHERE c.TABLE_NAME = '\(escapedTable)' - AND c.TABLE_SCHEMA = '\(esc)' - ORDER BY c.ORDINAL_POSITION - """ - let result = try await execute(query: sql) - var identityColumns: Set = [] - let columns: [PluginColumnInfo] = result.rows.compactMap { row -> PluginColumnInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let dataType = row[safe: 1]?.asText - let charLen = row[safe: 2]?.asText - let numPrecision = row[safe: 3]?.asText - let numScale = row[safe: 4]?.asText - let isNullable = (row[safe: 5]?.asText) == "YES" - let defaultValue = row[safe: 6]?.asText - let isIdentity = (row[safe: 7]?.asText) == "1" - let isPk = (row[safe: 8]?.asText) == "1" - - if isIdentity { - identityColumns.insert(name) - } - - let baseType = (dataType ?? "nvarchar").lowercased() - let fixedSizeTypes: Set = [ - "int", "bigint", "smallint", "tinyint", "bit", - "money", "smallmoney", "float", "real", - "datetime", "datetime2", "smalldatetime", "date", "time", - "uniqueidentifier", "text", "ntext", "image", "xml", - "timestamp", "rowversion" - ] - var fullType = baseType - if fixedSizeTypes.contains(baseType) { - // No suffix - } else if let charLen, let len = Int(charLen), len > 0 { - fullType += "(\(len))" - } else if charLen == "-1" { - fullType += "(max)" - } else if let prec = numPrecision, let scale = numScale, - let p = Int(prec), let s = Int(scale) { - fullType += "(\(p),\(s))" - } - - return PluginColumnInfo( - name: name, - dataType: fullType, - isNullable: isNullable, - isPrimaryKey: isPk, - defaultValue: defaultValue, - extra: isIdentity ? "IDENTITY" : nil - ) - } - identityCacheLock.lock() - identityColumnsByTable[table] = identityColumns - identityCacheLock.unlock() - return columns - } - - /// Snapshot of IDENTITY columns observed by the most recent `fetchColumns` for the table. - /// Returns an empty set when `fetchColumns` hasn't run for this table yet, so callers - /// fall through to including every typed value (matching pre-cache behavior). - internal func cachedIdentityColumns(for table: String) -> Set { - identityCacheLock.lock() - defer { identityCacheLock.unlock() } - return identityColumnsByTable[table] ?? [] - } - - /// Test seam: pre-populate the cache so generateMssqlInsert can be exercised - /// without going through a live `fetchColumns` round-trip. - internal func setIdentityColumnsForTesting(_ columns: Set, table: String) { - identityCacheLock.lock() - identityColumnsByTable[table] = columns - identityCacheLock.unlock() - } - - func fetchIndexes(table: String, schema: String?) async throws -> [PluginIndexInfo] { - let esc = (schema ?? _currentSchema).replacingOccurrences(of: "]", with: "]]") - let bracketedTable = table.replacingOccurrences(of: "]", with: "]]") - let bracketedFull = "[\(esc)].[\(bracketedTable)]" - let sql = """ - SELECT i.name, i.is_unique, i.is_primary_key, c.name AS column_name - FROM sys.indexes i - JOIN sys.index_columns ic - ON i.object_id = ic.object_id AND i.index_id = ic.index_id - JOIN sys.columns c - ON ic.object_id = c.object_id AND ic.column_id = c.column_id - WHERE i.object_id = OBJECT_ID('\(bracketedFull)') - AND i.name IS NOT NULL - ORDER BY i.index_id, ic.key_ordinal - """ - let result = try await execute(query: sql) - var indexMap: [String: (unique: Bool, primary: Bool, columns: [String])] = [:] - for row in result.rows { - guard let idxName = row[safe: 0]?.asText, - let colName = row[safe: 3]?.asText else { continue } - let isUnique = (row[safe: 1]?.asText) == "1" - let isPrimary = (row[safe: 2]?.asText) == "1" - if indexMap[idxName] == nil { - indexMap[idxName] = (unique: isUnique, primary: isPrimary, columns: []) - } - indexMap[idxName]?.columns.append(colName) - } - return indexMap.map { name, info in - PluginIndexInfo( - name: name, - columns: info.columns, - isUnique: info.unique, - isPrimary: info.primary, - type: "CLUSTERED" - ) - }.sorted { $0.name < $1.name } - } - - func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - fk.name AS constraint_name, - cp.name AS column_name, - tr.name AS ref_table, - cr.name AS ref_column - FROM sys.foreign_keys fk - JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id - JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id - JOIN sys.schemas s ON tp.schema_id = s.schema_id - JOIN sys.columns cp - ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id - JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id - JOIN sys.columns cr - ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id - WHERE tp.name = '\(escapedTable)' AND s.name = '\(esc)' - ORDER BY fk.name - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginForeignKeyInfo? in - guard let constraintName = row[safe: 0]?.asText, - let columnName = row[safe: 1]?.asText, - let refTable = row[safe: 2]?.asText, - let refColumn = row[safe: 3]?.asText else { return nil } - return PluginForeignKeyInfo( - name: constraintName, - column: columnName, - referencedTable: refTable, - referencedColumn: refColumn - ) - } - } - - func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - c.TABLE_NAME, - c.COLUMN_NAME, - c.DATA_TYPE, - c.CHARACTER_MAXIMUM_LENGTH, - c.NUMERIC_PRECISION, - c.NUMERIC_SCALE, - c.IS_NULLABLE, - c.COLUMN_DEFAULT, - COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, - CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK - FROM INFORMATION_SCHEMA.COLUMNS c - LEFT JOIN ( - SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc - JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu - ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME - AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA - WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' - AND tc.TABLE_SCHEMA = '\(esc)' - ) pk ON c.TABLE_NAME = pk.TABLE_NAME AND c.COLUMN_NAME = pk.COLUMN_NAME - WHERE c.TABLE_SCHEMA = '\(esc)' - ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION - """ - let result = try await execute(query: sql) - var columnsByTable: [String: [PluginColumnInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let name = row[safe: 1]?.asText else { continue } - let dataType = row[safe: 2]?.asText - let charLen = row[safe: 3]?.asText - let numPrecision = row[safe: 4]?.asText - let numScale = row[safe: 5]?.asText - let isNullable = (row[safe: 6]?.asText) == "YES" - let defaultValue = row[safe: 7]?.asText - let isIdentity = (row[safe: 8]?.asText) == "1" - let isPk = (row[safe: 9]?.asText) == "1" - - let baseType = (dataType ?? "nvarchar").lowercased() - let fixedSizeTypes: Set = [ - "int", "bigint", "smallint", "tinyint", "bit", - "money", "smallmoney", "float", "real", - "datetime", "datetime2", "smalldatetime", "date", "time", - "uniqueidentifier", "text", "ntext", "image", "xml", - "timestamp", "rowversion" - ] - var fullType = baseType - if fixedSizeTypes.contains(baseType) { - // No suffix - } else if let charLen, let len = Int(charLen), len > 0 { - fullType += "(\(len))" - } else if charLen == "-1" { - fullType += "(max)" - } else if let prec = numPrecision, let scale = numScale, - let p = Int(prec), let s = Int(scale) { - fullType += "(\(p),\(s))" - } - - let col = PluginColumnInfo( - name: name, - dataType: fullType, - isNullable: isNullable, - isPrimaryKey: isPk, - defaultValue: defaultValue, - extra: isIdentity ? "IDENTITY" : nil - ) - columnsByTable[tableName, default: []].append(col) - } - return columnsByTable - } - - func fetchAllForeignKeys(schema: String?) async throws -> [String: [PluginForeignKeyInfo]] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - tp.name AS table_name, - fk.name AS constraint_name, - cp.name AS column_name, - tr.name AS ref_table, - cr.name AS ref_column - FROM sys.foreign_keys fk - JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id - JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id - JOIN sys.schemas s ON tp.schema_id = s.schema_id - JOIN sys.columns cp - ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id - JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id - JOIN sys.columns cr - ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id - WHERE s.name = '\(esc)' - ORDER BY tp.name, fk.name - """ - let result = try await execute(query: sql) - var fksByTable: [String: [PluginForeignKeyInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let constraintName = row[safe: 1]?.asText, - let columnName = row[safe: 2]?.asText, - let refTable = row[safe: 3]?.asText, - let refColumn = row[safe: 4]?.asText else { continue } - let fk = PluginForeignKeyInfo( - name: constraintName, - column: columnName, - referencedTable: refTable, - referencedColumn: refColumn - ) - fksByTable[tableName, default: []].append(fk) - } - return fksByTable - } - - func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { - let sql = """ - SELECT d.name, - SUM(mf.size) * 8 * 1024 AS size_bytes - FROM sys.databases d - LEFT JOIN sys.master_files mf ON d.database_id = mf.database_id - GROUP BY d.name - ORDER BY d.name - """ - do { - let result = try await execute(query: sql) - var metadata = result.rows.compactMap { row -> PluginDatabaseMetadata? in - guard let name = row[safe: 0]?.asText else { return nil } - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata(name: name, sizeBytes: sizeBytes) - } - - for i in metadata.indices { - let dbName = metadata[i].name.replacingOccurrences(of: "]", with: "]]") - do { - let countResult = try await execute( - query: "SELECT COUNT(*) FROM [\(dbName)].sys.tables" - ) - if let countStr = countResult.rows.first?[safe: 0]?.asText, - let count = Int(countStr) { - metadata[i] = PluginDatabaseMetadata( - name: metadata[i].name, - tableCount: count, - sizeBytes: metadata[i].sizeBytes - ) - } - } catch { - // Database offline or permission denied — leave tableCount as nil - } - } - - return metadata - } catch { - // Fall back to N+1 if permission denied on sys.master_files - let dbs = try await fetchDatabases() - var result: [PluginDatabaseMetadata] = [] - for db in dbs { - do { - result.append(try await fetchDatabaseMetadata(db)) - } catch { - result.append(PluginDatabaseMetadata(name: db)) - } - } - return result - } - } - - func fetchTableDDL(table: String, schema: String?) async throws -> String { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let cols = try await fetchColumns(table: table, schema: schema) - let indexes = try await fetchIndexes(table: table, schema: schema) - let fks = try await fetchForeignKeys(table: table, schema: schema) - - var ddl = "CREATE TABLE [\(esc)].[\(escapedTable)] (\n" - let colDefs = cols.map { col -> String in - var def = " [\(col.name)] \(col.dataType.uppercased())" - if col.extra == "IDENTITY" { def += " IDENTITY(1,1)" } - def += col.isNullable ? " NULL" : " NOT NULL" - if let d = col.defaultValue { def += " DEFAULT \(d)" } - return def - } - - let pkCols = indexes.filter(\.isPrimary).flatMap(\.columns) - var parts = colDefs - if !pkCols.isEmpty { - let pkName = "PK_\(table)" - let pkDef = " CONSTRAINT [\(pkName)] PRIMARY KEY (\(pkCols.map { "[\($0)]" }.joined(separator: ", ")))" - parts.append(pkDef) - } - - for fk in fks { - let fkDef = " CONSTRAINT [\(fk.name)] FOREIGN KEY ([\(fk.column)]) REFERENCES [\(fk.referencedTable)] ([\(fk.referencedColumn)])" - parts.append(fkDef) - } - - ddl += parts.joined(separator: ",\n") - ddl += "\n);" - return ddl - } - - func fetchViewDefinition(view: String, schema: String?) async throws -> String { - let esc = effectiveSchemaEscaped(schema) - let escapedView = "\(esc).\(view.replacingOccurrences(of: "'", with: "''"))" - let sql = "SELECT definition FROM sys.sql_modules WHERE object_id = OBJECT_ID('\(escapedView)')" - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - SUM(p.rows) AS row_count, - 8 * SUM(a.used_pages) AS size_kb, - ep.value AS comment - FROM sys.tables t - JOIN sys.schemas s ON t.schema_id = s.schema_id - JOIN sys.partitions p - ON t.object_id = p.object_id AND p.index_id IN (0, 1) - JOIN sys.allocation_units a ON p.partition_id = a.container_id - LEFT JOIN sys.extended_properties ep - ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.name = 'MS_Description' - WHERE t.name = '\(escapedTable)' AND s.name = '\(esc)' - GROUP BY ep.value - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } - let sizeKb = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 - let comment = row[safe: 2]?.asText - return PluginTableMetadata( - tableName: table, - dataSize: sizeKb * 1_024, - totalSize: sizeKb * 1_024, - rowCount: rowCount, - comment: comment - ) - } - return PluginTableMetadata(tableName: table) - } - - func fetchDatabases() async throws -> [String] { - let sql = "SELECT name FROM sys.databases ORDER BY name" - let result = try await execute(query: sql) - return result.rows.compactMap { $0.first?.asText } - } - - func fetchSchemas() async throws -> [String] { - let sql = """ - SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA - WHERE SCHEMA_NAME NOT IN ( - 'information_schema','sys','db_owner','db_accessadmin', - 'db_securityadmin','db_ddladmin','db_backupoperator', - 'db_datareader','db_datawriter','db_denydatareader', - 'db_denydatawriter','guest' - ) - ORDER BY SCHEMA_NAME - """ - let result = try await execute(query: sql) - return result.rows.compactMap { $0.first?.asText } - } - - func switchSchema(to schema: String) async throws { - _currentSchema = schema - } - - func switchDatabase(to database: String) async throws { - guard let conn = freeTDSConn else { - throw MSSQLPluginError.notConnected - } - try await conn.switchDatabase(database) - } - - func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { - let sql = """ - SELECT - SUM(size) * 8.0 / 1024 AS size_mb, - (SELECT COUNT(*) FROM sys.tables) AS table_count - FROM sys.database_files - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let sizeMb = (row[safe: 0]?.asText).flatMap { Double($0) } ?? 0 - let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 - return PluginDatabaseMetadata( - name: database, - tableCount: tableCount, - sizeBytes: Int64(sizeMb * 1_024 * 1_024) - ) - } - return PluginDatabaseMetadata(name: database) - } - - func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { - PluginCreateDatabaseFormSpec(fields: [], footnote: nil) - } - - func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { - let quotedName = "[\(request.name.replacingOccurrences(of: "]", with: "]]"))]" - _ = try await execute(query: "CREATE DATABASE \(quotedName)") - } - - func dropDatabase(name: String) async throws { - let quotedName = "[\(name.replacingOccurrences(of: "]", with: "]]"))]" - _ = try await execute(query: "DROP DATABASE \(quotedName)") - } - - // MARK: - All Tables Metadata - - func allTablesMetadataSQL(schema: String?) -> String? { - """ - SELECT - s.name as schema_name, - t.name as name, - CASE WHEN v.object_id IS NOT NULL THEN 'VIEW' ELSE 'TABLE' END as kind, - p.rows as estimated_rows, - CAST(ROUND(SUM(a.total_pages) * 8 / 1024.0, 2) AS VARCHAR) + ' MB' as total_size - FROM sys.tables t - INNER JOIN sys.schemas s ON t.schema_id = s.schema_id - INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.index_id IN (0, 1) - INNER JOIN sys.partitions p ON i.object_id = p.object_id AND i.index_id = p.index_id - INNER JOIN sys.allocation_units a ON p.partition_id = a.container_id - LEFT JOIN sys.views v ON t.object_id = v.object_id - GROUP BY s.name, t.name, p.rows, v.object_id - ORDER BY t.name - """ - } - // MARK: - Query Building func buildBrowseQuery( @@ -1086,8 +565,9 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = mssqlQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let orderBy = mssqlBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY (SELECT NULL)" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: mssqlQuoteIdentifier + ) ?? "ORDER BY (SELECT NULL)" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1103,12 +583,21 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = mssqlQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let whereClause = mssqlBuildWhereClause(filters: filters, logicMode: logicMode) + let whereClause = PluginSQLFilter.buildWhereClause( + filters: filters, + logicMode: logicMode, + quoteIdentifier: mssqlQuoteIdentifier, + escapeValue: mssqlEscapeValue, + regexCondition: { quoted, value in + "\(quoted) LIKE '%\(value.replacingOccurrences(of: "'", with: "''"))%'" + } + ) if !whereClause.isEmpty { query += " WHERE \(whereClause)" } - let orderBy = mssqlBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY (SELECT NULL)" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: mssqlQuoteIdentifier + ) ?? "ORDER BY (SELECT NULL)" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1119,29 +608,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { quoteIdentifier(identifier) } - private func mssqlBuildOrderByClause( - sortColumns: [(columnIndex: Int, ascending: Bool)], - columns: [String] - ) -> String? { - let parts = sortColumns.compactMap { sortCol -> String? in - guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } - let columnName = columns[sortCol.columnIndex] - let direction = sortCol.ascending ? "ASC" : "DESC" - let quotedColumn = mssqlQuoteIdentifier(columnName) - return "\(quotedColumn) \(direction)" - } - guard !parts.isEmpty else { return nil } - return "ORDER BY " + parts.joined(separator: ", ") - } - - private func mssqlEscapeForLike(_ text: String) -> String { - text - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "%", with: "\\%") - .replacingOccurrences(of: "_", with: "\\_") - .replacingOccurrences(of: "'", with: "''") - } - private func mssqlEscapeValue(_ value: String) -> String { let trimmed = value.trimmingCharacters(in: .whitespaces) if trimmed.caseInsensitiveCompare("NULL") == .orderedSame { return "NULL" } @@ -1151,65 +617,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { return "'\(trimmed.replacingOccurrences(of: "'", with: "''"))'" } - private func mssqlBuildWhereClause( - filters: [(column: String, op: String, value: String)], - logicMode: String - ) -> String { - let conditions = filters.compactMap { filter -> String? in - mssqlBuildFilterCondition(column: filter.column, op: filter.op, value: filter.value) - } - guard !conditions.isEmpty else { return "" } - let separator = logicMode == "and" ? " AND " : " OR " - return conditions.joined(separator: separator) - } - - private func mssqlBuildFilterCondition(column: String, op: String, value: String) -> String? { - let quoted = mssqlQuoteIdentifier(column) - switch op { - case "=": return "\(quoted) = \(mssqlEscapeValue(value))" - case "!=": return "\(quoted) != \(mssqlEscapeValue(value))" - case ">": return "\(quoted) > \(mssqlEscapeValue(value))" - case ">=": return "\(quoted) >= \(mssqlEscapeValue(value))" - case "<": return "\(quoted) < \(mssqlEscapeValue(value))" - case "<=": return "\(quoted) <= \(mssqlEscapeValue(value))" - case "IS NULL": return "\(quoted) IS NULL" - case "IS NOT NULL": return "\(quoted) IS NOT NULL" - case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" - case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" - case "CONTAINS": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)%' ESCAPE '\\'" - case "NOT CONTAINS": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) NOT LIKE '%\(escaped)%' ESCAPE '\\'" - case "STARTS WITH": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '\(escaped)%' ESCAPE '\\'" - case "ENDS WITH": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)' ESCAPE '\\'" - case "IN": - let values = value.split(separator: ",") - .map { mssqlEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) IN (\(values))" - case "NOT IN": - let values = value.split(separator: ",") - .map { mssqlEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" - case "BETWEEN": - let parts = value.split(separator: ",", maxSplits: 1) - guard parts.count == 2 else { return nil } - let v1 = mssqlEscapeValue(parts[0].trimmingCharacters(in: .whitespaces)) - let v2 = mssqlEscapeValue(parts[1].trimmingCharacters(in: .whitespaces)) - return "\(quoted) BETWEEN \(v1) AND \(v2)" - case "REGEX": - let escaped = value.replacingOccurrences(of: "'", with: "''") - return "\(quoted) LIKE '%\(escaped)%'" - default: return nil - } - } // MARK: - Private Helpers @@ -1277,183 +684,11 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { value.replacingOccurrences(of: "'", with: "''") } - private func effectiveSchemaEscaped(_ schema: String?) -> String { + func effectiveSchemaEscaped(_ schema: String?) -> String { let raw = schema ?? _currentSchema return raw.replacingOccurrences(of: "'", with: "''") } - // MARK: - Create Table DDL - - func generateCreateTableSQL(definition: PluginCreateTableDefinition) -> String? { - guard !definition.columns.isEmpty else { return nil } - - let schema = _currentSchema - let qualifiedTable = "\(quoteIdentifier(schema)).\(quoteIdentifier(definition.tableName))" - let pkColumns = definition.columns.filter { $0.isPrimaryKey } - let inlinePK = pkColumns.count == 1 - var parts: [String] = definition.columns.map { mssqlColumnDefinition($0, inlinePK: inlinePK) } - - if pkColumns.count > 1 { - let pkCols = pkColumns.map { quoteIdentifier($0.name) }.joined(separator: ", ") - parts.append("PRIMARY KEY (\(pkCols))") - } - - for fk in definition.foreignKeys { - parts.append(mssqlForeignKeyDefinition(fk)) - } - - var sql = "CREATE TABLE \(qualifiedTable) (\n " + - parts.joined(separator: ",\n ") + - "\n);" - - var indexStatements: [String] = [] - for index in definition.indexes { - indexStatements.append(mssqlIndexDefinition(index, qualifiedTable: qualifiedTable)) - } - if !indexStatements.isEmpty { - sql += "\n\n" + indexStatements.joined(separator: ";\n") + ";" - } - - return sql - } - - private func mssqlColumnDefinition(_ col: PluginColumnDefinition, inlinePK: Bool) -> String { - var def = "\(quoteIdentifier(col.name)) \(col.dataType)" - if col.autoIncrement { - def += " IDENTITY(1,1)" - } - if col.isNullable { - def += " NULL" - } else { - def += " NOT NULL" - } - if let defaultValue = col.defaultValue { - def += " DEFAULT \(mssqlDefaultValue(defaultValue))" - } - if inlinePK && col.isPrimaryKey { - def += " PRIMARY KEY" - } - return def - } - - private func mssqlDefaultValue(_ value: String) -> String { - let upper = value.uppercased() - if upper == "NULL" || upper == "GETDATE()" || upper == "NEWID()" || upper == "GETUTCDATE()" - || value.hasPrefix("'") || value.hasPrefix("(") || Int64(value) != nil || Double(value) != nil { - return value - } - return "'\(escapeStringLiteral(value))'" - } - - private func mssqlIndexDefinition(_ index: PluginIndexDefinition, qualifiedTable: String) -> String { - let cols = index.columns.map { quoteIdentifier($0) }.joined(separator: ", ") - let unique = index.isUnique ? "UNIQUE " : "" - var def = "CREATE \(unique)INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - if let type = index.indexType?.uppercased(), type == "CLUSTERED" { - def = "CREATE \(unique)CLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - } else if let type = index.indexType?.uppercased(), type == "NONCLUSTERED" { - def = "CREATE \(unique)NONCLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - } - return def - } - - private func mssqlForeignKeyDefinition(_ fk: PluginForeignKeyDefinition) -> String { - let cols = fk.columns.map { quoteIdentifier($0) }.joined(separator: ", ") - let refCols = fk.referencedColumns.map { quoteIdentifier($0) }.joined(separator: ", ") - var def = "CONSTRAINT \(quoteIdentifier(fk.name)) FOREIGN KEY (\(cols)) REFERENCES \(quoteIdentifier(fk.referencedTable)) (\(refCols))" - if fk.onDelete != "NO ACTION" { - def += " ON DELETE \(fk.onDelete)" - } - if fk.onUpdate != "NO ACTION" { - def += " ON UPDATE \(fk.onUpdate)" - } - return def - } - - // MARK: - ALTER TABLE DDL - - private func mssqlQualifiedTable(_ table: String) -> String { - "\(quoteIdentifier(_currentSchema)).\(quoteIdentifier(table))" - } - - func generateAddColumnSQL(table: String, column: PluginColumnDefinition) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlColumnDefinition(column, inlinePK: false))" - } - - func generateModifyColumnSQL(table: String, oldColumn: PluginColumnDefinition, newColumn: PluginColumnDefinition) -> String? { - let qt = mssqlQualifiedTable(table) - var stmts: [String] = [] - let needsTypeChange = oldColumn.dataType != newColumn.dataType || oldColumn.isNullable != newColumn.isNullable - let defaultChanged = oldColumn.defaultValue != newColumn.defaultValue - - // Rename column first so subsequent statements reference the correct name - if oldColumn.name != newColumn.name { - let escapedPath = "\(escapeStringLiteral(_currentSchema)).\(escapeStringLiteral(table)).\(escapeStringLiteral(oldColumn.name))" - stmts.append("EXEC sp_rename '\(escapedPath)', '\(escapeStringLiteral(newColumn.name))', 'COLUMN'") - } - - let colName = quoteIdentifier(newColumn.name) - - // Drop existing default constraint before ALTER COLUMN or default change - if (defaultChanged || needsTypeChange) && oldColumn.defaultValue != nil { - let objectId = escapeStringLiteral("\(_currentSchema).\(table)") - stmts.append(""" - DECLARE @dfName NVARCHAR(256); \ - SELECT @dfName = dc.name FROM sys.default_constraints dc \ - JOIN sys.columns c ON dc.parent_column_id = c.column_id AND dc.parent_object_id = c.object_id \ - WHERE c.name = '\(escapeStringLiteral(newColumn.name))' \ - AND dc.parent_object_id = OBJECT_ID('\(objectId)'); \ - IF @dfName IS NOT NULL EXEC('ALTER TABLE \(qt) DROP CONSTRAINT [' + @dfName + ']') - """) - } - - if needsTypeChange { - let nullable = newColumn.isNullable ? "NULL" : "NOT NULL" - stmts.append("ALTER TABLE \(qt) ALTER COLUMN \(colName) \(newColumn.dataType) \(nullable)") - } - - if defaultChanged, let defaultValue = newColumn.defaultValue { - stmts.append("ALTER TABLE \(qt) ADD DEFAULT \(mssqlDefaultValue(defaultValue)) FOR \(colName)") - } - - return stmts.isEmpty ? nil : stmts.joined(separator: ";\n") - } - - func generateDropColumnSQL(table: String, columnName: String) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) DROP COLUMN \(quoteIdentifier(columnName))" - } - - func generateAddIndexSQL(table: String, index: PluginIndexDefinition) -> String? { - mssqlIndexDefinition(index, qualifiedTable: mssqlQualifiedTable(table)) - } - - func generateDropIndexSQL(table: String, indexName: String) -> String? { - "DROP INDEX \(quoteIdentifier(indexName)) ON \(mssqlQualifiedTable(table))" - } - - func generateAddForeignKeySQL(table: String, fk: PluginForeignKeyDefinition) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlForeignKeyDefinition(fk))" - } - - func generateDropForeignKeySQL(table: String, constraintName: String) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) DROP CONSTRAINT \(quoteIdentifier(constraintName))" - } - - func generateModifyPrimaryKeySQL(table: String, oldColumns: [String], newColumns: [String], constraintName: String?) -> [String]? { - let qt = mssqlQualifiedTable(table) - var stmts: [String] = [] - if !oldColumns.isEmpty { - let name = constraintName.map { quoteIdentifier($0) } ?? "/* unknown constraint */" - stmts.append("ALTER TABLE \(qt) DROP CONSTRAINT \(name)") - } - if !newColumns.isEmpty { - let cols = newColumns.map { quoteIdentifier($0) }.joined(separator: ", ") - let pkName = constraintName.map { quoteIdentifier($0) } ?? quoteIdentifier("PK_\(table)") - stmts.append("ALTER TABLE \(qt) ADD CONSTRAINT \(pkName) PRIMARY KEY (\(cols))") - } - return stmts.isEmpty ? nil : stmts - } - } // MARK: - Errors diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift new file mode 100644 index 000000000..fabe9f46a --- /dev/null +++ b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift @@ -0,0 +1,184 @@ +// +// MSSQLPluginDriver+DDL.swift +// MSSQLDriverPlugin +// + +import Foundation +import os +import TableProMSSQLCore +import TableProPluginKit + +extension MSSQLPluginDriver { + // MARK: - Create Table DDL + + func generateCreateTableSQL(definition: PluginCreateTableDefinition) -> String? { + guard !definition.columns.isEmpty else { return nil } + + let schema = _currentSchema + let qualifiedTable = "\(quoteIdentifier(schema)).\(quoteIdentifier(definition.tableName))" + let pkColumns = definition.columns.filter { $0.isPrimaryKey } + let inlinePK = pkColumns.count == 1 + var parts: [String] = definition.columns.map { mssqlColumnDefinition($0, inlinePK: inlinePK) } + + if pkColumns.count > 1 { + let pkCols = pkColumns.map { quoteIdentifier($0.name) }.joined(separator: ", ") + parts.append("PRIMARY KEY (\(pkCols))") + } + + for fk in definition.foreignKeys { + parts.append(mssqlForeignKeyDefinition(fk)) + } + + var sql = "CREATE TABLE \(qualifiedTable) (\n " + + parts.joined(separator: ",\n ") + + "\n);" + + var indexStatements: [String] = [] + for index in definition.indexes { + indexStatements.append(mssqlIndexDefinition(index, qualifiedTable: qualifiedTable)) + } + if !indexStatements.isEmpty { + sql += "\n\n" + indexStatements.joined(separator: ";\n") + ";" + } + + return sql + } + + private func mssqlColumnDefinition(_ col: PluginColumnDefinition, inlinePK: Bool) -> String { + var def = "\(quoteIdentifier(col.name)) \(col.dataType)" + if col.autoIncrement { + def += " IDENTITY(1,1)" + } + if col.isNullable { + def += " NULL" + } else { + def += " NOT NULL" + } + if let defaultValue = col.defaultValue { + def += " DEFAULT \(mssqlDefaultValue(defaultValue))" + } + if inlinePK && col.isPrimaryKey { + def += " PRIMARY KEY" + } + return def + } + + private func mssqlDefaultValue(_ value: String) -> String { + let upper = value.uppercased() + if upper == "NULL" || upper == "GETDATE()" || upper == "NEWID()" || upper == "GETUTCDATE()" + || value.hasPrefix("'") || value.hasPrefix("(") || Int64(value) != nil || Double(value) != nil { + return value + } + return "'\(escapeStringLiteral(value))'" + } + + private func mssqlIndexDefinition(_ index: PluginIndexDefinition, qualifiedTable: String) -> String { + let cols = index.columns.map { quoteIdentifier($0) }.joined(separator: ", ") + let unique = index.isUnique ? "UNIQUE " : "" + var def = "CREATE \(unique)INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + if let type = index.indexType?.uppercased(), type == "CLUSTERED" { + def = "CREATE \(unique)CLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + } else if let type = index.indexType?.uppercased(), type == "NONCLUSTERED" { + def = "CREATE \(unique)NONCLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + } + return def + } + + private func mssqlForeignKeyDefinition(_ fk: PluginForeignKeyDefinition) -> String { + let cols = fk.columns.map { quoteIdentifier($0) }.joined(separator: ", ") + let refCols = fk.referencedColumns.map { quoteIdentifier($0) }.joined(separator: ", ") + var def = "CONSTRAINT \(quoteIdentifier(fk.name)) FOREIGN KEY (\(cols)) REFERENCES \(quoteIdentifier(fk.referencedTable)) (\(refCols))" + if fk.onDelete != "NO ACTION" { + def += " ON DELETE \(fk.onDelete)" + } + if fk.onUpdate != "NO ACTION" { + def += " ON UPDATE \(fk.onUpdate)" + } + return def + } + + // MARK: - ALTER TABLE DDL + + private func mssqlQualifiedTable(_ table: String) -> String { + "\(quoteIdentifier(_currentSchema)).\(quoteIdentifier(table))" + } + + func generateAddColumnSQL(table: String, column: PluginColumnDefinition) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlColumnDefinition(column, inlinePK: false))" + } + + func generateModifyColumnSQL(table: String, oldColumn: PluginColumnDefinition, newColumn: PluginColumnDefinition) -> String? { + let qt = mssqlQualifiedTable(table) + var stmts: [String] = [] + let needsTypeChange = oldColumn.dataType != newColumn.dataType || oldColumn.isNullable != newColumn.isNullable + let defaultChanged = oldColumn.defaultValue != newColumn.defaultValue + + // Rename column first so subsequent statements reference the correct name + if oldColumn.name != newColumn.name { + let escapedPath = "\(escapeStringLiteral(_currentSchema)).\(escapeStringLiteral(table)).\(escapeStringLiteral(oldColumn.name))" + stmts.append("EXEC sp_rename '\(escapedPath)', '\(escapeStringLiteral(newColumn.name))', 'COLUMN'") + } + + let colName = quoteIdentifier(newColumn.name) + + // Drop existing default constraint before ALTER COLUMN or default change + if (defaultChanged || needsTypeChange) && oldColumn.defaultValue != nil { + let objectId = escapeStringLiteral("\(_currentSchema).\(table)") + stmts.append(""" + DECLARE @dfName NVARCHAR(256); \ + SELECT @dfName = dc.name FROM sys.default_constraints dc \ + JOIN sys.columns c ON dc.parent_column_id = c.column_id AND dc.parent_object_id = c.object_id \ + WHERE c.name = '\(escapeStringLiteral(newColumn.name))' \ + AND dc.parent_object_id = OBJECT_ID('\(objectId)'); \ + IF @dfName IS NOT NULL EXEC('ALTER TABLE \(qt) DROP CONSTRAINT [' + @dfName + ']') + """) + } + + if needsTypeChange { + let nullable = newColumn.isNullable ? "NULL" : "NOT NULL" + stmts.append("ALTER TABLE \(qt) ALTER COLUMN \(colName) \(newColumn.dataType) \(nullable)") + } + + if defaultChanged, let defaultValue = newColumn.defaultValue { + stmts.append("ALTER TABLE \(qt) ADD DEFAULT \(mssqlDefaultValue(defaultValue)) FOR \(colName)") + } + + return stmts.isEmpty ? nil : stmts.joined(separator: ";\n") + } + + func generateDropColumnSQL(table: String, columnName: String) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) DROP COLUMN \(quoteIdentifier(columnName))" + } + + func generateAddIndexSQL(table: String, index: PluginIndexDefinition) -> String? { + mssqlIndexDefinition(index, qualifiedTable: mssqlQualifiedTable(table)) + } + + func generateDropIndexSQL(table: String, indexName: String) -> String? { + "DROP INDEX \(quoteIdentifier(indexName)) ON \(mssqlQualifiedTable(table))" + } + + func generateAddForeignKeySQL(table: String, fk: PluginForeignKeyDefinition) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlForeignKeyDefinition(fk))" + } + + func generateDropForeignKeySQL(table: String, constraintName: String) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) DROP CONSTRAINT \(quoteIdentifier(constraintName))" + } + + func generateModifyPrimaryKeySQL(table: String, oldColumns: [String], newColumns: [String], constraintName: String?) -> [String]? { + let qt = mssqlQualifiedTable(table) + var stmts: [String] = [] + if !oldColumns.isEmpty { + let name = constraintName.map { quoteIdentifier($0) } ?? "/* unknown constraint */" + stmts.append("ALTER TABLE \(qt) DROP CONSTRAINT \(name)") + } + if !newColumns.isEmpty { + let cols = newColumns.map { quoteIdentifier($0) }.joined(separator: ", ") + let pkName = constraintName.map { quoteIdentifier($0) } ?? quoteIdentifier("PK_\(table)") + stmts.append("ALTER TABLE \(qt) ADD CONSTRAINT \(pkName) PRIMARY KEY (\(cols))") + } + return stmts.isEmpty ? nil : stmts + } + +} diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift new file mode 100644 index 000000000..e2b27121d --- /dev/null +++ b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift @@ -0,0 +1,533 @@ +// +// MSSQLPluginDriver+Schema.swift +// MSSQLDriverPlugin +// + +import Foundation +import os +import TableProMSSQLCore +import TableProPluginKit + +extension MSSQLPluginDriver { + // MARK: - Schema Operations + + func fetchTables(schema: String?) async throws -> [PluginTableInfo] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT t.TABLE_NAME, t.TABLE_TYPE + FROM INFORMATION_SCHEMA.TABLES t + WHERE t.TABLE_SCHEMA = '\(esc)' + AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') + ORDER BY t.TABLE_NAME + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginTableInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let rawType = row[safe: 1]?.asText + let tableType = (rawType == "VIEW") ? "VIEW" : "TABLE" + return PluginTableInfo(name: name, type: tableType) + } + } + + func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + c.COLUMN_NAME, + c.DATA_TYPE, + c.CHARACTER_MAXIMUM_LENGTH, + c.NUMERIC_PRECISION, + c.NUMERIC_SCALE, + c.IS_NULLABLE, + c.COLUMN_DEFAULT, + COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK + FROM INFORMATION_SCHEMA.COLUMNS c + LEFT JOIN ( + SELECT kcu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + AND tc.TABLE_SCHEMA = '\(esc)' + AND tc.TABLE_NAME = '\(escapedTable)' + ) pk ON c.COLUMN_NAME = pk.COLUMN_NAME + WHERE c.TABLE_NAME = '\(escapedTable)' + AND c.TABLE_SCHEMA = '\(esc)' + ORDER BY c.ORDINAL_POSITION + """ + let result = try await execute(query: sql) + var identityColumns: Set = [] + let columns: [PluginColumnInfo] = result.rows.compactMap { row -> PluginColumnInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let dataType = row[safe: 1]?.asText + let charLen = row[safe: 2]?.asText + let numPrecision = row[safe: 3]?.asText + let numScale = row[safe: 4]?.asText + let isNullable = (row[safe: 5]?.asText) == "YES" + let defaultValue = row[safe: 6]?.asText + let isIdentity = (row[safe: 7]?.asText) == "1" + let isPk = (row[safe: 8]?.asText) == "1" + + if isIdentity { + identityColumns.insert(name) + } + + let baseType = (dataType ?? "nvarchar").lowercased() + let fixedSizeTypes: Set = [ + "int", "bigint", "smallint", "tinyint", "bit", + "money", "smallmoney", "float", "real", + "datetime", "datetime2", "smalldatetime", "date", "time", + "uniqueidentifier", "text", "ntext", "image", "xml", + "timestamp", "rowversion" + ] + var fullType = baseType + if fixedSizeTypes.contains(baseType) { + // No suffix + } else if let charLen, let len = Int(charLen), len > 0 { + fullType += "(\(len))" + } else if charLen == "-1" { + fullType += "(max)" + } else if let prec = numPrecision, let scale = numScale, + let p = Int(prec), let s = Int(scale) { + fullType += "(\(p),\(s))" + } + + return PluginColumnInfo( + name: name, + dataType: fullType, + isNullable: isNullable, + isPrimaryKey: isPk, + defaultValue: defaultValue, + extra: isIdentity ? "IDENTITY" : nil + ) + } + identityCacheLock.lock() + identityColumnsByTable[table] = identityColumns + identityCacheLock.unlock() + return columns + } + + /// Snapshot of IDENTITY columns observed by the most recent `fetchColumns` for the table. + /// Returns an empty set when `fetchColumns` hasn't run for this table yet, so callers + /// fall through to including every typed value (matching pre-cache behavior). + internal func cachedIdentityColumns(for table: String) -> Set { + identityCacheLock.lock() + defer { identityCacheLock.unlock() } + return identityColumnsByTable[table] ?? [] + } + + /// Test seam: pre-populate the cache so generateMssqlInsert can be exercised + /// without going through a live `fetchColumns` round-trip. + internal func setIdentityColumnsForTesting(_ columns: Set, table: String) { + identityCacheLock.lock() + identityColumnsByTable[table] = columns + identityCacheLock.unlock() + } + + func fetchIndexes(table: String, schema: String?) async throws -> [PluginIndexInfo] { + let esc = (schema ?? _currentSchema).replacingOccurrences(of: "]", with: "]]") + let bracketedTable = table.replacingOccurrences(of: "]", with: "]]") + let bracketedFull = "[\(esc)].[\(bracketedTable)]" + let sql = """ + SELECT i.name, i.is_unique, i.is_primary_key, c.name AS column_name + FROM sys.indexes i + JOIN sys.index_columns ic + ON i.object_id = ic.object_id AND i.index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id AND ic.column_id = c.column_id + WHERE i.object_id = OBJECT_ID('\(bracketedFull)') + AND i.name IS NOT NULL + ORDER BY i.index_id, ic.key_ordinal + """ + let result = try await execute(query: sql) + var indexMap: [String: (unique: Bool, primary: Bool, columns: [String])] = [:] + for row in result.rows { + guard let idxName = row[safe: 0]?.asText, + let colName = row[safe: 3]?.asText else { continue } + let isUnique = (row[safe: 1]?.asText) == "1" + let isPrimary = (row[safe: 2]?.asText) == "1" + if indexMap[idxName] == nil { + indexMap[idxName] = (unique: isUnique, primary: isPrimary, columns: []) + } + indexMap[idxName]?.columns.append(colName) + } + return indexMap.map { name, info in + PluginIndexInfo( + name: name, + columns: info.columns, + isUnique: info.unique, + isPrimary: info.primary, + type: "CLUSTERED" + ) + }.sorted { $0.name < $1.name } + } + + func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + fk.name AS constraint_name, + cp.name AS column_name, + tr.name AS ref_table, + cr.name AS ref_column + FROM sys.foreign_keys fk + JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id + JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id + JOIN sys.schemas s ON tp.schema_id = s.schema_id + JOIN sys.columns cp + ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id + JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id + JOIN sys.columns cr + ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id + WHERE tp.name = '\(escapedTable)' AND s.name = '\(esc)' + ORDER BY fk.name + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginForeignKeyInfo? in + guard let constraintName = row[safe: 0]?.asText, + let columnName = row[safe: 1]?.asText, + let refTable = row[safe: 2]?.asText, + let refColumn = row[safe: 3]?.asText else { return nil } + return PluginForeignKeyInfo( + name: constraintName, + column: columnName, + referencedTable: refTable, + referencedColumn: refColumn + ) + } + } + + func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + c.TABLE_NAME, + c.COLUMN_NAME, + c.DATA_TYPE, + c.CHARACTER_MAXIMUM_LENGTH, + c.NUMERIC_PRECISION, + c.NUMERIC_SCALE, + c.IS_NULLABLE, + c.COLUMN_DEFAULT, + COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK + FROM INFORMATION_SCHEMA.COLUMNS c + LEFT JOIN ( + SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + AND tc.TABLE_SCHEMA = '\(esc)' + ) pk ON c.TABLE_NAME = pk.TABLE_NAME AND c.COLUMN_NAME = pk.COLUMN_NAME + WHERE c.TABLE_SCHEMA = '\(esc)' + ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION + """ + let result = try await execute(query: sql) + var columnsByTable: [String: [PluginColumnInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let name = row[safe: 1]?.asText else { continue } + let dataType = row[safe: 2]?.asText + let charLen = row[safe: 3]?.asText + let numPrecision = row[safe: 4]?.asText + let numScale = row[safe: 5]?.asText + let isNullable = (row[safe: 6]?.asText) == "YES" + let defaultValue = row[safe: 7]?.asText + let isIdentity = (row[safe: 8]?.asText) == "1" + let isPk = (row[safe: 9]?.asText) == "1" + + let baseType = (dataType ?? "nvarchar").lowercased() + let fixedSizeTypes: Set = [ + "int", "bigint", "smallint", "tinyint", "bit", + "money", "smallmoney", "float", "real", + "datetime", "datetime2", "smalldatetime", "date", "time", + "uniqueidentifier", "text", "ntext", "image", "xml", + "timestamp", "rowversion" + ] + var fullType = baseType + if fixedSizeTypes.contains(baseType) { + // No suffix + } else if let charLen, let len = Int(charLen), len > 0 { + fullType += "(\(len))" + } else if charLen == "-1" { + fullType += "(max)" + } else if let prec = numPrecision, let scale = numScale, + let p = Int(prec), let s = Int(scale) { + fullType += "(\(p),\(s))" + } + + let col = PluginColumnInfo( + name: name, + dataType: fullType, + isNullable: isNullable, + isPrimaryKey: isPk, + defaultValue: defaultValue, + extra: isIdentity ? "IDENTITY" : nil + ) + columnsByTable[tableName, default: []].append(col) + } + return columnsByTable + } + + func fetchAllForeignKeys(schema: String?) async throws -> [String: [PluginForeignKeyInfo]] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + tp.name AS table_name, + fk.name AS constraint_name, + cp.name AS column_name, + tr.name AS ref_table, + cr.name AS ref_column + FROM sys.foreign_keys fk + JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id + JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id + JOIN sys.schemas s ON tp.schema_id = s.schema_id + JOIN sys.columns cp + ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id + JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id + JOIN sys.columns cr + ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id + WHERE s.name = '\(esc)' + ORDER BY tp.name, fk.name + """ + let result = try await execute(query: sql) + var fksByTable: [String: [PluginForeignKeyInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let constraintName = row[safe: 1]?.asText, + let columnName = row[safe: 2]?.asText, + let refTable = row[safe: 3]?.asText, + let refColumn = row[safe: 4]?.asText else { continue } + let fk = PluginForeignKeyInfo( + name: constraintName, + column: columnName, + referencedTable: refTable, + referencedColumn: refColumn + ) + fksByTable[tableName, default: []].append(fk) + } + return fksByTable + } + + func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { + let sql = """ + SELECT d.name, + SUM(mf.size) * 8 * 1024 AS size_bytes + FROM sys.databases d + LEFT JOIN sys.master_files mf ON d.database_id = mf.database_id + GROUP BY d.name + ORDER BY d.name + """ + do { + let result = try await execute(query: sql) + var metadata = result.rows.compactMap { row -> PluginDatabaseMetadata? in + guard let name = row[safe: 0]?.asText else { return nil } + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata(name: name, sizeBytes: sizeBytes) + } + + for i in metadata.indices { + let dbName = metadata[i].name.replacingOccurrences(of: "]", with: "]]") + do { + let countResult = try await execute( + query: "SELECT COUNT(*) FROM [\(dbName)].sys.tables" + ) + if let countStr = countResult.rows.first?[safe: 0]?.asText, + let count = Int(countStr) { + metadata[i] = PluginDatabaseMetadata( + name: metadata[i].name, + tableCount: count, + sizeBytes: metadata[i].sizeBytes + ) + } + } catch { + // Database offline or permission denied: leave tableCount as nil + } + } + + return metadata + } catch { + // Fall back to N+1 if permission denied on sys.master_files + let dbs = try await fetchDatabases() + var result: [PluginDatabaseMetadata] = [] + for db in dbs { + do { + result.append(try await fetchDatabaseMetadata(db)) + } catch { + result.append(PluginDatabaseMetadata(name: db)) + } + } + return result + } + } + + func fetchTableDDL(table: String, schema: String?) async throws -> String { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let cols = try await fetchColumns(table: table, schema: schema) + let indexes = try await fetchIndexes(table: table, schema: schema) + let fks = try await fetchForeignKeys(table: table, schema: schema) + + var ddl = "CREATE TABLE [\(esc)].[\(escapedTable)] (\n" + let colDefs = cols.map { col -> String in + var def = " [\(col.name)] \(col.dataType.uppercased())" + if col.extra == "IDENTITY" { def += " IDENTITY(1,1)" } + def += col.isNullable ? " NULL" : " NOT NULL" + if let d = col.defaultValue { def += " DEFAULT \(d)" } + return def + } + + let pkCols = indexes.filter(\.isPrimary).flatMap(\.columns) + var parts = colDefs + if !pkCols.isEmpty { + let pkName = "PK_\(table)" + let pkDef = " CONSTRAINT [\(pkName)] PRIMARY KEY (\(pkCols.map { "[\($0)]" }.joined(separator: ", ")))" + parts.append(pkDef) + } + + for fk in fks { + let fkDef = " CONSTRAINT [\(fk.name)] FOREIGN KEY ([\(fk.column)]) REFERENCES [\(fk.referencedTable)] ([\(fk.referencedColumn)])" + parts.append(fkDef) + } + + ddl += parts.joined(separator: ",\n") + ddl += "\n);" + return ddl + } + + func fetchViewDefinition(view: String, schema: String?) async throws -> String { + let esc = effectiveSchemaEscaped(schema) + let escapedView = "\(esc).\(view.replacingOccurrences(of: "'", with: "''"))" + let sql = "SELECT definition FROM sys.sql_modules WHERE object_id = OBJECT_ID('\(escapedView)')" + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + SUM(p.rows) AS row_count, + 8 * SUM(a.used_pages) AS size_kb, + ep.value AS comment + FROM sys.tables t + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.partitions p + ON t.object_id = p.object_id AND p.index_id IN (0, 1) + JOIN sys.allocation_units a ON p.partition_id = a.container_id + LEFT JOIN sys.extended_properties ep + ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.name = 'MS_Description' + WHERE t.name = '\(escapedTable)' AND s.name = '\(esc)' + GROUP BY ep.value + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } + let sizeKb = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 + let comment = row[safe: 2]?.asText + return PluginTableMetadata( + tableName: table, + dataSize: sizeKb * 1_024, + totalSize: sizeKb * 1_024, + rowCount: rowCount, + comment: comment + ) + } + return PluginTableMetadata(tableName: table) + } + + func fetchDatabases() async throws -> [String] { + let sql = "SELECT name FROM sys.databases ORDER BY name" + let result = try await execute(query: sql) + return result.rows.compactMap { $0.first?.asText } + } + + func fetchSchemas() async throws -> [String] { + let sql = """ + SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME NOT IN ( + 'information_schema','sys','db_owner','db_accessadmin', + 'db_securityadmin','db_ddladmin','db_backupoperator', + 'db_datareader','db_datawriter','db_denydatareader', + 'db_denydatawriter','guest' + ) + ORDER BY SCHEMA_NAME + """ + let result = try await execute(query: sql) + return result.rows.compactMap { $0.first?.asText } + } + + func switchSchema(to schema: String) async throws { + _currentSchema = schema + } + + func switchDatabase(to database: String) async throws { + guard let conn = freeTDSConn else { + throw MSSQLPluginError.notConnected + } + try await conn.switchDatabase(database) + } + + func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { + let sql = """ + SELECT + SUM(size) * 8.0 / 1024 AS size_mb, + (SELECT COUNT(*) FROM sys.tables) AS table_count + FROM sys.database_files + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let sizeMb = (row[safe: 0]?.asText).flatMap { Double($0) } ?? 0 + let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 + return PluginDatabaseMetadata( + name: database, + tableCount: tableCount, + sizeBytes: Int64(sizeMb * 1_024 * 1_024) + ) + } + return PluginDatabaseMetadata(name: database) + } + + func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { + PluginCreateDatabaseFormSpec(fields: [], footnote: nil) + } + + func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { + let quotedName = "[\(request.name.replacingOccurrences(of: "]", with: "]]"))]" + _ = try await execute(query: "CREATE DATABASE \(quotedName)") + } + + func dropDatabase(name: String) async throws { + let quotedName = "[\(name.replacingOccurrences(of: "]", with: "]]"))]" + _ = try await execute(query: "DROP DATABASE \(quotedName)") + } + + // MARK: - All Tables Metadata + + func allTablesMetadataSQL(schema: String?) -> String? { + """ + SELECT + s.name as schema_name, + t.name as name, + CASE WHEN v.object_id IS NOT NULL THEN 'VIEW' ELSE 'TABLE' END as kind, + p.rows as estimated_rows, + CAST(ROUND(SUM(a.total_pages) * 8 / 1024.0, 2) AS VARCHAR) + ' MB' as total_size + FROM sys.tables t + INNER JOIN sys.schemas s ON t.schema_id = s.schema_id + INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.index_id IN (0, 1) + INNER JOIN sys.partitions p ON i.object_id = p.object_id AND i.index_id = p.index_id + INNER JOIN sys.allocation_units a ON p.partition_id = a.container_id + LEFT JOIN sys.views v ON t.object_id = v.object_id + GROUP BY s.name, t.name, p.rows, v.object_id + ORDER BY t.name + """ + } + +} diff --git a/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift b/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift new file mode 100644 index 000000000..776ef452e --- /dev/null +++ b/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift @@ -0,0 +1,458 @@ +// +// MongoDBConnection+SyncHelpers.swift +// MongoDBDriverPlugin +// + +#if canImport(CLibMongoc) +import CLibMongoc +#endif +import Foundation +import OSLog +import TableProPluginKit + +#if canImport(CLibMongoc) +extension MongoDBConnection { + func bsonErrorMessage(_ error: inout bson_error_t) -> String { + withUnsafePointer(to: &error.message) { ptr in + ptr.withMemoryRebound(to: CChar.self, capacity: 504) { String(cString: $0) } + } + } + + func makeError(_ error: bson_error_t) -> MongoDBError { + var err = error + return MongoDBError(code: err.code, message: bsonErrorMessage(&err)) + } + + func fetchServerVersionSync() -> String? { + guard let client = self.client, + let command = jsonToBson("{\"buildInfo\": 1}") else { return nil } + defer { bson_destroy(command) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let dbName = database.isEmpty ? "admin" : database + let ok = dbName.withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } + guard ok else { return nil } + + return bsonToDict(reply)["version"] as? String + } + + func getCollection( + _ client: OpaquePointer, database: String, collection: String + ) throws -> OpaquePointer { + guard let col = database.withCString({ dbPtr in + collection.withCString { colPtr in mongoc_client_get_collection(client, dbPtr, colPtr) } + }) else { + throw MongoDBError(code: 0, message: "Failed to get collection \(collection)") + } + return col + } + + func runCommandSync( + client: OpaquePointer, command: String, database: String? + ) throws -> [[String: Any]] { + try checkCancelled() + + guard let bsonCmd = jsonToBson(command) else { + throw MongoDBError(code: 0, message: "Invalid JSON command: \(command)") + } + defer { bson_destroy(bsonCmd) } + + let timeoutMS = queryTimeoutMS + if timeoutMS > 0, !bson_has_field(bsonCmd, "maxTimeMS") { + bson_append_int32(bsonCmd, "maxTimeMS", -1, timeoutMS) + } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let effectiveDb = (database ?? self.database).isEmpty ? "admin" : (database ?? self.database) + let ok = effectiveDb.withCString { mongoc_client_command_simple(client, $0, bsonCmd, nil, reply, &error) } + + try checkCancelled() + guard ok else { throw makeError(error) } + + return [bsonToDict(reply)] + } + + func findSync( + client: OpaquePointer, database: String, collection: String, + filter: String, sort: String?, projection: String?, skip: Int, limit: Int + ) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + var optsJson: [String: Any] = ["skip": skip, "limit": limit] + if let sort = sort, let data = sort.data(using: .utf8), + let obj = try? JSONSerialization.jsonObject(with: data) { + optsJson["sort"] = obj + } + if let projection = projection, let data = projection.data(using: .utf8), + let obj = try? JSONSerialization.jsonObject(with: data) { + optsJson["projection"] = obj + } + + let timeoutMS = queryTimeoutMS + if timeoutMS > 0 { + optsJson["maxTimeMS"] = timeoutMS + } + + let optsData = try JSONSerialization.data(withJSONObject: optsJson) + guard let optsStr = String(data: optsData, encoding: .utf8), + let optsBson = jsonToBson(optsStr) else { + throw MongoDBError(code: 0, message: "Failed to build query options") + } + defer { bson_destroy(optsBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + try checkCancelled() + + guard let cursor = mongoc_collection_find_with_opts(col, filterBson, optsBson, nil) else { + throw MongoDBError(code: 0, message: "Failed to create find cursor") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor) + } + + func aggregateSync( + client: OpaquePointer, database: String, collection: String, pipeline: String + ) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + guard let pipelineBson = jsonToBson(pipeline) else { + throw MongoDBError(code: 0, message: "Invalid JSON pipeline: \(pipeline)") + } + defer { bson_destroy(pipelineBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let timeoutMS = queryTimeoutMS + var optsBson: OpaquePointer? + if timeoutMS > 0 { + optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") + } + defer { if let opts = optsBson { bson_destroy(opts) } } + + try checkCancelled() + + guard let cursor = mongoc_collection_aggregate( + col, MONGOC_QUERY_NONE, pipelineBson, optsBson, nil + ) else { + throw MongoDBError(code: 0, message: "Failed to create aggregation cursor") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor) + } + + func countDocumentsSync( + client: OpaquePointer, database: String, collection: String, filter: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let timeoutMS = queryTimeoutMS + var optsBson: OpaquePointer? + if timeoutMS > 0 { + optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") + } + defer { if let opts = optsBson { bson_destroy(opts) } } + + var error = bson_error_t() + let count = mongoc_collection_count_documents(col, filterBson, optsBson, nil, nil, &error) + + try checkCancelled() + guard count >= 0 else { throw makeError(error) } + return count + } + + func insertOneSync( + client: OpaquePointer, database: String, collection: String, document: String + ) throws -> String? { + try checkCancelled() + + guard let docBson = jsonToBson(document) else { + throw MongoDBError(code: 0, message: "Invalid JSON document: \(document)") + } + defer { bson_destroy(docBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_insert_one(col, docBson, nil, reply, &error) else { + throw makeError(error) + } + + if let objectId = bsonToDict(docBson)["_id"] { return "\(objectId)" } + return nil + } + + func updateOneSync( + client: OpaquePointer, database: String, collection: String, filter: String, update: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + guard let updateBson = jsonToBson(update) else { + throw MongoDBError(code: 0, message: "Invalid JSON update: \(update)") + } + defer { bson_destroy(updateBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_update_one(col, filterBson, updateBson, nil, reply, &error) else { + throw makeError(error) + } + return (bsonToDict(reply)["modifiedCount"] as? Int64) ?? 0 + } + + func deleteOneSync( + client: OpaquePointer, database: String, collection: String, filter: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_delete_one(col, filterBson, nil, reply, &error) else { + throw makeError(error) + } + return (bsonToDict(reply)["deletedCount"] as? Int64) ?? 0 + } + + func listDatabasesSync(client: OpaquePointer) throws -> [String] { + try checkCancelled() + + let caps = MongoDBCapabilities.parse(serverVersion()) + var fields = ["\"listDatabases\": 1"] + if caps.supportsListDatabasesNameOnly { + fields.append("\"nameOnly\": true") + } + if caps.supportsAuthorizedDatabases { + fields.append("\"authorizedDatabases\": true") + } + let commandJSON = "{\(fields.joined(separator: ", "))}" + guard let command = jsonToBson(commandJSON) else { + throw MongoDBError(code: 0, message: "Failed to create listDatabases command") + } + defer { bson_destroy(command) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let ok = "admin".withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } + + try checkCancelled() + guard ok else { throw makeError(error) } + + guard let databases = bsonToDict(reply)["databases"] as? [[String: Any]] else { return [] } + return databases.compactMap { $0["name"] as? String } + } + + func listCollectionsSync(client: OpaquePointer, database: String) throws -> [String] { + try checkCancelled() + + guard let mongocDb = database.withCString({ mongoc_client_get_database(client, $0) }) else { + throw MongoDBError(code: 0, message: "Failed to get database \(database)") + } + defer { mongoc_database_destroy(mongocDb) } + + var error = bson_error_t() + guard let names = mongoc_database_get_collection_names_with_opts(mongocDb, nil, &error) else { + throw makeError(error) + } + defer { bson_strfreev(names) } + + try checkCancelled() + + var collections: [String] = [] + var index = 0 + while let namePtr = names[index] { + collections.append(String(cString: namePtr)) + index += 1 + } + return collections + } + + func listIndexesSync( + client: OpaquePointer, database: String, collection: String + ) throws -> [[String: Any]] { + try checkCancelled() + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + guard let cursor = mongoc_collection_find_indexes_with_opts(col, nil) else { + throw MongoDBError(code: 0, message: "Failed to list indexes for \(collection)") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor).docs + } + + func iterateCursor(_ cursor: OpaquePointer) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + var results: [[String: Any]] = [] + var docPtr: OpaquePointer? + var truncated = false + + while mongoc_cursor_next(cursor, &docPtr) { + try checkCancelled() + + if let doc = docPtr { + results.append(bsonToDict(doc)) + } + + if results.count >= PluginRowLimits.emergencyMax { + truncated = true + logger.warning("Result set truncated at \(PluginRowLimits.emergencyMax) documents") + break + } + } + + var error = bson_error_t() + if mongoc_cursor_error(cursor, &error) { + throw makeError(error) + } + return (docs: results, isTruncated: truncated) + } + + func iterateCursorStreaming( + cursor: OpaquePointer, + continuation: AsyncThrowingStream.Continuation, + streamState: MongoStreamState + ) { + var docPtr: OpaquePointer? + var headerSent = false + var columns: [String] = [] + var columnTypeNames: [String] = [] + + while mongoc_cursor_next(cursor, &docPtr) { + if Task.isCancelled { + cleanup(streamState) + continuation.finish(throwing: CancellationError()) + return + } + + guard let doc = docPtr else { continue } + let dict = bsonToDict(doc) + + if !headerSent { + columns = BsonDocumentFlattener.unionColumns(from: [dict]) + let bsonTypes = BsonDocumentFlattener.columnTypes(for: columns, documents: [dict]) + columnTypeNames = bsonTypes.map { bsonTypeToStreamString($0) } + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames + ))) + headerSent = true + } else { + for key in dict.keys.sorted() where !columns.contains(key) { + columns.append(key) + let type = BsonDocumentFlattener.columnTypes(for: [key], documents: [dict]) + columnTypeNames.append(bsonTypeToStreamString(type.first ?? 2)) + } + } + + let row: [PluginCellValue] = columns.map { column in + guard let value = dict[column] else { return .null } + if let data = value as? Data { + return .bytes(data) + } + return PluginCellValue.fromOptional(BsonDocumentFlattener.stringValue(for: value)) + } + continuation.yield(.rows([row])) + } + + var error = bson_error_t() + if mongoc_cursor_error(cursor, &error) { + cleanup(streamState) + continuation.finish(throwing: makeError(error)) + return + } + + if !headerSent { + continuation.yield(.header(PluginStreamHeader( + columns: ["_id"], + columnTypeNames: ["VARCHAR"] + ))) + } + + cleanup(streamState) + continuation.finish() + } + + private func cleanup(_ state: MongoStreamState) { + state.lock.lock() + let cur = state.cursor + let col = state.collection + let alreadyDrained = state.drained + state.drained = true + state.cursor = nil + state.collection = nil + state.lock.unlock() + guard !alreadyDrained else { return } + if let cur { mongoc_cursor_destroy(cur) } + if let col { mongoc_collection_destroy(col) } + } + + private func bsonTypeToStreamString(_ type: Int32) -> String { + switch type { + case 1: return "FLOAT" + case 2: return "VARCHAR" + case 3: return "JSON" + case 4: return "JSON" + case 5: return "BLOB" + case 7: return "VARCHAR" + case 8: return "BOOLEAN" + case 9: return "TIMESTAMP" + case 10: return "VARCHAR" + case 16: return "INTEGER" + case 18: return "BIGINT" + default: return "VARCHAR" + } + } +} +#endif diff --git a/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift b/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift index f82187aa4..8e05ab707 100644 --- a/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift +++ b/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift @@ -13,7 +13,7 @@ import Foundation import OSLog import TableProPluginKit -private let logger = Logger(subsystem: "com.TablePro", category: "MongoDBConnection") +let logger = Logger(subsystem: "com.TablePro", category: "MongoDBConnection") // MARK: - Error Types @@ -49,7 +49,7 @@ final class MongoDBConnection: @unchecked Sendable { mongoc_init() }() - private var client: OpaquePointer? + var client: OpaquePointer? #endif private static let queueKey = DispatchSpecificKey() @@ -58,7 +58,7 @@ final class MongoDBConnection: @unchecked Sendable { private let port: Int private let user: String private let password: String? - private let database: String + let database: String private let ssl: SSLConfiguration private let authSource: String? private let readPreference: String? @@ -94,7 +94,7 @@ final class MongoDBConnection: @unchecked Sendable { } } - private var queryTimeoutMS: Int32 { + var queryTimeoutMS: Int32 { stateLock.lock() defer { stateLock.unlock() } return _queryTimeoutMS @@ -350,7 +350,7 @@ final class MongoDBConnection: @unchecked Sendable { /// Throws if cancellation was requested, resetting the flag atomically. /// Safe to call from any thread. - private func checkCancelled() throws { + func checkCancelled() throws { stateLock.lock() let cancelled = _isCancelled if cancelled { _isCancelled = false } @@ -777,454 +777,6 @@ final class MongoDBConnection: @unchecked Sendable { } } -// MARK: - Synchronous Helpers (must be called on the serial queue) - -#if canImport(CLibMongoc) -private extension MongoDBConnection { - func bsonErrorMessage(_ error: inout bson_error_t) -> String { - withUnsafePointer(to: &error.message) { ptr in - ptr.withMemoryRebound(to: CChar.self, capacity: 504) { String(cString: $0) } - } - } - - func makeError(_ error: bson_error_t) -> MongoDBError { - var err = error - return MongoDBError(code: err.code, message: bsonErrorMessage(&err)) - } - - func fetchServerVersionSync() -> String? { - guard let client = self.client, - let command = jsonToBson("{\"buildInfo\": 1}") else { return nil } - defer { bson_destroy(command) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let dbName = database.isEmpty ? "admin" : database - let ok = dbName.withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } - guard ok else { return nil } - - return bsonToDict(reply)["version"] as? String - } - - func getCollection( - _ client: OpaquePointer, database: String, collection: String - ) throws -> OpaquePointer { - guard let col = database.withCString({ dbPtr in - collection.withCString { colPtr in mongoc_client_get_collection(client, dbPtr, colPtr) } - }) else { - throw MongoDBError(code: 0, message: "Failed to get collection \(collection)") - } - return col - } - - func runCommandSync( - client: OpaquePointer, command: String, database: String? - ) throws -> [[String: Any]] { - try checkCancelled() - - guard let bsonCmd = jsonToBson(command) else { - throw MongoDBError(code: 0, message: "Invalid JSON command: \(command)") - } - defer { bson_destroy(bsonCmd) } - - let timeoutMS = queryTimeoutMS - if timeoutMS > 0, !bson_has_field(bsonCmd, "maxTimeMS") { - bson_append_int32(bsonCmd, "maxTimeMS", -1, timeoutMS) - } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let effectiveDb = (database ?? self.database).isEmpty ? "admin" : (database ?? self.database) - let ok = effectiveDb.withCString { mongoc_client_command_simple(client, $0, bsonCmd, nil, reply, &error) } - - try checkCancelled() - guard ok else { throw makeError(error) } - - return [bsonToDict(reply)] - } - - func findSync( - client: OpaquePointer, database: String, collection: String, - filter: String, sort: String?, projection: String?, skip: Int, limit: Int - ) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - var optsJson: [String: Any] = ["skip": skip, "limit": limit] - if let sort = sort, let data = sort.data(using: .utf8), - let obj = try? JSONSerialization.jsonObject(with: data) { - optsJson["sort"] = obj - } - if let projection = projection, let data = projection.data(using: .utf8), - let obj = try? JSONSerialization.jsonObject(with: data) { - optsJson["projection"] = obj - } - - let timeoutMS = queryTimeoutMS - if timeoutMS > 0 { - optsJson["maxTimeMS"] = timeoutMS - } - - let optsData = try JSONSerialization.data(withJSONObject: optsJson) - guard let optsStr = String(data: optsData, encoding: .utf8), - let optsBson = jsonToBson(optsStr) else { - throw MongoDBError(code: 0, message: "Failed to build query options") - } - defer { bson_destroy(optsBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - try checkCancelled() - - guard let cursor = mongoc_collection_find_with_opts(col, filterBson, optsBson, nil) else { - throw MongoDBError(code: 0, message: "Failed to create find cursor") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor) - } - - func aggregateSync( - client: OpaquePointer, database: String, collection: String, pipeline: String - ) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - guard let pipelineBson = jsonToBson(pipeline) else { - throw MongoDBError(code: 0, message: "Invalid JSON pipeline: \(pipeline)") - } - defer { bson_destroy(pipelineBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let timeoutMS = queryTimeoutMS - var optsBson: OpaquePointer? - if timeoutMS > 0 { - optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") - } - defer { if let opts = optsBson { bson_destroy(opts) } } - - try checkCancelled() - - guard let cursor = mongoc_collection_aggregate( - col, MONGOC_QUERY_NONE, pipelineBson, optsBson, nil - ) else { - throw MongoDBError(code: 0, message: "Failed to create aggregation cursor") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor) - } - - func countDocumentsSync( - client: OpaquePointer, database: String, collection: String, filter: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let timeoutMS = queryTimeoutMS - var optsBson: OpaquePointer? - if timeoutMS > 0 { - optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") - } - defer { if let opts = optsBson { bson_destroy(opts) } } - - var error = bson_error_t() - let count = mongoc_collection_count_documents(col, filterBson, optsBson, nil, nil, &error) - - try checkCancelled() - guard count >= 0 else { throw makeError(error) } - return count - } - - func insertOneSync( - client: OpaquePointer, database: String, collection: String, document: String - ) throws -> String? { - try checkCancelled() - - guard let docBson = jsonToBson(document) else { - throw MongoDBError(code: 0, message: "Invalid JSON document: \(document)") - } - defer { bson_destroy(docBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_insert_one(col, docBson, nil, reply, &error) else { - throw makeError(error) - } - - if let objectId = bsonToDict(docBson)["_id"] { return "\(objectId)" } - return nil - } - - func updateOneSync( - client: OpaquePointer, database: String, collection: String, filter: String, update: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - guard let updateBson = jsonToBson(update) else { - throw MongoDBError(code: 0, message: "Invalid JSON update: \(update)") - } - defer { bson_destroy(updateBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_update_one(col, filterBson, updateBson, nil, reply, &error) else { - throw makeError(error) - } - return (bsonToDict(reply)["modifiedCount"] as? Int64) ?? 0 - } - - func deleteOneSync( - client: OpaquePointer, database: String, collection: String, filter: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_delete_one(col, filterBson, nil, reply, &error) else { - throw makeError(error) - } - return (bsonToDict(reply)["deletedCount"] as? Int64) ?? 0 - } - - func listDatabasesSync(client: OpaquePointer) throws -> [String] { - try checkCancelled() - - let caps = MongoDBCapabilities.parse(serverVersion()) - var fields = ["\"listDatabases\": 1"] - if caps.supportsListDatabasesNameOnly { - fields.append("\"nameOnly\": true") - } - if caps.supportsAuthorizedDatabases { - fields.append("\"authorizedDatabases\": true") - } - let commandJSON = "{\(fields.joined(separator: ", "))}" - guard let command = jsonToBson(commandJSON) else { - throw MongoDBError(code: 0, message: "Failed to create listDatabases command") - } - defer { bson_destroy(command) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let ok = "admin".withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } - - try checkCancelled() - guard ok else { throw makeError(error) } - - guard let databases = bsonToDict(reply)["databases"] as? [[String: Any]] else { return [] } - return databases.compactMap { $0["name"] as? String } - } - - func listCollectionsSync(client: OpaquePointer, database: String) throws -> [String] { - try checkCancelled() - - guard let mongocDb = database.withCString({ mongoc_client_get_database(client, $0) }) else { - throw MongoDBError(code: 0, message: "Failed to get database \(database)") - } - defer { mongoc_database_destroy(mongocDb) } - - var error = bson_error_t() - guard let names = mongoc_database_get_collection_names_with_opts(mongocDb, nil, &error) else { - throw makeError(error) - } - defer { bson_strfreev(names) } - - try checkCancelled() - - var collections: [String] = [] - var index = 0 - while let namePtr = names[index] { - collections.append(String(cString: namePtr)) - index += 1 - } - return collections - } - - func listIndexesSync( - client: OpaquePointer, database: String, collection: String - ) throws -> [[String: Any]] { - try checkCancelled() - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - guard let cursor = mongoc_collection_find_indexes_with_opts(col, nil) else { - throw MongoDBError(code: 0, message: "Failed to list indexes for \(collection)") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor).docs - } - - func iterateCursor(_ cursor: OpaquePointer) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - var results: [[String: Any]] = [] - var docPtr: OpaquePointer? - var truncated = false - - while mongoc_cursor_next(cursor, &docPtr) { - try checkCancelled() - - if let doc = docPtr { - results.append(bsonToDict(doc)) - } - - if results.count >= PluginRowLimits.emergencyMax { - truncated = true - logger.warning("Result set truncated at \(PluginRowLimits.emergencyMax) documents") - break - } - } - - var error = bson_error_t() - if mongoc_cursor_error(cursor, &error) { - throw makeError(error) - } - return (docs: results, isTruncated: truncated) - } - - func iterateCursorStreaming( - cursor: OpaquePointer, - continuation: AsyncThrowingStream.Continuation, - streamState: MongoStreamState - ) { - var docPtr: OpaquePointer? - var headerSent = false - var columns: [String] = [] - var columnTypeNames: [String] = [] - - while mongoc_cursor_next(cursor, &docPtr) { - if Task.isCancelled { - cleanup(streamState) - continuation.finish(throwing: CancellationError()) - return - } - - guard let doc = docPtr else { continue } - let dict = bsonToDict(doc) - - if !headerSent { - columns = BsonDocumentFlattener.unionColumns(from: [dict]) - let bsonTypes = BsonDocumentFlattener.columnTypes(for: columns, documents: [dict]) - columnTypeNames = bsonTypes.map { bsonTypeToStreamString($0) } - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames - ))) - headerSent = true - } else { - for key in dict.keys.sorted() where !columns.contains(key) { - columns.append(key) - let type = BsonDocumentFlattener.columnTypes(for: [key], documents: [dict]) - columnTypeNames.append(bsonTypeToStreamString(type.first ?? 2)) - } - } - - let row: [PluginCellValue] = columns.map { column in - guard let value = dict[column] else { return .null } - if let data = value as? Data { - return .bytes(data) - } - return PluginCellValue.fromOptional(BsonDocumentFlattener.stringValue(for: value)) - } - continuation.yield(.rows([row])) - } - - var error = bson_error_t() - if mongoc_cursor_error(cursor, &error) { - cleanup(streamState) - continuation.finish(throwing: makeError(error)) - return - } - - if !headerSent { - continuation.yield(.header(PluginStreamHeader( - columns: ["_id"], - columnTypeNames: ["VARCHAR"] - ))) - } - - cleanup(streamState) - continuation.finish() - } - - private func cleanup(_ state: MongoStreamState) { - state.lock.lock() - let cur = state.cursor - let col = state.collection - let alreadyDrained = state.drained - state.drained = true - state.cursor = nil - state.collection = nil - state.lock.unlock() - guard !alreadyDrained else { return } - if let cur { mongoc_cursor_destroy(cur) } - if let col { mongoc_collection_destroy(col) } - } - - private func bsonTypeToStreamString(_ type: Int32) -> String { - switch type { - case 1: return "FLOAT" - case 2: return "VARCHAR" - case 3: return "JSON" - case 4: return "JSON" - case 5: return "BLOB" - case 7: return "VARCHAR" - case 8: return "BOOLEAN" - case 9: return "TIMESTAMP" - case 10: return "VARCHAR" - case 16: return "INTEGER" - case 18: return "BIGINT" - default: return "VARCHAR" - } - } -} -#endif final class MongoStreamState: @unchecked Sendable { var cursor: OpaquePointer? @@ -1235,7 +787,7 @@ final class MongoStreamState: @unchecked Sendable { // MARK: - BSON Helpers -private extension MongoDBConnection { +extension MongoDBConnection { /// Convert a JSON string to a bson_t pointer. Caller must call bson_destroy on the result. func jsonToBson(_ json: String) -> OpaquePointer? { #if canImport(CLibMongoc) diff --git a/Plugins/OracleDriverPlugin/OraclePlugin.swift b/Plugins/OracleDriverPlugin/OraclePlugin.swift index a1d47de61..0fffc29c8 100644 --- a/Plugins/OracleDriverPlugin/OraclePlugin.swift +++ b/Plugins/OracleDriverPlugin/OraclePlugin.swift @@ -1037,8 +1037,9 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = oracleQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let orderBy = oracleBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY 1" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: oracleQuoteIdentifier + ) ?? "ORDER BY 1" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1054,12 +1055,21 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = oracleQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let whereClause = oracleBuildWhereClause(filters: filters, logicMode: logicMode) + let whereClause = PluginSQLFilter.buildWhereClause( + filters: filters, + logicMode: logicMode, + quoteIdentifier: oracleQuoteIdentifier, + escapeValue: oracleEscapeValue, + regexCondition: { quoted, value in + "REGEXP_LIKE(\(quoted), '\(value.replacingOccurrences(of: "'", with: "''"))')" + } + ) if !whereClause.isEmpty { query += " WHERE \(whereClause)" } - let orderBy = oracleBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY 1" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: oracleQuoteIdentifier + ) ?? "ORDER BY 1" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1070,29 +1080,6 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { "\"\(identifier.replacingOccurrences(of: "\"", with: "\"\""))\"" } - private func oracleBuildOrderByClause( - sortColumns: [(columnIndex: Int, ascending: Bool)], - columns: [String] - ) -> String? { - let parts = sortColumns.compactMap { sortCol -> String? in - guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } - let columnName = columns[sortCol.columnIndex] - let direction = sortCol.ascending ? "ASC" : "DESC" - let quotedColumn = oracleQuoteIdentifier(columnName) - return "\(quotedColumn) \(direction)" - } - guard !parts.isEmpty else { return nil } - return "ORDER BY " + parts.joined(separator: ", ") - } - - private func oracleEscapeForLike(_ text: String) -> String { - text - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "%", with: "\\%") - .replacingOccurrences(of: "_", with: "\\_") - .replacingOccurrences(of: "'", with: "''") - } - private func oracleEscapeValue(_ value: String) -> String { let trimmed = value.trimmingCharacters(in: .whitespaces) if trimmed.caseInsensitiveCompare("NULL") == .orderedSame { return "NULL" } @@ -1100,65 +1087,6 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { return "'\(trimmed.replacingOccurrences(of: "'", with: "''"))'" } - private func oracleBuildWhereClause( - filters: [(column: String, op: String, value: String)], - logicMode: String - ) -> String { - let conditions = filters.compactMap { filter -> String? in - oracleBuildFilterCondition(column: filter.column, op: filter.op, value: filter.value) - } - guard !conditions.isEmpty else { return "" } - let separator = logicMode == "and" ? " AND " : " OR " - return conditions.joined(separator: separator) - } - - private func oracleBuildFilterCondition(column: String, op: String, value: String) -> String? { - let quoted = oracleQuoteIdentifier(column) - switch op { - case "=": return "\(quoted) = \(oracleEscapeValue(value))" - case "!=": return "\(quoted) != \(oracleEscapeValue(value))" - case ">": return "\(quoted) > \(oracleEscapeValue(value))" - case ">=": return "\(quoted) >= \(oracleEscapeValue(value))" - case "<": return "\(quoted) < \(oracleEscapeValue(value))" - case "<=": return "\(quoted) <= \(oracleEscapeValue(value))" - case "IS NULL": return "\(quoted) IS NULL" - case "IS NOT NULL": return "\(quoted) IS NOT NULL" - case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" - case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" - case "CONTAINS": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)%' ESCAPE '\\'" - case "NOT CONTAINS": - let escaped = oracleEscapeForLike(value) - return "\(quoted) NOT LIKE '%\(escaped)%' ESCAPE '\\'" - case "STARTS WITH": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '\(escaped)%' ESCAPE '\\'" - case "ENDS WITH": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)' ESCAPE '\\'" - case "IN": - let values = value.split(separator: ",") - .map { oracleEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) IN (\(values))" - case "NOT IN": - let values = value.split(separator: ",") - .map { oracleEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" - case "BETWEEN": - let parts = value.split(separator: ",", maxSplits: 1) - guard parts.count == 2 else { return nil } - let v1 = oracleEscapeValue(parts[0].trimmingCharacters(in: .whitespaces)) - let v2 = oracleEscapeValue(parts[1].trimmingCharacters(in: .whitespaces)) - return "\(quoted) BETWEEN \(v1) AND \(v2)" - case "REGEX": - let escaped = value.replacingOccurrences(of: "'", with: "''") - return "REGEXP_LIKE(\(quoted), '\(escaped)')" - default: return nil - } - } // MARK: - Private Helpers diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift new file mode 100644 index 000000000..120200c2a --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift @@ -0,0 +1,510 @@ +// +// RedisPluginDriver+Operations.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + func executeOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .get, .set, .del, .keys, .scan, .type, .ttl, .pttl, .expire, .persist, .rename, .exists: + return try await executeKeyOperation(operation, connection: conn, startTime: startTime) + + case .hget, .hset, .hgetall, .hdel: + return try await executeHashOperation(operation, connection: conn, startTime: startTime) + + case .lrange, .lpush, .rpush, .llen: + return try await executeListOperation(operation, connection: conn, startTime: startTime) + + case .smembers, .sadd, .srem, .scard: + return try await executeSetOperation(operation, connection: conn, startTime: startTime) + + case .zrange, .zadd, .zrem, .zcard: + return try await executeSortedSetOperation(operation, connection: conn, startTime: startTime) + + case .xrange, .xlen: + return try await executeStreamOperation(operation, connection: conn, startTime: startTime) + + case .ping, .info, .dbsize, .flushdb, .select, .configGet, .configSet, .command, .multi, .exec, .discard: + return try await executeServerOperation(operation, connection: conn, startTime: startTime) + } + } + + // MARK: - Key Operations + + func executeKeyOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .get(let key): + let result = try await conn.executeCommand(["GET", key]) + let value = result.stringValue + return PluginQueryResult( + columns: ["Key", "Value"], + columnTypeNames: ["String", "String"], + rows: [[key, value].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .set(let key, let value, let options): + var args = ["SET", key, value] + if let opts = options { + if let ex = opts.ex { args += ["EX", String(ex)] } + if let px = opts.px { args += ["PX", String(px)] } + if let exat = opts.exat { args += ["EXAT", String(exat)] } + if let pxat = opts.pxat { args += ["PXAT", String(pxat)] } + if opts.nx { args.append("NX") } + if opts.xx { args.append("XX") } + } + _ = try await conn.executeCommand(args) + return buildStatusResult("OK", startTime: startTime) + + case .del(let keys): + let args = ["DEL"] + keys + let result = try await conn.executeCommand(args) + let deleted = result.intValue ?? 0 + return PluginQueryResult( + columns: ["deleted"], + columnTypeNames: ["Int64"], + rows: [[String(deleted)].asCells], + rowsAffected: deleted, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .keys(let pattern): + let result = try await conn.executeCommand(["KEYS", pattern]) + guard let items = result.arrayValue else { + return buildEmptyKeyResult(startTime: startTime) + } + let keys = items.map { redisReplyToString($0) } + let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) + let keysTruncated = keys.count > PluginRowLimits.emergencyMax + return try await buildKeyBrowseResult( + keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated + ) + + case .scan(let cursor, let pattern, let count): + var args = ["SCAN", String(cursor)] + if let p = pattern { args += ["MATCH", p] } + if let c = count { args += ["COUNT", String(c)] } + let result = try await conn.executeCommand(args) + return try await handleScanResult(result, connection: conn, startTime: startTime) + + case .type(let key): + let result = try await conn.executeCommand(["TYPE", key]) + let typeName = result.stringValue ?? "none" + return PluginQueryResult( + columns: ["Key", "Type"], + columnTypeNames: ["String", "String"], + rows: [[key, typeName].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .ttl(let key): + let result = try await conn.executeCommand(["TTL", key]) + let ttl = result.intValue ?? -1 + return PluginQueryResult( + columns: ["Key", "TTL"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(ttl)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .pttl(let key): + let result = try await conn.executeCommand(["PTTL", key]) + let pttl = result.intValue ?? -1 + return PluginQueryResult( + columns: ["Key", "PTTL"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(pttl)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .expire(let key, let seconds): + let result = try await conn.executeCommand(["EXPIRE", key, String(seconds)]) + let success = (result.intValue ?? 0) == 1 + return buildStatusResult(success ? "OK" : "Key not found", startTime: startTime) + + case .persist(let key): + let result = try await conn.executeCommand(["PERSIST", key]) + let success = (result.intValue ?? 0) == 1 + return buildStatusResult(success ? "OK" : "Key not found or no TTL", startTime: startTime) + + case .rename(let key, let newKey): + let reply = try await conn.executeCommand(["RENAME", key, newKey]) + if case .error(let msg) = reply { + throw RedisPluginError(code: 0, message: "RENAME failed: \(msg)") + } + return buildStatusResult("OK", startTime: startTime) + + case .exists(let keys): + let args = ["EXISTS"] + keys + let result = try await conn.executeCommand(args) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["exists"], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeKeyOperation") + } + } + + // MARK: - Hash Operations + + func executeHashOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .hget(let key, let field): + let result = try await conn.executeCommand(["HGET", key, field]) + let value = result.stringValue + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: [[field, value].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .hset(let key, let fieldValues): + var args = ["HSET", key] + for (field, value) in fieldValues { + args += [field, value] + } + let result = try await conn.executeCommand(args) + let added = result.intValue ?? 0 + return PluginQueryResult( + columns: ["added"], + columnTypeNames: ["Int64"], + rows: [[String(added)].asCells], + rowsAffected: added, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .hgetall(let key): + let result = try await conn.executeCommand(["HGETALL", key]) + return buildHashResult(result, startTime: startTime) + + case .hdel(let key, let fields): + let args = ["HDEL", key] + fields + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeHashOperation") + } + } + + // MARK: - List Operations + + func executeListOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .lrange(let key, let start, let stop): + let result = try await conn.executeCommand(["LRANGE", key, String(start), String(stop)]) + return buildListResult(result, startOffset: start, startTime: startTime) + + case .lpush(let key, let values): + let args = ["LPUSH", key] + values + let result = try await conn.executeCommand(args) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["length"], + columnTypeNames: ["Int64"], + rows: [[String(length)].asCells], + rowsAffected: values.count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .rpush(let key, let values): + let args = ["RPUSH", key] + values + let result = try await conn.executeCommand(args) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["length"], + columnTypeNames: ["Int64"], + rows: [[String(length)].asCells], + rowsAffected: values.count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .llen(let key): + let result = try await conn.executeCommand(["LLEN", key]) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Length"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(length)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeListOperation") + } + } + + // MARK: - Set Operations + + func executeSetOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .smembers(let key): + let result = try await conn.executeCommand(["SMEMBERS", key]) + return buildSetResult(result, startTime: startTime) + + case .sadd(let key, let members): + let args = ["SADD", key] + members + let result = try await conn.executeCommand(args) + let added = result.intValue ?? 0 + return PluginQueryResult( + columns: ["added"], + columnTypeNames: ["Int64"], + rows: [[String(added)].asCells], + rowsAffected: added, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .srem(let key, let members): + let args = ["SREM", key] + members + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .scard(let key): + let result = try await conn.executeCommand(["SCARD", key]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Cardinality"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeSetOperation") + } + } + + // MARK: - Sorted Set Operations + + func executeSortedSetOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .zrange(let key, let start, let stop, let flags): + var args = ["ZRANGE", key, start, stop] + args += flags + let withScores = flags.contains("WITHSCORES") + let result = try await conn.executeCommand(args) + return buildSortedSetResult(result, withScores: withScores, startTime: startTime) + + case .zadd(let key, let flags, let scoreMembers): + var args = ["ZADD", key] + args += flags + for (score, member) in scoreMembers { + args += [String(score), member] + } + let result = try await conn.executeCommand(args) + if flags.contains("INCR") { + // INCR mode returns the new score (or nil for NX miss) + let scoreStr = result.stringValue ?? "nil" + return PluginQueryResult( + columns: ["score"], + columnTypeNames: ["String"], + rows: [[scoreStr].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + let count = result.intValue ?? 0 + let columnName = flags.contains("CH") ? "changed" : "added" + return PluginQueryResult( + columns: [columnName], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .zrem(let key, let members): + let args = ["ZREM", key] + members + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .zcard(let key): + let result = try await conn.executeCommand(["ZCARD", key]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Cardinality"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeSortedSetOperation") + } + } + + // MARK: - Stream Operations + + func executeStreamOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .xrange(let key, let start, let end, let count): + var args = ["XRANGE", key, start, end] + if let c = count { args += ["COUNT", String(c)] } + let result = try await conn.executeCommand(args) + return buildStreamResult(result, startTime: startTime) + + case .xlen(let key): + let result = try await conn.executeCommand(["XLEN", key]) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Length"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(length)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeStreamOperation") + } + } + + // MARK: - Server Operations + + func executeServerOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .ping: + _ = try await conn.executeCommand(["PING"]) + return PluginQueryResult( + columns: ["ok"], + columnTypeNames: ["Int32"], + rows: [["1"].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .info(let section): + var args = ["INFO"] + if let s = section { args.append(s) } + let result = try await conn.executeCommand(args) + let infoText = result.stringValue ?? String(describing: result) + return PluginQueryResult( + columns: ["info"], + columnTypeNames: ["String"], + rows: [[infoText].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .dbsize: + let result = try await conn.executeCommand(["DBSIZE"]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["keys"], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .flushdb: + _ = try await conn.executeCommand(["FLUSHDB"]) + return buildStatusResult("OK", startTime: startTime) + + case .select(let database): + try await conn.selectDatabase(database) + cachedScanPattern = nil + cachedScanKeys = nil + return buildStatusResult("OK", startTime: startTime) + + case .configGet(let parameter): + let result = try await conn.executeCommand(["CONFIG", "GET", parameter]) + return buildConfigResult(result, startTime: startTime) + + case .configSet(let parameter, let value): + _ = try await conn.executeCommand(["CONFIG", "SET", parameter, value]) + return buildStatusResult("OK", startTime: startTime) + + case .command(let args): + let result = try await conn.executeCommand(args) + return buildGenericResult(result, startTime: startTime) + + case .multi: + _ = try await conn.executeCommand(["MULTI"]) + return buildStatusResult("OK", startTime: startTime) + + case .exec: + let result = try await conn.executeCommand(["EXEC"]) + return buildGenericResult(result, startTime: startTime) + + case .discard: + _ = try await conn.executeCommand(["DISCARD"]) + return buildStatusResult("OK", startTime: startTime) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeServerOperation") + } + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift new file mode 100644 index 000000000..61d160ccd --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift @@ -0,0 +1,486 @@ +// +// RedisPluginDriver+ResultBuilding.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + static let previewLimit = 100 + static let previewMaxChars = 1_000 + + func buildKeyBrowseResult( + keys: [String], + connection conn: RedisPluginConnection, + startTime: Date, + isTruncated: Bool = false + ) async throws -> PluginQueryResult { + guard !keys.isEmpty else { + return buildEmptyKeyResult(startTime: startTime) + } + + var typeAndTtlCommands: [[String]] = [] + typeAndTtlCommands.reserveCapacity(keys.count * 2) + for key in keys { + typeAndTtlCommands.append(["TYPE", key]) + typeAndTtlCommands.append(["TTL", key]) + } + let typeAndTtlReplies = try await conn.executePipeline(typeAndTtlCommands) + + var typeNames: [String] = [] + typeNames.reserveCapacity(keys.count) + var ttlValues: [Int] = [] + ttlValues.reserveCapacity(keys.count) + for i in 0 ..< keys.count { + let typeName = (typeAndTtlReplies[i * 2].stringValue ?? "unknown").uppercased() + let ttl = typeAndTtlReplies[i * 2 + 1].intValue ?? -1 + typeNames.append(typeName) + ttlValues.append(ttl) + } + + var previewCommands: [[String]] = [] + previewCommands.reserveCapacity(keys.count) + var previewCommandIndices: [Int] = [] + previewCommandIndices.reserveCapacity(keys.count) + + for (i, key) in keys.enumerated() { + let command: [String]? = previewCommandForType(typeNames[i], key: key) + if let command { + previewCommandIndices.append(previewCommands.count) + previewCommands.append(command) + } else { + previewCommandIndices.append(-1) + } + } + + var previewReplies: [RedisReply] = [] + if !previewCommands.isEmpty { + previewReplies = try await conn.executePipeline(previewCommands) + } + + var rows: [[PluginCellValue]] = [] + rows.reserveCapacity(keys.count) + for (i, key) in keys.enumerated() { + let ttlStr = String(ttlValues[i]) + let pipelineIndex = previewCommandIndices[i] + let preview: String? + if pipelineIndex >= 0, pipelineIndex < previewReplies.count { + preview = formatPreviewReply( + previewReplies[pipelineIndex], type: typeNames[i] + ) + } else { + preview = nil + } + rows.append([key, typeNames[i], ttlStr, preview].asCells) + } + + return PluginQueryResult( + columns: ["Key", "Type", "TTL", "Value"], + columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime), + isTruncated: isTruncated + ) + } + + func previewCommandForType(_ type: String, key: String) -> [String]? { + switch type.lowercased() { + case "string": + return ["GET", key] + case "hash": + return ["HSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "list": + return ["LRANGE", key, "0", String(Self.previewLimit - 1)] + case "set": + return ["SSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "zset": + return ["ZRANGE", key, "0", String(Self.previewLimit - 1), "WITHSCORES"] + case "stream": + return ["XREVRANGE", key, "+", "-", "COUNT", "5"] + default: + return nil + } + } + + func formatPreviewReply(_ reply: RedisReply, type: String) -> String? { + switch type.lowercased() { + case "string": + return truncatePreview(redisReplyToString(reply)) + + case "hash": + let array: [RedisReply] + if case .array(let scanResult) = reply, + scanResult.count == 2, + let items = scanResult[1].arrayValue { + array = items + } else if let items = reply.arrayValue, !items.isEmpty { + array = items + } else { + return "{}" + } + guard !array.isEmpty else { return "{}" } + var pairs: [String] = [] + var idx = 0 + while idx + 1 < array.count { + let field = redisReplyToString(array[idx]) + let value = redisReplyToString(array[idx + 1]) + pairs.append( + "\"\(escapeJsonString(field))\":\"\(escapeJsonString(value))\"" + ) + idx += 2 + } + return truncatePreview("{\(pairs.joined(separator: ","))}") + + case "list": + guard let items = reply.arrayValue else { return "[]" } + let quoted = items.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } + return truncatePreview("[\(quoted.joined(separator: ", "))]") + + case "set": + let members: [RedisReply] + if case .array(let scanResult) = reply, + scanResult.count == 2, + let items = scanResult[1].arrayValue { + members = items + } else if let items = reply.arrayValue { + members = items + } else { + return "[]" + } + let quoted = members.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } + return truncatePreview("[\(quoted.joined(separator: ", "))]") + + case "zset": + // Parse WITHSCORES result: alternating member, score pairs + guard let items = reply.arrayValue, !items.isEmpty else { return "[]" } + var pairs: [String] = [] + var i = 0 + while i + 1 < items.count { + pairs.append("\(redisReplyToString(items[i])):\(redisReplyToString(items[i + 1]))") + i += 2 + } + return truncatePreview(pairs.joined(separator: ", ")) + + case "stream": + // Parse XREVRANGE result: array of [id, [field, value, ...]] entries + guard let entries = reply.arrayValue, !entries.isEmpty else { + return "(0 entries)" + } + var entryStrings: [String] = [] + for entry in entries { + guard let parts = entry.arrayValue, parts.count >= 2, + let fields = parts[1].arrayValue else { + continue + } + let entryId = redisReplyToString(parts[0]) + var fieldPairs: [String] = [] + var j = 0 + while j + 1 < fields.count { + fieldPairs.append("\(redisReplyToString(fields[j]))=\(redisReplyToString(fields[j + 1]))") + j += 2 + } + entryStrings.append("\(entryId): \(fieldPairs.joined(separator: ", "))") + } + return truncatePreview(entryStrings.joined(separator: "; ")) + + default: + return nil + } + } + + func truncatePreview(_ value: String?) -> String? { + guard let value else { return nil } + let nsValue = value as NSString + if nsValue.length > Self.previewMaxChars { + return nsValue.substring(to: Self.previewMaxChars) + "..." + } + return value + } + + func escapeJsonString(_ str: String) -> String { + var result = "" + for scalar in str.unicodeScalars { + switch scalar { + case "\\": result += "\\\\" + case "\"": result += "\\\"" + case "\n": result += "\\n" + case "\r": result += "\\r" + case "\t": result += "\\t" + default: + if scalar.value < 0x20 { + result += String(format: "\\u%04X", scalar.value) + } else { + result += String(scalar) + } + } + } + return result + } + + func buildEmptyKeyResult(startTime: Date) -> PluginQueryResult { + PluginQueryResult( + columns: ["Key", "Type", "TTL", "Value"], + columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildStatusResult(_ message: String, startTime: Date) -> PluginQueryResult { + PluginQueryResult( + columns: ["status"], + columnTypeNames: ["String"], + rows: [[message].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildGenericResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + switch result { + case .string(let s), .status(let s): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[s].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .integer(let i): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["Int64"], + rows: [[String(i)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .data(let d): + let str = String(data: d, encoding: .utf8) ?? d.base64EncodedString() + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[str].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .array(let items): + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .error(let e): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[e].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .null: + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [["(nil)"].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + } + + func redisReplyToString(_ reply: RedisReply) -> String { + switch reply { + case .string(let s), .status(let s), .error(let s): return s + case .integer(let i): return String(i) + case .data(let d): return String(data: d, encoding: .utf8) ?? d.base64EncodedString() + case .array(let items): return "[\(items.map { redisReplyToString($0) }.joined(separator: ", "))]" + case .null: return "(nil)" + } + } + + func buildHashResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue, !items.isEmpty else { + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildListResult(_ result: RedisReply, startOffset: Int = 0, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: ["Index", "Value"], + columnTypeNames: ["Int64", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + let rows = items.enumerated().map { index, item -> [PluginCellValue] in + ([String(startOffset + index), redisReplyToString(item)] as [String?]).asCells + } + + return PluginQueryResult( + columns: ["Index", "Value"], + columnTypeNames: ["Int64", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildSetResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildSortedSetResult(_ result: RedisReply, withScores: Bool, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: withScores ? ["Member", "Score"] : ["Member"], + columnTypeNames: withScores ? ["String", "Double"] : ["String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + if withScores { + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + return PluginQueryResult( + columns: ["Member", "Score"], + columnTypeNames: ["String", "Double"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } else { + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + } + + func buildStreamResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let entries = result.arrayValue else { + return PluginQueryResult( + columns: ["ID", "Fields"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + for entry in entries { + guard let entryParts = entry.arrayValue, entryParts.count >= 2, + let fields = entryParts[1].arrayValue else { + continue + } + let entryId = redisReplyToString(entryParts[0]) + + var fieldPairs: [String] = [] + var i = 0 + while i + 1 < fields.count { + fieldPairs.append("\(redisReplyToString(fields[i]))=\(redisReplyToString(fields[i + 1]))") + i += 2 + } + rows.append([entryId, fieldPairs.joined(separator: ", ")].asCells) + } + + return PluginQueryResult( + columns: ["ID", "Fields"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildConfigResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue, !items.isEmpty else { + return PluginQueryResult( + columns: ["Parameter", "Value"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + + return PluginQueryResult( + columns: ["Parameter", "Value"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift new file mode 100644 index 000000000..ebd86c01b --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift @@ -0,0 +1,85 @@ +// +// RedisPluginDriver+Scan.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + func scanAllKeys( + connection conn: RedisPluginConnection, + pattern: String?, + maxKeys: Int + ) async throws -> [String] { + var allKeys: [String] = [] + var cursor = "0" + + repeat { + var args = ["SCAN", cursor] + if let p = pattern { + args += ["MATCH", p] + } + args += ["COUNT", "1000"] + + let result = try await conn.executeCommand(args) + + guard case .array(let scanResult) = result, + scanResult.count == 2 else { + break + } + + let nextCursor: String + switch scanResult[0] { + case .string(let s): nextCursor = s + case .status(let s): nextCursor = s + case .data(let d): nextCursor = String(data: d, encoding: .utf8) ?? "0" + default: nextCursor = "0" + } + cursor = nextCursor + + if case .array(let keyReplies) = scanResult[1] { + for reply in keyReplies { + switch reply { + case .string(let k): allKeys.append(k) + case .data(let d): + if let k = String(data: d, encoding: .utf8) { allKeys.append(k) } + default: break + } + } + } + + if allKeys.count >= maxKeys { + allKeys = Array(allKeys.prefix(maxKeys)) + break + } + } while cursor != "0" + + return allKeys.sorted() + } + + func handleScanResult( + _ result: RedisReply, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + guard case .array(let scanResult) = result, + scanResult.count == 2, + case .array(let keyReplies) = scanResult[1] else { + return buildEmptyKeyResult(startTime: startTime) + } + + let keys = keyReplies.compactMap { reply -> String? in + if case .string(let k) = reply { return k } + if case .data(let d) = reply { return String(data: d, encoding: .utf8) } + return nil + } + + let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) + let keysTruncated = keys.count > PluginRowLimits.emergencyMax + return try await buildKeyBrowseResult( + keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated + ) + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift index f38c68a57..3e1954d1a 100644 --- a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift @@ -11,19 +11,19 @@ import Foundation import OSLog import TableProPluginKit -private extension Array where Element == String? { +extension Array where Element == String? { var asCells: [PluginCellValue] { map(PluginCellValue.fromOptional) } } -private extension Array where Element == String { +extension Array where Element == String { var asCells: [PluginCellValue] { map(PluginCellValue.text) } } -private extension Array where Element == [String?] { +extension Array where Element == [String?] { var asCellRows: [[PluginCellValue]] { map { $0.map(PluginCellValue.fromOptional) } } } -private extension Array where Element == [String] { +extension Array where Element == [String] { var asCellRows: [[PluginCellValue]] { map { $0.map(PluginCellValue.text) } } } @@ -35,8 +35,8 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { private static let maxScanKeys = PluginRowLimits.emergencyMax - private var cachedScanPattern: String? - private var cachedScanKeys: [String]? + var cachedScanPattern: String? + var cachedScanKeys: [String]? var serverVersion: String? { redisConnection?.serverVersion() @@ -630,1066 +630,3 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) } } - -// MARK: - Operation Dispatch - -private extension RedisPluginDriver { - func executeOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .get, .set, .del, .keys, .scan, .type, .ttl, .pttl, .expire, .persist, .rename, .exists: - return try await executeKeyOperation(operation, connection: conn, startTime: startTime) - - case .hget, .hset, .hgetall, .hdel: - return try await executeHashOperation(operation, connection: conn, startTime: startTime) - - case .lrange, .lpush, .rpush, .llen: - return try await executeListOperation(operation, connection: conn, startTime: startTime) - - case .smembers, .sadd, .srem, .scard: - return try await executeSetOperation(operation, connection: conn, startTime: startTime) - - case .zrange, .zadd, .zrem, .zcard: - return try await executeSortedSetOperation(operation, connection: conn, startTime: startTime) - - case .xrange, .xlen: - return try await executeStreamOperation(operation, connection: conn, startTime: startTime) - - case .ping, .info, .dbsize, .flushdb, .select, .configGet, .configSet, .command, .multi, .exec, .discard: - return try await executeServerOperation(operation, connection: conn, startTime: startTime) - } - } - - // MARK: - Key Operations - - func executeKeyOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .get(let key): - let result = try await conn.executeCommand(["GET", key]) - let value = result.stringValue - return PluginQueryResult( - columns: ["Key", "Value"], - columnTypeNames: ["String", "String"], - rows: [[key, value].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .set(let key, let value, let options): - var args = ["SET", key, value] - if let opts = options { - if let ex = opts.ex { args += ["EX", String(ex)] } - if let px = opts.px { args += ["PX", String(px)] } - if let exat = opts.exat { args += ["EXAT", String(exat)] } - if let pxat = opts.pxat { args += ["PXAT", String(pxat)] } - if opts.nx { args.append("NX") } - if opts.xx { args.append("XX") } - } - _ = try await conn.executeCommand(args) - return buildStatusResult("OK", startTime: startTime) - - case .del(let keys): - let args = ["DEL"] + keys - let result = try await conn.executeCommand(args) - let deleted = result.intValue ?? 0 - return PluginQueryResult( - columns: ["deleted"], - columnTypeNames: ["Int64"], - rows: [[String(deleted)].asCells], - rowsAffected: deleted, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .keys(let pattern): - let result = try await conn.executeCommand(["KEYS", pattern]) - guard let items = result.arrayValue else { - return buildEmptyKeyResult(startTime: startTime) - } - let keys = items.map { redisReplyToString($0) } - let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) - let keysTruncated = keys.count > PluginRowLimits.emergencyMax - return try await buildKeyBrowseResult( - keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated - ) - - case .scan(let cursor, let pattern, let count): - var args = ["SCAN", String(cursor)] - if let p = pattern { args += ["MATCH", p] } - if let c = count { args += ["COUNT", String(c)] } - let result = try await conn.executeCommand(args) - return try await handleScanResult(result, connection: conn, startTime: startTime) - - case .type(let key): - let result = try await conn.executeCommand(["TYPE", key]) - let typeName = result.stringValue ?? "none" - return PluginQueryResult( - columns: ["Key", "Type"], - columnTypeNames: ["String", "String"], - rows: [[key, typeName].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .ttl(let key): - let result = try await conn.executeCommand(["TTL", key]) - let ttl = result.intValue ?? -1 - return PluginQueryResult( - columns: ["Key", "TTL"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(ttl)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .pttl(let key): - let result = try await conn.executeCommand(["PTTL", key]) - let pttl = result.intValue ?? -1 - return PluginQueryResult( - columns: ["Key", "PTTL"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(pttl)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .expire(let key, let seconds): - let result = try await conn.executeCommand(["EXPIRE", key, String(seconds)]) - let success = (result.intValue ?? 0) == 1 - return buildStatusResult(success ? "OK" : "Key not found", startTime: startTime) - - case .persist(let key): - let result = try await conn.executeCommand(["PERSIST", key]) - let success = (result.intValue ?? 0) == 1 - return buildStatusResult(success ? "OK" : "Key not found or no TTL", startTime: startTime) - - case .rename(let key, let newKey): - let reply = try await conn.executeCommand(["RENAME", key, newKey]) - if case .error(let msg) = reply { - throw RedisPluginError(code: 0, message: "RENAME failed: \(msg)") - } - return buildStatusResult("OK", startTime: startTime) - - case .exists(let keys): - let args = ["EXISTS"] + keys - let result = try await conn.executeCommand(args) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["exists"], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeKeyOperation") - } - } - - // MARK: - Hash Operations - - func executeHashOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .hget(let key, let field): - let result = try await conn.executeCommand(["HGET", key, field]) - let value = result.stringValue - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: [[field, value].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .hset(let key, let fieldValues): - var args = ["HSET", key] - for (field, value) in fieldValues { - args += [field, value] - } - let result = try await conn.executeCommand(args) - let added = result.intValue ?? 0 - return PluginQueryResult( - columns: ["added"], - columnTypeNames: ["Int64"], - rows: [[String(added)].asCells], - rowsAffected: added, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .hgetall(let key): - let result = try await conn.executeCommand(["HGETALL", key]) - return buildHashResult(result, startTime: startTime) - - case .hdel(let key, let fields): - let args = ["HDEL", key] + fields - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeHashOperation") - } - } - - // MARK: - List Operations - - func executeListOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .lrange(let key, let start, let stop): - let result = try await conn.executeCommand(["LRANGE", key, String(start), String(stop)]) - return buildListResult(result, startOffset: start, startTime: startTime) - - case .lpush(let key, let values): - let args = ["LPUSH", key] + values - let result = try await conn.executeCommand(args) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["length"], - columnTypeNames: ["Int64"], - rows: [[String(length)].asCells], - rowsAffected: values.count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .rpush(let key, let values): - let args = ["RPUSH", key] + values - let result = try await conn.executeCommand(args) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["length"], - columnTypeNames: ["Int64"], - rows: [[String(length)].asCells], - rowsAffected: values.count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .llen(let key): - let result = try await conn.executeCommand(["LLEN", key]) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Length"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(length)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeListOperation") - } - } - - // MARK: - Set Operations - - func executeSetOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .smembers(let key): - let result = try await conn.executeCommand(["SMEMBERS", key]) - return buildSetResult(result, startTime: startTime) - - case .sadd(let key, let members): - let args = ["SADD", key] + members - let result = try await conn.executeCommand(args) - let added = result.intValue ?? 0 - return PluginQueryResult( - columns: ["added"], - columnTypeNames: ["Int64"], - rows: [[String(added)].asCells], - rowsAffected: added, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .srem(let key, let members): - let args = ["SREM", key] + members - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .scard(let key): - let result = try await conn.executeCommand(["SCARD", key]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Cardinality"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeSetOperation") - } - } - - // MARK: - Sorted Set Operations - - func executeSortedSetOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .zrange(let key, let start, let stop, let flags): - var args = ["ZRANGE", key, start, stop] - args += flags - let withScores = flags.contains("WITHSCORES") - let result = try await conn.executeCommand(args) - return buildSortedSetResult(result, withScores: withScores, startTime: startTime) - - case .zadd(let key, let flags, let scoreMembers): - var args = ["ZADD", key] - args += flags - for (score, member) in scoreMembers { - args += [String(score), member] - } - let result = try await conn.executeCommand(args) - if flags.contains("INCR") { - // INCR mode returns the new score (or nil for NX miss) - let scoreStr = result.stringValue ?? "nil" - return PluginQueryResult( - columns: ["score"], - columnTypeNames: ["String"], - rows: [[scoreStr].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - let count = result.intValue ?? 0 - let columnName = flags.contains("CH") ? "changed" : "added" - return PluginQueryResult( - columns: [columnName], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .zrem(let key, let members): - let args = ["ZREM", key] + members - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .zcard(let key): - let result = try await conn.executeCommand(["ZCARD", key]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Cardinality"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeSortedSetOperation") - } - } - - // MARK: - Stream Operations - - func executeStreamOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .xrange(let key, let start, let end, let count): - var args = ["XRANGE", key, start, end] - if let c = count { args += ["COUNT", String(c)] } - let result = try await conn.executeCommand(args) - return buildStreamResult(result, startTime: startTime) - - case .xlen(let key): - let result = try await conn.executeCommand(["XLEN", key]) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Length"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(length)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeStreamOperation") - } - } - - // MARK: - Server Operations - - func executeServerOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .ping: - _ = try await conn.executeCommand(["PING"]) - return PluginQueryResult( - columns: ["ok"], - columnTypeNames: ["Int32"], - rows: [["1"].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .info(let section): - var args = ["INFO"] - if let s = section { args.append(s) } - let result = try await conn.executeCommand(args) - let infoText = result.stringValue ?? String(describing: result) - return PluginQueryResult( - columns: ["info"], - columnTypeNames: ["String"], - rows: [[infoText].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .dbsize: - let result = try await conn.executeCommand(["DBSIZE"]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["keys"], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .flushdb: - _ = try await conn.executeCommand(["FLUSHDB"]) - return buildStatusResult("OK", startTime: startTime) - - case .select(let database): - try await conn.selectDatabase(database) - cachedScanPattern = nil - cachedScanKeys = nil - return buildStatusResult("OK", startTime: startTime) - - case .configGet(let parameter): - let result = try await conn.executeCommand(["CONFIG", "GET", parameter]) - return buildConfigResult(result, startTime: startTime) - - case .configSet(let parameter, let value): - _ = try await conn.executeCommand(["CONFIG", "SET", parameter, value]) - return buildStatusResult("OK", startTime: startTime) - - case .command(let args): - let result = try await conn.executeCommand(args) - return buildGenericResult(result, startTime: startTime) - - case .multi: - _ = try await conn.executeCommand(["MULTI"]) - return buildStatusResult("OK", startTime: startTime) - - case .exec: - let result = try await conn.executeCommand(["EXEC"]) - return buildGenericResult(result, startTime: startTime) - - case .discard: - _ = try await conn.executeCommand(["DISCARD"]) - return buildStatusResult("OK", startTime: startTime) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeServerOperation") - } - } -} - -// MARK: - SCAN Helpers - -private extension RedisPluginDriver { - func scanAllKeys( - connection conn: RedisPluginConnection, - pattern: String?, - maxKeys: Int - ) async throws -> [String] { - var allKeys: [String] = [] - var cursor = "0" - - repeat { - var args = ["SCAN", cursor] - if let p = pattern { - args += ["MATCH", p] - } - args += ["COUNT", "1000"] - - let result = try await conn.executeCommand(args) - - guard case .array(let scanResult) = result, - scanResult.count == 2 else { - break - } - - let nextCursor: String - switch scanResult[0] { - case .string(let s): nextCursor = s - case .status(let s): nextCursor = s - case .data(let d): nextCursor = String(data: d, encoding: .utf8) ?? "0" - default: nextCursor = "0" - } - cursor = nextCursor - - if case .array(let keyReplies) = scanResult[1] { - for reply in keyReplies { - switch reply { - case .string(let k): allKeys.append(k) - case .data(let d): - if let k = String(data: d, encoding: .utf8) { allKeys.append(k) } - default: break - } - } - } - - if allKeys.count >= maxKeys { - allKeys = Array(allKeys.prefix(maxKeys)) - break - } - } while cursor != "0" - - return allKeys.sorted() - } - - func handleScanResult( - _ result: RedisReply, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - guard case .array(let scanResult) = result, - scanResult.count == 2, - case .array(let keyReplies) = scanResult[1] else { - return buildEmptyKeyResult(startTime: startTime) - } - - let keys = keyReplies.compactMap { reply -> String? in - if case .string(let k) = reply { return k } - if case .data(let d) = reply { return String(data: d, encoding: .utf8) } - return nil - } - - let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) - let keysTruncated = keys.count > PluginRowLimits.emergencyMax - return try await buildKeyBrowseResult( - keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated - ) - } -} - -// MARK: - Result Building - -private extension RedisPluginDriver { - static let previewLimit = 100 - static let previewMaxChars = 1_000 - - func buildKeyBrowseResult( - keys: [String], - connection conn: RedisPluginConnection, - startTime: Date, - isTruncated: Bool = false - ) async throws -> PluginQueryResult { - guard !keys.isEmpty else { - return buildEmptyKeyResult(startTime: startTime) - } - - var typeAndTtlCommands: [[String]] = [] - typeAndTtlCommands.reserveCapacity(keys.count * 2) - for key in keys { - typeAndTtlCommands.append(["TYPE", key]) - typeAndTtlCommands.append(["TTL", key]) - } - let typeAndTtlReplies = try await conn.executePipeline(typeAndTtlCommands) - - var typeNames: [String] = [] - typeNames.reserveCapacity(keys.count) - var ttlValues: [Int] = [] - ttlValues.reserveCapacity(keys.count) - for i in 0 ..< keys.count { - let typeName = (typeAndTtlReplies[i * 2].stringValue ?? "unknown").uppercased() - let ttl = typeAndTtlReplies[i * 2 + 1].intValue ?? -1 - typeNames.append(typeName) - ttlValues.append(ttl) - } - - var previewCommands: [[String]] = [] - previewCommands.reserveCapacity(keys.count) - var previewCommandIndices: [Int] = [] - previewCommandIndices.reserveCapacity(keys.count) - - for (i, key) in keys.enumerated() { - let command: [String]? = previewCommandForType(typeNames[i], key: key) - if let command { - previewCommandIndices.append(previewCommands.count) - previewCommands.append(command) - } else { - previewCommandIndices.append(-1) - } - } - - var previewReplies: [RedisReply] = [] - if !previewCommands.isEmpty { - previewReplies = try await conn.executePipeline(previewCommands) - } - - var rows: [[PluginCellValue]] = [] - rows.reserveCapacity(keys.count) - for (i, key) in keys.enumerated() { - let ttlStr = String(ttlValues[i]) - let pipelineIndex = previewCommandIndices[i] - let preview: String? - if pipelineIndex >= 0, pipelineIndex < previewReplies.count { - preview = formatPreviewReply( - previewReplies[pipelineIndex], type: typeNames[i] - ) - } else { - preview = nil - } - rows.append([key, typeNames[i], ttlStr, preview].asCells) - } - - return PluginQueryResult( - columns: ["Key", "Type", "TTL", "Value"], - columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime), - isTruncated: isTruncated - ) - } - - func previewCommandForType(_ type: String, key: String) -> [String]? { - switch type.lowercased() { - case "string": - return ["GET", key] - case "hash": - return ["HSCAN", key, "0", "COUNT", String(Self.previewLimit)] - case "list": - return ["LRANGE", key, "0", String(Self.previewLimit - 1)] - case "set": - return ["SSCAN", key, "0", "COUNT", String(Self.previewLimit)] - case "zset": - return ["ZRANGE", key, "0", String(Self.previewLimit - 1), "WITHSCORES"] - case "stream": - return ["XREVRANGE", key, "+", "-", "COUNT", "5"] - default: - return nil - } - } - - func formatPreviewReply(_ reply: RedisReply, type: String) -> String? { - switch type.lowercased() { - case "string": - return truncatePreview(redisReplyToString(reply)) - - case "hash": - let array: [RedisReply] - if case .array(let scanResult) = reply, - scanResult.count == 2, - let items = scanResult[1].arrayValue { - array = items - } else if let items = reply.arrayValue, !items.isEmpty { - array = items - } else { - return "{}" - } - guard !array.isEmpty else { return "{}" } - var pairs: [String] = [] - var idx = 0 - while idx + 1 < array.count { - let field = redisReplyToString(array[idx]) - let value = redisReplyToString(array[idx + 1]) - pairs.append( - "\"\(escapeJsonString(field))\":\"\(escapeJsonString(value))\"" - ) - idx += 2 - } - return truncatePreview("{\(pairs.joined(separator: ","))}") - - case "list": - guard let items = reply.arrayValue else { return "[]" } - let quoted = items.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } - return truncatePreview("[\(quoted.joined(separator: ", "))]") - - case "set": - let members: [RedisReply] - if case .array(let scanResult) = reply, - scanResult.count == 2, - let items = scanResult[1].arrayValue { - members = items - } else if let items = reply.arrayValue { - members = items - } else { - return "[]" - } - let quoted = members.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } - return truncatePreview("[\(quoted.joined(separator: ", "))]") - - case "zset": - // Parse WITHSCORES result: alternating member, score pairs - guard let items = reply.arrayValue, !items.isEmpty else { return "[]" } - var pairs: [String] = [] - var i = 0 - while i + 1 < items.count { - pairs.append("\(redisReplyToString(items[i])):\(redisReplyToString(items[i + 1]))") - i += 2 - } - return truncatePreview(pairs.joined(separator: ", ")) - - case "stream": - // Parse XREVRANGE result: array of [id, [field, value, ...]] entries - guard let entries = reply.arrayValue, !entries.isEmpty else { - return "(0 entries)" - } - var entryStrings: [String] = [] - for entry in entries { - guard let parts = entry.arrayValue, parts.count >= 2, - let fields = parts[1].arrayValue else { - continue - } - let entryId = redisReplyToString(parts[0]) - var fieldPairs: [String] = [] - var j = 0 - while j + 1 < fields.count { - fieldPairs.append("\(redisReplyToString(fields[j]))=\(redisReplyToString(fields[j + 1]))") - j += 2 - } - entryStrings.append("\(entryId): \(fieldPairs.joined(separator: ", "))") - } - return truncatePreview(entryStrings.joined(separator: "; ")) - - default: - return nil - } - } - - func truncatePreview(_ value: String?) -> String? { - guard let value else { return nil } - let nsValue = value as NSString - if nsValue.length > Self.previewMaxChars { - return nsValue.substring(to: Self.previewMaxChars) + "..." - } - return value - } - - func escapeJsonString(_ str: String) -> String { - var result = "" - for scalar in str.unicodeScalars { - switch scalar { - case "\\": result += "\\\\" - case "\"": result += "\\\"" - case "\n": result += "\\n" - case "\r": result += "\\r" - case "\t": result += "\\t" - default: - if scalar.value < 0x20 { - result += String(format: "\\u%04X", scalar.value) - } else { - result += String(scalar) - } - } - } - return result - } - - func buildEmptyKeyResult(startTime: Date) -> PluginQueryResult { - PluginQueryResult( - columns: ["Key", "Type", "TTL", "Value"], - columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildStatusResult(_ message: String, startTime: Date) -> PluginQueryResult { - PluginQueryResult( - columns: ["status"], - columnTypeNames: ["String"], - rows: [[message].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildGenericResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - switch result { - case .string(let s), .status(let s): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[s].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .integer(let i): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["Int64"], - rows: [[String(i)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .data(let d): - let str = String(data: d, encoding: .utf8) ?? d.base64EncodedString() - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[str].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .array(let items): - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .error(let e): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[e].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .null: - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [["(nil)"].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - } - - func redisReplyToString(_ reply: RedisReply) -> String { - switch reply { - case .string(let s), .status(let s), .error(let s): return s - case .integer(let i): return String(i) - case .data(let d): return String(data: d, encoding: .utf8) ?? d.base64EncodedString() - case .array(let items): return "[\(items.map { redisReplyToString($0) }.joined(separator: ", "))]" - case .null: return "(nil)" - } - } - - func buildHashResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue, !items.isEmpty else { - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildListResult(_ result: RedisReply, startOffset: Int = 0, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: ["Index", "Value"], - columnTypeNames: ["Int64", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - let rows = items.enumerated().map { index, item -> [PluginCellValue] in - ([String(startOffset + index), redisReplyToString(item)] as [String?]).asCells - } - - return PluginQueryResult( - columns: ["Index", "Value"], - columnTypeNames: ["Int64", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildSetResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildSortedSetResult(_ result: RedisReply, withScores: Bool, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: withScores ? ["Member", "Score"] : ["Member"], - columnTypeNames: withScores ? ["String", "Double"] : ["String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - if withScores { - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - return PluginQueryResult( - columns: ["Member", "Score"], - columnTypeNames: ["String", "Double"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } else { - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - } - - func buildStreamResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let entries = result.arrayValue else { - return PluginQueryResult( - columns: ["ID", "Fields"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - for entry in entries { - guard let entryParts = entry.arrayValue, entryParts.count >= 2, - let fields = entryParts[1].arrayValue else { - continue - } - let entryId = redisReplyToString(entryParts[0]) - - var fieldPairs: [String] = [] - var i = 0 - while i + 1 < fields.count { - fieldPairs.append("\(redisReplyToString(fields[i]))=\(redisReplyToString(fields[i + 1]))") - i += 2 - } - rows.append([entryId, fieldPairs.joined(separator: ", ")].asCells) - } - - return PluginQueryResult( - columns: ["ID", "Fields"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildConfigResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue, !items.isEmpty else { - return PluginQueryResult( - columns: ["Parameter", "Value"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - - return PluginQueryResult( - columns: ["Parameter", "Value"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } -} diff --git a/Plugins/SQLExportPlugin/SQLExportPlugin.swift b/Plugins/SQLExportPlugin/SQLExportPlugin.swift index 00a1d515a..578c31e6d 100644 --- a/Plugins/SQLExportPlugin/SQLExportPlugin.swift +++ b/Plugins/SQLExportPlugin/SQLExportPlugin.swift @@ -526,7 +526,7 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { let insertPrefix = "INSERT INTO \(tableRef) (\(quotedColumns))\(overriding) VALUES\n" let numericIndices: Set = Set(includedColumnIndices.filter { idx in - idx < columnTypeNames.count && isNumericColumnType(columnTypeNames[idx]) + idx < columnTypeNames.count && PluginExportUtilities.isNumericColumnType(columnTypeNames[idx]) }) let effectiveBatchSize = batchSize <= 1 ? 1 : batchSize @@ -546,7 +546,7 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { let hex = data.map { String(format: "%02X", $0) }.joined() return "X'\(hex)'" case .text(let val): - if numericIndices.contains(colIndex) && isNumericLiteral(val) { + if numericIndices.contains(colIndex) && PluginNumericLiteral.isValid(val) { return val } let escaped = dataSource.escapeStringLiteral(val) @@ -571,19 +571,6 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { } } - private func isNumericColumnType(_ typeName: String) -> Bool { - let numericPrefixes = [ - "int", "bigint", "decimal", "float", "double", "numeric", - "real", "smallint", "tinyint", "mediumint", "integer", "number" - ] - let lower = typeName.lowercased() - return numericPrefixes.contains { lower.hasPrefix($0) } - } - - private func isNumericLiteral(_ val: String) -> Bool { - val.allSatisfy { $0.isNumber || $0 == "." || $0 == "-" || $0 == "+" || $0 == "e" || $0 == "E" } - } - private func compressFile(source: URL, destination: URL) async throws { let gzipPath = "/usr/bin/gzip" guard FileManager.default.isExecutableFile(atPath: gzipPath) else { diff --git a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift index 3f469a184..1f9f1330b 100644 --- a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift +++ b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift @@ -1,5 +1,44 @@ import Foundation +public enum PluginNumericLiteral { + public static func isValid(_ value: String) -> Bool { + guard !value.isEmpty else { return false } + var scanner = value.makeIterator() + var hasDigit = false + var hasDot = false + var hasE = false + + var first = true + while let c = scanner.next() { + if first { + first = false + if c == "-" || c == "+" { continue } + } + if c.isNumber { + hasDigit = true + continue + } + if c == "." && !hasDot && !hasE { + hasDot = true + continue + } + if (c == "e" || c == "E") && hasDigit && !hasE { + hasE = true + hasDigit = false + if let next = scanner.next() { + if next == "+" || next == "-" || next.isNumber { + if next.isNumber { hasDigit = true } + continue + } + } + return false + } + return false + } + return hasDigit + } +} + @frozen public enum ParameterStyle: String, Sendable { case questionMark // ? @@ -535,40 +574,7 @@ public extension PluginDatabaseDriver { } static func isNumericLiteral(_ value: String) -> Bool { - guard !value.isEmpty else { return false } - var scanner = value.makeIterator() - var hasDigit = false - var hasDot = false - var hasE = false - - var first = true - while let c = scanner.next() { - if first { - first = false - if c == "-" || c == "+" { continue } - } - if c.isNumber { - hasDigit = true - continue - } - if c == "." && !hasDot && !hasE { - hasDot = true - continue - } - if (c == "e" || c == "E") && hasDigit && !hasE { - hasE = true - hasDigit = false - if let next = scanner.next() { - if next == "+" || next == "-" || next.isNumber { - if next.isNumber { hasDigit = true } - continue - } - } - return false - } - return false - } - return hasDigit + PluginNumericLiteral.isValid(value) } func executeUserQuery(query: String, rowCap: Int?, parameters: [PluginCellValue]?) async throws -> PluginQueryResult { @@ -592,3 +598,99 @@ public extension PluginDatabaseDriver { ) } } + +public enum PluginSQLFilter { + public static func escapeForLike(_ text: String) -> String { + text + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "%", with: "\\%") + .replacingOccurrences(of: "_", with: "\\_") + .replacingOccurrences(of: "'", with: "''") + } + + public static func buildOrderByClause( + sortColumns: [(columnIndex: Int, ascending: Bool)], + columns: [String], + quoteIdentifier: (String) -> String + ) -> String? { + let parts = sortColumns.compactMap { sortCol -> String? in + guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } + let direction = sortCol.ascending ? "ASC" : "DESC" + return "\(quoteIdentifier(columns[sortCol.columnIndex])) \(direction)" + } + guard !parts.isEmpty else { return nil } + return "ORDER BY " + parts.joined(separator: ", ") + } + + public static func buildWhereClause( + filters: [(column: String, op: String, value: String)], + logicMode: String, + quoteIdentifier: (String) -> String, + escapeValue: (String) -> String, + regexCondition: (_ quotedColumn: String, _ value: String) -> String? + ) -> String { + let conditions = filters.compactMap { filter in + buildFilterCondition( + column: filter.column, + op: filter.op, + value: filter.value, + quoteIdentifier: quoteIdentifier, + escapeValue: escapeValue, + regexCondition: regexCondition + ) + } + guard !conditions.isEmpty else { return "" } + let separator = logicMode == "and" ? " AND " : " OR " + return conditions.joined(separator: separator) + } + + public static func buildFilterCondition( + column: String, + op: String, + value: String, + quoteIdentifier: (String) -> String, + escapeValue: (String) -> String, + regexCondition: (_ quotedColumn: String, _ value: String) -> String? + ) -> String? { + let quoted = quoteIdentifier(column) + switch op { + case "=": return "\(quoted) = \(escapeValue(value))" + case "!=": return "\(quoted) != \(escapeValue(value))" + case ">": return "\(quoted) > \(escapeValue(value))" + case ">=": return "\(quoted) >= \(escapeValue(value))" + case "<": return "\(quoted) < \(escapeValue(value))" + case "<=": return "\(quoted) <= \(escapeValue(value))" + case "IS NULL": return "\(quoted) IS NULL" + case "IS NOT NULL": return "\(quoted) IS NOT NULL" + case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" + case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" + case "CONTAINS": + return "\(quoted) LIKE '%\(escapeForLike(value))%' ESCAPE '\\'" + case "NOT CONTAINS": + return "\(quoted) NOT LIKE '%\(escapeForLike(value))%' ESCAPE '\\'" + case "STARTS WITH": + return "\(quoted) LIKE '\(escapeForLike(value))%' ESCAPE '\\'" + case "ENDS WITH": + return "\(quoted) LIKE '%\(escapeForLike(value))' ESCAPE '\\'" + case "IN": + let values = value.split(separator: ",") + .map { escapeValue($0.trimmingCharacters(in: .whitespaces)) } + .joined(separator: ", ") + return values.isEmpty ? nil : "\(quoted) IN (\(values))" + case "NOT IN": + let values = value.split(separator: ",") + .map { escapeValue($0.trimmingCharacters(in: .whitespaces)) } + .joined(separator: ", ") + return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" + case "BETWEEN": + let parts = value.split(separator: ",", maxSplits: 1) + guard parts.count == 2 else { return nil } + let v1 = escapeValue(parts[0].trimmingCharacters(in: .whitespaces)) + let v2 = escapeValue(parts[1].trimmingCharacters(in: .whitespaces)) + return "\(quoted) BETWEEN \(v1) AND \(v2)" + case "REGEX": + return regexCondition(quoted, value) + default: return nil + } + } +} diff --git a/Plugins/TableProPluginKit/PluginExportUtilities.swift b/Plugins/TableProPluginKit/PluginExportUtilities.swift index 5bdf9f718..222ec7ee6 100644 --- a/Plugins/TableProPluginKit/PluginExportUtilities.swift +++ b/Plugins/TableProPluginKit/PluginExportUtilities.swift @@ -84,6 +84,15 @@ public enum PluginExportUtilities { result = result.replacingOccurrences(of: "--", with: "") return result } + + public static func isNumericColumnType(_ typeName: String) -> Bool { + let numericPrefixes = [ + "int", "bigint", "decimal", "float", "double", "numeric", + "real", "smallint", "tinyint", "mediumint", "integer", "number" + ] + let lower = typeName.lowercased() + return numericPrefixes.contains { lower.hasPrefix($0) } + } } public extension String { diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx new file mode 100644 index 000000000..31b350361 --- /dev/null +++ b/docs/development/refactor-audit.mdx @@ -0,0 +1,69 @@ +--- +title: "Refactor audit" +description: "Internal engineering audit and refactor checklist. Not part of the published navigation." +--- + +Internal tracking doc. Not linked in `docs.json` on purpose. Candid file:line findings for an ongoing, tiered refactor. Source: full-codebase audit on 2026-06-03. + +Headline: the codebase is well built and well maintained. Zero real `print()` in production code, near-zero force unwraps, one coherent modern concurrency model (async/await + actors), disciplined error types, a native color/material system. This is not a rewrite. It is a sequence of surgical, root-cause refactors. Work top-down by tier; keep each commit reviewable and atomic. + +## Tier 0: confirmed bugs (done) + +- [x] JSON import plugin failed to load: `JSONImportPlugin/Info.plist` declared `TableProPluginKitVersion` 17 while the app is at 18, so the gate at `PluginManager.swift:471` rejected it. Bumped to 18. +- [x] Change managers defaulted to MySQL dialect. `DataChangeManager.databaseType` is now optional and `generateSQL` throws if unconfigured instead of silently emitting MySQL; `configureForTable` requires an explicit dialect. `StructureChangeManager.databaseType` was dead write-only state (DDL comes from `SchemaStatementGenerator(pluginDriver:)`) and was removed. +- [x] `SQLExportPlugin.isNumericLiteral` accepted garbage (`+-e..`, `1.2.3`, empty) and emitted it unquoted into INSERT. Replaced with the kit's validated check, now exposed as `PluginNumericLiteral.isValid` in TableProPluginKit (additive, ABI-safe). +- [x] `CloudflareTunnelManager.persistPidRecords` swallowed the encode failure with `try?`, so leaked `cloudflared` PIDs were never cleaned up. Now logs via OSLog. + +## Tier 1: high-value architecture + +- [ ] `MainContentCoordinator` god object: 1,392-line core + 34 extension files (~4,278 lines) in one `@Observable` type, ~30 responsibilities. Push state and logic into the existing sub-coordinators (`filterCoordinator`, `paginationCoordinator`, `rowEditingCoordinator`, `queryExecutionCoordinator`) and have the main type delegate. +- [ ] DI is theatre: 75 `static let shared` in `Core/`; `AppServices.live` wires 26 of them. Route through the `AppServices` composition root, drop `.shared` from types that flow through it. +- [ ] Triplicated reconnect logic in `DatabaseManager+Health.swift:111-308` (x2) and `+Sessions.swift:108-155`. Extract one `establishDriver(for:)` (connect + timeout + startup commands + restore schema/db). +- [ ] `ConnectionStorage` god object (988L): CRUD + file persistence + UserDefaults migration + Keychain for 8 secret kinds + the `StoredConnection` model. Split into `ConnectionStore` / `ConnectionSecretsStore` / model file. +- [ ] `Core/Services/Infrastructure/` (38 files) mixes AppKit view controllers (`MainSplitViewController`, `TabWindowController`, toolbars) into Core. Move view controllers to `Views/`; split the rest into Launch / Windowing / Tabs. +Cross-plugin duplication (investigated 2026-06-03; the original audit overstated this, 3 of 5 were false positives): + +- [x] `isNumericColumnType` (SQLExport/JSONExport): byte-identical. Extracted to `PluginExportUtilities.isNumericColumnType`. +- [x] SQL filter/where/order-by builder (Oracle + MSSQL): real. Extracted to `PluginSQLFilter` (parameterized by quote/escape/regex closures), which also makes the injection-defense escaping a single canonical copy. ClickHouse's `buildWhereClause` is unrelated (it is DML, not browse filtering). +- [x] JSON scalar decoder (`D1Value`/`HranaValue`/`BQCellValue`): FALSE POSITIVE. Three different wire formats (bare scalar vs keyed `{type,value,base64}` vs recursive `record`/`array`). A shared type would be a leaky union. Skipped. +- [x] Schemaless flattener (Mongo/DynamoDB): FALSE POSITIVE. Different priority-key source, ordering, value types, return types. Skipped. +- [x] DML INSERT/UPDATE/DELETE generator (Oracle/MSSQL/ClickHouse): FALSE POSITIVE. Differ in quoting, default/identity handling, WHERE strategy, row-limit, statement form (standard vs ClickHouse ALTER TABLE), and dispatch. Correctness-critical, no headless test net. Skipped. + +- [x] Oversized plugin files split (2026-06-03), all build green: + - Redis 1695 -> 632 + `+Operations` (510) / `+Scan` (85) / `+ResultBuilding` (486) + - DuckDB 1491 -> 879 + `DuckDBConnection` (600) / `DuckDBPluginError` (25) + - Cassandra 1449 -> 588 + `CassandraConnection` (851) / `CassandraPluginError` (25) + - ClickHouse 1426 -> 799 + `+Schema` (340) / `+Http` (309) + - MSSQL 1403 -> 710 + `+Schema` (533) / `+DDL` (184) + - Mongo 1351 -> 903 + `+SyncHelpers` (458) + - Oracle (1152) and the rest already under the limit after de-dup. + - The 6 split plugins use Xcode synchronized file-system groups, so new files auto-included with NO pbxproj edits. Splitting across files required broadening the `private` members the moved code touches to `internal` (the established `MainContentCoordinator+*` pattern). +- [ ] DynamoDB (1238, 38 over) left intact. It is the one plugin still on a legacy explicit `PBXGroup` (not a synchronized group), so a new file needs manual pbxproj target-membership wiring, which the build system did not honor (xcodebuild never compiled the added file). Not worth the risk for 38 lines on a limit that is not CI-enforced for plugins. To do it properly: convert `Plugins/DynamoDBDriverPlugin` to a `PBXFileSystemSynchronizedRootGroup` (matching the other plugins), then split like the rest. + +## Tier 2: macOS HIG / native + +- [ ] No design-token layer. 5 competing corner radii (`6`x34, `4`x17, `8`x8), ad-hoc paddings, hardcoded control sizes. Add `Theme/Metrics.swift` (the existing `ThemeLayout.swift` only holds fonts; misnamed). Route literals through it. +- [ ] `runModal()` blocking panels instead of sheets: `IntegrationsActivityLogPane.swift:236,242`, `ERDiagramView.swift:270`. Use `beginSheetModal(for:)`. +- [ ] Thin accessibility: 61/343 view files set accessibility labels. Backfill `.accessibilityLabel` on icon-only controls. +- [ ] Business logic in views: `ExportDialog.swift:674` builds and runs raw `information_schema`/Oracle SQL inside a SwiftUI View. Move to an export service. +- [ ] Raw RGB colors in `ChatComposerView.swift:204`, hardcoded editor font in `HighlightedSQLTextView.swift:14`. Route through the theme. + +## Tier 3: consistency / hygiene + +- [ ] 19 plugin `String(localized: "...\(x)")` dynamic-key bugs (DynamoDB, Etcd, Mongo, CloudflareD1, Redis). Plugins ship no `.xcstrings`, so every plugin `String(localized:)` is a no-op. Pick a policy: add catalogs or drop the calls. +- [ ] Dead `DialectQuoteHelper.swift` (zero refs). Delete; consolidate identifier quoting into PluginKit's `SQLDialectDescriptor`. +- [ ] `Data.hexEncoded` reinvented 11x; `DatabaseType.brandColor` duplicates the registry and silently `.accentColor`s unknown plugins; dead `allCases` compatibility shim; `SidebarViewModel` static registry holds strong refs (leak), use `NSMapTable` weak values like `ConnectionDataCache`. +- [ ] Driver error types (`MariaDBPluginError`, `LibPQPluginError`, `OracleError`, ...) don't conform to `PluginDriverError`, bypassing the kit's error formatting inconsistently. +- [ ] `PluginStreamElement` is `@frozen` but will grow (`.progress`/`.warning`); that makes a future additive case a breaking ABI bump. Drop `@frozen`, add `@unknown default` at the 5 consumers. +- [ ] Suffix vocabulary overload (Manager 22 / Service 15 / Coordinator 12 / Provider 22). 14 top-level free functions (mostly `ConnectionGroupTree.swift`) violate explicit access control. +- [ ] Five `@Observable @MainActor` state classes live under `Models/` (view state). Move to `ViewModels/`. + +## Tier 4: test coverage gaps + +- [ ] Sync has zero tests: `ConflictResolver`, `SyncCoordinator`, `SyncChangeTracker`. Holds last-write-wins/merge logic and the documented "markDeleted after saveConnections" delete-ordering invariant. Highest-value gap. +- [ ] Untested storage: `AppSettingsStorage`, `SSHProfileStorage`, `TabStateStorage` (the 500KB truncation guard). + +## Policy decisions (owner's call) + +- [ ] The "no comments" rule. About 1,700 inline + 2,200 `///` doc comments exist; SwiftLint ignores comments, so the rule is unenforced. Mass deletion is high-churn and low-value. Recommendation: relax CLAUDE.md to "no what-comments; `///` API docs and load-bearing why-notes allowed" rather than stripping thousands of lines. +- [ ] Suffix vocabulary convention (Manager vs Service vs Coordinator). One line in CLAUDE.md before any rename pass. From 85d2d404000a593cedc64daa3f8ec101c7d04f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:36:38 +0700 Subject: [PATCH 03/12] refactor: dedup hex encoding into a shared extension and drop dead DatabaseType.allCases shim --- TablePro/Core/Database/AWS/AWSSSO.swift | 2 +- TablePro/Core/Database/AWS/AWSSigV4.swift | 4 ++-- TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift | 2 +- TablePro/Core/MCP/MCPTokenStore.swift | 2 +- TablePro/Core/Plugins/PluginDriverAdapter.swift | 2 +- TablePro/Core/Plugins/PluginInstaller.swift | 2 +- TablePro/Core/SSH/SSHPathUtilities.swift | 2 +- .../Services/Formatting/ValueDisplayFormatService.swift | 2 +- TablePro/Extensions/Sequence+HexEncoded.swift | 7 +++++++ TablePro/Extensions/String+SHA256.swift | 2 +- TablePro/Models/Connection/DatabaseConnection.swift | 3 --- TablePro/Theme/ThemeRegistryInstaller.swift | 2 +- 12 files changed, 18 insertions(+), 14 deletions(-) create mode 100644 TablePro/Extensions/Sequence+HexEncoded.swift diff --git a/TablePro/Core/Database/AWS/AWSSSO.swift b/TablePro/Core/Database/AWS/AWSSSO.swift index d33a86828..695b43cdc 100644 --- a/TablePro/Core/Database/AWS/AWSSSO.swift +++ b/TablePro/Core/Database/AWS/AWSSSO.swift @@ -300,6 +300,6 @@ enum AWSSSO { data.withUnsafeBytes { ptr in _ = CC_SHA1(ptr.baseAddress, CC_LONG(data.count), &hash) } - return hash.map { String(format: "%02x", $0) }.joined() + return hash.hexEncoded } } diff --git a/TablePro/Core/Database/AWS/AWSSigV4.swift b/TablePro/Core/Database/AWS/AWSSigV4.swift index 63d38b6fd..db96d0ff0 100644 --- a/TablePro/Core/Database/AWS/AWSSigV4.swift +++ b/TablePro/Core/Database/AWS/AWSSigV4.swift @@ -27,7 +27,7 @@ enum AWSSigV4 { } static func hmacHex(key: Data, data: Data) -> String { - hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined() + hmac(key: key, data: data).hexEncoded } static func sha256Hex(_ data: Data) -> String { @@ -35,7 +35,7 @@ enum AWSSigV4 { data.withUnsafeBytes { ptr in _ = CC_SHA256(ptr.baseAddress, CC_LONG(data.count), &hash) } - return hash.map { String(format: "%02x", $0) }.joined() + return hash.hexEncoded } static func deriveSigningKey(secretKey: String, dateStamp: String, region: String, service: String) -> Data { diff --git a/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift index 5177af407..4204f0cef 100644 --- a/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift +++ b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift @@ -207,7 +207,7 @@ public actor MCPBearerTokenAuthenticator: MCPAuthenticator { internal static func fingerprint(of token: String) -> String { guard let data = token.data(using: .utf8) else { return "" } let digest = SHA256.hash(data: data) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded return String(hex.prefix(16)) } } diff --git a/TablePro/Core/MCP/MCPTokenStore.swift b/TablePro/Core/MCP/MCPTokenStore.swift index f1b848848..bc71debf1 100644 --- a/TablePro/Core/MCP/MCPTokenStore.swift +++ b/TablePro/Core/MCP/MCPTokenStore.swift @@ -341,7 +341,7 @@ actor MCPTokenStore { let input = salt + plaintext guard let data = input.data(using: .utf8) else { return "" } let digest = SHA256.hash(data: data) - return digest.map { String(format: "%02x", $0) }.joined() + return digest.hexEncoded } private func base64UrlEncode(_ data: Data) -> String { diff --git a/TablePro/Core/Plugins/PluginDriverAdapter.swift b/TablePro/Core/Plugins/PluginDriverAdapter.swift index d4d9955b1..e95e5275b 100644 --- a/TablePro/Core/Plugins/PluginDriverAdapter.swift +++ b/TablePro/Core/Plugins/PluginDriverAdapter.swift @@ -85,7 +85,7 @@ final class PluginDriverAdapter: DatabaseDriver, SchemaSwitchable { case let d as Date: return Self.iso8601Formatter.string(from: d) case let data as Data: - return data.map { String(format: "%02x", $0) }.joined() + return data.hexEncoded case let uuid as UUID: return uuid.uuidString default: diff --git a/TablePro/Core/Plugins/PluginInstaller.swift b/TablePro/Core/Plugins/PluginInstaller.swift index 981f0054e..9b56f14fd 100644 --- a/TablePro/Core/Plugins/PluginInstaller.swift +++ b/TablePro/Core/Plugins/PluginInstaller.swift @@ -196,7 +196,7 @@ actor PluginInstaller { let payload = try Data(contentsOf: tempDownloadURL) let digest = SHA256.hash(data: payload) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded guard hex == binary.sha256.lowercased() else { throw PluginError.checksumMismatch } diff --git a/TablePro/Core/SSH/SSHPathUtilities.swift b/TablePro/Core/SSH/SSHPathUtilities.swift index ff995fb18..db75e4557 100644 --- a/TablePro/Core/SSH/SSHPathUtilities.swift +++ b/TablePro/Core/SSH/SSHPathUtilities.swift @@ -85,7 +85,7 @@ struct SSHTokenContext: Sendable { if result.contains("%C") { let basis = "\(localHostnameFQDN())\(hostname ?? "")\(port.map(String.init) ?? "")\(remoteUser ?? "")" let digest = Insecure.SHA1.hash(data: Data(basis.utf8)) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded result = result.replacingOccurrences(of: "%C", with: hex) } diff --git a/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift b/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift index 1e919dabb..82fef04d9 100644 --- a/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift +++ b/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift @@ -116,7 +116,7 @@ final class ValueDisplayFormatService { // Try raw binary bytes (isoLatin1 encoding from MySQL) if let data = rawValue.data(using: .isoLatin1), data.count == 16 { let bytes = [UInt8](data) - let hex = bytes.map { String(format: "%02x", $0) }.joined() + let hex = bytes.hexEncoded return insertUuidHyphens(hex) } diff --git a/TablePro/Extensions/Sequence+HexEncoded.swift b/TablePro/Extensions/Sequence+HexEncoded.swift new file mode 100644 index 000000000..fe5a1c002 --- /dev/null +++ b/TablePro/Extensions/Sequence+HexEncoded.swift @@ -0,0 +1,7 @@ +import Foundation + +extension Sequence where Element == UInt8 { + var hexEncoded: String { + map { String(format: "%02x", $0) }.joined() + } +} diff --git a/TablePro/Extensions/String+SHA256.swift b/TablePro/Extensions/String+SHA256.swift index 001a933e2..09277d317 100644 --- a/TablePro/Extensions/String+SHA256.swift +++ b/TablePro/Extensions/String+SHA256.swift @@ -13,6 +13,6 @@ extension String { var sha256: String { let data = Data(utf8) let digest = SHA256.hash(data: data) - return digest.map { String(format: "%02x", $0) }.joined() + return digest.hexEncoded } } diff --git a/TablePro/Models/Connection/DatabaseConnection.swift b/TablePro/Models/Connection/DatabaseConnection.swift index c420c17bd..8196a64bd 100644 --- a/TablePro/Models/Connection/DatabaseConnection.swift +++ b/TablePro/Models/Connection/DatabaseConnection.swift @@ -64,9 +64,6 @@ extension DatabaseType { static var allKnownTypes: [DatabaseType] { PluginMetadataRegistry.shared.allRegisteredTypeIds().map { DatabaseType(rawValue: $0) } } - - /// Compatibility shim for CaseIterable call sites. - static var allCases: [DatabaseType] { allKnownTypes } } extension DatabaseType { diff --git a/TablePro/Theme/ThemeRegistryInstaller.swift b/TablePro/Theme/ThemeRegistryInstaller.swift index 75aca2d4e..159ba82f4 100644 --- a/TablePro/Theme/ThemeRegistryInstaller.swift +++ b/TablePro/Theme/ThemeRegistryInstaller.swift @@ -213,7 +213,7 @@ internal final class ThemeRegistryInstaller { let downloadedData = try Data(contentsOf: tempDownloadURL) let digest = SHA256.hash(data: downloadedData) - let hexChecksum = digest.map { String(format: "%02x", $0) }.joined() + let hexChecksum = digest.hexEncoded if hexChecksum != resolved.sha256.lowercased() { throw PluginError.checksumMismatch From 1ce28df59433f94ed67f9349e103e835a6f09dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:37:12 +0700 Subject: [PATCH 04/12] docs: update refactor audit with Tier 3 progress --- docs/development/refactor-audit.mdx | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx index 31b350361..1832f5a2d 100644 --- a/docs/development/refactor-audit.mdx +++ b/docs/development/refactor-audit.mdx @@ -50,9 +50,12 @@ Cross-plugin duplication (investigated 2026-06-03; the original audit overstated ## Tier 3: consistency / hygiene -- [ ] 19 plugin `String(localized: "...\(x)")` dynamic-key bugs (DynamoDB, Etcd, Mongo, CloudflareD1, Redis). Plugins ship no `.xcstrings`, so every plugin `String(localized:)` is a no-op. Pick a policy: add catalogs or drop the calls. -- [ ] Dead `DialectQuoteHelper.swift` (zero refs). Delete; consolidate identifier quoting into PluginKit's `SQLDialectDescriptor`. -- [ ] `Data.hexEncoded` reinvented 11x; `DatabaseType.brandColor` duplicates the registry and silently `.accentColor`s unknown plugins; dead `allCases` compatibility shim; `SidebarViewModel` static registry holds strong refs (leak), use `NSMapTable` weak values like `ConnectionDataCache`. +- [x] `Data.hexEncoded` reinvented 11x: extracted `Sequence where Element == UInt8 { hexEncoded }`, routed all 11 call sites. (2026-06-03) +- [x] Dead `DatabaseType.allCases` compatibility shim: removed (verified no callers). (2026-06-03) +- [~] DialectQuoteHelper: VERIFIED FALSE POSITIVE. `quoteIdentifierFromDialect`/`resolveSQLDialect` are used by FilterSQLGenerator, SQLStatementGenerator, ExportDataTool, SQLRowToStatementConverter, QueryTab, MainContentCoordinator. NOT dead. Left as-is. +- [ ] `DatabaseType.brandColor` (hardcoded 19-case switch, one caller) duplicates the registry-driven `PluginManager.brandColor(for:)`. Fix = route through the registry. NEEDS SIGN-OFF: shifts some chooser-sheet colors to match the registry (e.g. ClickHouse FFCC01 -> FFD100), a visible change. +- [ ] `SidebarViewModel` static `[UUID: SidebarViewModel]` registry holds strong refs (leak if `removeConnection` is missed). NOT a mechanical swap: the registry is the strong owner, so converting to weak needs the window/view ownership graph checked first or the cache breaks. +- [ ] 19 plugin `String(localized: "...\(x)")` dynamic-key calls (DynamoDB, Etcd, Mongo, CloudflareD1, Redis). Plugins ship no `.xcstrings`, so these are passthroughs today (not user-visible bugs), but they break if a catalog is ever added. POLICY: add per-plugin catalogs or drop `String(localized:)` in plugin code. - [ ] Driver error types (`MariaDBPluginError`, `LibPQPluginError`, `OracleError`, ...) don't conform to `PluginDriverError`, bypassing the kit's error formatting inconsistently. - [ ] `PluginStreamElement` is `@frozen` but will grow (`.progress`/`.warning`); that makes a future additive case a breaking ABI bump. Drop `@frozen`, add `@unknown default` at the 5 consumers. - [ ] Suffix vocabulary overload (Manager 22 / Service 15 / Coordinator 12 / Provider 22). 14 top-level free functions (mostly `ConnectionGroupTree.swift`) violate explicit access control. From 0af7c5b8334237c1f4f80bd579b4e2020cad65b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:46:33 +0700 Subject: [PATCH 05/12] docs: record Tier 3 plugin-side verification findings --- docs/development/refactor-audit.mdx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx index 1832f5a2d..4db5e77c2 100644 --- a/docs/development/refactor-audit.mdx +++ b/docs/development/refactor-audit.mdx @@ -53,11 +53,11 @@ Cross-plugin duplication (investigated 2026-06-03; the original audit overstated - [x] `Data.hexEncoded` reinvented 11x: extracted `Sequence where Element == UInt8 { hexEncoded }`, routed all 11 call sites. (2026-06-03) - [x] Dead `DatabaseType.allCases` compatibility shim: removed (verified no callers). (2026-06-03) - [~] DialectQuoteHelper: VERIFIED FALSE POSITIVE. `quoteIdentifierFromDialect`/`resolveSQLDialect` are used by FilterSQLGenerator, SQLStatementGenerator, ExportDataTool, SQLRowToStatementConverter, QueryTab, MainContentCoordinator. NOT dead. Left as-is. -- [ ] `DatabaseType.brandColor` (hardcoded 19-case switch, one caller) duplicates the registry-driven `PluginManager.brandColor(for:)`. Fix = route through the registry. NEEDS SIGN-OFF: shifts some chooser-sheet colors to match the registry (e.g. ClickHouse FFCC01 -> FFD100), a visible change. -- [ ] `SidebarViewModel` static `[UUID: SidebarViewModel]` registry holds strong refs (leak if `removeConnection` is missed). NOT a mechanical swap: the registry is the strong owner, so converting to weak needs the window/view ownership graph checked first or the cache breaks. +- [ ] `DatabaseType.brandColor` (hardcoded 19-case switch, one caller). Attempted routing through `PluginManager.brandColor(for:)`, then REVERTED: that helper resolves via `pluginTypeId`, which maps variants to their driver (MariaDB->MySQL), and the variant registry snapshots carry no `brandColorHex`, so variants would lose their color (color/icon mismatch). PROPER FIX (deliberate, not hygiene): add `brandColorHex` to the variant registry snapshots, then delegate to `snapshot(forTypeId: rawValue)?.brandColorHex` like `iconName` already does. +- [~] `SidebarViewModel` static registry: VERIFIED the audit's fix is WRONG. `MainContentCoordinator.sidebarViewModel` is `weak` and `SidebarView` only holds a local; the registry is the SOLE strong owner, so `NSMapTable` weak values would deallocate the VM immediately and break the sidebar. Lifecycle is managed by `removeConnection` on session teardown (`DatabaseManager+Sessions.swift:365`). Design is sound; left as-is. - [ ] 19 plugin `String(localized: "...\(x)")` dynamic-key calls (DynamoDB, Etcd, Mongo, CloudflareD1, Redis). Plugins ship no `.xcstrings`, so these are passthroughs today (not user-visible bugs), but they break if a catalog is ever added. POLICY: add per-plugin catalogs or drop `String(localized:)` in plugin code. -- [ ] Driver error types (`MariaDBPluginError`, `LibPQPluginError`, `OracleError`, ...) don't conform to `PluginDriverError`, bypassing the kit's error formatting inconsistently. -- [ ] `PluginStreamElement` is `@frozen` but will grow (`.progress`/`.warning`); that makes a future additive case a breaking ABI bump. Drop `@frozen`, add `@unknown default` at the 5 consumers. +- [x] Driver error types -> `PluginDriverError`: VERIFIED ALREADY DONE. All driver error types (MariaDB, LibPQ, Oracle, SQLite, MongoDB, Redis, etc.) already conform. Audit false positive. +- [ ] `PluginStreamElement` is `@frozen` and the kit uses Library Evolution, so dropping `@frozen` is a BINARY-BREAKING ABI change (many plugins switch over it), needing a coordinated 18->19 bump + `release-all-plugins`. NOT hygiene. Defer to the next deliberate ABI bump; do it then so future additive cases are free. - [ ] Suffix vocabulary overload (Manager 22 / Service 15 / Coordinator 12 / Provider 22). 14 top-level free functions (mostly `ConnectionGroupTree.swift`) violate explicit access control. - [ ] Five `@Observable @MainActor` state classes live under `Models/` (view state). Move to `ViewModels/`. From 32127830299a7d1374f4d60c3b89b57b02dfb910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 11:12:25 +0700 Subject: [PATCH 06/12] fix(tests): update allCases and StructureChangeManager.loadSchema callers after API changes --- TableProTests/Models/DatabaseTypeCassandraTests.swift | 4 ++-- TableProTests/Models/DatabaseTypeCockroachDBTests.swift | 2 +- TableProTests/Models/DatabaseTypeMSSQLTests.swift | 2 +- TableProTests/Models/DatabaseTypeRedisTests.swift | 2 +- TableProTests/Models/DatabaseTypeTests.swift | 5 ----- .../Structure/StructureEditingSupportFieldDiffTests.swift | 3 +-- 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/TableProTests/Models/DatabaseTypeCassandraTests.swift b/TableProTests/Models/DatabaseTypeCassandraTests.swift index 0ad36d743..013961300 100644 --- a/TableProTests/Models/DatabaseTypeCassandraTests.swift +++ b/TableProTests/Models/DatabaseTypeCassandraTests.swift @@ -86,11 +86,11 @@ struct DatabaseTypeCassandraTests { @Test("Cassandra included in allCases") func cassandraIncludedInAllCases() { - #expect(DatabaseType.allCases.contains(.cassandra)) + #expect(DatabaseType.allKnownTypes.contains(.cassandra)) } @Test("ScyllaDB included in allCases") func scylladbIncludedInAllCases() { - #expect(DatabaseType.allCases.contains(.scylladb)) + #expect(DatabaseType.allKnownTypes.contains(.scylladb)) } } diff --git a/TableProTests/Models/DatabaseTypeCockroachDBTests.swift b/TableProTests/Models/DatabaseTypeCockroachDBTests.swift index 5ad8a4f71..80723a128 100644 --- a/TableProTests/Models/DatabaseTypeCockroachDBTests.swift +++ b/TableProTests/Models/DatabaseTypeCockroachDBTests.swift @@ -68,6 +68,6 @@ struct DatabaseTypeCockroachDBTests { @Test("allCases shim contains cockroachdb") func allCasesContainsCockroachDB() { - #expect(DatabaseType.allCases.contains(.cockroachdb)) + #expect(DatabaseType.allKnownTypes.contains(.cockroachdb)) } } diff --git a/TableProTests/Models/DatabaseTypeMSSQLTests.swift b/TableProTests/Models/DatabaseTypeMSSQLTests.swift index 7b2f9d882..2c93361dc 100644 --- a/TableProTests/Models/DatabaseTypeMSSQLTests.swift +++ b/TableProTests/Models/DatabaseTypeMSSQLTests.swift @@ -53,6 +53,6 @@ struct DatabaseTypeMSSQLTests { @Test("allCases shim contains mssql") func allCasesContainsMSSql() { - #expect(DatabaseType.allCases.contains(.mssql)) + #expect(DatabaseType.allKnownTypes.contains(.mssql)) } } diff --git a/TableProTests/Models/DatabaseTypeRedisTests.swift b/TableProTests/Models/DatabaseTypeRedisTests.swift index 0f0cc44c2..8e1b2ee7b 100644 --- a/TableProTests/Models/DatabaseTypeRedisTests.swift +++ b/TableProTests/Models/DatabaseTypeRedisTests.swift @@ -46,6 +46,6 @@ struct DatabaseTypeRedisTests { @Test("Included in allCases shim") func includedInAllCases() { - #expect(DatabaseType.allCases.contains(.redis)) + #expect(DatabaseType.allKnownTypes.contains(.redis)) } } diff --git a/TableProTests/Models/DatabaseTypeTests.swift b/TableProTests/Models/DatabaseTypeTests.swift index ee4d1616b..215b8cde2 100644 --- a/TableProTests/Models/DatabaseTypeTests.swift +++ b/TableProTests/Models/DatabaseTypeTests.swift @@ -47,11 +47,6 @@ struct DatabaseTypeTests { #expect(knownTypes.count >= 5) } - @Test("allCases shim matches allKnownTypes") - func testAllCasesShim() { - #expect(DatabaseType.allCases == DatabaseType.allKnownTypes) - } - @Test("Raw value matches display name", arguments: [ (DatabaseType.mysql, "MySQL"), (DatabaseType.mariadb, "MariaDB"), diff --git a/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift b/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift index f8e2d47fd..ff40043e9 100644 --- a/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift +++ b/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift @@ -251,8 +251,7 @@ struct StructureChangeManagerUndoDeleteTests { columns: columns, indexes: [], foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) return manager } From ff50f7acea07a6575dca53ac589ff5a7a6549b5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 11:12:26 +0700 Subject: [PATCH 07/12] test(sync): add SyncChangeTracker dirty/tombstone/suppression unit tests --- .../Core/Sync/SyncChangeTrackerTests.swift | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 TableProTests/Core/Sync/SyncChangeTrackerTests.swift diff --git a/TableProTests/Core/Sync/SyncChangeTrackerTests.swift b/TableProTests/Core/Sync/SyncChangeTrackerTests.swift new file mode 100644 index 000000000..d15237de4 --- /dev/null +++ b/TableProTests/Core/Sync/SyncChangeTrackerTests.swift @@ -0,0 +1,79 @@ +// +// SyncChangeTrackerTests.swift +// TableProTests +// + +import Foundation +import Testing + +@testable import TablePro + +@Suite("SyncChangeTracker") +@MainActor +struct SyncChangeTrackerTests { + private let metadata: SyncMetadataStorage + private let tracker: SyncChangeTracker + + init() { + let unique = UUID().uuidString + let syncDefaults = UserDefaults(suiteName: "com.TablePro.tests.SyncChangeTracker.\(unique)")! + metadata = SyncMetadataStorage(userDefaults: syncDefaults) + tracker = SyncChangeTracker(metadataStorage: metadata) + } + + @Test("markDirty records the id as dirty") + func markDirtyAddsId() { + tracker.markDirty(.connection, id: "conn-1") + #expect(tracker.dirtyRecords(for: .connection) == ["conn-1"]) + } + + @Test("markDirty with multiple ids records all of them") + func markDirtyMultiple() { + tracker.markDirty(.connection, ids: ["a", "b", "c"]) + #expect(tracker.dirtyRecords(for: .connection) == ["a", "b", "c"]) + } + + @Test("markDirty with an empty id list records nothing") + func markDirtyEmptyIsNoop() { + tracker.markDirty(.connection, ids: []) + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + } + + @Test("markDeleted clears the dirty flag and records a tombstone") + func markDeletedClearsDirtyAndTombstones() { + tracker.markDirty(.connection, id: "conn-1") + tracker.markDeleted(.connection, id: "conn-1") + + #expect(!tracker.dirtyRecords(for: .connection).contains("conn-1")) + #expect(metadata.tombstones(for: .connection).contains { $0.id == "conn-1" }) + } + + @Test("Suppression makes markDirty and markDeleted no-ops") + func suppressionDisablesTracking() { + tracker.isSuppressed = true + tracker.markDirty(.connection, id: "conn-1") + tracker.markDeleted(.group, id: "group-1") + + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + #expect(metadata.tombstones(for: .group).isEmpty) + } + + @Test("clearDirty removes one id; clearAllDirty clears the type") + func clearDirtyBehavior() { + tracker.markDirty(.connection, ids: ["a", "b"]) + tracker.clearDirty(.connection, id: "a") + #expect(tracker.dirtyRecords(for: .connection) == ["b"]) + + tracker.clearAllDirty(.connection) + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + } + + @Test("Dirty records are scoped per record type") + func dirtyRecordsScopedByType() { + tracker.markDirty(.connection, id: "x") + tracker.markDirty(.group, id: "y") + + #expect(tracker.dirtyRecords(for: .connection) == ["x"]) + #expect(tracker.dirtyRecords(for: .group) == ["y"]) + } +} From 7284e466bf3a6c092a7b0ad5ddcabc1f53ec2dae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 11:14:06 +0700 Subject: [PATCH 08/12] docs: record Tier 4 Sync test + the allCases caller-grep lesson --- docs/development/refactor-audit.mdx | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx index 4db5e77c2..6bc29aa06 100644 --- a/docs/development/refactor-audit.mdx +++ b/docs/development/refactor-audit.mdx @@ -51,7 +51,7 @@ Cross-plugin duplication (investigated 2026-06-03; the original audit overstated ## Tier 3: consistency / hygiene - [x] `Data.hexEncoded` reinvented 11x: extracted `Sequence where Element == UInt8 { hexEncoded }`, routed all 11 call sites. (2026-06-03) -- [x] Dead `DatabaseType.allCases` compatibility shim: removed (verified no callers). (2026-06-03) +- [x] `DatabaseType.allCases` compatibility shim: removed, callers migrated to `allKnownTypes`. (2026-06-03) NOTE: my first "no callers" grep checked only `TablePro`/`Plugins`, missing 6 usages in `TableProTests` (DatabaseType*Tests). The app-only build passed but the TEST build failed; caught by `build-for-testing` and fixed (callers -> `allKnownTypes`, deleted the shim-equivalence test). Lesson: grep `TableProTests` too for any public-API removal. - [~] DialectQuoteHelper: VERIFIED FALSE POSITIVE. `quoteIdentifierFromDialect`/`resolveSQLDialect` are used by FilterSQLGenerator, SQLStatementGenerator, ExportDataTool, SQLRowToStatementConverter, QueryTab, MainContentCoordinator. NOT dead. Left as-is. - [ ] `DatabaseType.brandColor` (hardcoded 19-case switch, one caller). Attempted routing through `PluginManager.brandColor(for:)`, then REVERTED: that helper resolves via `pluginTypeId`, which maps variants to their driver (MariaDB->MySQL), and the variant registry snapshots carry no `brandColorHex`, so variants would lose their color (color/icon mismatch). PROPER FIX (deliberate, not hygiene): add `brandColorHex` to the variant registry snapshots, then delegate to `snapshot(forTypeId: rawValue)?.brandColorHex` like `iconName` already does. - [~] `SidebarViewModel` static registry: VERIFIED the audit's fix is WRONG. `MainContentCoordinator.sidebarViewModel` is `weak` and `SidebarView` only holds a local; the registry is the SOLE strong owner, so `NSMapTable` weak values would deallocate the VM immediately and break the sidebar. Lifecycle is managed by `removeConnection` on session teardown (`DatabaseManager+Sessions.swift:365`). Design is sound; left as-is. @@ -63,7 +63,10 @@ Cross-plugin duplication (investigated 2026-06-03; the original audit overstated ## Tier 4: test coverage gaps -- [ ] Sync has zero tests: `ConflictResolver`, `SyncCoordinator`, `SyncChangeTracker`. Holds last-write-wins/merge logic and the documented "markDeleted after saveConnections" delete-ordering invariant. Highest-value gap. +- [~] "Sync has zero tests" was FALSE: `CloudKitSyncEngineTests` + `SyncRecordMapperFavoriteTableTests` exist, and `ConnectionStorageSyncDeleteTests` already covers the delete-ordering invariant ("records a tombstone after it is persisted" / "persist before notify"). The real gap was direct unit tests for the dirty/tombstone machinery. +- [x] Added `SyncChangeTrackerTests` (dirty/tombstone/suppression/clear + per-type scoping), isolated via an injected `SyncMetadataStorage(userDefaults:)`. (2026-06-03) Compiles (verified via `build-for-testing`); runs in CI (the test target can't run locally here, SwiftLint SPM plugin). +- [ ] `ConflictResolver`: still untested. Harder, it is a private-init `@MainActor` singleton over `CKRecord`; needs a serialized suite driving `.shared`. +- [ ] `SyncCoordinator`: untested. Large `@MainActor` coordinator with many dependencies; not cleanly unit-testable in isolation. - [ ] Untested storage: `AppSettingsStorage`, `SSHProfileStorage`, `TabStateStorage` (the 500KB truncation guard). ## Policy decisions (owner's call) From e31cc96c9181d088bc98a11750404ec8e0f4af87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 11:20:59 +0700 Subject: [PATCH 09/12] test(sync): add ConflictResolver queue and keep-local/keep-server resolution tests --- .../Core/Sync/ConflictResolverTests.swift | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 TableProTests/Core/Sync/ConflictResolverTests.swift diff --git a/TableProTests/Core/Sync/ConflictResolverTests.swift b/TableProTests/Core/Sync/ConflictResolverTests.swift new file mode 100644 index 000000000..7a2e93b8d --- /dev/null +++ b/TableProTests/Core/Sync/ConflictResolverTests.swift @@ -0,0 +1,89 @@ +// +// ConflictResolverTests.swift +// TableProTests +// + +import CloudKit +import Foundation +import Testing + +@testable import TablePro + +@Suite("ConflictResolver", .serialized) +@MainActor +struct ConflictResolverTests { + private let resolver = ConflictResolver.shared + + init() { + while resolver.hasConflicts { + _ = resolver.resolveCurrentConflict(keepLocal: false) + } + } + + private func makeConflict(local: [String: String], server: [String: String]) -> SyncConflict { + let localRecord = CKRecord(recordType: "Connection") + for (key, value) in local { + localRecord[key] = value as CKRecordValue + } + let serverRecord = CKRecord(recordType: "Connection") + for (key, value) in server { + serverRecord[key] = value as CKRecordValue + } + return SyncConflict( + recordType: .connection, + entityName: "users", + localRecord: localRecord, + serverRecord: serverRecord, + localModifiedAt: Date(timeIntervalSince1970: 100), + serverModifiedAt: Date(timeIntervalSince1970: 200) + ) + } + + @Test("addConflict queues the conflict as current") + func addConflictQueues() { + #expect(!resolver.hasConflicts) + resolver.addConflict(makeConflict(local: ["name": "L"], server: ["name": "S"])) + + #expect(resolver.hasConflicts) + #expect(resolver.currentConflict?.entityName == "users") + + _ = resolver.resolveCurrentConflict(keepLocal: false) + } + + @Test("Keeping the server version discards the conflict and returns nil") + func keepServerReturnsNil() { + resolver.addConflict(makeConflict(local: ["name": "L"], server: ["name": "S"])) + + let result = resolver.resolveCurrentConflict(keepLocal: false) + + #expect(result == nil) + #expect(!resolver.hasConflicts) + } + + @Test("Keeping local copies local field values onto the server record") + func keepLocalCopiesFieldsOntoServerRecord() { + resolver.addConflict(makeConflict(local: ["name": "Local"], server: ["name": "Server"])) + + let resolved = resolver.resolveCurrentConflict(keepLocal: true) + + #expect(resolved?["name"] as? String == "Local") + #expect(!resolver.hasConflicts) + } + + @Test("Conflicts are resolved in FIFO order") + func conflictsResolveInFifoOrder() { + resolver.addConflict(makeConflict(local: ["name": "first"], server: ["name": "s1"])) + resolver.addConflict(makeConflict(local: ["name": "second"], server: ["name": "s2"])) + + #expect(resolver.currentConflict?.localRecord["name"] as? String == "first") + _ = resolver.resolveCurrentConflict(keepLocal: false) + #expect(resolver.currentConflict?.localRecord["name"] as? String == "second") + _ = resolver.resolveCurrentConflict(keepLocal: false) + #expect(!resolver.hasConflicts) + } + + @Test("Resolving with no pending conflicts returns nil") + func resolveWithNoConflictsReturnsNil() { + #expect(resolver.resolveCurrentConflict(keepLocal: false) == nil) + } +} From b82af570c3e2f2a1cf404e0174101132c74f680c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 11:21:16 +0700 Subject: [PATCH 10/12] docs: mark ConflictResolver tested, record storage-coverage findings --- docs/development/refactor-audit.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx index 6bc29aa06..b4a7666ec 100644 --- a/docs/development/refactor-audit.mdx +++ b/docs/development/refactor-audit.mdx @@ -65,9 +65,9 @@ Cross-plugin duplication (investigated 2026-06-03; the original audit overstated - [~] "Sync has zero tests" was FALSE: `CloudKitSyncEngineTests` + `SyncRecordMapperFavoriteTableTests` exist, and `ConnectionStorageSyncDeleteTests` already covers the delete-ordering invariant ("records a tombstone after it is persisted" / "persist before notify"). The real gap was direct unit tests for the dirty/tombstone machinery. - [x] Added `SyncChangeTrackerTests` (dirty/tombstone/suppression/clear + per-type scoping), isolated via an injected `SyncMetadataStorage(userDefaults:)`. (2026-06-03) Compiles (verified via `build-for-testing`); runs in CI (the test target can't run locally here, SwiftLint SPM plugin). -- [ ] `ConflictResolver`: still untested. Harder, it is a private-init `@MainActor` singleton over `CKRecord`; needs a serialized suite driving `.shared`. +- [x] Added `ConflictResolverTests` (queue/FIFO, keep-server returns nil, keep-local copies local fields onto the server record), via a `.serialized` suite that drains `.shared` in `init`. (2026-06-03) Compiles via `build-for-testing`; runs in CI. - [ ] `SyncCoordinator`: untested. Large `@MainActor` coordinator with many dependencies; not cleanly unit-testable in isolation. -- [ ] Untested storage: `AppSettingsStorage`, `SSHProfileStorage`, `TabStateStorage` (the 500KB truncation guard). +- [~] "Untested storage" was mostly false: the 500KB truncation guard is already covered by `TabPersistenceCoordinatorTests` ("Large query over 500KB is truncated", 600KB query). `SSHProfileStorage` is Keychain-backed (Keychain tests are flaky/quarantined here, skip). `AppSettingsStorage` is a thin Codable-through-`UserDefaults` wrapper (low bug-risk boilerplate); injectable via `init(userDefaults:)` if a round-trip test is wanted. ## Policy decisions (owner's call) From 52adaa1995278fe81b22dbf572a6383f574f9f49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ng=C3=B4=20Qu=E1=BB=91c=20=C4=90=E1=BA=A1t?= Date: Wed, 3 Jun 2026 10:54:48 +0700 Subject: [PATCH 11/12] refactor(hig): show activity-log export errors as sheets, rename ThemeLayout file to ThemeFonts --- .../{ThemeLayout.swift => ThemeFonts.swift} | 2 +- .../IntegrationsActivityLogPane.swift | 26 ++++++++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) rename TablePro/Theme/{ThemeLayout.swift => ThemeFonts.swift} (98%) diff --git a/TablePro/Theme/ThemeLayout.swift b/TablePro/Theme/ThemeFonts.swift similarity index 98% rename from TablePro/Theme/ThemeLayout.swift rename to TablePro/Theme/ThemeFonts.swift index 4f796657c..8fa0901dd 100644 --- a/TablePro/Theme/ThemeLayout.swift +++ b/TablePro/Theme/ThemeFonts.swift @@ -1,5 +1,5 @@ // -// ThemeLayout.swift +// ThemeFonts.swift // TablePro // diff --git a/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift b/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift index 45d9afe1f..6df0fd756 100644 --- a/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift +++ b/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift @@ -233,18 +233,26 @@ struct IntegrationsActivityLogPane: View { panel.canCreateDirectories = true panel.title = String(localized: "Export Activity Log") - guard panel.runModal() == .OK, let url = panel.url else { return } + if let window = AlertHelper.resolveWindow(nil) { + panel.beginSheetModal(for: window) { response in + guard response == .OK, let url = panel.url else { return } + writeActivityLog(to: url) + } + } else { + guard panel.runModal() == .OK, let url = panel.url else { return } + writeActivityLog(to: url) + } + } - let csv = csvString(for: filteredEntries) + private func writeActivityLog(to url: URL) { do { - try csv.write(to: url, atomically: true, encoding: .utf8) + try csvString(for: filteredEntries).write(to: url, atomically: true, encoding: .utf8) } catch { - let alert = NSAlert() - alert.messageText = String(localized: "Could not export activity log") - alert.informativeText = error.localizedDescription - alert.alertStyle = .warning - alert.addButton(withTitle: String(localized: "OK")) - alert.runModal() + AlertHelper.showErrorSheet( + title: String(localized: "Could not export activity log"), + message: error.localizedDescription, + window: nil + ) } } From c05a624ccebabcf323d61076c6b9dfe204d5c897 Mon Sep 17 00:00:00 2001 From: Ngo Quoc Dat Date: Wed, 3 Jun 2026 17:19:08 +0700 Subject: [PATCH 12/12] wip --- docs/development/refactor-audit.mdx | 75 ----------------------------- 1 file changed, 75 deletions(-) delete mode 100644 docs/development/refactor-audit.mdx diff --git a/docs/development/refactor-audit.mdx b/docs/development/refactor-audit.mdx deleted file mode 100644 index b4a7666ec..000000000 --- a/docs/development/refactor-audit.mdx +++ /dev/null @@ -1,75 +0,0 @@ ---- -title: "Refactor audit" -description: "Internal engineering audit and refactor checklist. Not part of the published navigation." ---- - -Internal tracking doc. Not linked in `docs.json` on purpose. Candid file:line findings for an ongoing, tiered refactor. Source: full-codebase audit on 2026-06-03. - -Headline: the codebase is well built and well maintained. Zero real `print()` in production code, near-zero force unwraps, one coherent modern concurrency model (async/await + actors), disciplined error types, a native color/material system. This is not a rewrite. It is a sequence of surgical, root-cause refactors. Work top-down by tier; keep each commit reviewable and atomic. - -## Tier 0: confirmed bugs (done) - -- [x] JSON import plugin failed to load: `JSONImportPlugin/Info.plist` declared `TableProPluginKitVersion` 17 while the app is at 18, so the gate at `PluginManager.swift:471` rejected it. Bumped to 18. -- [x] Change managers defaulted to MySQL dialect. `DataChangeManager.databaseType` is now optional and `generateSQL` throws if unconfigured instead of silently emitting MySQL; `configureForTable` requires an explicit dialect. `StructureChangeManager.databaseType` was dead write-only state (DDL comes from `SchemaStatementGenerator(pluginDriver:)`) and was removed. -- [x] `SQLExportPlugin.isNumericLiteral` accepted garbage (`+-e..`, `1.2.3`, empty) and emitted it unquoted into INSERT. Replaced with the kit's validated check, now exposed as `PluginNumericLiteral.isValid` in TableProPluginKit (additive, ABI-safe). -- [x] `CloudflareTunnelManager.persistPidRecords` swallowed the encode failure with `try?`, so leaked `cloudflared` PIDs were never cleaned up. Now logs via OSLog. - -## Tier 1: high-value architecture - -- [ ] `MainContentCoordinator` god object: 1,392-line core + 34 extension files (~4,278 lines) in one `@Observable` type, ~30 responsibilities. Push state and logic into the existing sub-coordinators (`filterCoordinator`, `paginationCoordinator`, `rowEditingCoordinator`, `queryExecutionCoordinator`) and have the main type delegate. -- [ ] DI is theatre: 75 `static let shared` in `Core/`; `AppServices.live` wires 26 of them. Route through the `AppServices` composition root, drop `.shared` from types that flow through it. -- [ ] Triplicated reconnect logic in `DatabaseManager+Health.swift:111-308` (x2) and `+Sessions.swift:108-155`. Extract one `establishDriver(for:)` (connect + timeout + startup commands + restore schema/db). -- [ ] `ConnectionStorage` god object (988L): CRUD + file persistence + UserDefaults migration + Keychain for 8 secret kinds + the `StoredConnection` model. Split into `ConnectionStore` / `ConnectionSecretsStore` / model file. -- [ ] `Core/Services/Infrastructure/` (38 files) mixes AppKit view controllers (`MainSplitViewController`, `TabWindowController`, toolbars) into Core. Move view controllers to `Views/`; split the rest into Launch / Windowing / Tabs. -Cross-plugin duplication (investigated 2026-06-03; the original audit overstated this, 3 of 5 were false positives): - -- [x] `isNumericColumnType` (SQLExport/JSONExport): byte-identical. Extracted to `PluginExportUtilities.isNumericColumnType`. -- [x] SQL filter/where/order-by builder (Oracle + MSSQL): real. Extracted to `PluginSQLFilter` (parameterized by quote/escape/regex closures), which also makes the injection-defense escaping a single canonical copy. ClickHouse's `buildWhereClause` is unrelated (it is DML, not browse filtering). -- [x] JSON scalar decoder (`D1Value`/`HranaValue`/`BQCellValue`): FALSE POSITIVE. Three different wire formats (bare scalar vs keyed `{type,value,base64}` vs recursive `record`/`array`). A shared type would be a leaky union. Skipped. -- [x] Schemaless flattener (Mongo/DynamoDB): FALSE POSITIVE. Different priority-key source, ordering, value types, return types. Skipped. -- [x] DML INSERT/UPDATE/DELETE generator (Oracle/MSSQL/ClickHouse): FALSE POSITIVE. Differ in quoting, default/identity handling, WHERE strategy, row-limit, statement form (standard vs ClickHouse ALTER TABLE), and dispatch. Correctness-critical, no headless test net. Skipped. - -- [x] Oversized plugin files split (2026-06-03), all build green: - - Redis 1695 -> 632 + `+Operations` (510) / `+Scan` (85) / `+ResultBuilding` (486) - - DuckDB 1491 -> 879 + `DuckDBConnection` (600) / `DuckDBPluginError` (25) - - Cassandra 1449 -> 588 + `CassandraConnection` (851) / `CassandraPluginError` (25) - - ClickHouse 1426 -> 799 + `+Schema` (340) / `+Http` (309) - - MSSQL 1403 -> 710 + `+Schema` (533) / `+DDL` (184) - - Mongo 1351 -> 903 + `+SyncHelpers` (458) - - Oracle (1152) and the rest already under the limit after de-dup. - - The 6 split plugins use Xcode synchronized file-system groups, so new files auto-included with NO pbxproj edits. Splitting across files required broadening the `private` members the moved code touches to `internal` (the established `MainContentCoordinator+*` pattern). -- [ ] DynamoDB (1238, 38 over) left intact. It is the one plugin still on a legacy explicit `PBXGroup` (not a synchronized group), so a new file needs manual pbxproj target-membership wiring, which the build system did not honor (xcodebuild never compiled the added file). Not worth the risk for 38 lines on a limit that is not CI-enforced for plugins. To do it properly: convert `Plugins/DynamoDBDriverPlugin` to a `PBXFileSystemSynchronizedRootGroup` (matching the other plugins), then split like the rest. - -## Tier 2: macOS HIG / native - -- [ ] No design-token layer. 5 competing corner radii (`6`x34, `4`x17, `8`x8), ad-hoc paddings, hardcoded control sizes. Add `Theme/Metrics.swift` (the existing `ThemeLayout.swift` only holds fonts; misnamed). Route literals through it. -- [ ] `runModal()` blocking panels instead of sheets: `IntegrationsActivityLogPane.swift:236,242`, `ERDiagramView.swift:270`. Use `beginSheetModal(for:)`. -- [ ] Thin accessibility: 61/343 view files set accessibility labels. Backfill `.accessibilityLabel` on icon-only controls. -- [ ] Business logic in views: `ExportDialog.swift:674` builds and runs raw `information_schema`/Oracle SQL inside a SwiftUI View. Move to an export service. -- [ ] Raw RGB colors in `ChatComposerView.swift:204`, hardcoded editor font in `HighlightedSQLTextView.swift:14`. Route through the theme. - -## Tier 3: consistency / hygiene - -- [x] `Data.hexEncoded` reinvented 11x: extracted `Sequence where Element == UInt8 { hexEncoded }`, routed all 11 call sites. (2026-06-03) -- [x] `DatabaseType.allCases` compatibility shim: removed, callers migrated to `allKnownTypes`. (2026-06-03) NOTE: my first "no callers" grep checked only `TablePro`/`Plugins`, missing 6 usages in `TableProTests` (DatabaseType*Tests). The app-only build passed but the TEST build failed; caught by `build-for-testing` and fixed (callers -> `allKnownTypes`, deleted the shim-equivalence test). Lesson: grep `TableProTests` too for any public-API removal. -- [~] DialectQuoteHelper: VERIFIED FALSE POSITIVE. `quoteIdentifierFromDialect`/`resolveSQLDialect` are used by FilterSQLGenerator, SQLStatementGenerator, ExportDataTool, SQLRowToStatementConverter, QueryTab, MainContentCoordinator. NOT dead. Left as-is. -- [ ] `DatabaseType.brandColor` (hardcoded 19-case switch, one caller). Attempted routing through `PluginManager.brandColor(for:)`, then REVERTED: that helper resolves via `pluginTypeId`, which maps variants to their driver (MariaDB->MySQL), and the variant registry snapshots carry no `brandColorHex`, so variants would lose their color (color/icon mismatch). PROPER FIX (deliberate, not hygiene): add `brandColorHex` to the variant registry snapshots, then delegate to `snapshot(forTypeId: rawValue)?.brandColorHex` like `iconName` already does. -- [~] `SidebarViewModel` static registry: VERIFIED the audit's fix is WRONG. `MainContentCoordinator.sidebarViewModel` is `weak` and `SidebarView` only holds a local; the registry is the SOLE strong owner, so `NSMapTable` weak values would deallocate the VM immediately and break the sidebar. Lifecycle is managed by `removeConnection` on session teardown (`DatabaseManager+Sessions.swift:365`). Design is sound; left as-is. -- [ ] 19 plugin `String(localized: "...\(x)")` dynamic-key calls (DynamoDB, Etcd, Mongo, CloudflareD1, Redis). Plugins ship no `.xcstrings`, so these are passthroughs today (not user-visible bugs), but they break if a catalog is ever added. POLICY: add per-plugin catalogs or drop `String(localized:)` in plugin code. -- [x] Driver error types -> `PluginDriverError`: VERIFIED ALREADY DONE. All driver error types (MariaDB, LibPQ, Oracle, SQLite, MongoDB, Redis, etc.) already conform. Audit false positive. -- [ ] `PluginStreamElement` is `@frozen` and the kit uses Library Evolution, so dropping `@frozen` is a BINARY-BREAKING ABI change (many plugins switch over it), needing a coordinated 18->19 bump + `release-all-plugins`. NOT hygiene. Defer to the next deliberate ABI bump; do it then so future additive cases are free. -- [ ] Suffix vocabulary overload (Manager 22 / Service 15 / Coordinator 12 / Provider 22). 14 top-level free functions (mostly `ConnectionGroupTree.swift`) violate explicit access control. -- [ ] Five `@Observable @MainActor` state classes live under `Models/` (view state). Move to `ViewModels/`. - -## Tier 4: test coverage gaps - -- [~] "Sync has zero tests" was FALSE: `CloudKitSyncEngineTests` + `SyncRecordMapperFavoriteTableTests` exist, and `ConnectionStorageSyncDeleteTests` already covers the delete-ordering invariant ("records a tombstone after it is persisted" / "persist before notify"). The real gap was direct unit tests for the dirty/tombstone machinery. -- [x] Added `SyncChangeTrackerTests` (dirty/tombstone/suppression/clear + per-type scoping), isolated via an injected `SyncMetadataStorage(userDefaults:)`. (2026-06-03) Compiles (verified via `build-for-testing`); runs in CI (the test target can't run locally here, SwiftLint SPM plugin). -- [x] Added `ConflictResolverTests` (queue/FIFO, keep-server returns nil, keep-local copies local fields onto the server record), via a `.serialized` suite that drains `.shared` in `init`. (2026-06-03) Compiles via `build-for-testing`; runs in CI. -- [ ] `SyncCoordinator`: untested. Large `@MainActor` coordinator with many dependencies; not cleanly unit-testable in isolation. -- [~] "Untested storage" was mostly false: the 500KB truncation guard is already covered by `TabPersistenceCoordinatorTests` ("Large query over 500KB is truncated", 600KB query). `SSHProfileStorage` is Keychain-backed (Keychain tests are flaky/quarantined here, skip). `AppSettingsStorage` is a thin Codable-through-`UserDefaults` wrapper (low bug-risk boilerplate); injectable via `init(userDefaults:)` if a round-trip test is wanted. - -## Policy decisions (owner's call) - -- [ ] The "no comments" rule. About 1,700 inline + 2,200 `///` doc comments exist; SwiftLint ignores comments, so the rule is unenforced. Mass deletion is high-churn and low-value. Recommendation: relax CLAUDE.md to "no what-comments; `///` API docs and load-bearing why-notes allowed" rather than stripping thousands of lines. -- [ ] Suffix vocabulary convention (Manager vs Service vs Coordinator). One line in CLAUDE.md before any rename pass.