diff --git a/CHANGELOG.md b/CHANGELOG.md index 72c6aa956..1f18f0ffa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- JSON file import works again. It failed to load in 0.48.0. +- SQL export quotes empty or malformed values in numeric columns instead of writing them unquoted, which could produce invalid INSERT statements. ### Added - Each filter row has a checkbox to turn it on or off and an Apply button to filter by just that row. The main Apply runs every active filter, and disabled filters stay in the panel for later. (#1561) diff --git a/Plugins/CassandraDriverPlugin/CassandraConnection.swift b/Plugins/CassandraDriverPlugin/CassandraConnection.swift new file mode 100644 index 000000000..b80b97e67 --- /dev/null +++ b/Plugins/CassandraDriverPlugin/CassandraConnection.swift @@ -0,0 +1,851 @@ +// +// CassandraConnection.swift +// CassandraDriverPlugin +// + +#if canImport(CCassandra) +import CCassandra +#endif +import Foundation +import os +import TableProPluginKit + +actor CassandraConnectionActor { + private static let logger = Logger(subsystem: "com.TablePro.CassandraDriver", category: "Connection") + + nonisolated(unsafe) private static let isoFormatter: ISO8601DateFormatter = { + let f = ISO8601DateFormatter() + f.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + return f + }() + + nonisolated(unsafe) private static let dateFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "yyyy-MM-dd" + f.timeZone = TimeZone(identifier: "UTC") + return f + }() + + private var cluster: OpaquePointer? // CassCluster* + private var session: OpaquePointer? // CassSession* + private var currentKeyspace: String? + + var isConnected: Bool { session != nil } + + var keyspace: String? { currentKeyspace } + + func connect( + host: String, + port: Int, + username: String?, + password: String?, + keyspace: String?, + sslMode: SSLMode, + sslCaCertPath: String?, + sslClientCertPath: String?, + sslClientKeyPath: String?, + sslClientKeyPassphrase: String? + ) throws { + cluster = cass_cluster_new() + guard let cluster else { + throw CassandraPluginError.connectionFailed("Failed to create cluster object") + } + + cass_cluster_set_contact_points(cluster, host) + cass_cluster_set_port(cluster, Int32(port)) + + if let username, !username.isEmpty, let password { + cass_cluster_set_credentials(cluster, username, password) + } + + if sslMode != .disabled { + guard let ssl = cass_ssl_new() else { + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to create SSL context") + } + + cass_ssl_set_verify_flags(ssl, CassandraSSLMapping.verifyFlags(for: sslMode)) + + if sslMode == .verifyCa || sslMode == .verifyIdentity { + guard let caCertPath = sslCaCertPath, !caCertPath.isEmpty else { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "Verify CA or Verify Identity requires a CA certificate path") + } + guard let certData = FileManager.default.contents(atPath: caCertPath), + let certString = String(data: certData, encoding: .utf8) else { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "Could not read CA certificate at \(caCertPath)") + } + let rc = cass_ssl_add_trusted_cert(ssl, certString) + if rc != CASS_OK { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + throw SSLHandshakeError.untrustedCertificate(serverMessage: "CA certificate at \(caCertPath) is not a valid PEM") + } + } + + let trimmedClientCertPath = sslClientCertPath?.trimmingCharacters(in: .whitespaces) ?? "" + let trimmedClientKeyPath = sslClientKeyPath?.trimmingCharacters(in: .whitespaces) ?? "" + if !trimmedClientCertPath.isEmpty || !trimmedClientKeyPath.isEmpty { + try applyClientCertificate( + to: ssl, + certPath: trimmedClientCertPath, + keyPath: trimmedClientKeyPath, + keyPassphrase: sslClientKeyPassphrase + ) { + cass_ssl_free(ssl) + cass_cluster_free(cluster) + self.cluster = nil + } + } + + cass_cluster_set_ssl(cluster, ssl) + cass_ssl_free(ssl) + } + + // Connection timeout (10 seconds) + cass_cluster_set_connect_timeout(cluster, 10_000) + cass_cluster_set_request_timeout(cluster, 30_000) + + let newSession = cass_session_new() + guard let newSession else { + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to create session") + } + + let connectFuture: OpaquePointer? + if let keyspace, !keyspace.isEmpty { + connectFuture = cass_session_connect_keyspace(newSession, cluster, keyspace) + currentKeyspace = keyspace + } else { + connectFuture = cass_session_connect(newSession, cluster) + currentKeyspace = nil + } + + guard let future = connectFuture else { + cass_session_free(newSession) + cass_cluster_free(cluster) + self.cluster = nil + throw CassandraPluginError.connectionFailed("Failed to initiate connection") + } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + let errorMessage = extractFutureError(future) + cass_future_free(future) + cass_session_free(newSession) + cass_cluster_free(cluster) + self.cluster = nil + if let sslError = Self.classifySSLError(rc: rc, message: errorMessage) { + throw sslError + } + throw CassandraPluginError.connectionFailed(errorMessage) + } + + cass_future_free(future) + session = newSession + + Self.logger.info("Connected to Cassandra at \(host):\(port)") + } + + private func applyClientCertificate( + to ssl: OpaquePointer, + certPath: String, + keyPath: String, + keyPassphrase: String?, + cleanup: () -> Void + ) throws { + guard !certPath.isEmpty else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "A client certificate is required when a client key is set") + } + guard !keyPath.isEmpty else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "A client key is required when a client certificate is set") + } + + guard let certData = FileManager.default.contents(atPath: certPath), + let certString = String(data: certData, encoding: .utf8) else { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "Could not read client certificate at \(certPath)") + } + let certResult = cass_ssl_set_cert(ssl, certString) + if certResult != CASS_OK { + cleanup() + throw SSLHandshakeError.clientCertRequired(serverMessage: "Client certificate at \(certPath) is not a valid PEM") + } + + guard let keyData = FileManager.default.contents(atPath: keyPath), + let keyString = String(data: keyData, encoding: .utf8) else { + cleanup() + throw SSLHandshakeError.clientKeyInvalid(serverMessage: "Could not read client key at \(keyPath)") + } + let passphrase = keyPassphrase?.isEmpty == false ? keyPassphrase : nil + let keyResult = cass_ssl_set_private_key(ssl, keyString, passphrase) + if keyResult != CASS_OK { + cleanup() + throw Self.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath) + } + } + + static func isEncryptedPrivateKey(_ pem: String) -> Bool { + pem.contains("ENCRYPTED PRIVATE KEY") || (pem.contains("Proc-Type:") && pem.contains("ENCRYPTED")) + } + + static func privateKeyLoadError(keyPEM: String, hasPassphrase: Bool, keyPath: String) -> SSLHandshakeError { + guard isEncryptedPrivateKey(keyPEM) else { + return .clientKeyInvalid(serverMessage: "The client key at \(keyPath) is not a valid private key") + } + if hasPassphrase { + return .clientKeyPassphraseIncorrect(serverMessage: "The passphrase for the client key at \(keyPath) is incorrect") + } + return .clientKeyPassphraseRequired(serverMessage: "The client key at \(keyPath) is encrypted. Enter its passphrase.") + } + + func close() { + if let session { + let closeFuture = cass_session_close(session) + if let closeFuture { + cass_future_wait(closeFuture) + cass_future_free(closeFuture) + } + cass_session_free(session) + self.session = nil + } + + if let cluster { + cass_cluster_free(cluster) + self.cluster = nil + } + + currentKeyspace = nil + Self.logger.info("Disconnected from Cassandra") + } + + func executeQuery(_ cql: String) throws -> CassandraRawResult { + guard let session else { + throw CassandraPluginError.notConnected + } + + let startTime = Date() + let statement = cass_statement_new(cql, 0) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to create statement") + } + + defer { cass_statement_free(statement) } + + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute query") + } + + defer { cass_future_free(future) } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(future)) + } + + let result = cass_future_get_result(future) + defer { + if let result { cass_result_free(result) } + } + + guard let result else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: [], + columnTypeNames: [], + rows: [], + rowsAffected: 0, + executionTime: executionTime + ) + } + + return extractResult(from: result, startTime: startTime) + } + + func executePrepared(_ cql: String, parameters: [PluginCellValue]) throws -> CassandraRawResult { + guard let session else { + throw CassandraPluginError.notConnected + } + + let startTime = Date() + + // Prepare + let prepareFuture = cass_session_prepare(session, cql) + guard let prepareFuture else { + throw CassandraPluginError.queryFailed("Failed to prepare statement") + } + defer { cass_future_free(prepareFuture) } + + cass_future_wait(prepareFuture) + let prepRc = cass_future_error_code(prepareFuture) + if prepRc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(prepareFuture)) + } + + let prepared = cass_future_get_prepared(prepareFuture) + guard let prepared else { + throw CassandraPluginError.queryFailed("Failed to get prepared statement") + } + defer { cass_prepared_free(prepared) } + + // Bind parameters + let statement = cass_prepared_bind(prepared) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to bind prepared statement") + } + defer { cass_statement_free(statement) } + + for (index, param) in parameters.enumerated() { + switch param { + case .text(let value): + cass_statement_bind_string(statement, index, value) + case .bytes(let data): + data.withUnsafeBytes { rawBuffer in + if let base = rawBuffer.baseAddress?.assumingMemoryBound(to: UInt8.self) { + cass_statement_bind_bytes(statement, index, base, data.count) + } else { + cass_statement_bind_null(statement, index) + } + } + case .null: + cass_statement_bind_null(statement, index) + } + } + + // Execute + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute prepared statement") + } + defer { cass_future_free(future) } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + throw CassandraPluginError.queryFailed(extractFutureError(future)) + } + + let result = cass_future_get_result(future) + defer { + if let result { cass_result_free(result) } + } + + guard let result else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: [], + columnTypeNames: [], + rows: [], + rowsAffected: 0, + executionTime: executionTime + ) + } + + return extractResult(from: result, startTime: startTime) + } + + func switchKeyspace(_ keyspace: String) throws { + _ = try executeQuery("USE \"\(escapeIdentifier(keyspace))\"") + currentKeyspace = keyspace + } + + func serverVersion() throws -> String? { + let result = try executeQuery("SELECT release_version FROM system.local WHERE key = 'local'") + return result.rows.first?.first?.asText + } + + // MARK: - Private Helpers + + private func extractResult( + from result: OpaquePointer, + startTime: Date + ) -> CassandraRawResult { + let colCount = cass_result_column_count(result) + let rowCount = cass_result_row_count(result) + + var columns: [String] = [] + var columnTypeNames: [String] = [] + + for i in 0..? + var nameLength: Int = 0 + cass_result_column_name(result, i, &namePtr, &nameLength) + if let namePtr { + columns.append(String(cString: namePtr)) + } else { + columns.append("column_\(i)") + } + + let colType = cass_result_column_type(result, i) + columnTypeNames.append(Self.cassTypeName(colType)) + } + + var rows: [[PluginCellValue]] = [] + let iterator = cass_iterator_from_result(result) + defer { + if let iterator { cass_iterator_free(iterator) } + } + + guard let iterator else { + let executionTime = Date().timeIntervalSince(startTime) + return CassandraRawResult( + columns: columns, + columnTypeNames: columnTypeNames, + rows: [], + rowsAffected: Int(rowCount), + executionTime: executionTime + ) + } + + let maxRows = min(Int(rowCount), 100_000) + var count = 0 + + while cass_iterator_next(iterator) == cass_true && count < maxRows { + let row = cass_iterator_get_row(iterator) + guard let row else { continue } + + var rowData: [PluginCellValue] = [] + for col in 0.. Data? { + var bytes: UnsafePointer? + var length: Int = 0 + guard cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes else { + return nil + } + return Data(bytes: bytes, count: length) + } + + private static func extractStringValue(_ value: OpaquePointer) -> String? { + let valueType = cass_value_type(value) + + switch valueType { + case CASS_VALUE_TYPE_ASCII, CASS_VALUE_TYPE_TEXT, CASS_VALUE_TYPE_VARCHAR: + var output: UnsafePointer? + var outputLength: Int = 0 + let rc = cass_value_get_string(value, &output, &outputLength) + if rc == CASS_OK, let output { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: output), + length: outputLength, + encoding: .utf8, + freeWhenDone: false + ) + } + return nil + + case CASS_VALUE_TYPE_INT: + var intVal: Int32 = 0 + if cass_value_get_int32(value, &intVal) == CASS_OK { + return String(intVal) + } + return nil + + case CASS_VALUE_TYPE_BIGINT, CASS_VALUE_TYPE_COUNTER: + var bigintVal: Int64 = 0 + if cass_value_get_int64(value, &bigintVal) == CASS_OK { + return String(bigintVal) + } + return nil + + case CASS_VALUE_TYPE_SMALL_INT: + var smallVal: Int16 = 0 + if cass_value_get_int16(value, &smallVal) == CASS_OK { + return String(smallVal) + } + return nil + + case CASS_VALUE_TYPE_TINY_INT: + var tinyVal: Int8 = 0 + if cass_value_get_int8(value, &tinyVal) == CASS_OK { + return String(tinyVal) + } + return nil + + case CASS_VALUE_TYPE_FLOAT: + var floatVal: Float = 0 + if cass_value_get_float(value, &floatVal) == CASS_OK { + return String(floatVal) + } + return nil + + case CASS_VALUE_TYPE_DOUBLE: + var doubleVal: Double = 0 + if cass_value_get_double(value, &doubleVal) == CASS_OK { + return String(doubleVal) + } + return nil + + case CASS_VALUE_TYPE_BOOLEAN: + var boolVal: cass_bool_t = cass_false + if cass_value_get_bool(value, &boolVal) == CASS_OK { + return boolVal == cass_true ? "true" : "false" + } + return nil + + case CASS_VALUE_TYPE_UUID, CASS_VALUE_TYPE_TIMEUUID: + var uuid = CassUuid() + if cass_value_get_uuid(value, &uuid) == CASS_OK { + var buffer = [CChar](repeating: 0, count: Int(CASS_UUID_STRING_LENGTH)) + cass_uuid_string(uuid, &buffer) + return String(cString: buffer) + } + return nil + + case CASS_VALUE_TYPE_TIMESTAMP: + var timestamp: Int64 = 0 + if cass_value_get_int64(value, ×tamp) == CASS_OK { + let date = Date(timeIntervalSince1970: Double(timestamp) / 1000.0) + return isoFormatter.string(from: date) + } + return nil + + case CASS_VALUE_TYPE_BLOB: + if let data = extractBlobValue(value) { + return "0x" + data.map { String(format: "%02x", $0) }.joined() + } + return nil + + case CASS_VALUE_TYPE_INET: + var inet = CassInet() + if cass_value_get_inet(value, &inet) == CASS_OK { + var buffer = [CChar](repeating: 0, count: Int(CASS_INET_STRING_LENGTH)) + cass_inet_string(inet, &buffer) + return String(cString: buffer) + } + return nil + + case CASS_VALUE_TYPE_LIST, CASS_VALUE_TYPE_SET: + return extractCollectionString(value, open: "[", close: "]") + + case CASS_VALUE_TYPE_MAP: + return extractMapString(value) + + case CASS_VALUE_TYPE_TUPLE: + return extractCollectionString(value, open: "(", close: ")") + + case CASS_VALUE_TYPE_DATE: + var dateVal: UInt32 = 0 + if cass_value_get_uint32(value, &dateVal) == CASS_OK { + let daysSinceEpoch = Int64(dateVal) - Int64(1 << 31) + let epochSeconds = daysSinceEpoch * 86400 + let date = Date(timeIntervalSince1970: Double(epochSeconds)) + return dateFormatter.string(from: date) + } + return nil + + case CASS_VALUE_TYPE_TIME: + var timeVal: Int64 = 0 + if cass_value_get_int64(value, &timeVal) == CASS_OK { + // Cassandra time is nanoseconds since midnight + let totalSeconds = timeVal / 1_000_000_000 + let hours = totalSeconds / 3600 + let minutes = (totalSeconds % 3600) / 60 + let seconds = totalSeconds % 60 + let nanos = timeVal % 1_000_000_000 + if nanos > 0 { + let millis = nanos / 1_000_000 + return String(format: "%02lld:%02lld:%02lld.%03lld", hours, minutes, seconds, millis) + } + return String(format: "%02lld:%02lld:%02lld", hours, minutes, seconds) + } + return nil + + case CASS_VALUE_TYPE_DECIMAL, CASS_VALUE_TYPE_VARINT: + // Read as bytes and display as hex since proper numeric decoding + // requires BigInteger support not available in the C driver API + var bytes: UnsafePointer? + var length: Int = 0 + if cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes { + let data = Data(bytes: bytes, count: length) + return "0x" + data.map { String(format: "%02x", $0) }.joined() + } + return nil + + default: + // Fallback: try reading as string + var output: UnsafePointer? + var outputLength: Int = 0 + if cass_value_get_string(value, &output, &outputLength) == CASS_OK, let output { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: output), + length: outputLength, + encoding: .utf8, + freeWhenDone: false + ) + } + return "" + } + } + + private static func extractCollectionString( + _ value: OpaquePointer, + open: String, + close: String + ) -> String { + guard let iterator = cass_iterator_from_collection(value) else { + return "\(open)\(close)" + } + defer { cass_iterator_free(iterator) } + + var elements: [String] = [] + while cass_iterator_next(iterator) == cass_true { + if let elem = cass_iterator_get_value(iterator) { + elements.append(extractStringValue(elem) ?? "null") + } + } + return "\(open)\(elements.joined(separator: ", "))\(close)" + } + + private static func extractMapString(_ value: OpaquePointer) -> String { + guard let iterator = cass_iterator_from_map(value) else { + return "{}" + } + defer { cass_iterator_free(iterator) } + + var pairs: [String] = [] + while cass_iterator_next(iterator) == cass_true { + let key = cass_iterator_get_map_key(iterator) + let val = cass_iterator_get_map_value(iterator) + let keyStr = key.flatMap { extractStringValue($0) } ?? "null" + let valStr = val.flatMap { extractStringValue($0) } ?? "null" + pairs.append("\(keyStr): \(valStr)") + } + return "{\(pairs.joined(separator: ", "))}" + } + + private static func cassTypeName(_ type: CassValueType) -> String { + switch type { + case CASS_VALUE_TYPE_ASCII: return "ascii" + case CASS_VALUE_TYPE_BIGINT: return "bigint" + case CASS_VALUE_TYPE_BLOB: return "blob" + case CASS_VALUE_TYPE_BOOLEAN: return "boolean" + case CASS_VALUE_TYPE_COUNTER: return "counter" + case CASS_VALUE_TYPE_DECIMAL: return "decimal" + case CASS_VALUE_TYPE_DOUBLE: return "double" + case CASS_VALUE_TYPE_FLOAT: return "float" + case CASS_VALUE_TYPE_INT: return "int" + case CASS_VALUE_TYPE_TEXT: return "text" + case CASS_VALUE_TYPE_TIMESTAMP: return "timestamp" + case CASS_VALUE_TYPE_UUID: return "uuid" + case CASS_VALUE_TYPE_VARCHAR: return "varchar" + case CASS_VALUE_TYPE_VARINT: return "varint" + case CASS_VALUE_TYPE_TIMEUUID: return "timeuuid" + case CASS_VALUE_TYPE_INET: return "inet" + case CASS_VALUE_TYPE_DATE: return "date" + case CASS_VALUE_TYPE_TIME: return "time" + case CASS_VALUE_TYPE_SMALL_INT: return "smallint" + case CASS_VALUE_TYPE_TINY_INT: return "tinyint" + case CASS_VALUE_TYPE_LIST: return "list" + case CASS_VALUE_TYPE_MAP: return "map" + case CASS_VALUE_TYPE_SET: return "set" + case CASS_VALUE_TYPE_TUPLE: return "tuple" + case CASS_VALUE_TYPE_UDT: return "udt" + default: return "text" + } + } + + private func extractFutureError(_ future: OpaquePointer) -> String { + var message: UnsafePointer? + var messageLength: Int = 0 + cass_future_error_message(future, &message, &messageLength) + if let message { + return String( + bytesNoCopy: UnsafeMutableRawPointer(mutating: message), + length: messageLength, + encoding: .utf8, + freeWhenDone: false + ) ?? "Unknown error" + } + return "Unknown error" + } + + func streamQuery( + _ cql: String, + continuation: AsyncThrowingStream.Continuation + ) throws { + guard let session else { + throw CassandraPluginError.notConnected + } + + let pageSize: Int32 = 5_000 + let statement = cass_statement_new(cql, 0) + guard let statement else { + throw CassandraPluginError.queryFailed("Failed to create statement") + } + + cass_statement_set_paging_size(statement, pageSize) + + var headerSent = false + + defer { cass_statement_free(statement) } + + while true { + let future = cass_session_execute(session, statement) + guard let future else { + throw CassandraPluginError.queryFailed("Failed to execute query") + } + + cass_future_wait(future) + let rc = cass_future_error_code(future) + + if rc != CASS_OK { + let errorMessage = extractFutureError(future) + cass_future_free(future) + throw CassandraPluginError.queryFailed(errorMessage) + } + + let result = cass_future_get_result(future) + cass_future_free(future) + + guard let result else { break } + + if !headerSent { + let colCount = cass_result_column_count(result) + var columns: [String] = [] + var columnTypeNames: [String] = [] + + for i in 0..? + var nameLength: Int = 0 + cass_result_column_name(result, i, &namePtr, &nameLength) + if let namePtr { + columns.append(String(cString: namePtr)) + } else { + columns.append("column_\(i)") + } + let colType = cass_result_column_type(result, i) + columnTypeNames.append(Self.cassTypeName(colType)) + } + + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames, + estimatedRowCount: nil + ))) + headerSent = true + } + + let colCount = cass_result_column_count(result) + let iterator = cass_iterator_from_result(result) + + if let iterator { + while cass_iterator_next(iterator) == cass_true { + let row = cass_iterator_get_row(iterator) + guard let row else { continue } + + var rowData: [PluginCellValue] = [] + for col in 0.. String { + value.replacingOccurrences(of: "\"", with: "\"\"") + } + + static func classifySSLError(rc: CassError, message: String) -> SSLHandshakeError? { + switch rc { + case CASS_ERROR_SSL_NO_PEER_CERT, CASS_ERROR_SSL_INVALID_PEER_CERT: + return .untrustedCertificate(serverMessage: message) + case CASS_ERROR_SSL_IDENTITY_MISMATCH: + return .hostnameMismatch(serverMessage: message) + case CASS_ERROR_SSL_INVALID_PRIVATE_KEY, CASS_ERROR_SSL_INVALID_CERT: + return .clientCertRequired(serverMessage: message) + case CASS_ERROR_SSL_PROTOCOL_ERROR: + return .cipherMismatch(serverMessage: message) + default: + break + } + let lower = message.lowercased() + if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("ssl_connect") { + return .cipherMismatch(serverMessage: message) + } + return nil + } +} + +// MARK: - Raw Result + +struct CassandraRawResult: Sendable { + let columns: [String] + let columnTypeNames: [String] + let rows: [[PluginCellValue]] + let rowsAffected: Int + let executionTime: TimeInterval +} diff --git a/Plugins/CassandraDriverPlugin/CassandraPlugin.swift b/Plugins/CassandraDriverPlugin/CassandraPlugin.swift index 2d966008b..f72a3a2fe 100644 --- a/Plugins/CassandraDriverPlugin/CassandraPlugin.swift +++ b/Plugins/CassandraDriverPlugin/CassandraPlugin.swift @@ -104,848 +104,6 @@ internal final class CassandraPlugin: NSObject, TableProPlugin, DriverPlugin { } } -// MARK: - Connection Actor - -private actor CassandraConnectionActor { - private static let logger = Logger(subsystem: "com.TablePro.CassandraDriver", category: "Connection") - - nonisolated(unsafe) private static let isoFormatter: ISO8601DateFormatter = { - let f = ISO8601DateFormatter() - f.formatOptions = [.withInternetDateTime, .withFractionalSeconds] - return f - }() - - nonisolated(unsafe) private static let dateFormatter: DateFormatter = { - let f = DateFormatter() - f.dateFormat = "yyyy-MM-dd" - f.timeZone = TimeZone(identifier: "UTC") - return f - }() - - private var cluster: OpaquePointer? // CassCluster* - private var session: OpaquePointer? // CassSession* - private var currentKeyspace: String? - - var isConnected: Bool { session != nil } - - var keyspace: String? { currentKeyspace } - - func connect( - host: String, - port: Int, - username: String?, - password: String?, - keyspace: String?, - sslMode: SSLMode, - sslCaCertPath: String?, - sslClientCertPath: String?, - sslClientKeyPath: String?, - sslClientKeyPassphrase: String? - ) throws { - cluster = cass_cluster_new() - guard let cluster else { - throw CassandraPluginError.connectionFailed("Failed to create cluster object") - } - - cass_cluster_set_contact_points(cluster, host) - cass_cluster_set_port(cluster, Int32(port)) - - if let username, !username.isEmpty, let password { - cass_cluster_set_credentials(cluster, username, password) - } - - if sslMode != .disabled { - guard let ssl = cass_ssl_new() else { - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to create SSL context") - } - - cass_ssl_set_verify_flags(ssl, CassandraSSLMapping.verifyFlags(for: sslMode)) - - if sslMode == .verifyCa || sslMode == .verifyIdentity { - guard let caCertPath = sslCaCertPath, !caCertPath.isEmpty else { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "Verify CA or Verify Identity requires a CA certificate path") - } - guard let certData = FileManager.default.contents(atPath: caCertPath), - let certString = String(data: certData, encoding: .utf8) else { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "Could not read CA certificate at \(caCertPath)") - } - let rc = cass_ssl_add_trusted_cert(ssl, certString) - if rc != CASS_OK { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - throw SSLHandshakeError.untrustedCertificate(serverMessage: "CA certificate at \(caCertPath) is not a valid PEM") - } - } - - let trimmedClientCertPath = sslClientCertPath?.trimmingCharacters(in: .whitespaces) ?? "" - let trimmedClientKeyPath = sslClientKeyPath?.trimmingCharacters(in: .whitespaces) ?? "" - if !trimmedClientCertPath.isEmpty || !trimmedClientKeyPath.isEmpty { - try applyClientCertificate( - to: ssl, - certPath: trimmedClientCertPath, - keyPath: trimmedClientKeyPath, - keyPassphrase: sslClientKeyPassphrase - ) { - cass_ssl_free(ssl) - cass_cluster_free(cluster) - self.cluster = nil - } - } - - cass_cluster_set_ssl(cluster, ssl) - cass_ssl_free(ssl) - } - - // Connection timeout (10 seconds) - cass_cluster_set_connect_timeout(cluster, 10_000) - cass_cluster_set_request_timeout(cluster, 30_000) - - let newSession = cass_session_new() - guard let newSession else { - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to create session") - } - - let connectFuture: OpaquePointer? - if let keyspace, !keyspace.isEmpty { - connectFuture = cass_session_connect_keyspace(newSession, cluster, keyspace) - currentKeyspace = keyspace - } else { - connectFuture = cass_session_connect(newSession, cluster) - currentKeyspace = nil - } - - guard let future = connectFuture else { - cass_session_free(newSession) - cass_cluster_free(cluster) - self.cluster = nil - throw CassandraPluginError.connectionFailed("Failed to initiate connection") - } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - let errorMessage = extractFutureError(future) - cass_future_free(future) - cass_session_free(newSession) - cass_cluster_free(cluster) - self.cluster = nil - if let sslError = Self.classifySSLError(rc: rc, message: errorMessage) { - throw sslError - } - throw CassandraPluginError.connectionFailed(errorMessage) - } - - cass_future_free(future) - session = newSession - - Self.logger.info("Connected to Cassandra at \(host):\(port)") - } - - private func applyClientCertificate( - to ssl: OpaquePointer, - certPath: String, - keyPath: String, - keyPassphrase: String?, - cleanup: () -> Void - ) throws { - guard !certPath.isEmpty else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "A client certificate is required when a client key is set") - } - guard !keyPath.isEmpty else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "A client key is required when a client certificate is set") - } - - guard let certData = FileManager.default.contents(atPath: certPath), - let certString = String(data: certData, encoding: .utf8) else { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "Could not read client certificate at \(certPath)") - } - let certResult = cass_ssl_set_cert(ssl, certString) - if certResult != CASS_OK { - cleanup() - throw SSLHandshakeError.clientCertRequired(serverMessage: "Client certificate at \(certPath) is not a valid PEM") - } - - guard let keyData = FileManager.default.contents(atPath: keyPath), - let keyString = String(data: keyData, encoding: .utf8) else { - cleanup() - throw SSLHandshakeError.clientKeyInvalid(serverMessage: "Could not read client key at \(keyPath)") - } - let passphrase = keyPassphrase?.isEmpty == false ? keyPassphrase : nil - let keyResult = cass_ssl_set_private_key(ssl, keyString, passphrase) - if keyResult != CASS_OK { - cleanup() - throw Self.privateKeyLoadError(keyPEM: keyString, hasPassphrase: passphrase != nil, keyPath: keyPath) - } - } - - static func isEncryptedPrivateKey(_ pem: String) -> Bool { - pem.contains("ENCRYPTED PRIVATE KEY") || (pem.contains("Proc-Type:") && pem.contains("ENCRYPTED")) - } - - static func privateKeyLoadError(keyPEM: String, hasPassphrase: Bool, keyPath: String) -> SSLHandshakeError { - guard isEncryptedPrivateKey(keyPEM) else { - return .clientKeyInvalid(serverMessage: "The client key at \(keyPath) is not a valid private key") - } - if hasPassphrase { - return .clientKeyPassphraseIncorrect(serverMessage: "The passphrase for the client key at \(keyPath) is incorrect") - } - return .clientKeyPassphraseRequired(serverMessage: "The client key at \(keyPath) is encrypted. Enter its passphrase.") - } - - func close() { - if let session { - let closeFuture = cass_session_close(session) - if let closeFuture { - cass_future_wait(closeFuture) - cass_future_free(closeFuture) - } - cass_session_free(session) - self.session = nil - } - - if let cluster { - cass_cluster_free(cluster) - self.cluster = nil - } - - currentKeyspace = nil - Self.logger.info("Disconnected from Cassandra") - } - - func executeQuery(_ cql: String) throws -> CassandraRawResult { - guard let session else { - throw CassandraPluginError.notConnected - } - - let startTime = Date() - let statement = cass_statement_new(cql, 0) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to create statement") - } - - defer { cass_statement_free(statement) } - - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute query") - } - - defer { cass_future_free(future) } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(future)) - } - - let result = cass_future_get_result(future) - defer { - if let result { cass_result_free(result) } - } - - guard let result else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: [], - columnTypeNames: [], - rows: [], - rowsAffected: 0, - executionTime: executionTime - ) - } - - return extractResult(from: result, startTime: startTime) - } - - func executePrepared(_ cql: String, parameters: [PluginCellValue]) throws -> CassandraRawResult { - guard let session else { - throw CassandraPluginError.notConnected - } - - let startTime = Date() - - // Prepare - let prepareFuture = cass_session_prepare(session, cql) - guard let prepareFuture else { - throw CassandraPluginError.queryFailed("Failed to prepare statement") - } - defer { cass_future_free(prepareFuture) } - - cass_future_wait(prepareFuture) - let prepRc = cass_future_error_code(prepareFuture) - if prepRc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(prepareFuture)) - } - - let prepared = cass_future_get_prepared(prepareFuture) - guard let prepared else { - throw CassandraPluginError.queryFailed("Failed to get prepared statement") - } - defer { cass_prepared_free(prepared) } - - // Bind parameters - let statement = cass_prepared_bind(prepared) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to bind prepared statement") - } - defer { cass_statement_free(statement) } - - for (index, param) in parameters.enumerated() { - switch param { - case .text(let value): - cass_statement_bind_string(statement, index, value) - case .bytes(let data): - data.withUnsafeBytes { rawBuffer in - if let base = rawBuffer.baseAddress?.assumingMemoryBound(to: UInt8.self) { - cass_statement_bind_bytes(statement, index, base, data.count) - } else { - cass_statement_bind_null(statement, index) - } - } - case .null: - cass_statement_bind_null(statement, index) - } - } - - // Execute - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute prepared statement") - } - defer { cass_future_free(future) } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - throw CassandraPluginError.queryFailed(extractFutureError(future)) - } - - let result = cass_future_get_result(future) - defer { - if let result { cass_result_free(result) } - } - - guard let result else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: [], - columnTypeNames: [], - rows: [], - rowsAffected: 0, - executionTime: executionTime - ) - } - - return extractResult(from: result, startTime: startTime) - } - - func switchKeyspace(_ keyspace: String) throws { - _ = try executeQuery("USE \"\(escapeIdentifier(keyspace))\"") - currentKeyspace = keyspace - } - - func serverVersion() throws -> String? { - let result = try executeQuery("SELECT release_version FROM system.local WHERE key = 'local'") - return result.rows.first?.first?.asText - } - - // MARK: - Private Helpers - - private func extractResult( - from result: OpaquePointer, - startTime: Date - ) -> CassandraRawResult { - let colCount = cass_result_column_count(result) - let rowCount = cass_result_row_count(result) - - var columns: [String] = [] - var columnTypeNames: [String] = [] - - for i in 0..? - var nameLength: Int = 0 - cass_result_column_name(result, i, &namePtr, &nameLength) - if let namePtr { - columns.append(String(cString: namePtr)) - } else { - columns.append("column_\(i)") - } - - let colType = cass_result_column_type(result, i) - columnTypeNames.append(Self.cassTypeName(colType)) - } - - var rows: [[PluginCellValue]] = [] - let iterator = cass_iterator_from_result(result) - defer { - if let iterator { cass_iterator_free(iterator) } - } - - guard let iterator else { - let executionTime = Date().timeIntervalSince(startTime) - return CassandraRawResult( - columns: columns, - columnTypeNames: columnTypeNames, - rows: [], - rowsAffected: Int(rowCount), - executionTime: executionTime - ) - } - - let maxRows = min(Int(rowCount), 100_000) - var count = 0 - - while cass_iterator_next(iterator) == cass_true && count < maxRows { - let row = cass_iterator_get_row(iterator) - guard let row else { continue } - - var rowData: [PluginCellValue] = [] - for col in 0.. Data? { - var bytes: UnsafePointer? - var length: Int = 0 - guard cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes else { - return nil - } - return Data(bytes: bytes, count: length) - } - - private static func extractStringValue(_ value: OpaquePointer) -> String? { - let valueType = cass_value_type(value) - - switch valueType { - case CASS_VALUE_TYPE_ASCII, CASS_VALUE_TYPE_TEXT, CASS_VALUE_TYPE_VARCHAR: - var output: UnsafePointer? - var outputLength: Int = 0 - let rc = cass_value_get_string(value, &output, &outputLength) - if rc == CASS_OK, let output { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: output), - length: outputLength, - encoding: .utf8, - freeWhenDone: false - ) - } - return nil - - case CASS_VALUE_TYPE_INT: - var intVal: Int32 = 0 - if cass_value_get_int32(value, &intVal) == CASS_OK { - return String(intVal) - } - return nil - - case CASS_VALUE_TYPE_BIGINT, CASS_VALUE_TYPE_COUNTER: - var bigintVal: Int64 = 0 - if cass_value_get_int64(value, &bigintVal) == CASS_OK { - return String(bigintVal) - } - return nil - - case CASS_VALUE_TYPE_SMALL_INT: - var smallVal: Int16 = 0 - if cass_value_get_int16(value, &smallVal) == CASS_OK { - return String(smallVal) - } - return nil - - case CASS_VALUE_TYPE_TINY_INT: - var tinyVal: Int8 = 0 - if cass_value_get_int8(value, &tinyVal) == CASS_OK { - return String(tinyVal) - } - return nil - - case CASS_VALUE_TYPE_FLOAT: - var floatVal: Float = 0 - if cass_value_get_float(value, &floatVal) == CASS_OK { - return String(floatVal) - } - return nil - - case CASS_VALUE_TYPE_DOUBLE: - var doubleVal: Double = 0 - if cass_value_get_double(value, &doubleVal) == CASS_OK { - return String(doubleVal) - } - return nil - - case CASS_VALUE_TYPE_BOOLEAN: - var boolVal: cass_bool_t = cass_false - if cass_value_get_bool(value, &boolVal) == CASS_OK { - return boolVal == cass_true ? "true" : "false" - } - return nil - - case CASS_VALUE_TYPE_UUID, CASS_VALUE_TYPE_TIMEUUID: - var uuid = CassUuid() - if cass_value_get_uuid(value, &uuid) == CASS_OK { - var buffer = [CChar](repeating: 0, count: Int(CASS_UUID_STRING_LENGTH)) - cass_uuid_string(uuid, &buffer) - return String(cString: buffer) - } - return nil - - case CASS_VALUE_TYPE_TIMESTAMP: - var timestamp: Int64 = 0 - if cass_value_get_int64(value, ×tamp) == CASS_OK { - let date = Date(timeIntervalSince1970: Double(timestamp) / 1000.0) - return isoFormatter.string(from: date) - } - return nil - - case CASS_VALUE_TYPE_BLOB: - if let data = extractBlobValue(value) { - return "0x" + data.map { String(format: "%02x", $0) }.joined() - } - return nil - - case CASS_VALUE_TYPE_INET: - var inet = CassInet() - if cass_value_get_inet(value, &inet) == CASS_OK { - var buffer = [CChar](repeating: 0, count: Int(CASS_INET_STRING_LENGTH)) - cass_inet_string(inet, &buffer) - return String(cString: buffer) - } - return nil - - case CASS_VALUE_TYPE_LIST, CASS_VALUE_TYPE_SET: - return extractCollectionString(value, open: "[", close: "]") - - case CASS_VALUE_TYPE_MAP: - return extractMapString(value) - - case CASS_VALUE_TYPE_TUPLE: - return extractCollectionString(value, open: "(", close: ")") - - case CASS_VALUE_TYPE_DATE: - var dateVal: UInt32 = 0 - if cass_value_get_uint32(value, &dateVal) == CASS_OK { - let daysSinceEpoch = Int64(dateVal) - Int64(1 << 31) - let epochSeconds = daysSinceEpoch * 86400 - let date = Date(timeIntervalSince1970: Double(epochSeconds)) - return dateFormatter.string(from: date) - } - return nil - - case CASS_VALUE_TYPE_TIME: - var timeVal: Int64 = 0 - if cass_value_get_int64(value, &timeVal) == CASS_OK { - // Cassandra time is nanoseconds since midnight - let totalSeconds = timeVal / 1_000_000_000 - let hours = totalSeconds / 3600 - let minutes = (totalSeconds % 3600) / 60 - let seconds = totalSeconds % 60 - let nanos = timeVal % 1_000_000_000 - if nanos > 0 { - let millis = nanos / 1_000_000 - return String(format: "%02lld:%02lld:%02lld.%03lld", hours, minutes, seconds, millis) - } - return String(format: "%02lld:%02lld:%02lld", hours, minutes, seconds) - } - return nil - - case CASS_VALUE_TYPE_DECIMAL, CASS_VALUE_TYPE_VARINT: - // Read as bytes and display as hex since proper numeric decoding - // requires BigInteger support not available in the C driver API - var bytes: UnsafePointer? - var length: Int = 0 - if cass_value_get_bytes(value, &bytes, &length) == CASS_OK, let bytes { - let data = Data(bytes: bytes, count: length) - return "0x" + data.map { String(format: "%02x", $0) }.joined() - } - return nil - - default: - // Fallback: try reading as string - var output: UnsafePointer? - var outputLength: Int = 0 - if cass_value_get_string(value, &output, &outputLength) == CASS_OK, let output { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: output), - length: outputLength, - encoding: .utf8, - freeWhenDone: false - ) - } - return "" - } - } - - private static func extractCollectionString( - _ value: OpaquePointer, - open: String, - close: String - ) -> String { - guard let iterator = cass_iterator_from_collection(value) else { - return "\(open)\(close)" - } - defer { cass_iterator_free(iterator) } - - var elements: [String] = [] - while cass_iterator_next(iterator) == cass_true { - if let elem = cass_iterator_get_value(iterator) { - elements.append(extractStringValue(elem) ?? "null") - } - } - return "\(open)\(elements.joined(separator: ", "))\(close)" - } - - private static func extractMapString(_ value: OpaquePointer) -> String { - guard let iterator = cass_iterator_from_map(value) else { - return "{}" - } - defer { cass_iterator_free(iterator) } - - var pairs: [String] = [] - while cass_iterator_next(iterator) == cass_true { - let key = cass_iterator_get_map_key(iterator) - let val = cass_iterator_get_map_value(iterator) - let keyStr = key.flatMap { extractStringValue($0) } ?? "null" - let valStr = val.flatMap { extractStringValue($0) } ?? "null" - pairs.append("\(keyStr): \(valStr)") - } - return "{\(pairs.joined(separator: ", "))}" - } - - private static func cassTypeName(_ type: CassValueType) -> String { - switch type { - case CASS_VALUE_TYPE_ASCII: return "ascii" - case CASS_VALUE_TYPE_BIGINT: return "bigint" - case CASS_VALUE_TYPE_BLOB: return "blob" - case CASS_VALUE_TYPE_BOOLEAN: return "boolean" - case CASS_VALUE_TYPE_COUNTER: return "counter" - case CASS_VALUE_TYPE_DECIMAL: return "decimal" - case CASS_VALUE_TYPE_DOUBLE: return "double" - case CASS_VALUE_TYPE_FLOAT: return "float" - case CASS_VALUE_TYPE_INT: return "int" - case CASS_VALUE_TYPE_TEXT: return "text" - case CASS_VALUE_TYPE_TIMESTAMP: return "timestamp" - case CASS_VALUE_TYPE_UUID: return "uuid" - case CASS_VALUE_TYPE_VARCHAR: return "varchar" - case CASS_VALUE_TYPE_VARINT: return "varint" - case CASS_VALUE_TYPE_TIMEUUID: return "timeuuid" - case CASS_VALUE_TYPE_INET: return "inet" - case CASS_VALUE_TYPE_DATE: return "date" - case CASS_VALUE_TYPE_TIME: return "time" - case CASS_VALUE_TYPE_SMALL_INT: return "smallint" - case CASS_VALUE_TYPE_TINY_INT: return "tinyint" - case CASS_VALUE_TYPE_LIST: return "list" - case CASS_VALUE_TYPE_MAP: return "map" - case CASS_VALUE_TYPE_SET: return "set" - case CASS_VALUE_TYPE_TUPLE: return "tuple" - case CASS_VALUE_TYPE_UDT: return "udt" - default: return "text" - } - } - - private func extractFutureError(_ future: OpaquePointer) -> String { - var message: UnsafePointer? - var messageLength: Int = 0 - cass_future_error_message(future, &message, &messageLength) - if let message { - return String( - bytesNoCopy: UnsafeMutableRawPointer(mutating: message), - length: messageLength, - encoding: .utf8, - freeWhenDone: false - ) ?? "Unknown error" - } - return "Unknown error" - } - - func streamQuery( - _ cql: String, - continuation: AsyncThrowingStream.Continuation - ) throws { - guard let session else { - throw CassandraPluginError.notConnected - } - - let pageSize: Int32 = 5_000 - let statement = cass_statement_new(cql, 0) - guard let statement else { - throw CassandraPluginError.queryFailed("Failed to create statement") - } - - cass_statement_set_paging_size(statement, pageSize) - - var headerSent = false - - defer { cass_statement_free(statement) } - - while true { - let future = cass_session_execute(session, statement) - guard let future else { - throw CassandraPluginError.queryFailed("Failed to execute query") - } - - cass_future_wait(future) - let rc = cass_future_error_code(future) - - if rc != CASS_OK { - let errorMessage = extractFutureError(future) - cass_future_free(future) - throw CassandraPluginError.queryFailed(errorMessage) - } - - let result = cass_future_get_result(future) - cass_future_free(future) - - guard let result else { break } - - if !headerSent { - let colCount = cass_result_column_count(result) - var columns: [String] = [] - var columnTypeNames: [String] = [] - - for i in 0..? - var nameLength: Int = 0 - cass_result_column_name(result, i, &namePtr, &nameLength) - if let namePtr { - columns.append(String(cString: namePtr)) - } else { - columns.append("column_\(i)") - } - let colType = cass_result_column_type(result, i) - columnTypeNames.append(Self.cassTypeName(colType)) - } - - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames, - estimatedRowCount: nil - ))) - headerSent = true - } - - let colCount = cass_result_column_count(result) - let iterator = cass_iterator_from_result(result) - - if let iterator { - while cass_iterator_next(iterator) == cass_true { - let row = cass_iterator_get_row(iterator) - guard let row else { continue } - - var rowData: [PluginCellValue] = [] - for col in 0.. String { - value.replacingOccurrences(of: "\"", with: "\"\"") - } - - static func classifySSLError(rc: CassError, message: String) -> SSLHandshakeError? { - switch rc { - case CASS_ERROR_SSL_NO_PEER_CERT, CASS_ERROR_SSL_INVALID_PEER_CERT: - return .untrustedCertificate(serverMessage: message) - case CASS_ERROR_SSL_IDENTITY_MISMATCH: - return .hostnameMismatch(serverMessage: message) - case CASS_ERROR_SSL_INVALID_PRIVATE_KEY, CASS_ERROR_SSL_INVALID_CERT: - return .clientCertRequired(serverMessage: message) - case CASS_ERROR_SSL_PROTOCOL_ERROR: - return .cipherMismatch(serverMessage: message) - default: - break - } - let lower = message.lowercased() - if lower.contains("ssl handshake") || lower.contains("tls handshake") || lower.contains("ssl_connect") { - return .cipherMismatch(serverMessage: message) - } - return nil - } -} - -// MARK: - Raw Result - -private struct CassandraRawResult: Sendable { - let columns: [String] - let columnTypeNames: [String] - let rows: [[PluginCellValue]] - let rowsAffected: Int - let executionTime: TimeInterval -} - // MARK: - Plugin Driver internal final class CassandraPluginDriver: PluginDatabaseDriver, @unchecked Sendable { @@ -1428,22 +586,3 @@ internal final class CassandraPluginDriver: PluginDatabaseDriver, @unchecked Sen } } -// MARK: - Errors - -internal enum CassandraPluginError: Error { - case connectionFailed(String) - case notConnected - case queryFailed(String) - case unsupportedOperation -} - -extension CassandraPluginError: PluginDriverError { - var pluginErrorMessage: String { - switch self { - case .connectionFailed(let msg): return msg - case .notConnected: return String(localized: "Not connected to database") - case .queryFailed(let msg): return msg - case .unsupportedOperation: return String(localized: "Operation not supported by Cassandra") - } - } -} diff --git a/Plugins/CassandraDriverPlugin/CassandraPluginError.swift b/Plugins/CassandraDriverPlugin/CassandraPluginError.swift new file mode 100644 index 000000000..e338b852a --- /dev/null +++ b/Plugins/CassandraDriverPlugin/CassandraPluginError.swift @@ -0,0 +1,25 @@ +// +// CassandraPluginError.swift +// CassandraDriverPlugin +// + +import Foundation +import TableProPluginKit + +internal enum CassandraPluginError: Error { + case connectionFailed(String) + case notConnected + case queryFailed(String) + case unsupportedOperation +} + +extension CassandraPluginError: PluginDriverError { + var pluginErrorMessage: String { + switch self { + case .connectionFailed(let msg): return msg + case .notConnected: return String(localized: "Not connected to database") + case .queryFailed(let msg): return msg + case .unsupportedOperation: return String(localized: "Operation not supported by Cassandra") + } + } +} diff --git a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift index 16d96e512..3dac36eaf 100644 --- a/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift +++ b/Plugins/ClickHouseDriverPlugin/ClickHousePlugin.swift @@ -115,7 +115,7 @@ final class ClickHousePlugin: NSObject, TableProPlugin, DriverPlugin { // MARK: - Error Types -private struct ClickHouseError: Error, PluginDriverError { +struct ClickHouseError: Error, PluginDriverError { let message: String var pluginErrorMessage: String { message } @@ -126,7 +126,7 @@ private struct ClickHouseError: Error, PluginDriverError { // MARK: - Internal Query Result -private struct CHQueryResult { +struct CHQueryResult { let columns: [String] let columnTypeNames: [String] let rows: [[PluginCellValue]] @@ -137,19 +137,19 @@ private struct CHQueryResult { // MARK: - Plugin Driver final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { - private let config: DriverConnectionConfig + let config: DriverConnectionConfig private var _serverVersion: String? - private let lock = NSLock() - private var session: URLSession? - private var currentTask: URLSessionDataTask? - private var _currentDatabase: String - private var _lastQueryId: String? - private let _queryTimeout = HttpQueryTimeoutBox() + let lock = NSLock() + var session: URLSession? + var currentTask: URLSessionDataTask? + var _currentDatabase: String + var _lastQueryId: String? + let _queryTimeout = HttpQueryTimeoutBox() - private static let logger = Logger(subsystem: "com.TablePro", category: "ClickHousePluginDriver") + static let logger = Logger(subsystem: "com.TablePro", category: "ClickHousePluginDriver") - private static let selectPrefixes: Set = [ + static let selectPrefixes: Set = [ "SELECT", "SHOW", "DESCRIBE", "DESC", "EXISTS", "EXPLAIN", "WITH" ] @@ -282,335 +282,6 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) } - // MARK: - Schema Operations - - func fetchTables(schema: String?) async throws -> [PluginTableInfo] { - let sql = """ - SELECT name, engine FROM system.tables - WHERE database = currentDatabase() AND name NOT LIKE '.%' - ORDER BY name - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginTableInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let engine = row[safe: 1]?.asText - let tableType = (engine?.contains("View") == true) ? "VIEW" : "TABLE" - return PluginTableInfo(name: name, type: tableType) - } - } - - func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - - let pkSql = """ - SELECT primary_key, sorting_key FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let pkResult = try await execute(query: pkSql) - let primaryKey = pkResult.rows.first.flatMap { $0[safe: 0]?.asText } ?? "" - let sortingKey = pkResult.rows.first.flatMap { $0[safe: 1]?.asText } ?? "" - let keyString = primaryKey.isEmpty ? sortingKey : primaryKey - let pkColumns = Set(keyString.split(separator: ",").map { $0.trimmingCharacters(in: .whitespaces) }) - - let sql = """ - SELECT name, type, default_kind, default_expression, comment - FROM system.columns - WHERE database = currentDatabase() AND table = '\(escapedTable)' - ORDER BY position - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginColumnInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let dataType = (row[safe: 1]?.asText) ?? "String" - let defaultKind = row[safe: 2]?.asText - let defaultExpr = row[safe: 3]?.asText - let comment = row[safe: 4]?.asText - - let isNullable = dataType.hasPrefix("Nullable(") - - var defaultValue: String? - if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { - defaultValue = expr - } - - var extra: String? - if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { - extra = kind - } - - return PluginColumnInfo( - name: name, - dataType: dataType, - isNullable: isNullable, - isPrimaryKey: pkColumns.contains(name), - defaultValue: defaultValue, - extra: extra, - comment: (comment?.isEmpty == false) ? comment : nil, - allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) - ) - } - } - - func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { - // Pre-fetch PK columns for all tables. Falls back to sorting_key when - // primary_key is empty (MergeTree without explicit PRIMARY KEY clause). - // Note: expression-based keys like toDate(col) won't match bare column names. - let pkSql = """ - SELECT name, primary_key, sorting_key FROM system.tables - WHERE database = currentDatabase() - """ - let pkResult = try await execute(query: pkSql) - var pkLookup: [String: Set] = [:] - for row in pkResult.rows { - guard let tableName = row[safe: 0]?.asText else { continue } - let primaryKey = (row[safe: 1]?.asText) ?? "" - let sortingKey = (row[safe: 2]?.asText) ?? "" - let keyString = primaryKey.isEmpty ? sortingKey : primaryKey - guard !keyString.isEmpty else { continue } - let cols = Set(keyString.split(separator: ",").map { String($0).trimmingCharacters(in: .whitespaces) }) - pkLookup[tableName] = cols - } - - let sql = """ - SELECT table, name, type, default_kind, default_expression, comment - FROM system.columns - WHERE database = currentDatabase() - ORDER BY table, position - """ - let result = try await execute(query: sql) - var columnsByTable: [String: [PluginColumnInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let colName = row[safe: 1]?.asText else { continue } - let dataType = (row[safe: 2]?.asText) ?? "String" - let defaultKind = row[safe: 3]?.asText - let defaultExpr = row[safe: 4]?.asText - let comment = row[safe: 5]?.asText - - let isNullable = dataType.hasPrefix("Nullable(") - - var defaultValue: String? - if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { - defaultValue = expr - } - - var extra: String? - if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { - extra = kind - } - - let colInfo = PluginColumnInfo( - name: colName, - dataType: dataType, - isNullable: isNullable, - isPrimaryKey: pkLookup[tableName]?.contains(colName) == true, - defaultValue: defaultValue, - extra: extra, - comment: (comment?.isEmpty == false) ? comment : nil, - allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) - ) - columnsByTable[tableName, default: []].append(colInfo) - } - return columnsByTable - } - - static func unwrapTypeWrappers(_ value: String) -> String { - for prefix in ["Nullable(", "LowCardinality("] { - if value.hasPrefix(prefix), value.hasSuffix(")") { - let start = value.index(value.startIndex, offsetBy: prefix.count) - let end = value.index(before: value.endIndex) - return unwrapTypeWrappers(String(value[start.. [PluginIndexInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - var indexes: [PluginIndexInfo] = [] - - let sortingKeySql = """ - SELECT sorting_key FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let sortingResult = try await execute(query: sortingKeySql) - if let row = sortingResult.rows.first, - let sortingKey = row[safe: 0]?.asText, !sortingKey.isEmpty { - let columns = sortingKey.components(separatedBy: ",").map { - $0.trimmingCharacters(in: .whitespacesAndNewlines) - } - indexes.append(PluginIndexInfo( - name: "PRIMARY (sorting key)", - columns: columns, - isUnique: false, - isPrimary: true, - type: "SORTING KEY" - )) - } - - let caps = ClickHouseCapabilities.parse(serverVersion) - guard caps.hasDataSkippingIndicesTable else { return indexes } - let skippingSql = """ - SELECT name, expr FROM system.data_skipping_indices - WHERE database = currentDatabase() AND table = '\(escapedTable)' - """ - let skippingResult = try await execute(query: skippingSql) - for row in skippingResult.rows { - guard let idxName = row[safe: 0]?.asText else { continue } - let expr = (row[safe: 1]?.asText) ?? "" - let columns = expr.components(separatedBy: ",").map { - $0.trimmingCharacters(in: .whitespacesAndNewlines) - } - indexes.append(PluginIndexInfo( - name: idxName, - columns: columns, - isUnique: false, - isPrimary: false, - type: "DATA_SKIPPING" - )) - } - - return indexes - } - - func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { - [] - } - - func fetchApproximateRowCount(table: String, schema: String?) async throws -> Int? { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT sum(rows) FROM system.parts - WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 - """ - let result = try await execute(query: sql) - if let row = result.rows.first, let cell = row.first, let str = cell.asText { - return Int(str) - } - return nil - } - - func fetchTableDDL(table: String, schema: String?) async throws -> String { - let escapedTable = table.replacingOccurrences(of: "`", with: "``") - let sql = "SHOW CREATE TABLE `\(escapedTable)`" - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchViewDefinition(view: String, schema: String?) async throws -> String { - let escapedView = view.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT as_select FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedView)' - """ - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - - let engineSql = """ - SELECT engine, comment FROM system.tables - WHERE database = currentDatabase() AND name = '\(escapedTable)' - """ - let engineResult = try await execute(query: engineSql) - let engine = engineResult.rows.first.flatMap { $0[safe: 0]?.asText } - let tableComment = engineResult.rows.first.flatMap { $0[safe: 1]?.asText } - - let partsSql = """ - SELECT sum(rows), sum(bytes_on_disk) - FROM system.parts - WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 - """ - let partsResult = try await execute(query: partsSql) - if let row = partsResult.rows.first { - let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 - return PluginTableMetadata( - tableName: table, - dataSize: sizeBytes, - totalSize: sizeBytes, - rowCount: rowCount, - comment: (tableComment?.isEmpty == false) ? tableComment : nil, - engine: engine - ) - } - - return PluginTableMetadata(tableName: table, engine: engine) - } - - func fetchDatabases() async throws -> [String] { - let result = try await execute(query: "SHOW DATABASES") - return result.rows.compactMap { $0.first?.asText } - } - - func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { - let escapedDb = database.replacingOccurrences(of: "'", with: "''") - let sql = """ - SELECT count() AS table_count, sum(total_bytes) AS size_bytes - FROM system.tables WHERE database = '\(escapedDb)' - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let tableCount = (row[safe: 0]?.asText).flatMap { Int($0) } ?? 0 - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata( - name: database, - tableCount: tableCount, - sizeBytes: sizeBytes - ) - } - return PluginDatabaseMetadata(name: database) - } - - func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { - let sql = """ - SELECT database, count() AS table_count, sum(total_bytes) AS size_bytes - FROM system.tables - GROUP BY database - ORDER BY database - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginDatabaseMetadata? in - guard let name = row[safe: 0]?.asText else { return nil } - let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 - let sizeBytes = (row[safe: 2]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata(name: name, tableCount: tableCount, sizeBytes: sizeBytes) - } - } - - func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { - PluginCreateDatabaseFormSpec(fields: [], footnote: nil) - } - - func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { - let escapedName = request.name.replacingOccurrences(of: "`", with: "``") - _ = try await execute(query: "CREATE DATABASE `\(escapedName)`") - } - - func dropDatabase(name: String) async throws { - let escapedName = name.replacingOccurrences(of: "`", with: "``") - _ = try await execute(query: "DROP DATABASE `\(escapedName)`") - } - - // MARK: - All Tables Metadata - - func allTablesMetadataSQL(schema: String?) -> String? { - """ - SELECT - database as `schema`, - name, - engine as kind, - total_rows as estimated_rows, - formatReadableSize(total_bytes) as total_size, - comment - FROM system.tables - WHERE database = currentDatabase() - ORDER BY name - """ - } - // MARK: - DML Statement Generation func generateStatements( @@ -806,304 +477,6 @@ final class ClickHousePluginDriver: PluginDatabaseDriver, @unchecked Sendable { } } - // MARK: - Private HTTP Layer - - private func executeRaw(_ query: String, queryId: String? = nil) async throws -> CHQueryResult { - lock.lock() - guard let session = self.session else { - lock.unlock() - throw ClickHouseError.notConnected - } - let database = _currentDatabase - if let queryId { - _lastQueryId = queryId - } - lock.unlock() - - var request = try buildRequest(query: query, database: database, queryId: queryId) - request.timeoutInterval = _queryTimeout.requestTimeoutInterval - let isSelect = Self.isSelectLikeQuery(query) - - let (data, response) = try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in - let task = session.dataTask(with: request) { data, response, error in - if let error { - continuation.resume(throwing: error) - return - } - guard let data, let response else { - continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) - return - } - continuation.resume(returning: (data, response)) - } - - self.lock.lock() - self.currentTask = task - self.lock.unlock() - - task.resume() - } - } onCancel: { - self.lock.lock() - self.currentTask?.cancel() - self.currentTask = nil - self.lock.unlock() - } - - lock.lock() - currentTask = nil - lock.unlock() - - if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { - let body = String(data: data, encoding: .utf8) ?? "Unknown error" - Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") - throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) - } - - if isSelect { - return parseTabSeparatedResponse(data) - } - - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - private func executeRawWithParams(_ query: String, params: [String: String?], queryId: String? = nil) async throws -> CHQueryResult { - lock.lock() - guard let session = self.session else { - lock.unlock() - throw ClickHouseError.notConnected - } - let database = _currentDatabase - if let queryId { - _lastQueryId = queryId - } - lock.unlock() - - var request = try buildRequest(query: query, database: database, queryId: queryId, params: params) - request.timeoutInterval = _queryTimeout.requestTimeoutInterval - let isSelect = Self.isSelectLikeQuery(query) - - let (data, response) = try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in - let task = session.dataTask(with: request) { data, response, error in - if let error { - continuation.resume(throwing: error) - return - } - guard let data, let response else { - continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) - return - } - continuation.resume(returning: (data, response)) - } - self.lock.lock() - self.currentTask = task - self.lock.unlock() - task.resume() - } - } onCancel: { - self.lock.lock() - self.currentTask?.cancel() - self.currentTask = nil - self.lock.unlock() - } - - lock.lock() - currentTask = nil - lock.unlock() - - if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { - let body = String(data: data, encoding: .utf8) ?? "Unknown error" - Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") - throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) - } - - if isSelect { - return parseTabSeparatedResponse(data) - } - - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - private func buildRequest(query: String, database: String, queryId: String? = nil, params: [String: String?]? = nil) throws -> URLRequest { - let useTLS = config.ssl.isEnabled - - var components = URLComponents() - components.scheme = useTLS ? "https" : "http" - components.host = config.host - components.port = config.port - components.path = "/" - - var queryItems = [URLQueryItem]() - if !database.isEmpty { - queryItems.append(URLQueryItem(name: "database", value: database)) - } - if let queryId { - queryItems.append(URLQueryItem(name: "query_id", value: queryId)) - } - queryItems.append(URLQueryItem(name: "send_progress_in_http_headers", value: "1")) - if let params { - for (key, value) in params.sorted(by: { $0.key < $1.key }) { - queryItems.append(URLQueryItem(name: "param_\(key)", value: value)) - } - } - if !queryItems.isEmpty { - components.queryItems = queryItems - } - - guard let url = components.url else { - throw ClickHouseError(message: "Failed to construct request URL") - } - - var request = URLRequest(url: url) - request.httpMethod = "POST" - - let credentials = "\(config.username):\(config.password)" - if let credData = credentials.data(using: .utf8) { - request.setValue("Basic \(credData.base64EncodedString())", forHTTPHeaderField: "Authorization") - } - - let trimmedQuery = query.trimmingCharacters(in: .whitespacesAndNewlines) - .replacingOccurrences(of: ";+$", with: "", options: .regularExpression) - - if Self.isSelectLikeQuery(trimmedQuery) { - request.httpBody = (trimmedQuery + " FORMAT TabSeparatedWithNamesAndTypes").data(using: .utf8) - } else { - request.httpBody = trimmedQuery.data(using: .utf8) - } - - return request - } - - private static func isSelectLikeQuery(_ query: String) -> Bool { - let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) - guard let firstWord = trimmed.split(separator: " ", maxSplits: 1).first else { - return false - } - return selectPrefixes.contains(firstWord.uppercased()) - } - - private func parseTabSeparatedResponse(_ data: Data) -> CHQueryResult { - guard let text = String(data: data, encoding: .utf8), !text.isEmpty else { - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - let lines = text.components(separatedBy: "\n") - - guard lines.count >= 2 else { - return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) - } - - let columns = lines[0].components(separatedBy: "\t") - let columnTypes = lines[1].components(separatedBy: "\t") - - var rows: [[PluginCellValue]] = [] - var truncated = false - for i in 2..= PluginRowLimits.emergencyMax { - truncated = true - break - } - } - - return CHQueryResult( - columns: columns, - columnTypeNames: columnTypes, - rows: rows, - affectedRows: rows.count, - isTruncated: truncated - ) - } - - private static func unescapeTsvField(_ field: String) -> String { - var result = "" - result.reserveCapacity((field as NSString).length) - var iterator = field.makeIterator() - - while let char = iterator.next() { - if char == "\\" { - if let next = iterator.next() { - switch next { - case "\\": result.append("\\") - case "t": result.append("\t") - case "n": result.append("\n") - default: - result.append("\\") - result.append(next) - } - } else { - result.append("\\") - } - } else { - result.append(char) - } - } - - return result - } - - /// Convert `?` placeholders to `{p1:String}` and build parameter map for ClickHouse HTTP params. - private static func buildClickHouseParams( - query: String, - parameters: [PluginCellValue] - ) -> (String, [String: String?]) { - var converted = "" - var paramIndex = 0 - var inSingleQuote = false - var inDoubleQuote = false - var isEscaped = false - - for char in query { - if isEscaped { - isEscaped = false - converted.append(char) - continue - } - if char == "\\" && (inSingleQuote || inDoubleQuote) { - isEscaped = true - converted.append(char) - continue - } - if char == "'" && !inDoubleQuote { - inSingleQuote.toggle() - } else if char == "\"" && !inSingleQuote { - inDoubleQuote.toggle() - } - if char == "?" && !inSingleQuote && !inDoubleQuote && paramIndex < parameters.count { - paramIndex += 1 - converted.append("{p\(paramIndex):String}") - } else { - converted.append(char) - } - } - - var paramMap: [String: String?] = [:] - for i in 0.. AsyncThrowingStream { diff --git a/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift b/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift new file mode 100644 index 000000000..99006be07 --- /dev/null +++ b/Plugins/ClickHouseDriverPlugin/ClickHousePluginDriver+Http.swift @@ -0,0 +1,309 @@ +// +// ClickHousePluginDriver+Http.swift +// ClickHouseDriverPlugin +// + +import Foundation +import os +import TableProPluginKit + +extension ClickHousePluginDriver { + // MARK: - Private HTTP Layer + + func executeRaw(_ query: String, queryId: String? = nil) async throws -> CHQueryResult { + lock.lock() + guard let session = self.session else { + lock.unlock() + throw ClickHouseError.notConnected + } + let database = _currentDatabase + if let queryId { + _lastQueryId = queryId + } + lock.unlock() + + var request = try buildRequest(query: query, database: database, queryId: queryId) + request.timeoutInterval = _queryTimeout.requestTimeoutInterval + let isSelect = Self.isSelectLikeQuery(query) + + let (data, response) = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in + let task = session.dataTask(with: request) { data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let data, let response else { + continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) + return + } + continuation.resume(returning: (data, response)) + } + + self.lock.lock() + self.currentTask = task + self.lock.unlock() + + task.resume() + } + } onCancel: { + self.lock.lock() + self.currentTask?.cancel() + self.currentTask = nil + self.lock.unlock() + } + + lock.lock() + currentTask = nil + lock.unlock() + + if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { + let body = String(data: data, encoding: .utf8) ?? "Unknown error" + Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") + throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + if isSelect { + return parseTabSeparatedResponse(data) + } + + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + func executeRawWithParams(_ query: String, params: [String: String?], queryId: String? = nil) async throws -> CHQueryResult { + lock.lock() + guard let session = self.session else { + lock.unlock() + throw ClickHouseError.notConnected + } + let database = _currentDatabase + if let queryId { + _lastQueryId = queryId + } + lock.unlock() + + var request = try buildRequest(query: query, database: database, queryId: queryId, params: params) + request.timeoutInterval = _queryTimeout.requestTimeoutInterval + let isSelect = Self.isSelectLikeQuery(query) + + let (data, response) = try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<(Data, URLResponse), Error>) in + let task = session.dataTask(with: request) { data, response, error in + if let error { + continuation.resume(throwing: error) + return + } + guard let data, let response else { + continuation.resume(throwing: ClickHouseError(message: "Empty response from server")) + return + } + continuation.resume(returning: (data, response)) + } + self.lock.lock() + self.currentTask = task + self.lock.unlock() + task.resume() + } + } onCancel: { + self.lock.lock() + self.currentTask?.cancel() + self.currentTask = nil + self.lock.unlock() + } + + lock.lock() + currentTask = nil + lock.unlock() + + if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode >= 400 { + let body = String(data: data, encoding: .utf8) ?? "Unknown error" + Self.logger.error("ClickHouse HTTP \(httpResponse.statusCode): \(body)") + throw ClickHouseError(message: body.trimmingCharacters(in: .whitespacesAndNewlines)) + } + + if isSelect { + return parseTabSeparatedResponse(data) + } + + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + func buildRequest(query: String, database: String, queryId: String? = nil, params: [String: String?]? = nil) throws -> URLRequest { + let useTLS = config.ssl.isEnabled + + var components = URLComponents() + components.scheme = useTLS ? "https" : "http" + components.host = config.host + components.port = config.port + components.path = "/" + + var queryItems = [URLQueryItem]() + if !database.isEmpty { + queryItems.append(URLQueryItem(name: "database", value: database)) + } + if let queryId { + queryItems.append(URLQueryItem(name: "query_id", value: queryId)) + } + queryItems.append(URLQueryItem(name: "send_progress_in_http_headers", value: "1")) + if let params { + for (key, value) in params.sorted(by: { $0.key < $1.key }) { + queryItems.append(URLQueryItem(name: "param_\(key)", value: value)) + } + } + if !queryItems.isEmpty { + components.queryItems = queryItems + } + + guard let url = components.url else { + throw ClickHouseError(message: "Failed to construct request URL") + } + + var request = URLRequest(url: url) + request.httpMethod = "POST" + + let credentials = "\(config.username):\(config.password)" + if let credData = credentials.data(using: .utf8) { + request.setValue("Basic \(credData.base64EncodedString())", forHTTPHeaderField: "Authorization") + } + + let trimmedQuery = query.trimmingCharacters(in: .whitespacesAndNewlines) + .replacingOccurrences(of: ";+$", with: "", options: .regularExpression) + + if Self.isSelectLikeQuery(trimmedQuery) { + request.httpBody = (trimmedQuery + " FORMAT TabSeparatedWithNamesAndTypes").data(using: .utf8) + } else { + request.httpBody = trimmedQuery.data(using: .utf8) + } + + return request + } + + static func isSelectLikeQuery(_ query: String) -> Bool { + let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) + guard let firstWord = trimmed.split(separator: " ", maxSplits: 1).first else { + return false + } + return selectPrefixes.contains(firstWord.uppercased()) + } + + func parseTabSeparatedResponse(_ data: Data) -> CHQueryResult { + guard let text = String(data: data, encoding: .utf8), !text.isEmpty else { + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + let lines = text.components(separatedBy: "\n") + + guard lines.count >= 2 else { + return CHQueryResult(columns: [], columnTypeNames: [], rows: [], affectedRows: 0, isTruncated: false) + } + + let columns = lines[0].components(separatedBy: "\t") + let columnTypes = lines[1].components(separatedBy: "\t") + + var rows: [[PluginCellValue]] = [] + var truncated = false + for i in 2..= PluginRowLimits.emergencyMax { + truncated = true + break + } + } + + return CHQueryResult( + columns: columns, + columnTypeNames: columnTypes, + rows: rows, + affectedRows: rows.count, + isTruncated: truncated + ) + } + + static func unescapeTsvField(_ field: String) -> String { + var result = "" + result.reserveCapacity((field as NSString).length) + var iterator = field.makeIterator() + + while let char = iterator.next() { + if char == "\\" { + if let next = iterator.next() { + switch next { + case "\\": result.append("\\") + case "t": result.append("\t") + case "n": result.append("\n") + default: + result.append("\\") + result.append(next) + } + } else { + result.append("\\") + } + } else { + result.append(char) + } + } + + return result + } + + /// Convert `?` placeholders to `{p1:String}` and build parameter map for ClickHouse HTTP params. + static func buildClickHouseParams( + query: String, + parameters: [PluginCellValue] + ) -> (String, [String: String?]) { + var converted = "" + var paramIndex = 0 + var inSingleQuote = false + var inDoubleQuote = false + var isEscaped = false + + for char in query { + if isEscaped { + isEscaped = false + converted.append(char) + continue + } + if char == "\\" && (inSingleQuote || inDoubleQuote) { + isEscaped = true + converted.append(char) + continue + } + if char == "'" && !inDoubleQuote { + inSingleQuote.toggle() + } else if char == "\"" && !inSingleQuote { + inDoubleQuote.toggle() + } + if char == "?" && !inSingleQuote && !inDoubleQuote && paramIndex < parameters.count { + paramIndex += 1 + converted.append("{p\(paramIndex):String}") + } else { + converted.append(char) + } + } + + var paramMap: [String: String?] = [:] + for i in 0.. [PluginTableInfo] { + let sql = """ + SELECT name, engine FROM system.tables + WHERE database = currentDatabase() AND name NOT LIKE '.%' + ORDER BY name + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginTableInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let engine = row[safe: 1]?.asText + let tableType = (engine?.contains("View") == true) ? "VIEW" : "TABLE" + return PluginTableInfo(name: name, type: tableType) + } + } + + func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + + let pkSql = """ + SELECT primary_key, sorting_key FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let pkResult = try await execute(query: pkSql) + let primaryKey = pkResult.rows.first.flatMap { $0[safe: 0]?.asText } ?? "" + let sortingKey = pkResult.rows.first.flatMap { $0[safe: 1]?.asText } ?? "" + let keyString = primaryKey.isEmpty ? sortingKey : primaryKey + let pkColumns = Set(keyString.split(separator: ",").map { $0.trimmingCharacters(in: .whitespaces) }) + + let sql = """ + SELECT name, type, default_kind, default_expression, comment + FROM system.columns + WHERE database = currentDatabase() AND table = '\(escapedTable)' + ORDER BY position + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginColumnInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let dataType = (row[safe: 1]?.asText) ?? "String" + let defaultKind = row[safe: 2]?.asText + let defaultExpr = row[safe: 3]?.asText + let comment = row[safe: 4]?.asText + + let isNullable = dataType.hasPrefix("Nullable(") + + var defaultValue: String? + if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { + defaultValue = expr + } + + var extra: String? + if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { + extra = kind + } + + return PluginColumnInfo( + name: name, + dataType: dataType, + isNullable: isNullable, + isPrimaryKey: pkColumns.contains(name), + defaultValue: defaultValue, + extra: extra, + comment: (comment?.isEmpty == false) ? comment : nil, + allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) + ) + } + } + + func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { + // Pre-fetch PK columns for all tables. Falls back to sorting_key when + // primary_key is empty (MergeTree without explicit PRIMARY KEY clause). + // Note: expression-based keys like toDate(col) won't match bare column names. + let pkSql = """ + SELECT name, primary_key, sorting_key FROM system.tables + WHERE database = currentDatabase() + """ + let pkResult = try await execute(query: pkSql) + var pkLookup: [String: Set] = [:] + for row in pkResult.rows { + guard let tableName = row[safe: 0]?.asText else { continue } + let primaryKey = (row[safe: 1]?.asText) ?? "" + let sortingKey = (row[safe: 2]?.asText) ?? "" + let keyString = primaryKey.isEmpty ? sortingKey : primaryKey + guard !keyString.isEmpty else { continue } + let cols = Set(keyString.split(separator: ",").map { String($0).trimmingCharacters(in: .whitespaces) }) + pkLookup[tableName] = cols + } + + let sql = """ + SELECT table, name, type, default_kind, default_expression, comment + FROM system.columns + WHERE database = currentDatabase() + ORDER BY table, position + """ + let result = try await execute(query: sql) + var columnsByTable: [String: [PluginColumnInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let colName = row[safe: 1]?.asText else { continue } + let dataType = (row[safe: 2]?.asText) ?? "String" + let defaultKind = row[safe: 3]?.asText + let defaultExpr = row[safe: 4]?.asText + let comment = row[safe: 5]?.asText + + let isNullable = dataType.hasPrefix("Nullable(") + + var defaultValue: String? + if let kind = defaultKind, !kind.isEmpty, let expr = defaultExpr, !expr.isEmpty { + defaultValue = expr + } + + var extra: String? + if let kind = defaultKind, !kind.isEmpty, kind != "DEFAULT" { + extra = kind + } + + let colInfo = PluginColumnInfo( + name: colName, + dataType: dataType, + isNullable: isNullable, + isPrimaryKey: pkLookup[tableName]?.contains(colName) == true, + defaultValue: defaultValue, + extra: extra, + comment: (comment?.isEmpty == false) ? comment : nil, + allowedValues: EnumValueParser.parseClickHouseEnum(from: ClickHousePluginDriver.unwrapTypeWrappers(dataType)) + ) + columnsByTable[tableName, default: []].append(colInfo) + } + return columnsByTable + } + + static func unwrapTypeWrappers(_ value: String) -> String { + for prefix in ["Nullable(", "LowCardinality("] { + if value.hasPrefix(prefix), value.hasSuffix(")") { + let start = value.index(value.startIndex, offsetBy: prefix.count) + let end = value.index(before: value.endIndex) + return unwrapTypeWrappers(String(value[start.. [PluginIndexInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + var indexes: [PluginIndexInfo] = [] + + let sortingKeySql = """ + SELECT sorting_key FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let sortingResult = try await execute(query: sortingKeySql) + if let row = sortingResult.rows.first, + let sortingKey = row[safe: 0]?.asText, !sortingKey.isEmpty { + let columns = sortingKey.components(separatedBy: ",").map { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + } + indexes.append(PluginIndexInfo( + name: "PRIMARY (sorting key)", + columns: columns, + isUnique: false, + isPrimary: true, + type: "SORTING KEY" + )) + } + + let caps = ClickHouseCapabilities.parse(serverVersion) + guard caps.hasDataSkippingIndicesTable else { return indexes } + let skippingSql = """ + SELECT name, expr FROM system.data_skipping_indices + WHERE database = currentDatabase() AND table = '\(escapedTable)' + """ + let skippingResult = try await execute(query: skippingSql) + for row in skippingResult.rows { + guard let idxName = row[safe: 0]?.asText else { continue } + let expr = (row[safe: 1]?.asText) ?? "" + let columns = expr.components(separatedBy: ",").map { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + } + indexes.append(PluginIndexInfo( + name: idxName, + columns: columns, + isUnique: false, + isPrimary: false, + type: "DATA_SKIPPING" + )) + } + + return indexes + } + + func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { + [] + } + + func fetchApproximateRowCount(table: String, schema: String?) async throws -> Int? { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT sum(rows) FROM system.parts + WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 + """ + let result = try await execute(query: sql) + if let row = result.rows.first, let cell = row.first, let str = cell.asText { + return Int(str) + } + return nil + } + + func fetchTableDDL(table: String, schema: String?) async throws -> String { + let escapedTable = table.replacingOccurrences(of: "`", with: "``") + let sql = "SHOW CREATE TABLE `\(escapedTable)`" + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchViewDefinition(view: String, schema: String?) async throws -> String { + let escapedView = view.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT as_select FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedView)' + """ + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + + let engineSql = """ + SELECT engine, comment FROM system.tables + WHERE database = currentDatabase() AND name = '\(escapedTable)' + """ + let engineResult = try await execute(query: engineSql) + let engine = engineResult.rows.first.flatMap { $0[safe: 0]?.asText } + let tableComment = engineResult.rows.first.flatMap { $0[safe: 1]?.asText } + + let partsSql = """ + SELECT sum(rows), sum(bytes_on_disk) + FROM system.parts + WHERE database = currentDatabase() AND table = '\(escapedTable)' AND active = 1 + """ + let partsResult = try await execute(query: partsSql) + if let row = partsResult.rows.first { + let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 + return PluginTableMetadata( + tableName: table, + dataSize: sizeBytes, + totalSize: sizeBytes, + rowCount: rowCount, + comment: (tableComment?.isEmpty == false) ? tableComment : nil, + engine: engine + ) + } + + return PluginTableMetadata(tableName: table, engine: engine) + } + + func fetchDatabases() async throws -> [String] { + let result = try await execute(query: "SHOW DATABASES") + return result.rows.compactMap { $0.first?.asText } + } + + func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { + let escapedDb = database.replacingOccurrences(of: "'", with: "''") + let sql = """ + SELECT count() AS table_count, sum(total_bytes) AS size_bytes + FROM system.tables WHERE database = '\(escapedDb)' + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let tableCount = (row[safe: 0]?.asText).flatMap { Int($0) } ?? 0 + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata( + name: database, + tableCount: tableCount, + sizeBytes: sizeBytes + ) + } + return PluginDatabaseMetadata(name: database) + } + + func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { + let sql = """ + SELECT database, count() AS table_count, sum(total_bytes) AS size_bytes + FROM system.tables + GROUP BY database + ORDER BY database + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginDatabaseMetadata? in + guard let name = row[safe: 0]?.asText else { return nil } + let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 + let sizeBytes = (row[safe: 2]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata(name: name, tableCount: tableCount, sizeBytes: sizeBytes) + } + } + + func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { + PluginCreateDatabaseFormSpec(fields: [], footnote: nil) + } + + func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { + let escapedName = request.name.replacingOccurrences(of: "`", with: "``") + _ = try await execute(query: "CREATE DATABASE `\(escapedName)`") + } + + func dropDatabase(name: String) async throws { + let escapedName = name.replacingOccurrences(of: "`", with: "``") + _ = try await execute(query: "DROP DATABASE `\(escapedName)`") + } + + // MARK: - All Tables Metadata + + func allTablesMetadataSQL(schema: String?) -> String? { + """ + SELECT + database as `schema`, + name, + engine as kind, + total_rows as estimated_rows, + formatReadableSize(total_bytes) as total_size, + comment + FROM system.tables + WHERE database = currentDatabase() + ORDER BY name + """ + } + +} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift b/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift new file mode 100644 index 000000000..666add7c4 --- /dev/null +++ b/Plugins/DuckDBDriverPlugin/DuckDBConnection.swift @@ -0,0 +1,600 @@ +// +// DuckDBConnection.swift +// DuckDBDriverPlugin +// + +import CDuckDB +import Foundation +import os +import TableProPluginKit + +actor DuckDBConnectionActor { + private static let logger = Logger(subsystem: "com.TablePro", category: "DuckDBConnectionActor") + + private var database: duckdb_database? + private var connection: duckdb_connection? + + var isConnected: Bool { connection != nil } + + var connectionHandleForInterrupt: duckdb_connection? { connection } + + func open(path: String) throws { + var db: duckdb_database? + var errorPtr: UnsafeMutablePointer? + let state = duckdb_open_ext(path, &db, nil, &errorPtr) + + if state == DuckDBError { + let detail: String + if let errPtr = errorPtr { + detail = String(cString: errPtr) + duckdb_free(errPtr) + } else { + detail = "unknown error" + } + throw DuckDBPluginError.connectionFailed( + "Failed to open DuckDB database at '\(path)': \(detail)" + ) + } + + guard let openedDB = db else { + throw DuckDBPluginError.connectionFailed( + "Failed to open DuckDB database at '\(path)'" + ) + } + + var conn: duckdb_connection? + let connState = duckdb_connect(openedDB, &conn) + + if connState == DuckDBError { + duckdb_close(&db) + throw DuckDBPluginError.connectionFailed("Failed to create DuckDB connection") + } + + database = db + connection = conn + } + + func close() { + if connection != nil { + duckdb_disconnect(&connection) + connection = nil + } + if database != nil { + duckdb_close(&database) + database = nil + } + } + + func executeQuery(_ query: String) throws -> DuckDBRawResult { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + let startTime = Date() + var result = duckdb_result() + + let state = duckdb_query(conn, query, &result) + + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + defer { + duckdb_destroy_result(&result) + } + + var raw = Self.extractResult(from: &result, startTime: startTime) + Self.patchTzColumns(&raw, query: query, connection: conn) + return raw + } + + func executePrepared(_ query: String, parameters: [PluginCellValue]) throws -> DuckDBRawResult { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + let startTime = Date() + var stmtOpt: duckdb_prepared_statement? + + let prepState = duckdb_prepare(conn, query, &stmtOpt) + if prepState == DuckDBError { + let errorMsg: String + if let s = stmtOpt, let errPtr = duckdb_prepare_error(s) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Failed to prepare statement" + } + duckdb_destroy_prepare(&stmtOpt) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + guard let stmt = stmtOpt else { + throw DuckDBPluginError.queryFailed("Failed to prepare statement") + } + + defer { + duckdb_destroy_prepare(&stmtOpt) + } + + for (index, param) in parameters.enumerated() { + let paramIdx = idx_t(index + 1) + let bindState: duckdb_state + switch param { + case .null: + bindState = duckdb_bind_null(stmt, paramIdx) + case .text(let value): + bindState = duckdb_bind_varchar(stmt, paramIdx, value) + case .bytes(let data): + bindState = data.withUnsafeBytes { rawBuffer -> duckdb_state in + guard let baseAddress = rawBuffer.baseAddress else { + return duckdb_bind_null(stmt, paramIdx) + } + return duckdb_bind_blob(stmt, paramIdx, baseAddress, idx_t(data.count)) + } + } + if bindState == DuckDBError { + throw DuckDBPluginError.queryFailed("Failed to bind parameter at index \(index)") + } + } + + var result = duckdb_result() + let execState = duckdb_execute_prepared(stmt, &result) + + if execState == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Failed to execute prepared statement" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + defer { + duckdb_destroy_result(&result) + } + + var raw = Self.extractResult(from: &result, startTime: startTime) + Self.patchTzColumns(&raw, query: query, connection: conn) + return raw + } + + func streamQuery( + _ query: String, + continuation: AsyncThrowingStream.Continuation + ) throws { + guard let conn = connection else { + throw DuckDBPluginError.notConnected + } + + var result = duckdb_result() + let state = duckdb_query(conn, query, &result) + + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + + let colCount = duckdb_column_count(&result) + var columns: [String] = [] + var columnTypeNames: [String] = [] + var columnTypes: [duckdb_type] = [] + for i in 0...Continuation + ) throws { + let castExprs = columns.enumerated().map { i, name in + castExpression(for: columnTypes[i], column: name) + } + let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) + + var result = duckdb_result() + let state = duckdb_query(connection, wrappedQuery, &result) + if state == DuckDBError { + let errorMsg: String + if let errPtr = duckdb_result_error(&result) { + errorMsg = String(cString: errPtr) + } else { + errorMsg = "Unknown DuckDB error" + } + duckdb_destroy_result(&result) + throw DuckDBPluginError.queryFailed(errorMsg) + } + defer { duckdb_destroy_result(&result) } + + try Self.streamResultRows( + &result, + columns: columns, + columnTypeNames: columnTypeNames, + continuation: continuation + ) + } + + private static func streamResultRows( + _ result: inout duckdb_result, + columns: [String], + columnTypeNames: [String], + continuation: AsyncThrowingStream.Continuation + ) throws { + let colCount = duckdb_column_count(&result) + let rowCount = duckdb_row_count(&result) + + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames, + estimatedRowCount: Int(rowCount) + ))) + + let maxRows = min(rowCount, UInt64(PluginRowLimits.emergencyMax)) + if rowCount > UInt64(PluginRowLimits.emergencyMax) { + Self.logger.warning("streamQuery truncating result from \(rowCount) to \(maxRows) rows") + } + + for row in 0.. DuckDBRawResult { + let colCount = duckdb_column_count(&result) + let rowCount = duckdb_row_count(&result) + let rowsChanged = duckdb_rows_changed(&result) + + var columns: [String] = [] + var columnTypeNames: [String] = [] + var columnTypes: [duckdb_type] = [] + + for i in 0.. UInt64(PluginRowLimits.emergencyMax) { + truncated = true + } + + for row in 0.. String { + switch type { + case DUCKDB_TYPE_BOOLEAN: return "BOOLEAN" + case DUCKDB_TYPE_TINYINT: return "TINYINT" + case DUCKDB_TYPE_SMALLINT: return "SMALLINT" + case DUCKDB_TYPE_INTEGER: return "INTEGER" + case DUCKDB_TYPE_BIGINT: return "BIGINT" + case DUCKDB_TYPE_UTINYINT: return "UTINYINT" + case DUCKDB_TYPE_USMALLINT: return "USMALLINT" + case DUCKDB_TYPE_UINTEGER: return "UINTEGER" + case DUCKDB_TYPE_UBIGINT: return "UBIGINT" + case DUCKDB_TYPE_FLOAT: return "FLOAT" + case DUCKDB_TYPE_DOUBLE: return "DOUBLE" + case DUCKDB_TYPE_TIMESTAMP: return "TIMESTAMP" + case DUCKDB_TYPE_DATE: return "DATE" + case DUCKDB_TYPE_TIME: return "TIME" + case DUCKDB_TYPE_INTERVAL: return "INTERVAL" + case DUCKDB_TYPE_HUGEINT: return "HUGEINT" + case DUCKDB_TYPE_VARCHAR: return "VARCHAR" + case DUCKDB_TYPE_BLOB: return "BLOB" + case DUCKDB_TYPE_DECIMAL: return "DECIMAL" + case DUCKDB_TYPE_TIMESTAMP_S: return "TIMESTAMP_S" + case DUCKDB_TYPE_TIMESTAMP_MS: return "TIMESTAMP_MS" + case DUCKDB_TYPE_TIMESTAMP_NS: return "TIMESTAMP_NS" + case DUCKDB_TYPE_ENUM: return "ENUM" + case DUCKDB_TYPE_LIST: return "LIST" + case DUCKDB_TYPE_STRUCT: return "STRUCT" + case DUCKDB_TYPE_MAP: return "MAP" + case DUCKDB_TYPE_UUID: return "UUID" + case DUCKDB_TYPE_UNION: return "UNION" + case DUCKDB_TYPE_BIT: return "BIT" + case DUCKDB_TYPE_TIMESTAMP_TZ: return "TIMESTAMPTZ" + case DUCKDB_TYPE_TIME_TZ: return "TIMETZ" + case DUCKDB_TYPE_TIME_NS: return "TIME_NS" + case DUCKDB_TYPE_UHUGEINT: return "UHUGEINT" + case DUCKDB_TYPE_ARRAY: return "ARRAY" + case DUCKDB_TYPE_GEOMETRY: return "GEOMETRY" + default: return "VARCHAR" + } + } + + private static func extractFallbackValue( + _ result: inout duckdb_result, col: idx_t, row: idx_t, type: duckdb_type + ) -> String? { + switch type { + case DUCKDB_TYPE_TIMESTAMP, DUCKDB_TYPE_TIMESTAMP_S, DUCKDB_TYPE_TIMESTAMP_MS, DUCKDB_TYPE_TIMESTAMP_NS: + let ts = duckdb_value_timestamp(&result, col, row) + return formatTimestamp(ts) + + case DUCKDB_TYPE_DATE: + let date = duckdb_value_date(&result, col, row) + let d = duckdb_from_date(date) + return String(format: "\(formatYearISO(d.year))-%02d-%02d", d.month, d.day) + + case DUCKDB_TYPE_TIME, DUCKDB_TYPE_TIME_NS: + let time = duckdb_value_time(&result, col, row) + return formatTime(duckdb_from_time(time)) + + case DUCKDB_TYPE_BOOLEAN: + return duckdb_value_boolean(&result, col, row) ? "true" : "false" + + case DUCKDB_TYPE_TINYINT: + return String(duckdb_value_int8(&result, col, row)) + case DUCKDB_TYPE_SMALLINT: + return String(duckdb_value_int16(&result, col, row)) + case DUCKDB_TYPE_INTEGER: + return String(duckdb_value_int32(&result, col, row)) + case DUCKDB_TYPE_BIGINT: + return String(duckdb_value_int64(&result, col, row)) + case DUCKDB_TYPE_UTINYINT: + return String(duckdb_value_uint8(&result, col, row)) + case DUCKDB_TYPE_USMALLINT: + return String(duckdb_value_uint16(&result, col, row)) + case DUCKDB_TYPE_UINTEGER: + return String(duckdb_value_uint32(&result, col, row)) + case DUCKDB_TYPE_UBIGINT: + return String(duckdb_value_uint64(&result, col, row)) + case DUCKDB_TYPE_FLOAT: + return String(duckdb_value_float(&result, col, row)) + case DUCKDB_TYPE_DOUBLE: + return String(duckdb_value_double(&result, col, row)) + + case DUCKDB_TYPE_HUGEINT: + let h = duckdb_value_hugeint(&result, col, row) + return formatHugeInt(upper: h.upper, lower: h.lower) + + case DUCKDB_TYPE_UHUGEINT: + let u = duckdb_value_uhugeint(&result, col, row) + return formatUHugeInt(upper: u.upper, lower: u.lower) + + default: + return nil + } + } + + static func patchTzColumns( + _ raw: inout DuckDBRawResult, query: String, connection: duckdb_connection + ) { + let patchedColIndices = raw.columnTypes.enumerated().compactMap { idx, type in + isUnrenderable(type) ? idx : nil + } + guard !patchedColIndices.isEmpty, !raw.rows.isEmpty else { return } + + let castExprs = raw.columns.enumerated().map { i, name in + castExpression(for: raw.columnTypes[i], column: name) + } + let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) + + var patchResult = duckdb_result() + guard duckdb_query(connection, wrappedQuery, &patchResult) == DuckDBSuccess else { return } + defer { duckdb_destroy_result(&patchResult) } + + let patchRowCount = min(duckdb_row_count(&patchResult), UInt64(raw.rows.count)) + for row in 0.. Bool { + switch type { + case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ, DUCKDB_TYPE_GEOMETRY: + return true + default: + return false + } + } + + static func castExpression(for type: duckdb_type, column: String) -> String { + let quoted = quoteIdentifier(column) + switch type { + case DUCKDB_TYPE_GEOMETRY: + return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE ST_AsText(\(quoted)) END AS \(quoted)" + case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ: + return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE CAST(\(quoted) AS VARCHAR) END AS \(quoted)" + default: + return quoted + } + } + + static func buildWrappedQuery(originalQuery: String, castExprs: [String]) -> String { + var trimmed = originalQuery.trimmingCharacters(in: .whitespacesAndNewlines) + if trimmed.hasSuffix(";") { + trimmed = String(trimmed.dropLast()) + } + return "SELECT \(castExprs.joined(separator: ", ")) FROM (\(trimmed)) AS _tp_cast" + } + + static func quoteIdentifier(_ ident: String) -> String { + "\"\(ident.replacingOccurrences(of: "\"", with: "\"\""))\"" + } + + static func formatTimestamp(_ ts: duckdb_timestamp) -> String { + let parts = duckdb_from_timestamp(ts) + let d = parts.date + let t = parts.time + let micros = t.micros % 1_000_000 + let yearPart = formatYearISO(d.year) + if micros == 0 { + return String( + format: "\(yearPart)-%02d-%02d %02d:%02d:%02d", + d.month, d.day, t.hour, t.min, t.sec + ) + } + return String( + format: "\(yearPart)-%02d-%02d %02d:%02d:%02d.%06d", + d.month, d.day, t.hour, t.min, t.sec, micros + ) + } + + static func formatYearISO(_ year: Int32) -> String { + if year < 0 { + return String(format: "-%04d", -Int(year)) + } + return String(format: "%04d", year) + } + + private static func formatTime(_ t: duckdb_time_struct) -> String { + let micros = t.micros % 1_000_000 + if micros == 0 { + return String(format: "%02d:%02d:%02d", t.hour, t.min, t.sec) + } + return String(format: "%02d:%02d:%02d.%06d", t.hour, t.min, t.sec, micros) + } + + static func formatHugeInt(upper: Int64, lower: UInt64) -> String { + HugeIntFormatter.format(upper: upper, lower: lower) + } + + static func formatUHugeInt(upper: UInt64, lower: UInt64) -> String { + HugeIntFormatter.formatUnsigned(upper: upper, lower: lower) + } +} + +struct DuckDBRawResult: @unchecked Sendable { + let columns: [String] + let columnTypeNames: [String] + let columnTypes: [duckdb_type] + var rows: [[PluginCellValue]] + let rowsAffected: Int + let executionTime: TimeInterval + let isTruncated: Bool +} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift index a3874a070..d00210de1 100644 --- a/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift +++ b/Plugins/DuckDBDriverPlugin/DuckDBPlugin.swift @@ -103,599 +103,6 @@ final class DuckDBPlugin: NSObject, TableProPlugin, DriverPlugin { } } -// MARK: - DuckDB Connection Actor - -private actor DuckDBConnectionActor { - private static let logger = Logger(subsystem: "com.TablePro", category: "DuckDBConnectionActor") - - private var database: duckdb_database? - private var connection: duckdb_connection? - - var isConnected: Bool { connection != nil } - - var connectionHandleForInterrupt: duckdb_connection? { connection } - - func open(path: String) throws { - var db: duckdb_database? - var errorPtr: UnsafeMutablePointer? - let state = duckdb_open_ext(path, &db, nil, &errorPtr) - - if state == DuckDBError { - let detail: String - if let errPtr = errorPtr { - detail = String(cString: errPtr) - duckdb_free(errPtr) - } else { - detail = "unknown error" - } - throw DuckDBPluginError.connectionFailed( - "Failed to open DuckDB database at '\(path)': \(detail)" - ) - } - - guard let openedDB = db else { - throw DuckDBPluginError.connectionFailed( - "Failed to open DuckDB database at '\(path)'" - ) - } - - var conn: duckdb_connection? - let connState = duckdb_connect(openedDB, &conn) - - if connState == DuckDBError { - duckdb_close(&db) - throw DuckDBPluginError.connectionFailed("Failed to create DuckDB connection") - } - - database = db - connection = conn - } - - func close() { - if connection != nil { - duckdb_disconnect(&connection) - connection = nil - } - if database != nil { - duckdb_close(&database) - database = nil - } - } - - func executeQuery(_ query: String) throws -> DuckDBRawResult { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - let startTime = Date() - var result = duckdb_result() - - let state = duckdb_query(conn, query, &result) - - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - defer { - duckdb_destroy_result(&result) - } - - var raw = Self.extractResult(from: &result, startTime: startTime) - Self.patchTzColumns(&raw, query: query, connection: conn) - return raw - } - - func executePrepared(_ query: String, parameters: [PluginCellValue]) throws -> DuckDBRawResult { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - let startTime = Date() - var stmtOpt: duckdb_prepared_statement? - - let prepState = duckdb_prepare(conn, query, &stmtOpt) - if prepState == DuckDBError { - let errorMsg: String - if let s = stmtOpt, let errPtr = duckdb_prepare_error(s) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Failed to prepare statement" - } - duckdb_destroy_prepare(&stmtOpt) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - guard let stmt = stmtOpt else { - throw DuckDBPluginError.queryFailed("Failed to prepare statement") - } - - defer { - duckdb_destroy_prepare(&stmtOpt) - } - - for (index, param) in parameters.enumerated() { - let paramIdx = idx_t(index + 1) - let bindState: duckdb_state - switch param { - case .null: - bindState = duckdb_bind_null(stmt, paramIdx) - case .text(let value): - bindState = duckdb_bind_varchar(stmt, paramIdx, value) - case .bytes(let data): - bindState = data.withUnsafeBytes { rawBuffer -> duckdb_state in - guard let baseAddress = rawBuffer.baseAddress else { - return duckdb_bind_null(stmt, paramIdx) - } - return duckdb_bind_blob(stmt, paramIdx, baseAddress, idx_t(data.count)) - } - } - if bindState == DuckDBError { - throw DuckDBPluginError.queryFailed("Failed to bind parameter at index \(index)") - } - } - - var result = duckdb_result() - let execState = duckdb_execute_prepared(stmt, &result) - - if execState == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Failed to execute prepared statement" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - defer { - duckdb_destroy_result(&result) - } - - var raw = Self.extractResult(from: &result, startTime: startTime) - Self.patchTzColumns(&raw, query: query, connection: conn) - return raw - } - - func streamQuery( - _ query: String, - continuation: AsyncThrowingStream.Continuation - ) throws { - guard let conn = connection else { - throw DuckDBPluginError.notConnected - } - - var result = duckdb_result() - let state = duckdb_query(conn, query, &result) - - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - - let colCount = duckdb_column_count(&result) - var columns: [String] = [] - var columnTypeNames: [String] = [] - var columnTypes: [duckdb_type] = [] - for i in 0...Continuation - ) throws { - let castExprs = columns.enumerated().map { i, name in - castExpression(for: columnTypes[i], column: name) - } - let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) - - var result = duckdb_result() - let state = duckdb_query(connection, wrappedQuery, &result) - if state == DuckDBError { - let errorMsg: String - if let errPtr = duckdb_result_error(&result) { - errorMsg = String(cString: errPtr) - } else { - errorMsg = "Unknown DuckDB error" - } - duckdb_destroy_result(&result) - throw DuckDBPluginError.queryFailed(errorMsg) - } - defer { duckdb_destroy_result(&result) } - - try Self.streamResultRows( - &result, - columns: columns, - columnTypeNames: columnTypeNames, - continuation: continuation - ) - } - - private static func streamResultRows( - _ result: inout duckdb_result, - columns: [String], - columnTypeNames: [String], - continuation: AsyncThrowingStream.Continuation - ) throws { - let colCount = duckdb_column_count(&result) - let rowCount = duckdb_row_count(&result) - - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames, - estimatedRowCount: Int(rowCount) - ))) - - let maxRows = min(rowCount, UInt64(PluginRowLimits.emergencyMax)) - if rowCount > UInt64(PluginRowLimits.emergencyMax) { - Self.logger.warning("streamQuery truncating result from \(rowCount) to \(maxRows) rows") - } - - for row in 0.. DuckDBRawResult { - let colCount = duckdb_column_count(&result) - let rowCount = duckdb_row_count(&result) - let rowsChanged = duckdb_rows_changed(&result) - - var columns: [String] = [] - var columnTypeNames: [String] = [] - var columnTypes: [duckdb_type] = [] - - for i in 0.. UInt64(PluginRowLimits.emergencyMax) { - truncated = true - } - - for row in 0.. String { - switch type { - case DUCKDB_TYPE_BOOLEAN: return "BOOLEAN" - case DUCKDB_TYPE_TINYINT: return "TINYINT" - case DUCKDB_TYPE_SMALLINT: return "SMALLINT" - case DUCKDB_TYPE_INTEGER: return "INTEGER" - case DUCKDB_TYPE_BIGINT: return "BIGINT" - case DUCKDB_TYPE_UTINYINT: return "UTINYINT" - case DUCKDB_TYPE_USMALLINT: return "USMALLINT" - case DUCKDB_TYPE_UINTEGER: return "UINTEGER" - case DUCKDB_TYPE_UBIGINT: return "UBIGINT" - case DUCKDB_TYPE_FLOAT: return "FLOAT" - case DUCKDB_TYPE_DOUBLE: return "DOUBLE" - case DUCKDB_TYPE_TIMESTAMP: return "TIMESTAMP" - case DUCKDB_TYPE_DATE: return "DATE" - case DUCKDB_TYPE_TIME: return "TIME" - case DUCKDB_TYPE_INTERVAL: return "INTERVAL" - case DUCKDB_TYPE_HUGEINT: return "HUGEINT" - case DUCKDB_TYPE_VARCHAR: return "VARCHAR" - case DUCKDB_TYPE_BLOB: return "BLOB" - case DUCKDB_TYPE_DECIMAL: return "DECIMAL" - case DUCKDB_TYPE_TIMESTAMP_S: return "TIMESTAMP_S" - case DUCKDB_TYPE_TIMESTAMP_MS: return "TIMESTAMP_MS" - case DUCKDB_TYPE_TIMESTAMP_NS: return "TIMESTAMP_NS" - case DUCKDB_TYPE_ENUM: return "ENUM" - case DUCKDB_TYPE_LIST: return "LIST" - case DUCKDB_TYPE_STRUCT: return "STRUCT" - case DUCKDB_TYPE_MAP: return "MAP" - case DUCKDB_TYPE_UUID: return "UUID" - case DUCKDB_TYPE_UNION: return "UNION" - case DUCKDB_TYPE_BIT: return "BIT" - case DUCKDB_TYPE_TIMESTAMP_TZ: return "TIMESTAMPTZ" - case DUCKDB_TYPE_TIME_TZ: return "TIMETZ" - case DUCKDB_TYPE_TIME_NS: return "TIME_NS" - case DUCKDB_TYPE_UHUGEINT: return "UHUGEINT" - case DUCKDB_TYPE_ARRAY: return "ARRAY" - case DUCKDB_TYPE_GEOMETRY: return "GEOMETRY" - default: return "VARCHAR" - } - } - - private static func extractFallbackValue( - _ result: inout duckdb_result, col: idx_t, row: idx_t, type: duckdb_type - ) -> String? { - switch type { - case DUCKDB_TYPE_TIMESTAMP, DUCKDB_TYPE_TIMESTAMP_S, DUCKDB_TYPE_TIMESTAMP_MS, DUCKDB_TYPE_TIMESTAMP_NS: - let ts = duckdb_value_timestamp(&result, col, row) - return formatTimestamp(ts) - - case DUCKDB_TYPE_DATE: - let date = duckdb_value_date(&result, col, row) - let d = duckdb_from_date(date) - return String(format: "\(formatYearISO(d.year))-%02d-%02d", d.month, d.day) - - case DUCKDB_TYPE_TIME, DUCKDB_TYPE_TIME_NS: - let time = duckdb_value_time(&result, col, row) - return formatTime(duckdb_from_time(time)) - - case DUCKDB_TYPE_BOOLEAN: - return duckdb_value_boolean(&result, col, row) ? "true" : "false" - - case DUCKDB_TYPE_TINYINT: - return String(duckdb_value_int8(&result, col, row)) - case DUCKDB_TYPE_SMALLINT: - return String(duckdb_value_int16(&result, col, row)) - case DUCKDB_TYPE_INTEGER: - return String(duckdb_value_int32(&result, col, row)) - case DUCKDB_TYPE_BIGINT: - return String(duckdb_value_int64(&result, col, row)) - case DUCKDB_TYPE_UTINYINT: - return String(duckdb_value_uint8(&result, col, row)) - case DUCKDB_TYPE_USMALLINT: - return String(duckdb_value_uint16(&result, col, row)) - case DUCKDB_TYPE_UINTEGER: - return String(duckdb_value_uint32(&result, col, row)) - case DUCKDB_TYPE_UBIGINT: - return String(duckdb_value_uint64(&result, col, row)) - case DUCKDB_TYPE_FLOAT: - return String(duckdb_value_float(&result, col, row)) - case DUCKDB_TYPE_DOUBLE: - return String(duckdb_value_double(&result, col, row)) - - case DUCKDB_TYPE_HUGEINT: - let h = duckdb_value_hugeint(&result, col, row) - return formatHugeInt(upper: h.upper, lower: h.lower) - - case DUCKDB_TYPE_UHUGEINT: - let u = duckdb_value_uhugeint(&result, col, row) - return formatUHugeInt(upper: u.upper, lower: u.lower) - - default: - return nil - } - } - - static func patchTzColumns( - _ raw: inout DuckDBRawResult, query: String, connection: duckdb_connection - ) { - let patchedColIndices = raw.columnTypes.enumerated().compactMap { idx, type in - isUnrenderable(type) ? idx : nil - } - guard !patchedColIndices.isEmpty, !raw.rows.isEmpty else { return } - - let castExprs = raw.columns.enumerated().map { i, name in - castExpression(for: raw.columnTypes[i], column: name) - } - let wrappedQuery = buildWrappedQuery(originalQuery: query, castExprs: castExprs) - - var patchResult = duckdb_result() - guard duckdb_query(connection, wrappedQuery, &patchResult) == DuckDBSuccess else { return } - defer { duckdb_destroy_result(&patchResult) } - - let patchRowCount = min(duckdb_row_count(&patchResult), UInt64(raw.rows.count)) - for row in 0.. Bool { - switch type { - case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ, DUCKDB_TYPE_GEOMETRY: - return true - default: - return false - } - } - - static func castExpression(for type: duckdb_type, column: String) -> String { - let quoted = quoteIdentifier(column) - switch type { - case DUCKDB_TYPE_GEOMETRY: - return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE ST_AsText(\(quoted)) END AS \(quoted)" - case DUCKDB_TYPE_TIMESTAMP_TZ, DUCKDB_TYPE_TIME_TZ: - return "CASE WHEN \(quoted) IS NULL THEN NULL ELSE CAST(\(quoted) AS VARCHAR) END AS \(quoted)" - default: - return quoted - } - } - - static func buildWrappedQuery(originalQuery: String, castExprs: [String]) -> String { - var trimmed = originalQuery.trimmingCharacters(in: .whitespacesAndNewlines) - if trimmed.hasSuffix(";") { - trimmed = String(trimmed.dropLast()) - } - return "SELECT \(castExprs.joined(separator: ", ")) FROM (\(trimmed)) AS _tp_cast" - } - - static func quoteIdentifier(_ ident: String) -> String { - "\"\(ident.replacingOccurrences(of: "\"", with: "\"\""))\"" - } - - static func formatTimestamp(_ ts: duckdb_timestamp) -> String { - let parts = duckdb_from_timestamp(ts) - let d = parts.date - let t = parts.time - let micros = t.micros % 1_000_000 - let yearPart = formatYearISO(d.year) - if micros == 0 { - return String( - format: "\(yearPart)-%02d-%02d %02d:%02d:%02d", - d.month, d.day, t.hour, t.min, t.sec - ) - } - return String( - format: "\(yearPart)-%02d-%02d %02d:%02d:%02d.%06d", - d.month, d.day, t.hour, t.min, t.sec, micros - ) - } - - static func formatYearISO(_ year: Int32) -> String { - if year < 0 { - return String(format: "-%04d", -Int(year)) - } - return String(format: "%04d", year) - } - - private static func formatTime(_ t: duckdb_time_struct) -> String { - let micros = t.micros % 1_000_000 - if micros == 0 { - return String(format: "%02d:%02d:%02d", t.hour, t.min, t.sec) - } - return String(format: "%02d:%02d:%02d.%06d", t.hour, t.min, t.sec, micros) - } - - static func formatHugeInt(upper: Int64, lower: UInt64) -> String { - HugeIntFormatter.format(upper: upper, lower: lower) - } - - static func formatUHugeInt(upper: UInt64, lower: UInt64) -> String { - HugeIntFormatter.formatUnsigned(upper: upper, lower: lower) - } -} - -private struct DuckDBRawResult: @unchecked Sendable { - let columns: [String] - let columnTypeNames: [String] - let columnTypes: [duckdb_type] - var rows: [[PluginCellValue]] - let rowsAffected: Int - let executionTime: TimeInterval - let isTruncated: Bool -} - // MARK: - DuckDB Plugin Driver final class DuckDBPluginDriver: PluginDatabaseDriver, @unchecked Sendable { @@ -1470,22 +877,3 @@ final class DuckDBPluginDriver: PluginDatabaseDriver, @unchecked Sendable { } } -// MARK: - Errors - -enum DuckDBPluginError: Error { - case connectionFailed(String) - case notConnected - case queryFailed(String) - case unsupportedOperation -} - -extension DuckDBPluginError: PluginDriverError { - var pluginErrorMessage: String { - switch self { - case .connectionFailed(let msg): return msg - case .notConnected: return String(localized: "Not connected to database") - case .queryFailed(let msg): return msg - case .unsupportedOperation: return String(localized: "Operation not supported") - } - } -} diff --git a/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift b/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift new file mode 100644 index 000000000..28ae42001 --- /dev/null +++ b/Plugins/DuckDBDriverPlugin/DuckDBPluginError.swift @@ -0,0 +1,25 @@ +// +// DuckDBPluginError.swift +// DuckDBDriverPlugin +// + +import Foundation +import TableProPluginKit + +enum DuckDBPluginError: Error { + case connectionFailed(String) + case notConnected + case queryFailed(String) + case unsupportedOperation +} + +extension DuckDBPluginError: PluginDriverError { + var pluginErrorMessage: String { + switch self { + case .connectionFailed(let msg): return msg + case .notConnected: return String(localized: "Not connected to database") + case .queryFailed(let msg): return msg + case .unsupportedOperation: return String(localized: "Operation not supported") + } + } +} diff --git a/Plugins/JSONExportPlugin/JSONExportPlugin.swift b/Plugins/JSONExportPlugin/JSONExportPlugin.swift index 6aa73363e..2868e2890 100644 --- a/Plugins/JSONExportPlugin/JSONExportPlugin.swift +++ b/Plugins/JSONExportPlugin/JSONExportPlugin.swift @@ -157,7 +157,7 @@ final class JSONExportPlugin: ExportFormatPlugin, SettablePlugin { return val.lowercased() } - let isNumericCol = isNumericColumnType(columnTypeName) + let isNumericCol = PluginExportUtilities.isNumericColumnType(columnTypeName) if isNumericCol && isValidIntegerLiteral(val) { if let intVal = Int(val) { @@ -183,15 +183,6 @@ final class JSONExportPlugin: ExportFormatPlugin, SettablePlugin { return "\"\(PluginExportUtilities.escapeJSONString(val))\"" } - private func isNumericColumnType(_ typeName: String) -> Bool { - let numericPrefixes = [ - "int", "bigint", "decimal", "float", "double", "numeric", - "real", "smallint", "tinyint", "mediumint", "integer", "number" - ] - let lower = typeName.lowercased() - return numericPrefixes.contains { lower.hasPrefix($0) } - } - private func isValidIntegerLiteral(_ val: String) -> Bool { guard !val.isEmpty else { return false } let digits = val.hasPrefix("-") || val.hasPrefix("+") ? String(val.dropFirst()) : val diff --git a/Plugins/JSONImportPlugin/Info.plist b/Plugins/JSONImportPlugin/Info.plist index 8ceb6def3..6174f79e9 100644 --- a/Plugins/JSONImportPlugin/Info.plist +++ b/Plugins/JSONImportPlugin/Info.plist @@ -3,7 +3,7 @@ TableProPluginKitVersion - 17 + 18 TableProProvidesImportFormatIds json diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift index 16b33de1b..468243764 100644 --- a/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift +++ b/Plugins/MSSQLDriverPlugin/MSSQLPlugin.swift @@ -160,16 +160,16 @@ final class MSSQLPlugin: NSObject, TableProPlugin, DriverPlugin { final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { private let config: DriverConnectionConfig - private var freeTDSConn: FreeTDSConnection? - private var _currentSchema: String + var freeTDSConn: FreeTDSConnection? + var _currentSchema: String private var _serverVersion: String? /// IDENTITY columns observed during `fetchColumns`, keyed by table name. /// `generateMssqlInsert` reads this to skip IDENTITY columns: SQL Server /// rejects explicit values for IDENTITY columns unless IDENTITY_INSERT is ON, /// and the value the user typed is server-allocated anyway. - private var identityColumnsByTable: [String: Set] = [:] - private let identityCacheLock = NSLock() + var identityColumnsByTable: [String: Set] = [:] + let identityCacheLock = NSLock() private static let logger = Logger(subsystem: "com.TablePro", category: "MSSQLPluginDriver") @@ -554,527 +554,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { return nil } - // MARK: - Schema Operations - - func fetchTables(schema: String?) async throws -> [PluginTableInfo] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT t.TABLE_NAME, t.TABLE_TYPE - FROM INFORMATION_SCHEMA.TABLES t - WHERE t.TABLE_SCHEMA = '\(esc)' - AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') - ORDER BY t.TABLE_NAME - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginTableInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let rawType = row[safe: 1]?.asText - let tableType = (rawType == "VIEW") ? "VIEW" : "TABLE" - return PluginTableInfo(name: name, type: tableType) - } - } - - func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - c.COLUMN_NAME, - c.DATA_TYPE, - c.CHARACTER_MAXIMUM_LENGTH, - c.NUMERIC_PRECISION, - c.NUMERIC_SCALE, - c.IS_NULLABLE, - c.COLUMN_DEFAULT, - COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, - CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK - FROM INFORMATION_SCHEMA.COLUMNS c - LEFT JOIN ( - SELECT kcu.COLUMN_NAME - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc - JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu - ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME - AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA - WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' - AND tc.TABLE_SCHEMA = '\(esc)' - AND tc.TABLE_NAME = '\(escapedTable)' - ) pk ON c.COLUMN_NAME = pk.COLUMN_NAME - WHERE c.TABLE_NAME = '\(escapedTable)' - AND c.TABLE_SCHEMA = '\(esc)' - ORDER BY c.ORDINAL_POSITION - """ - let result = try await execute(query: sql) - var identityColumns: Set = [] - let columns: [PluginColumnInfo] = result.rows.compactMap { row -> PluginColumnInfo? in - guard let name = row[safe: 0]?.asText else { return nil } - let dataType = row[safe: 1]?.asText - let charLen = row[safe: 2]?.asText - let numPrecision = row[safe: 3]?.asText - let numScale = row[safe: 4]?.asText - let isNullable = (row[safe: 5]?.asText) == "YES" - let defaultValue = row[safe: 6]?.asText - let isIdentity = (row[safe: 7]?.asText) == "1" - let isPk = (row[safe: 8]?.asText) == "1" - - if isIdentity { - identityColumns.insert(name) - } - - let baseType = (dataType ?? "nvarchar").lowercased() - let fixedSizeTypes: Set = [ - "int", "bigint", "smallint", "tinyint", "bit", - "money", "smallmoney", "float", "real", - "datetime", "datetime2", "smalldatetime", "date", "time", - "uniqueidentifier", "text", "ntext", "image", "xml", - "timestamp", "rowversion" - ] - var fullType = baseType - if fixedSizeTypes.contains(baseType) { - // No suffix - } else if let charLen, let len = Int(charLen), len > 0 { - fullType += "(\(len))" - } else if charLen == "-1" { - fullType += "(max)" - } else if let prec = numPrecision, let scale = numScale, - let p = Int(prec), let s = Int(scale) { - fullType += "(\(p),\(s))" - } - - return PluginColumnInfo( - name: name, - dataType: fullType, - isNullable: isNullable, - isPrimaryKey: isPk, - defaultValue: defaultValue, - extra: isIdentity ? "IDENTITY" : nil - ) - } - identityCacheLock.lock() - identityColumnsByTable[table] = identityColumns - identityCacheLock.unlock() - return columns - } - - /// Snapshot of IDENTITY columns observed by the most recent `fetchColumns` for the table. - /// Returns an empty set when `fetchColumns` hasn't run for this table yet, so callers - /// fall through to including every typed value (matching pre-cache behavior). - internal func cachedIdentityColumns(for table: String) -> Set { - identityCacheLock.lock() - defer { identityCacheLock.unlock() } - return identityColumnsByTable[table] ?? [] - } - - /// Test seam: pre-populate the cache so generateMssqlInsert can be exercised - /// without going through a live `fetchColumns` round-trip. - internal func setIdentityColumnsForTesting(_ columns: Set, table: String) { - identityCacheLock.lock() - identityColumnsByTable[table] = columns - identityCacheLock.unlock() - } - - func fetchIndexes(table: String, schema: String?) async throws -> [PluginIndexInfo] { - let esc = (schema ?? _currentSchema).replacingOccurrences(of: "]", with: "]]") - let bracketedTable = table.replacingOccurrences(of: "]", with: "]]") - let bracketedFull = "[\(esc)].[\(bracketedTable)]" - let sql = """ - SELECT i.name, i.is_unique, i.is_primary_key, c.name AS column_name - FROM sys.indexes i - JOIN sys.index_columns ic - ON i.object_id = ic.object_id AND i.index_id = ic.index_id - JOIN sys.columns c - ON ic.object_id = c.object_id AND ic.column_id = c.column_id - WHERE i.object_id = OBJECT_ID('\(bracketedFull)') - AND i.name IS NOT NULL - ORDER BY i.index_id, ic.key_ordinal - """ - let result = try await execute(query: sql) - var indexMap: [String: (unique: Bool, primary: Bool, columns: [String])] = [:] - for row in result.rows { - guard let idxName = row[safe: 0]?.asText, - let colName = row[safe: 3]?.asText else { continue } - let isUnique = (row[safe: 1]?.asText) == "1" - let isPrimary = (row[safe: 2]?.asText) == "1" - if indexMap[idxName] == nil { - indexMap[idxName] = (unique: isUnique, primary: isPrimary, columns: []) - } - indexMap[idxName]?.columns.append(colName) - } - return indexMap.map { name, info in - PluginIndexInfo( - name: name, - columns: info.columns, - isUnique: info.unique, - isPrimary: info.primary, - type: "CLUSTERED" - ) - }.sorted { $0.name < $1.name } - } - - func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - fk.name AS constraint_name, - cp.name AS column_name, - tr.name AS ref_table, - cr.name AS ref_column - FROM sys.foreign_keys fk - JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id - JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id - JOIN sys.schemas s ON tp.schema_id = s.schema_id - JOIN sys.columns cp - ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id - JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id - JOIN sys.columns cr - ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id - WHERE tp.name = '\(escapedTable)' AND s.name = '\(esc)' - ORDER BY fk.name - """ - let result = try await execute(query: sql) - return result.rows.compactMap { row -> PluginForeignKeyInfo? in - guard let constraintName = row[safe: 0]?.asText, - let columnName = row[safe: 1]?.asText, - let refTable = row[safe: 2]?.asText, - let refColumn = row[safe: 3]?.asText else { return nil } - return PluginForeignKeyInfo( - name: constraintName, - column: columnName, - referencedTable: refTable, - referencedColumn: refColumn - ) - } - } - - func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - c.TABLE_NAME, - c.COLUMN_NAME, - c.DATA_TYPE, - c.CHARACTER_MAXIMUM_LENGTH, - c.NUMERIC_PRECISION, - c.NUMERIC_SCALE, - c.IS_NULLABLE, - c.COLUMN_DEFAULT, - COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, - CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK - FROM INFORMATION_SCHEMA.COLUMNS c - LEFT JOIN ( - SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc - JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu - ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME - AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA - WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' - AND tc.TABLE_SCHEMA = '\(esc)' - ) pk ON c.TABLE_NAME = pk.TABLE_NAME AND c.COLUMN_NAME = pk.COLUMN_NAME - WHERE c.TABLE_SCHEMA = '\(esc)' - ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION - """ - let result = try await execute(query: sql) - var columnsByTable: [String: [PluginColumnInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let name = row[safe: 1]?.asText else { continue } - let dataType = row[safe: 2]?.asText - let charLen = row[safe: 3]?.asText - let numPrecision = row[safe: 4]?.asText - let numScale = row[safe: 5]?.asText - let isNullable = (row[safe: 6]?.asText) == "YES" - let defaultValue = row[safe: 7]?.asText - let isIdentity = (row[safe: 8]?.asText) == "1" - let isPk = (row[safe: 9]?.asText) == "1" - - let baseType = (dataType ?? "nvarchar").lowercased() - let fixedSizeTypes: Set = [ - "int", "bigint", "smallint", "tinyint", "bit", - "money", "smallmoney", "float", "real", - "datetime", "datetime2", "smalldatetime", "date", "time", - "uniqueidentifier", "text", "ntext", "image", "xml", - "timestamp", "rowversion" - ] - var fullType = baseType - if fixedSizeTypes.contains(baseType) { - // No suffix - } else if let charLen, let len = Int(charLen), len > 0 { - fullType += "(\(len))" - } else if charLen == "-1" { - fullType += "(max)" - } else if let prec = numPrecision, let scale = numScale, - let p = Int(prec), let s = Int(scale) { - fullType += "(\(p),\(s))" - } - - let col = PluginColumnInfo( - name: name, - dataType: fullType, - isNullable: isNullable, - isPrimaryKey: isPk, - defaultValue: defaultValue, - extra: isIdentity ? "IDENTITY" : nil - ) - columnsByTable[tableName, default: []].append(col) - } - return columnsByTable - } - - func fetchAllForeignKeys(schema: String?) async throws -> [String: [PluginForeignKeyInfo]] { - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - tp.name AS table_name, - fk.name AS constraint_name, - cp.name AS column_name, - tr.name AS ref_table, - cr.name AS ref_column - FROM sys.foreign_keys fk - JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id - JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id - JOIN sys.schemas s ON tp.schema_id = s.schema_id - JOIN sys.columns cp - ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id - JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id - JOIN sys.columns cr - ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id - WHERE s.name = '\(esc)' - ORDER BY tp.name, fk.name - """ - let result = try await execute(query: sql) - var fksByTable: [String: [PluginForeignKeyInfo]] = [:] - for row in result.rows { - guard let tableName = row[safe: 0]?.asText, - let constraintName = row[safe: 1]?.asText, - let columnName = row[safe: 2]?.asText, - let refTable = row[safe: 3]?.asText, - let refColumn = row[safe: 4]?.asText else { continue } - let fk = PluginForeignKeyInfo( - name: constraintName, - column: columnName, - referencedTable: refTable, - referencedColumn: refColumn - ) - fksByTable[tableName, default: []].append(fk) - } - return fksByTable - } - - func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { - let sql = """ - SELECT d.name, - SUM(mf.size) * 8 * 1024 AS size_bytes - FROM sys.databases d - LEFT JOIN sys.master_files mf ON d.database_id = mf.database_id - GROUP BY d.name - ORDER BY d.name - """ - do { - let result = try await execute(query: sql) - var metadata = result.rows.compactMap { row -> PluginDatabaseMetadata? in - guard let name = row[safe: 0]?.asText else { return nil } - let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } - return PluginDatabaseMetadata(name: name, sizeBytes: sizeBytes) - } - - for i in metadata.indices { - let dbName = metadata[i].name.replacingOccurrences(of: "]", with: "]]") - do { - let countResult = try await execute( - query: "SELECT COUNT(*) FROM [\(dbName)].sys.tables" - ) - if let countStr = countResult.rows.first?[safe: 0]?.asText, - let count = Int(countStr) { - metadata[i] = PluginDatabaseMetadata( - name: metadata[i].name, - tableCount: count, - sizeBytes: metadata[i].sizeBytes - ) - } - } catch { - // Database offline or permission denied — leave tableCount as nil - } - } - - return metadata - } catch { - // Fall back to N+1 if permission denied on sys.master_files - let dbs = try await fetchDatabases() - var result: [PluginDatabaseMetadata] = [] - for db in dbs { - do { - result.append(try await fetchDatabaseMetadata(db)) - } catch { - result.append(PluginDatabaseMetadata(name: db)) - } - } - return result - } - } - - func fetchTableDDL(table: String, schema: String?) async throws -> String { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let cols = try await fetchColumns(table: table, schema: schema) - let indexes = try await fetchIndexes(table: table, schema: schema) - let fks = try await fetchForeignKeys(table: table, schema: schema) - - var ddl = "CREATE TABLE [\(esc)].[\(escapedTable)] (\n" - let colDefs = cols.map { col -> String in - var def = " [\(col.name)] \(col.dataType.uppercased())" - if col.extra == "IDENTITY" { def += " IDENTITY(1,1)" } - def += col.isNullable ? " NULL" : " NOT NULL" - if let d = col.defaultValue { def += " DEFAULT \(d)" } - return def - } - - let pkCols = indexes.filter(\.isPrimary).flatMap(\.columns) - var parts = colDefs - if !pkCols.isEmpty { - let pkName = "PK_\(table)" - let pkDef = " CONSTRAINT [\(pkName)] PRIMARY KEY (\(pkCols.map { "[\($0)]" }.joined(separator: ", ")))" - parts.append(pkDef) - } - - for fk in fks { - let fkDef = " CONSTRAINT [\(fk.name)] FOREIGN KEY ([\(fk.column)]) REFERENCES [\(fk.referencedTable)] ([\(fk.referencedColumn)])" - parts.append(fkDef) - } - - ddl += parts.joined(separator: ",\n") - ddl += "\n);" - return ddl - } - - func fetchViewDefinition(view: String, schema: String?) async throws -> String { - let esc = effectiveSchemaEscaped(schema) - let escapedView = "\(esc).\(view.replacingOccurrences(of: "'", with: "''"))" - let sql = "SELECT definition FROM sys.sql_modules WHERE object_id = OBJECT_ID('\(escapedView)')" - let result = try await execute(query: sql) - return result.rows.first?.first?.asText ?? "" - } - - func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { - let escapedTable = table.replacingOccurrences(of: "'", with: "''") - let esc = effectiveSchemaEscaped(schema) - let sql = """ - SELECT - SUM(p.rows) AS row_count, - 8 * SUM(a.used_pages) AS size_kb, - ep.value AS comment - FROM sys.tables t - JOIN sys.schemas s ON t.schema_id = s.schema_id - JOIN sys.partitions p - ON t.object_id = p.object_id AND p.index_id IN (0, 1) - JOIN sys.allocation_units a ON p.partition_id = a.container_id - LEFT JOIN sys.extended_properties ep - ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.name = 'MS_Description' - WHERE t.name = '\(escapedTable)' AND s.name = '\(esc)' - GROUP BY ep.value - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } - let sizeKb = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 - let comment = row[safe: 2]?.asText - return PluginTableMetadata( - tableName: table, - dataSize: sizeKb * 1_024, - totalSize: sizeKb * 1_024, - rowCount: rowCount, - comment: comment - ) - } - return PluginTableMetadata(tableName: table) - } - - func fetchDatabases() async throws -> [String] { - let sql = "SELECT name FROM sys.databases ORDER BY name" - let result = try await execute(query: sql) - return result.rows.compactMap { $0.first?.asText } - } - - func fetchSchemas() async throws -> [String] { - let sql = """ - SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA - WHERE SCHEMA_NAME NOT IN ( - 'information_schema','sys','db_owner','db_accessadmin', - 'db_securityadmin','db_ddladmin','db_backupoperator', - 'db_datareader','db_datawriter','db_denydatareader', - 'db_denydatawriter','guest' - ) - ORDER BY SCHEMA_NAME - """ - let result = try await execute(query: sql) - return result.rows.compactMap { $0.first?.asText } - } - - func switchSchema(to schema: String) async throws { - _currentSchema = schema - } - - func switchDatabase(to database: String) async throws { - guard let conn = freeTDSConn else { - throw MSSQLPluginError.notConnected - } - try await conn.switchDatabase(database) - } - - func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { - let sql = """ - SELECT - SUM(size) * 8.0 / 1024 AS size_mb, - (SELECT COUNT(*) FROM sys.tables) AS table_count - FROM sys.database_files - """ - let result = try await execute(query: sql) - if let row = result.rows.first { - let sizeMb = (row[safe: 0]?.asText).flatMap { Double($0) } ?? 0 - let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 - return PluginDatabaseMetadata( - name: database, - tableCount: tableCount, - sizeBytes: Int64(sizeMb * 1_024 * 1_024) - ) - } - return PluginDatabaseMetadata(name: database) - } - - func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { - PluginCreateDatabaseFormSpec(fields: [], footnote: nil) - } - - func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { - let quotedName = "[\(request.name.replacingOccurrences(of: "]", with: "]]"))]" - _ = try await execute(query: "CREATE DATABASE \(quotedName)") - } - - func dropDatabase(name: String) async throws { - let quotedName = "[\(name.replacingOccurrences(of: "]", with: "]]"))]" - _ = try await execute(query: "DROP DATABASE \(quotedName)") - } - - // MARK: - All Tables Metadata - - func allTablesMetadataSQL(schema: String?) -> String? { - """ - SELECT - s.name as schema_name, - t.name as name, - CASE WHEN v.object_id IS NOT NULL THEN 'VIEW' ELSE 'TABLE' END as kind, - p.rows as estimated_rows, - CAST(ROUND(SUM(a.total_pages) * 8 / 1024.0, 2) AS VARCHAR) + ' MB' as total_size - FROM sys.tables t - INNER JOIN sys.schemas s ON t.schema_id = s.schema_id - INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.index_id IN (0, 1) - INNER JOIN sys.partitions p ON i.object_id = p.object_id AND i.index_id = p.index_id - INNER JOIN sys.allocation_units a ON p.partition_id = a.container_id - LEFT JOIN sys.views v ON t.object_id = v.object_id - GROUP BY s.name, t.name, p.rows, v.object_id - ORDER BY t.name - """ - } - // MARK: - Query Building func buildBrowseQuery( @@ -1086,8 +565,9 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = mssqlQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let orderBy = mssqlBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY (SELECT NULL)" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: mssqlQuoteIdentifier + ) ?? "ORDER BY (SELECT NULL)" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1103,12 +583,21 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = mssqlQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let whereClause = mssqlBuildWhereClause(filters: filters, logicMode: logicMode) + let whereClause = PluginSQLFilter.buildWhereClause( + filters: filters, + logicMode: logicMode, + quoteIdentifier: mssqlQuoteIdentifier, + escapeValue: mssqlEscapeValue, + regexCondition: { quoted, value in + "\(quoted) LIKE '%\(value.replacingOccurrences(of: "'", with: "''"))%'" + } + ) if !whereClause.isEmpty { query += " WHERE \(whereClause)" } - let orderBy = mssqlBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY (SELECT NULL)" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: mssqlQuoteIdentifier + ) ?? "ORDER BY (SELECT NULL)" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1119,29 +608,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { quoteIdentifier(identifier) } - private func mssqlBuildOrderByClause( - sortColumns: [(columnIndex: Int, ascending: Bool)], - columns: [String] - ) -> String? { - let parts = sortColumns.compactMap { sortCol -> String? in - guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } - let columnName = columns[sortCol.columnIndex] - let direction = sortCol.ascending ? "ASC" : "DESC" - let quotedColumn = mssqlQuoteIdentifier(columnName) - return "\(quotedColumn) \(direction)" - } - guard !parts.isEmpty else { return nil } - return "ORDER BY " + parts.joined(separator: ", ") - } - - private func mssqlEscapeForLike(_ text: String) -> String { - text - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "%", with: "\\%") - .replacingOccurrences(of: "_", with: "\\_") - .replacingOccurrences(of: "'", with: "''") - } - private func mssqlEscapeValue(_ value: String) -> String { let trimmed = value.trimmingCharacters(in: .whitespaces) if trimmed.caseInsensitiveCompare("NULL") == .orderedSame { return "NULL" } @@ -1151,65 +617,6 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { return "'\(trimmed.replacingOccurrences(of: "'", with: "''"))'" } - private func mssqlBuildWhereClause( - filters: [(column: String, op: String, value: String)], - logicMode: String - ) -> String { - let conditions = filters.compactMap { filter -> String? in - mssqlBuildFilterCondition(column: filter.column, op: filter.op, value: filter.value) - } - guard !conditions.isEmpty else { return "" } - let separator = logicMode == "and" ? " AND " : " OR " - return conditions.joined(separator: separator) - } - - private func mssqlBuildFilterCondition(column: String, op: String, value: String) -> String? { - let quoted = mssqlQuoteIdentifier(column) - switch op { - case "=": return "\(quoted) = \(mssqlEscapeValue(value))" - case "!=": return "\(quoted) != \(mssqlEscapeValue(value))" - case ">": return "\(quoted) > \(mssqlEscapeValue(value))" - case ">=": return "\(quoted) >= \(mssqlEscapeValue(value))" - case "<": return "\(quoted) < \(mssqlEscapeValue(value))" - case "<=": return "\(quoted) <= \(mssqlEscapeValue(value))" - case "IS NULL": return "\(quoted) IS NULL" - case "IS NOT NULL": return "\(quoted) IS NOT NULL" - case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" - case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" - case "CONTAINS": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)%' ESCAPE '\\'" - case "NOT CONTAINS": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) NOT LIKE '%\(escaped)%' ESCAPE '\\'" - case "STARTS WITH": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '\(escaped)%' ESCAPE '\\'" - case "ENDS WITH": - let escaped = mssqlEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)' ESCAPE '\\'" - case "IN": - let values = value.split(separator: ",") - .map { mssqlEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) IN (\(values))" - case "NOT IN": - let values = value.split(separator: ",") - .map { mssqlEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" - case "BETWEEN": - let parts = value.split(separator: ",", maxSplits: 1) - guard parts.count == 2 else { return nil } - let v1 = mssqlEscapeValue(parts[0].trimmingCharacters(in: .whitespaces)) - let v2 = mssqlEscapeValue(parts[1].trimmingCharacters(in: .whitespaces)) - return "\(quoted) BETWEEN \(v1) AND \(v2)" - case "REGEX": - let escaped = value.replacingOccurrences(of: "'", with: "''") - return "\(quoted) LIKE '%\(escaped)%'" - default: return nil - } - } // MARK: - Private Helpers @@ -1277,183 +684,11 @@ final class MSSQLPluginDriver: PluginDatabaseDriver, @unchecked Sendable { value.replacingOccurrences(of: "'", with: "''") } - private func effectiveSchemaEscaped(_ schema: String?) -> String { + func effectiveSchemaEscaped(_ schema: String?) -> String { let raw = schema ?? _currentSchema return raw.replacingOccurrences(of: "'", with: "''") } - // MARK: - Create Table DDL - - func generateCreateTableSQL(definition: PluginCreateTableDefinition) -> String? { - guard !definition.columns.isEmpty else { return nil } - - let schema = _currentSchema - let qualifiedTable = "\(quoteIdentifier(schema)).\(quoteIdentifier(definition.tableName))" - let pkColumns = definition.columns.filter { $0.isPrimaryKey } - let inlinePK = pkColumns.count == 1 - var parts: [String] = definition.columns.map { mssqlColumnDefinition($0, inlinePK: inlinePK) } - - if pkColumns.count > 1 { - let pkCols = pkColumns.map { quoteIdentifier($0.name) }.joined(separator: ", ") - parts.append("PRIMARY KEY (\(pkCols))") - } - - for fk in definition.foreignKeys { - parts.append(mssqlForeignKeyDefinition(fk)) - } - - var sql = "CREATE TABLE \(qualifiedTable) (\n " + - parts.joined(separator: ",\n ") + - "\n);" - - var indexStatements: [String] = [] - for index in definition.indexes { - indexStatements.append(mssqlIndexDefinition(index, qualifiedTable: qualifiedTable)) - } - if !indexStatements.isEmpty { - sql += "\n\n" + indexStatements.joined(separator: ";\n") + ";" - } - - return sql - } - - private func mssqlColumnDefinition(_ col: PluginColumnDefinition, inlinePK: Bool) -> String { - var def = "\(quoteIdentifier(col.name)) \(col.dataType)" - if col.autoIncrement { - def += " IDENTITY(1,1)" - } - if col.isNullable { - def += " NULL" - } else { - def += " NOT NULL" - } - if let defaultValue = col.defaultValue { - def += " DEFAULT \(mssqlDefaultValue(defaultValue))" - } - if inlinePK && col.isPrimaryKey { - def += " PRIMARY KEY" - } - return def - } - - private func mssqlDefaultValue(_ value: String) -> String { - let upper = value.uppercased() - if upper == "NULL" || upper == "GETDATE()" || upper == "NEWID()" || upper == "GETUTCDATE()" - || value.hasPrefix("'") || value.hasPrefix("(") || Int64(value) != nil || Double(value) != nil { - return value - } - return "'\(escapeStringLiteral(value))'" - } - - private func mssqlIndexDefinition(_ index: PluginIndexDefinition, qualifiedTable: String) -> String { - let cols = index.columns.map { quoteIdentifier($0) }.joined(separator: ", ") - let unique = index.isUnique ? "UNIQUE " : "" - var def = "CREATE \(unique)INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - if let type = index.indexType?.uppercased(), type == "CLUSTERED" { - def = "CREATE \(unique)CLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - } else if let type = index.indexType?.uppercased(), type == "NONCLUSTERED" { - def = "CREATE \(unique)NONCLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" - } - return def - } - - private func mssqlForeignKeyDefinition(_ fk: PluginForeignKeyDefinition) -> String { - let cols = fk.columns.map { quoteIdentifier($0) }.joined(separator: ", ") - let refCols = fk.referencedColumns.map { quoteIdentifier($0) }.joined(separator: ", ") - var def = "CONSTRAINT \(quoteIdentifier(fk.name)) FOREIGN KEY (\(cols)) REFERENCES \(quoteIdentifier(fk.referencedTable)) (\(refCols))" - if fk.onDelete != "NO ACTION" { - def += " ON DELETE \(fk.onDelete)" - } - if fk.onUpdate != "NO ACTION" { - def += " ON UPDATE \(fk.onUpdate)" - } - return def - } - - // MARK: - ALTER TABLE DDL - - private func mssqlQualifiedTable(_ table: String) -> String { - "\(quoteIdentifier(_currentSchema)).\(quoteIdentifier(table))" - } - - func generateAddColumnSQL(table: String, column: PluginColumnDefinition) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlColumnDefinition(column, inlinePK: false))" - } - - func generateModifyColumnSQL(table: String, oldColumn: PluginColumnDefinition, newColumn: PluginColumnDefinition) -> String? { - let qt = mssqlQualifiedTable(table) - var stmts: [String] = [] - let needsTypeChange = oldColumn.dataType != newColumn.dataType || oldColumn.isNullable != newColumn.isNullable - let defaultChanged = oldColumn.defaultValue != newColumn.defaultValue - - // Rename column first so subsequent statements reference the correct name - if oldColumn.name != newColumn.name { - let escapedPath = "\(escapeStringLiteral(_currentSchema)).\(escapeStringLiteral(table)).\(escapeStringLiteral(oldColumn.name))" - stmts.append("EXEC sp_rename '\(escapedPath)', '\(escapeStringLiteral(newColumn.name))', 'COLUMN'") - } - - let colName = quoteIdentifier(newColumn.name) - - // Drop existing default constraint before ALTER COLUMN or default change - if (defaultChanged || needsTypeChange) && oldColumn.defaultValue != nil { - let objectId = escapeStringLiteral("\(_currentSchema).\(table)") - stmts.append(""" - DECLARE @dfName NVARCHAR(256); \ - SELECT @dfName = dc.name FROM sys.default_constraints dc \ - JOIN sys.columns c ON dc.parent_column_id = c.column_id AND dc.parent_object_id = c.object_id \ - WHERE c.name = '\(escapeStringLiteral(newColumn.name))' \ - AND dc.parent_object_id = OBJECT_ID('\(objectId)'); \ - IF @dfName IS NOT NULL EXEC('ALTER TABLE \(qt) DROP CONSTRAINT [' + @dfName + ']') - """) - } - - if needsTypeChange { - let nullable = newColumn.isNullable ? "NULL" : "NOT NULL" - stmts.append("ALTER TABLE \(qt) ALTER COLUMN \(colName) \(newColumn.dataType) \(nullable)") - } - - if defaultChanged, let defaultValue = newColumn.defaultValue { - stmts.append("ALTER TABLE \(qt) ADD DEFAULT \(mssqlDefaultValue(defaultValue)) FOR \(colName)") - } - - return stmts.isEmpty ? nil : stmts.joined(separator: ";\n") - } - - func generateDropColumnSQL(table: String, columnName: String) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) DROP COLUMN \(quoteIdentifier(columnName))" - } - - func generateAddIndexSQL(table: String, index: PluginIndexDefinition) -> String? { - mssqlIndexDefinition(index, qualifiedTable: mssqlQualifiedTable(table)) - } - - func generateDropIndexSQL(table: String, indexName: String) -> String? { - "DROP INDEX \(quoteIdentifier(indexName)) ON \(mssqlQualifiedTable(table))" - } - - func generateAddForeignKeySQL(table: String, fk: PluginForeignKeyDefinition) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlForeignKeyDefinition(fk))" - } - - func generateDropForeignKeySQL(table: String, constraintName: String) -> String? { - "ALTER TABLE \(mssqlQualifiedTable(table)) DROP CONSTRAINT \(quoteIdentifier(constraintName))" - } - - func generateModifyPrimaryKeySQL(table: String, oldColumns: [String], newColumns: [String], constraintName: String?) -> [String]? { - let qt = mssqlQualifiedTable(table) - var stmts: [String] = [] - if !oldColumns.isEmpty { - let name = constraintName.map { quoteIdentifier($0) } ?? "/* unknown constraint */" - stmts.append("ALTER TABLE \(qt) DROP CONSTRAINT \(name)") - } - if !newColumns.isEmpty { - let cols = newColumns.map { quoteIdentifier($0) }.joined(separator: ", ") - let pkName = constraintName.map { quoteIdentifier($0) } ?? quoteIdentifier("PK_\(table)") - stmts.append("ALTER TABLE \(qt) ADD CONSTRAINT \(pkName) PRIMARY KEY (\(cols))") - } - return stmts.isEmpty ? nil : stmts - } - } // MARK: - Errors diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift new file mode 100644 index 000000000..fabe9f46a --- /dev/null +++ b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+DDL.swift @@ -0,0 +1,184 @@ +// +// MSSQLPluginDriver+DDL.swift +// MSSQLDriverPlugin +// + +import Foundation +import os +import TableProMSSQLCore +import TableProPluginKit + +extension MSSQLPluginDriver { + // MARK: - Create Table DDL + + func generateCreateTableSQL(definition: PluginCreateTableDefinition) -> String? { + guard !definition.columns.isEmpty else { return nil } + + let schema = _currentSchema + let qualifiedTable = "\(quoteIdentifier(schema)).\(quoteIdentifier(definition.tableName))" + let pkColumns = definition.columns.filter { $0.isPrimaryKey } + let inlinePK = pkColumns.count == 1 + var parts: [String] = definition.columns.map { mssqlColumnDefinition($0, inlinePK: inlinePK) } + + if pkColumns.count > 1 { + let pkCols = pkColumns.map { quoteIdentifier($0.name) }.joined(separator: ", ") + parts.append("PRIMARY KEY (\(pkCols))") + } + + for fk in definition.foreignKeys { + parts.append(mssqlForeignKeyDefinition(fk)) + } + + var sql = "CREATE TABLE \(qualifiedTable) (\n " + + parts.joined(separator: ",\n ") + + "\n);" + + var indexStatements: [String] = [] + for index in definition.indexes { + indexStatements.append(mssqlIndexDefinition(index, qualifiedTable: qualifiedTable)) + } + if !indexStatements.isEmpty { + sql += "\n\n" + indexStatements.joined(separator: ";\n") + ";" + } + + return sql + } + + private func mssqlColumnDefinition(_ col: PluginColumnDefinition, inlinePK: Bool) -> String { + var def = "\(quoteIdentifier(col.name)) \(col.dataType)" + if col.autoIncrement { + def += " IDENTITY(1,1)" + } + if col.isNullable { + def += " NULL" + } else { + def += " NOT NULL" + } + if let defaultValue = col.defaultValue { + def += " DEFAULT \(mssqlDefaultValue(defaultValue))" + } + if inlinePK && col.isPrimaryKey { + def += " PRIMARY KEY" + } + return def + } + + private func mssqlDefaultValue(_ value: String) -> String { + let upper = value.uppercased() + if upper == "NULL" || upper == "GETDATE()" || upper == "NEWID()" || upper == "GETUTCDATE()" + || value.hasPrefix("'") || value.hasPrefix("(") || Int64(value) != nil || Double(value) != nil { + return value + } + return "'\(escapeStringLiteral(value))'" + } + + private func mssqlIndexDefinition(_ index: PluginIndexDefinition, qualifiedTable: String) -> String { + let cols = index.columns.map { quoteIdentifier($0) }.joined(separator: ", ") + let unique = index.isUnique ? "UNIQUE " : "" + var def = "CREATE \(unique)INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + if let type = index.indexType?.uppercased(), type == "CLUSTERED" { + def = "CREATE \(unique)CLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + } else if let type = index.indexType?.uppercased(), type == "NONCLUSTERED" { + def = "CREATE \(unique)NONCLUSTERED INDEX \(quoteIdentifier(index.name)) ON \(qualifiedTable) (\(cols))" + } + return def + } + + private func mssqlForeignKeyDefinition(_ fk: PluginForeignKeyDefinition) -> String { + let cols = fk.columns.map { quoteIdentifier($0) }.joined(separator: ", ") + let refCols = fk.referencedColumns.map { quoteIdentifier($0) }.joined(separator: ", ") + var def = "CONSTRAINT \(quoteIdentifier(fk.name)) FOREIGN KEY (\(cols)) REFERENCES \(quoteIdentifier(fk.referencedTable)) (\(refCols))" + if fk.onDelete != "NO ACTION" { + def += " ON DELETE \(fk.onDelete)" + } + if fk.onUpdate != "NO ACTION" { + def += " ON UPDATE \(fk.onUpdate)" + } + return def + } + + // MARK: - ALTER TABLE DDL + + private func mssqlQualifiedTable(_ table: String) -> String { + "\(quoteIdentifier(_currentSchema)).\(quoteIdentifier(table))" + } + + func generateAddColumnSQL(table: String, column: PluginColumnDefinition) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlColumnDefinition(column, inlinePK: false))" + } + + func generateModifyColumnSQL(table: String, oldColumn: PluginColumnDefinition, newColumn: PluginColumnDefinition) -> String? { + let qt = mssqlQualifiedTable(table) + var stmts: [String] = [] + let needsTypeChange = oldColumn.dataType != newColumn.dataType || oldColumn.isNullable != newColumn.isNullable + let defaultChanged = oldColumn.defaultValue != newColumn.defaultValue + + // Rename column first so subsequent statements reference the correct name + if oldColumn.name != newColumn.name { + let escapedPath = "\(escapeStringLiteral(_currentSchema)).\(escapeStringLiteral(table)).\(escapeStringLiteral(oldColumn.name))" + stmts.append("EXEC sp_rename '\(escapedPath)', '\(escapeStringLiteral(newColumn.name))', 'COLUMN'") + } + + let colName = quoteIdentifier(newColumn.name) + + // Drop existing default constraint before ALTER COLUMN or default change + if (defaultChanged || needsTypeChange) && oldColumn.defaultValue != nil { + let objectId = escapeStringLiteral("\(_currentSchema).\(table)") + stmts.append(""" + DECLARE @dfName NVARCHAR(256); \ + SELECT @dfName = dc.name FROM sys.default_constraints dc \ + JOIN sys.columns c ON dc.parent_column_id = c.column_id AND dc.parent_object_id = c.object_id \ + WHERE c.name = '\(escapeStringLiteral(newColumn.name))' \ + AND dc.parent_object_id = OBJECT_ID('\(objectId)'); \ + IF @dfName IS NOT NULL EXEC('ALTER TABLE \(qt) DROP CONSTRAINT [' + @dfName + ']') + """) + } + + if needsTypeChange { + let nullable = newColumn.isNullable ? "NULL" : "NOT NULL" + stmts.append("ALTER TABLE \(qt) ALTER COLUMN \(colName) \(newColumn.dataType) \(nullable)") + } + + if defaultChanged, let defaultValue = newColumn.defaultValue { + stmts.append("ALTER TABLE \(qt) ADD DEFAULT \(mssqlDefaultValue(defaultValue)) FOR \(colName)") + } + + return stmts.isEmpty ? nil : stmts.joined(separator: ";\n") + } + + func generateDropColumnSQL(table: String, columnName: String) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) DROP COLUMN \(quoteIdentifier(columnName))" + } + + func generateAddIndexSQL(table: String, index: PluginIndexDefinition) -> String? { + mssqlIndexDefinition(index, qualifiedTable: mssqlQualifiedTable(table)) + } + + func generateDropIndexSQL(table: String, indexName: String) -> String? { + "DROP INDEX \(quoteIdentifier(indexName)) ON \(mssqlQualifiedTable(table))" + } + + func generateAddForeignKeySQL(table: String, fk: PluginForeignKeyDefinition) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) ADD \(mssqlForeignKeyDefinition(fk))" + } + + func generateDropForeignKeySQL(table: String, constraintName: String) -> String? { + "ALTER TABLE \(mssqlQualifiedTable(table)) DROP CONSTRAINT \(quoteIdentifier(constraintName))" + } + + func generateModifyPrimaryKeySQL(table: String, oldColumns: [String], newColumns: [String], constraintName: String?) -> [String]? { + let qt = mssqlQualifiedTable(table) + var stmts: [String] = [] + if !oldColumns.isEmpty { + let name = constraintName.map { quoteIdentifier($0) } ?? "/* unknown constraint */" + stmts.append("ALTER TABLE \(qt) DROP CONSTRAINT \(name)") + } + if !newColumns.isEmpty { + let cols = newColumns.map { quoteIdentifier($0) }.joined(separator: ", ") + let pkName = constraintName.map { quoteIdentifier($0) } ?? quoteIdentifier("PK_\(table)") + stmts.append("ALTER TABLE \(qt) ADD CONSTRAINT \(pkName) PRIMARY KEY (\(cols))") + } + return stmts.isEmpty ? nil : stmts + } + +} diff --git a/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift new file mode 100644 index 000000000..e2b27121d --- /dev/null +++ b/Plugins/MSSQLDriverPlugin/MSSQLPluginDriver+Schema.swift @@ -0,0 +1,533 @@ +// +// MSSQLPluginDriver+Schema.swift +// MSSQLDriverPlugin +// + +import Foundation +import os +import TableProMSSQLCore +import TableProPluginKit + +extension MSSQLPluginDriver { + // MARK: - Schema Operations + + func fetchTables(schema: String?) async throws -> [PluginTableInfo] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT t.TABLE_NAME, t.TABLE_TYPE + FROM INFORMATION_SCHEMA.TABLES t + WHERE t.TABLE_SCHEMA = '\(esc)' + AND t.TABLE_TYPE IN ('BASE TABLE', 'VIEW') + ORDER BY t.TABLE_NAME + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginTableInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let rawType = row[safe: 1]?.asText + let tableType = (rawType == "VIEW") ? "VIEW" : "TABLE" + return PluginTableInfo(name: name, type: tableType) + } + } + + func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + c.COLUMN_NAME, + c.DATA_TYPE, + c.CHARACTER_MAXIMUM_LENGTH, + c.NUMERIC_PRECISION, + c.NUMERIC_SCALE, + c.IS_NULLABLE, + c.COLUMN_DEFAULT, + COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK + FROM INFORMATION_SCHEMA.COLUMNS c + LEFT JOIN ( + SELECT kcu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + AND tc.TABLE_SCHEMA = '\(esc)' + AND tc.TABLE_NAME = '\(escapedTable)' + ) pk ON c.COLUMN_NAME = pk.COLUMN_NAME + WHERE c.TABLE_NAME = '\(escapedTable)' + AND c.TABLE_SCHEMA = '\(esc)' + ORDER BY c.ORDINAL_POSITION + """ + let result = try await execute(query: sql) + var identityColumns: Set = [] + let columns: [PluginColumnInfo] = result.rows.compactMap { row -> PluginColumnInfo? in + guard let name = row[safe: 0]?.asText else { return nil } + let dataType = row[safe: 1]?.asText + let charLen = row[safe: 2]?.asText + let numPrecision = row[safe: 3]?.asText + let numScale = row[safe: 4]?.asText + let isNullable = (row[safe: 5]?.asText) == "YES" + let defaultValue = row[safe: 6]?.asText + let isIdentity = (row[safe: 7]?.asText) == "1" + let isPk = (row[safe: 8]?.asText) == "1" + + if isIdentity { + identityColumns.insert(name) + } + + let baseType = (dataType ?? "nvarchar").lowercased() + let fixedSizeTypes: Set = [ + "int", "bigint", "smallint", "tinyint", "bit", + "money", "smallmoney", "float", "real", + "datetime", "datetime2", "smalldatetime", "date", "time", + "uniqueidentifier", "text", "ntext", "image", "xml", + "timestamp", "rowversion" + ] + var fullType = baseType + if fixedSizeTypes.contains(baseType) { + // No suffix + } else if let charLen, let len = Int(charLen), len > 0 { + fullType += "(\(len))" + } else if charLen == "-1" { + fullType += "(max)" + } else if let prec = numPrecision, let scale = numScale, + let p = Int(prec), let s = Int(scale) { + fullType += "(\(p),\(s))" + } + + return PluginColumnInfo( + name: name, + dataType: fullType, + isNullable: isNullable, + isPrimaryKey: isPk, + defaultValue: defaultValue, + extra: isIdentity ? "IDENTITY" : nil + ) + } + identityCacheLock.lock() + identityColumnsByTable[table] = identityColumns + identityCacheLock.unlock() + return columns + } + + /// Snapshot of IDENTITY columns observed by the most recent `fetchColumns` for the table. + /// Returns an empty set when `fetchColumns` hasn't run for this table yet, so callers + /// fall through to including every typed value (matching pre-cache behavior). + internal func cachedIdentityColumns(for table: String) -> Set { + identityCacheLock.lock() + defer { identityCacheLock.unlock() } + return identityColumnsByTable[table] ?? [] + } + + /// Test seam: pre-populate the cache so generateMssqlInsert can be exercised + /// without going through a live `fetchColumns` round-trip. + internal func setIdentityColumnsForTesting(_ columns: Set, table: String) { + identityCacheLock.lock() + identityColumnsByTable[table] = columns + identityCacheLock.unlock() + } + + func fetchIndexes(table: String, schema: String?) async throws -> [PluginIndexInfo] { + let esc = (schema ?? _currentSchema).replacingOccurrences(of: "]", with: "]]") + let bracketedTable = table.replacingOccurrences(of: "]", with: "]]") + let bracketedFull = "[\(esc)].[\(bracketedTable)]" + let sql = """ + SELECT i.name, i.is_unique, i.is_primary_key, c.name AS column_name + FROM sys.indexes i + JOIN sys.index_columns ic + ON i.object_id = ic.object_id AND i.index_id = ic.index_id + JOIN sys.columns c + ON ic.object_id = c.object_id AND ic.column_id = c.column_id + WHERE i.object_id = OBJECT_ID('\(bracketedFull)') + AND i.name IS NOT NULL + ORDER BY i.index_id, ic.key_ordinal + """ + let result = try await execute(query: sql) + var indexMap: [String: (unique: Bool, primary: Bool, columns: [String])] = [:] + for row in result.rows { + guard let idxName = row[safe: 0]?.asText, + let colName = row[safe: 3]?.asText else { continue } + let isUnique = (row[safe: 1]?.asText) == "1" + let isPrimary = (row[safe: 2]?.asText) == "1" + if indexMap[idxName] == nil { + indexMap[idxName] = (unique: isUnique, primary: isPrimary, columns: []) + } + indexMap[idxName]?.columns.append(colName) + } + return indexMap.map { name, info in + PluginIndexInfo( + name: name, + columns: info.columns, + isUnique: info.unique, + isPrimary: info.primary, + type: "CLUSTERED" + ) + }.sorted { $0.name < $1.name } + } + + func fetchForeignKeys(table: String, schema: String?) async throws -> [PluginForeignKeyInfo] { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + fk.name AS constraint_name, + cp.name AS column_name, + tr.name AS ref_table, + cr.name AS ref_column + FROM sys.foreign_keys fk + JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id + JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id + JOIN sys.schemas s ON tp.schema_id = s.schema_id + JOIN sys.columns cp + ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id + JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id + JOIN sys.columns cr + ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id + WHERE tp.name = '\(escapedTable)' AND s.name = '\(esc)' + ORDER BY fk.name + """ + let result = try await execute(query: sql) + return result.rows.compactMap { row -> PluginForeignKeyInfo? in + guard let constraintName = row[safe: 0]?.asText, + let columnName = row[safe: 1]?.asText, + let refTable = row[safe: 2]?.asText, + let refColumn = row[safe: 3]?.asText else { return nil } + return PluginForeignKeyInfo( + name: constraintName, + column: columnName, + referencedTable: refTable, + referencedColumn: refColumn + ) + } + } + + func fetchAllColumns(schema: String?) async throws -> [String: [PluginColumnInfo]] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + c.TABLE_NAME, + c.COLUMN_NAME, + c.DATA_TYPE, + c.CHARACTER_MAXIMUM_LENGTH, + c.NUMERIC_PRECISION, + c.NUMERIC_SCALE, + c.IS_NULLABLE, + c.COLUMN_DEFAULT, + COLUMNPROPERTY(OBJECT_ID(c.TABLE_SCHEMA + '.' + c.TABLE_NAME), c.COLUMN_NAME, 'IsIdentity') AS IS_IDENTITY, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS IS_PK + FROM INFORMATION_SCHEMA.COLUMNS c + LEFT JOIN ( + SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc + JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu + ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME + AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA + WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' + AND tc.TABLE_SCHEMA = '\(esc)' + ) pk ON c.TABLE_NAME = pk.TABLE_NAME AND c.COLUMN_NAME = pk.COLUMN_NAME + WHERE c.TABLE_SCHEMA = '\(esc)' + ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION + """ + let result = try await execute(query: sql) + var columnsByTable: [String: [PluginColumnInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let name = row[safe: 1]?.asText else { continue } + let dataType = row[safe: 2]?.asText + let charLen = row[safe: 3]?.asText + let numPrecision = row[safe: 4]?.asText + let numScale = row[safe: 5]?.asText + let isNullable = (row[safe: 6]?.asText) == "YES" + let defaultValue = row[safe: 7]?.asText + let isIdentity = (row[safe: 8]?.asText) == "1" + let isPk = (row[safe: 9]?.asText) == "1" + + let baseType = (dataType ?? "nvarchar").lowercased() + let fixedSizeTypes: Set = [ + "int", "bigint", "smallint", "tinyint", "bit", + "money", "smallmoney", "float", "real", + "datetime", "datetime2", "smalldatetime", "date", "time", + "uniqueidentifier", "text", "ntext", "image", "xml", + "timestamp", "rowversion" + ] + var fullType = baseType + if fixedSizeTypes.contains(baseType) { + // No suffix + } else if let charLen, let len = Int(charLen), len > 0 { + fullType += "(\(len))" + } else if charLen == "-1" { + fullType += "(max)" + } else if let prec = numPrecision, let scale = numScale, + let p = Int(prec), let s = Int(scale) { + fullType += "(\(p),\(s))" + } + + let col = PluginColumnInfo( + name: name, + dataType: fullType, + isNullable: isNullable, + isPrimaryKey: isPk, + defaultValue: defaultValue, + extra: isIdentity ? "IDENTITY" : nil + ) + columnsByTable[tableName, default: []].append(col) + } + return columnsByTable + } + + func fetchAllForeignKeys(schema: String?) async throws -> [String: [PluginForeignKeyInfo]] { + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + tp.name AS table_name, + fk.name AS constraint_name, + cp.name AS column_name, + tr.name AS ref_table, + cr.name AS ref_column + FROM sys.foreign_keys fk + JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id + JOIN sys.tables tp ON fkc.parent_object_id = tp.object_id + JOIN sys.schemas s ON tp.schema_id = s.schema_id + JOIN sys.columns cp + ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id + JOIN sys.tables tr ON fkc.referenced_object_id = tr.object_id + JOIN sys.columns cr + ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id + WHERE s.name = '\(esc)' + ORDER BY tp.name, fk.name + """ + let result = try await execute(query: sql) + var fksByTable: [String: [PluginForeignKeyInfo]] = [:] + for row in result.rows { + guard let tableName = row[safe: 0]?.asText, + let constraintName = row[safe: 1]?.asText, + let columnName = row[safe: 2]?.asText, + let refTable = row[safe: 3]?.asText, + let refColumn = row[safe: 4]?.asText else { continue } + let fk = PluginForeignKeyInfo( + name: constraintName, + column: columnName, + referencedTable: refTable, + referencedColumn: refColumn + ) + fksByTable[tableName, default: []].append(fk) + } + return fksByTable + } + + func fetchAllDatabaseMetadata() async throws -> [PluginDatabaseMetadata] { + let sql = """ + SELECT d.name, + SUM(mf.size) * 8 * 1024 AS size_bytes + FROM sys.databases d + LEFT JOIN sys.master_files mf ON d.database_id = mf.database_id + GROUP BY d.name + ORDER BY d.name + """ + do { + let result = try await execute(query: sql) + var metadata = result.rows.compactMap { row -> PluginDatabaseMetadata? in + guard let name = row[safe: 0]?.asText else { return nil } + let sizeBytes = (row[safe: 1]?.asText).flatMap { Int64($0) } + return PluginDatabaseMetadata(name: name, sizeBytes: sizeBytes) + } + + for i in metadata.indices { + let dbName = metadata[i].name.replacingOccurrences(of: "]", with: "]]") + do { + let countResult = try await execute( + query: "SELECT COUNT(*) FROM [\(dbName)].sys.tables" + ) + if let countStr = countResult.rows.first?[safe: 0]?.asText, + let count = Int(countStr) { + metadata[i] = PluginDatabaseMetadata( + name: metadata[i].name, + tableCount: count, + sizeBytes: metadata[i].sizeBytes + ) + } + } catch { + // Database offline or permission denied: leave tableCount as nil + } + } + + return metadata + } catch { + // Fall back to N+1 if permission denied on sys.master_files + let dbs = try await fetchDatabases() + var result: [PluginDatabaseMetadata] = [] + for db in dbs { + do { + result.append(try await fetchDatabaseMetadata(db)) + } catch { + result.append(PluginDatabaseMetadata(name: db)) + } + } + return result + } + } + + func fetchTableDDL(table: String, schema: String?) async throws -> String { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let cols = try await fetchColumns(table: table, schema: schema) + let indexes = try await fetchIndexes(table: table, schema: schema) + let fks = try await fetchForeignKeys(table: table, schema: schema) + + var ddl = "CREATE TABLE [\(esc)].[\(escapedTable)] (\n" + let colDefs = cols.map { col -> String in + var def = " [\(col.name)] \(col.dataType.uppercased())" + if col.extra == "IDENTITY" { def += " IDENTITY(1,1)" } + def += col.isNullable ? " NULL" : " NOT NULL" + if let d = col.defaultValue { def += " DEFAULT \(d)" } + return def + } + + let pkCols = indexes.filter(\.isPrimary).flatMap(\.columns) + var parts = colDefs + if !pkCols.isEmpty { + let pkName = "PK_\(table)" + let pkDef = " CONSTRAINT [\(pkName)] PRIMARY KEY (\(pkCols.map { "[\($0)]" }.joined(separator: ", ")))" + parts.append(pkDef) + } + + for fk in fks { + let fkDef = " CONSTRAINT [\(fk.name)] FOREIGN KEY ([\(fk.column)]) REFERENCES [\(fk.referencedTable)] ([\(fk.referencedColumn)])" + parts.append(fkDef) + } + + ddl += parts.joined(separator: ",\n") + ddl += "\n);" + return ddl + } + + func fetchViewDefinition(view: String, schema: String?) async throws -> String { + let esc = effectiveSchemaEscaped(schema) + let escapedView = "\(esc).\(view.replacingOccurrences(of: "'", with: "''"))" + let sql = "SELECT definition FROM sys.sql_modules WHERE object_id = OBJECT_ID('\(escapedView)')" + let result = try await execute(query: sql) + return result.rows.first?.first?.asText ?? "" + } + + func fetchTableMetadata(table: String, schema: String?) async throws -> PluginTableMetadata { + let escapedTable = table.replacingOccurrences(of: "'", with: "''") + let esc = effectiveSchemaEscaped(schema) + let sql = """ + SELECT + SUM(p.rows) AS row_count, + 8 * SUM(a.used_pages) AS size_kb, + ep.value AS comment + FROM sys.tables t + JOIN sys.schemas s ON t.schema_id = s.schema_id + JOIN sys.partitions p + ON t.object_id = p.object_id AND p.index_id IN (0, 1) + JOIN sys.allocation_units a ON p.partition_id = a.container_id + LEFT JOIN sys.extended_properties ep + ON ep.major_id = t.object_id AND ep.minor_id = 0 AND ep.name = 'MS_Description' + WHERE t.name = '\(escapedTable)' AND s.name = '\(esc)' + GROUP BY ep.value + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let rowCount = (row[safe: 0]?.asText).flatMap { Int64($0) } + let sizeKb = (row[safe: 1]?.asText).flatMap { Int64($0) } ?? 0 + let comment = row[safe: 2]?.asText + return PluginTableMetadata( + tableName: table, + dataSize: sizeKb * 1_024, + totalSize: sizeKb * 1_024, + rowCount: rowCount, + comment: comment + ) + } + return PluginTableMetadata(tableName: table) + } + + func fetchDatabases() async throws -> [String] { + let sql = "SELECT name FROM sys.databases ORDER BY name" + let result = try await execute(query: sql) + return result.rows.compactMap { $0.first?.asText } + } + + func fetchSchemas() async throws -> [String] { + let sql = """ + SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA + WHERE SCHEMA_NAME NOT IN ( + 'information_schema','sys','db_owner','db_accessadmin', + 'db_securityadmin','db_ddladmin','db_backupoperator', + 'db_datareader','db_datawriter','db_denydatareader', + 'db_denydatawriter','guest' + ) + ORDER BY SCHEMA_NAME + """ + let result = try await execute(query: sql) + return result.rows.compactMap { $0.first?.asText } + } + + func switchSchema(to schema: String) async throws { + _currentSchema = schema + } + + func switchDatabase(to database: String) async throws { + guard let conn = freeTDSConn else { + throw MSSQLPluginError.notConnected + } + try await conn.switchDatabase(database) + } + + func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { + let sql = """ + SELECT + SUM(size) * 8.0 / 1024 AS size_mb, + (SELECT COUNT(*) FROM sys.tables) AS table_count + FROM sys.database_files + """ + let result = try await execute(query: sql) + if let row = result.rows.first { + let sizeMb = (row[safe: 0]?.asText).flatMap { Double($0) } ?? 0 + let tableCount = (row[safe: 1]?.asText).flatMap { Int($0) } ?? 0 + return PluginDatabaseMetadata( + name: database, + tableCount: tableCount, + sizeBytes: Int64(sizeMb * 1_024 * 1_024) + ) + } + return PluginDatabaseMetadata(name: database) + } + + func createDatabaseFormSpec() async throws -> PluginCreateDatabaseFormSpec? { + PluginCreateDatabaseFormSpec(fields: [], footnote: nil) + } + + func createDatabase(_ request: PluginCreateDatabaseRequest) async throws { + let quotedName = "[\(request.name.replacingOccurrences(of: "]", with: "]]"))]" + _ = try await execute(query: "CREATE DATABASE \(quotedName)") + } + + func dropDatabase(name: String) async throws { + let quotedName = "[\(name.replacingOccurrences(of: "]", with: "]]"))]" + _ = try await execute(query: "DROP DATABASE \(quotedName)") + } + + // MARK: - All Tables Metadata + + func allTablesMetadataSQL(schema: String?) -> String? { + """ + SELECT + s.name as schema_name, + t.name as name, + CASE WHEN v.object_id IS NOT NULL THEN 'VIEW' ELSE 'TABLE' END as kind, + p.rows as estimated_rows, + CAST(ROUND(SUM(a.total_pages) * 8 / 1024.0, 2) AS VARCHAR) + ' MB' as total_size + FROM sys.tables t + INNER JOIN sys.schemas s ON t.schema_id = s.schema_id + INNER JOIN sys.indexes i ON t.object_id = i.object_id AND i.index_id IN (0, 1) + INNER JOIN sys.partitions p ON i.object_id = p.object_id AND i.index_id = p.index_id + INNER JOIN sys.allocation_units a ON p.partition_id = a.container_id + LEFT JOIN sys.views v ON t.object_id = v.object_id + GROUP BY s.name, t.name, p.rows, v.object_id + ORDER BY t.name + """ + } + +} diff --git a/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift b/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift new file mode 100644 index 000000000..776ef452e --- /dev/null +++ b/Plugins/MongoDBDriverPlugin/MongoDBConnection+SyncHelpers.swift @@ -0,0 +1,458 @@ +// +// MongoDBConnection+SyncHelpers.swift +// MongoDBDriverPlugin +// + +#if canImport(CLibMongoc) +import CLibMongoc +#endif +import Foundation +import OSLog +import TableProPluginKit + +#if canImport(CLibMongoc) +extension MongoDBConnection { + func bsonErrorMessage(_ error: inout bson_error_t) -> String { + withUnsafePointer(to: &error.message) { ptr in + ptr.withMemoryRebound(to: CChar.self, capacity: 504) { String(cString: $0) } + } + } + + func makeError(_ error: bson_error_t) -> MongoDBError { + var err = error + return MongoDBError(code: err.code, message: bsonErrorMessage(&err)) + } + + func fetchServerVersionSync() -> String? { + guard let client = self.client, + let command = jsonToBson("{\"buildInfo\": 1}") else { return nil } + defer { bson_destroy(command) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let dbName = database.isEmpty ? "admin" : database + let ok = dbName.withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } + guard ok else { return nil } + + return bsonToDict(reply)["version"] as? String + } + + func getCollection( + _ client: OpaquePointer, database: String, collection: String + ) throws -> OpaquePointer { + guard let col = database.withCString({ dbPtr in + collection.withCString { colPtr in mongoc_client_get_collection(client, dbPtr, colPtr) } + }) else { + throw MongoDBError(code: 0, message: "Failed to get collection \(collection)") + } + return col + } + + func runCommandSync( + client: OpaquePointer, command: String, database: String? + ) throws -> [[String: Any]] { + try checkCancelled() + + guard let bsonCmd = jsonToBson(command) else { + throw MongoDBError(code: 0, message: "Invalid JSON command: \(command)") + } + defer { bson_destroy(bsonCmd) } + + let timeoutMS = queryTimeoutMS + if timeoutMS > 0, !bson_has_field(bsonCmd, "maxTimeMS") { + bson_append_int32(bsonCmd, "maxTimeMS", -1, timeoutMS) + } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let effectiveDb = (database ?? self.database).isEmpty ? "admin" : (database ?? self.database) + let ok = effectiveDb.withCString { mongoc_client_command_simple(client, $0, bsonCmd, nil, reply, &error) } + + try checkCancelled() + guard ok else { throw makeError(error) } + + return [bsonToDict(reply)] + } + + func findSync( + client: OpaquePointer, database: String, collection: String, + filter: String, sort: String?, projection: String?, skip: Int, limit: Int + ) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + var optsJson: [String: Any] = ["skip": skip, "limit": limit] + if let sort = sort, let data = sort.data(using: .utf8), + let obj = try? JSONSerialization.jsonObject(with: data) { + optsJson["sort"] = obj + } + if let projection = projection, let data = projection.data(using: .utf8), + let obj = try? JSONSerialization.jsonObject(with: data) { + optsJson["projection"] = obj + } + + let timeoutMS = queryTimeoutMS + if timeoutMS > 0 { + optsJson["maxTimeMS"] = timeoutMS + } + + let optsData = try JSONSerialization.data(withJSONObject: optsJson) + guard let optsStr = String(data: optsData, encoding: .utf8), + let optsBson = jsonToBson(optsStr) else { + throw MongoDBError(code: 0, message: "Failed to build query options") + } + defer { bson_destroy(optsBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + try checkCancelled() + + guard let cursor = mongoc_collection_find_with_opts(col, filterBson, optsBson, nil) else { + throw MongoDBError(code: 0, message: "Failed to create find cursor") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor) + } + + func aggregateSync( + client: OpaquePointer, database: String, collection: String, pipeline: String + ) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + guard let pipelineBson = jsonToBson(pipeline) else { + throw MongoDBError(code: 0, message: "Invalid JSON pipeline: \(pipeline)") + } + defer { bson_destroy(pipelineBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let timeoutMS = queryTimeoutMS + var optsBson: OpaquePointer? + if timeoutMS > 0 { + optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") + } + defer { if let opts = optsBson { bson_destroy(opts) } } + + try checkCancelled() + + guard let cursor = mongoc_collection_aggregate( + col, MONGOC_QUERY_NONE, pipelineBson, optsBson, nil + ) else { + throw MongoDBError(code: 0, message: "Failed to create aggregation cursor") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor) + } + + func countDocumentsSync( + client: OpaquePointer, database: String, collection: String, filter: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let timeoutMS = queryTimeoutMS + var optsBson: OpaquePointer? + if timeoutMS > 0 { + optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") + } + defer { if let opts = optsBson { bson_destroy(opts) } } + + var error = bson_error_t() + let count = mongoc_collection_count_documents(col, filterBson, optsBson, nil, nil, &error) + + try checkCancelled() + guard count >= 0 else { throw makeError(error) } + return count + } + + func insertOneSync( + client: OpaquePointer, database: String, collection: String, document: String + ) throws -> String? { + try checkCancelled() + + guard let docBson = jsonToBson(document) else { + throw MongoDBError(code: 0, message: "Invalid JSON document: \(document)") + } + defer { bson_destroy(docBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_insert_one(col, docBson, nil, reply, &error) else { + throw makeError(error) + } + + if let objectId = bsonToDict(docBson)["_id"] { return "\(objectId)" } + return nil + } + + func updateOneSync( + client: OpaquePointer, database: String, collection: String, filter: String, update: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + guard let updateBson = jsonToBson(update) else { + throw MongoDBError(code: 0, message: "Invalid JSON update: \(update)") + } + defer { bson_destroy(updateBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_update_one(col, filterBson, updateBson, nil, reply, &error) else { + throw makeError(error) + } + return (bsonToDict(reply)["modifiedCount"] as? Int64) ?? 0 + } + + func deleteOneSync( + client: OpaquePointer, database: String, collection: String, filter: String + ) throws -> Int64 { + try checkCancelled() + + guard let filterBson = jsonToBson(filter) else { + throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") + } + defer { bson_destroy(filterBson) } + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + guard mongoc_collection_delete_one(col, filterBson, nil, reply, &error) else { + throw makeError(error) + } + return (bsonToDict(reply)["deletedCount"] as? Int64) ?? 0 + } + + func listDatabasesSync(client: OpaquePointer) throws -> [String] { + try checkCancelled() + + let caps = MongoDBCapabilities.parse(serverVersion()) + var fields = ["\"listDatabases\": 1"] + if caps.supportsListDatabasesNameOnly { + fields.append("\"nameOnly\": true") + } + if caps.supportsAuthorizedDatabases { + fields.append("\"authorizedDatabases\": true") + } + let commandJSON = "{\(fields.joined(separator: ", "))}" + guard let command = jsonToBson(commandJSON) else { + throw MongoDBError(code: 0, message: "Failed to create listDatabases command") + } + defer { bson_destroy(command) } + + let reply = bson_new() + defer { bson_destroy(reply) } + var error = bson_error_t() + + let ok = "admin".withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } + + try checkCancelled() + guard ok else { throw makeError(error) } + + guard let databases = bsonToDict(reply)["databases"] as? [[String: Any]] else { return [] } + return databases.compactMap { $0["name"] as? String } + } + + func listCollectionsSync(client: OpaquePointer, database: String) throws -> [String] { + try checkCancelled() + + guard let mongocDb = database.withCString({ mongoc_client_get_database(client, $0) }) else { + throw MongoDBError(code: 0, message: "Failed to get database \(database)") + } + defer { mongoc_database_destroy(mongocDb) } + + var error = bson_error_t() + guard let names = mongoc_database_get_collection_names_with_opts(mongocDb, nil, &error) else { + throw makeError(error) + } + defer { bson_strfreev(names) } + + try checkCancelled() + + var collections: [String] = [] + var index = 0 + while let namePtr = names[index] { + collections.append(String(cString: namePtr)) + index += 1 + } + return collections + } + + func listIndexesSync( + client: OpaquePointer, database: String, collection: String + ) throws -> [[String: Any]] { + try checkCancelled() + + let col = try getCollection(client, database: database, collection: collection) + defer { mongoc_collection_destroy(col) } + + guard let cursor = mongoc_collection_find_indexes_with_opts(col, nil) else { + throw MongoDBError(code: 0, message: "Failed to list indexes for \(collection)") + } + defer { mongoc_cursor_destroy(cursor) } + + return try iterateCursor(cursor).docs + } + + func iterateCursor(_ cursor: OpaquePointer) throws -> (docs: [[String: Any]], isTruncated: Bool) { + try checkCancelled() + + var results: [[String: Any]] = [] + var docPtr: OpaquePointer? + var truncated = false + + while mongoc_cursor_next(cursor, &docPtr) { + try checkCancelled() + + if let doc = docPtr { + results.append(bsonToDict(doc)) + } + + if results.count >= PluginRowLimits.emergencyMax { + truncated = true + logger.warning("Result set truncated at \(PluginRowLimits.emergencyMax) documents") + break + } + } + + var error = bson_error_t() + if mongoc_cursor_error(cursor, &error) { + throw makeError(error) + } + return (docs: results, isTruncated: truncated) + } + + func iterateCursorStreaming( + cursor: OpaquePointer, + continuation: AsyncThrowingStream.Continuation, + streamState: MongoStreamState + ) { + var docPtr: OpaquePointer? + var headerSent = false + var columns: [String] = [] + var columnTypeNames: [String] = [] + + while mongoc_cursor_next(cursor, &docPtr) { + if Task.isCancelled { + cleanup(streamState) + continuation.finish(throwing: CancellationError()) + return + } + + guard let doc = docPtr else { continue } + let dict = bsonToDict(doc) + + if !headerSent { + columns = BsonDocumentFlattener.unionColumns(from: [dict]) + let bsonTypes = BsonDocumentFlattener.columnTypes(for: columns, documents: [dict]) + columnTypeNames = bsonTypes.map { bsonTypeToStreamString($0) } + continuation.yield(.header(PluginStreamHeader( + columns: columns, + columnTypeNames: columnTypeNames + ))) + headerSent = true + } else { + for key in dict.keys.sorted() where !columns.contains(key) { + columns.append(key) + let type = BsonDocumentFlattener.columnTypes(for: [key], documents: [dict]) + columnTypeNames.append(bsonTypeToStreamString(type.first ?? 2)) + } + } + + let row: [PluginCellValue] = columns.map { column in + guard let value = dict[column] else { return .null } + if let data = value as? Data { + return .bytes(data) + } + return PluginCellValue.fromOptional(BsonDocumentFlattener.stringValue(for: value)) + } + continuation.yield(.rows([row])) + } + + var error = bson_error_t() + if mongoc_cursor_error(cursor, &error) { + cleanup(streamState) + continuation.finish(throwing: makeError(error)) + return + } + + if !headerSent { + continuation.yield(.header(PluginStreamHeader( + columns: ["_id"], + columnTypeNames: ["VARCHAR"] + ))) + } + + cleanup(streamState) + continuation.finish() + } + + private func cleanup(_ state: MongoStreamState) { + state.lock.lock() + let cur = state.cursor + let col = state.collection + let alreadyDrained = state.drained + state.drained = true + state.cursor = nil + state.collection = nil + state.lock.unlock() + guard !alreadyDrained else { return } + if let cur { mongoc_cursor_destroy(cur) } + if let col { mongoc_collection_destroy(col) } + } + + private func bsonTypeToStreamString(_ type: Int32) -> String { + switch type { + case 1: return "FLOAT" + case 2: return "VARCHAR" + case 3: return "JSON" + case 4: return "JSON" + case 5: return "BLOB" + case 7: return "VARCHAR" + case 8: return "BOOLEAN" + case 9: return "TIMESTAMP" + case 10: return "VARCHAR" + case 16: return "INTEGER" + case 18: return "BIGINT" + default: return "VARCHAR" + } + } +} +#endif diff --git a/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift b/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift index f82187aa4..8e05ab707 100644 --- a/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift +++ b/Plugins/MongoDBDriverPlugin/MongoDBConnection.swift @@ -13,7 +13,7 @@ import Foundation import OSLog import TableProPluginKit -private let logger = Logger(subsystem: "com.TablePro", category: "MongoDBConnection") +let logger = Logger(subsystem: "com.TablePro", category: "MongoDBConnection") // MARK: - Error Types @@ -49,7 +49,7 @@ final class MongoDBConnection: @unchecked Sendable { mongoc_init() }() - private var client: OpaquePointer? + var client: OpaquePointer? #endif private static let queueKey = DispatchSpecificKey() @@ -58,7 +58,7 @@ final class MongoDBConnection: @unchecked Sendable { private let port: Int private let user: String private let password: String? - private let database: String + let database: String private let ssl: SSLConfiguration private let authSource: String? private let readPreference: String? @@ -94,7 +94,7 @@ final class MongoDBConnection: @unchecked Sendable { } } - private var queryTimeoutMS: Int32 { + var queryTimeoutMS: Int32 { stateLock.lock() defer { stateLock.unlock() } return _queryTimeoutMS @@ -350,7 +350,7 @@ final class MongoDBConnection: @unchecked Sendable { /// Throws if cancellation was requested, resetting the flag atomically. /// Safe to call from any thread. - private func checkCancelled() throws { + func checkCancelled() throws { stateLock.lock() let cancelled = _isCancelled if cancelled { _isCancelled = false } @@ -777,454 +777,6 @@ final class MongoDBConnection: @unchecked Sendable { } } -// MARK: - Synchronous Helpers (must be called on the serial queue) - -#if canImport(CLibMongoc) -private extension MongoDBConnection { - func bsonErrorMessage(_ error: inout bson_error_t) -> String { - withUnsafePointer(to: &error.message) { ptr in - ptr.withMemoryRebound(to: CChar.self, capacity: 504) { String(cString: $0) } - } - } - - func makeError(_ error: bson_error_t) -> MongoDBError { - var err = error - return MongoDBError(code: err.code, message: bsonErrorMessage(&err)) - } - - func fetchServerVersionSync() -> String? { - guard let client = self.client, - let command = jsonToBson("{\"buildInfo\": 1}") else { return nil } - defer { bson_destroy(command) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let dbName = database.isEmpty ? "admin" : database - let ok = dbName.withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } - guard ok else { return nil } - - return bsonToDict(reply)["version"] as? String - } - - func getCollection( - _ client: OpaquePointer, database: String, collection: String - ) throws -> OpaquePointer { - guard let col = database.withCString({ dbPtr in - collection.withCString { colPtr in mongoc_client_get_collection(client, dbPtr, colPtr) } - }) else { - throw MongoDBError(code: 0, message: "Failed to get collection \(collection)") - } - return col - } - - func runCommandSync( - client: OpaquePointer, command: String, database: String? - ) throws -> [[String: Any]] { - try checkCancelled() - - guard let bsonCmd = jsonToBson(command) else { - throw MongoDBError(code: 0, message: "Invalid JSON command: \(command)") - } - defer { bson_destroy(bsonCmd) } - - let timeoutMS = queryTimeoutMS - if timeoutMS > 0, !bson_has_field(bsonCmd, "maxTimeMS") { - bson_append_int32(bsonCmd, "maxTimeMS", -1, timeoutMS) - } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let effectiveDb = (database ?? self.database).isEmpty ? "admin" : (database ?? self.database) - let ok = effectiveDb.withCString { mongoc_client_command_simple(client, $0, bsonCmd, nil, reply, &error) } - - try checkCancelled() - guard ok else { throw makeError(error) } - - return [bsonToDict(reply)] - } - - func findSync( - client: OpaquePointer, database: String, collection: String, - filter: String, sort: String?, projection: String?, skip: Int, limit: Int - ) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - var optsJson: [String: Any] = ["skip": skip, "limit": limit] - if let sort = sort, let data = sort.data(using: .utf8), - let obj = try? JSONSerialization.jsonObject(with: data) { - optsJson["sort"] = obj - } - if let projection = projection, let data = projection.data(using: .utf8), - let obj = try? JSONSerialization.jsonObject(with: data) { - optsJson["projection"] = obj - } - - let timeoutMS = queryTimeoutMS - if timeoutMS > 0 { - optsJson["maxTimeMS"] = timeoutMS - } - - let optsData = try JSONSerialization.data(withJSONObject: optsJson) - guard let optsStr = String(data: optsData, encoding: .utf8), - let optsBson = jsonToBson(optsStr) else { - throw MongoDBError(code: 0, message: "Failed to build query options") - } - defer { bson_destroy(optsBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - try checkCancelled() - - guard let cursor = mongoc_collection_find_with_opts(col, filterBson, optsBson, nil) else { - throw MongoDBError(code: 0, message: "Failed to create find cursor") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor) - } - - func aggregateSync( - client: OpaquePointer, database: String, collection: String, pipeline: String - ) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - guard let pipelineBson = jsonToBson(pipeline) else { - throw MongoDBError(code: 0, message: "Invalid JSON pipeline: \(pipeline)") - } - defer { bson_destroy(pipelineBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let timeoutMS = queryTimeoutMS - var optsBson: OpaquePointer? - if timeoutMS > 0 { - optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") - } - defer { if let opts = optsBson { bson_destroy(opts) } } - - try checkCancelled() - - guard let cursor = mongoc_collection_aggregate( - col, MONGOC_QUERY_NONE, pipelineBson, optsBson, nil - ) else { - throw MongoDBError(code: 0, message: "Failed to create aggregation cursor") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor) - } - - func countDocumentsSync( - client: OpaquePointer, database: String, collection: String, filter: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let timeoutMS = queryTimeoutMS - var optsBson: OpaquePointer? - if timeoutMS > 0 { - optsBson = jsonToBson("{\"maxTimeMS\": \(timeoutMS)}") - } - defer { if let opts = optsBson { bson_destroy(opts) } } - - var error = bson_error_t() - let count = mongoc_collection_count_documents(col, filterBson, optsBson, nil, nil, &error) - - try checkCancelled() - guard count >= 0 else { throw makeError(error) } - return count - } - - func insertOneSync( - client: OpaquePointer, database: String, collection: String, document: String - ) throws -> String? { - try checkCancelled() - - guard let docBson = jsonToBson(document) else { - throw MongoDBError(code: 0, message: "Invalid JSON document: \(document)") - } - defer { bson_destroy(docBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_insert_one(col, docBson, nil, reply, &error) else { - throw makeError(error) - } - - if let objectId = bsonToDict(docBson)["_id"] { return "\(objectId)" } - return nil - } - - func updateOneSync( - client: OpaquePointer, database: String, collection: String, filter: String, update: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - guard let updateBson = jsonToBson(update) else { - throw MongoDBError(code: 0, message: "Invalid JSON update: \(update)") - } - defer { bson_destroy(updateBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_update_one(col, filterBson, updateBson, nil, reply, &error) else { - throw makeError(error) - } - return (bsonToDict(reply)["modifiedCount"] as? Int64) ?? 0 - } - - func deleteOneSync( - client: OpaquePointer, database: String, collection: String, filter: String - ) throws -> Int64 { - try checkCancelled() - - guard let filterBson = jsonToBson(filter) else { - throw MongoDBError(code: 0, message: "Invalid JSON filter: \(filter)") - } - defer { bson_destroy(filterBson) } - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - guard mongoc_collection_delete_one(col, filterBson, nil, reply, &error) else { - throw makeError(error) - } - return (bsonToDict(reply)["deletedCount"] as? Int64) ?? 0 - } - - func listDatabasesSync(client: OpaquePointer) throws -> [String] { - try checkCancelled() - - let caps = MongoDBCapabilities.parse(serverVersion()) - var fields = ["\"listDatabases\": 1"] - if caps.supportsListDatabasesNameOnly { - fields.append("\"nameOnly\": true") - } - if caps.supportsAuthorizedDatabases { - fields.append("\"authorizedDatabases\": true") - } - let commandJSON = "{\(fields.joined(separator: ", "))}" - guard let command = jsonToBson(commandJSON) else { - throw MongoDBError(code: 0, message: "Failed to create listDatabases command") - } - defer { bson_destroy(command) } - - let reply = bson_new() - defer { bson_destroy(reply) } - var error = bson_error_t() - - let ok = "admin".withCString { mongoc_client_command_simple(client, $0, command, nil, reply, &error) } - - try checkCancelled() - guard ok else { throw makeError(error) } - - guard let databases = bsonToDict(reply)["databases"] as? [[String: Any]] else { return [] } - return databases.compactMap { $0["name"] as? String } - } - - func listCollectionsSync(client: OpaquePointer, database: String) throws -> [String] { - try checkCancelled() - - guard let mongocDb = database.withCString({ mongoc_client_get_database(client, $0) }) else { - throw MongoDBError(code: 0, message: "Failed to get database \(database)") - } - defer { mongoc_database_destroy(mongocDb) } - - var error = bson_error_t() - guard let names = mongoc_database_get_collection_names_with_opts(mongocDb, nil, &error) else { - throw makeError(error) - } - defer { bson_strfreev(names) } - - try checkCancelled() - - var collections: [String] = [] - var index = 0 - while let namePtr = names[index] { - collections.append(String(cString: namePtr)) - index += 1 - } - return collections - } - - func listIndexesSync( - client: OpaquePointer, database: String, collection: String - ) throws -> [[String: Any]] { - try checkCancelled() - - let col = try getCollection(client, database: database, collection: collection) - defer { mongoc_collection_destroy(col) } - - guard let cursor = mongoc_collection_find_indexes_with_opts(col, nil) else { - throw MongoDBError(code: 0, message: "Failed to list indexes for \(collection)") - } - defer { mongoc_cursor_destroy(cursor) } - - return try iterateCursor(cursor).docs - } - - func iterateCursor(_ cursor: OpaquePointer) throws -> (docs: [[String: Any]], isTruncated: Bool) { - try checkCancelled() - - var results: [[String: Any]] = [] - var docPtr: OpaquePointer? - var truncated = false - - while mongoc_cursor_next(cursor, &docPtr) { - try checkCancelled() - - if let doc = docPtr { - results.append(bsonToDict(doc)) - } - - if results.count >= PluginRowLimits.emergencyMax { - truncated = true - logger.warning("Result set truncated at \(PluginRowLimits.emergencyMax) documents") - break - } - } - - var error = bson_error_t() - if mongoc_cursor_error(cursor, &error) { - throw makeError(error) - } - return (docs: results, isTruncated: truncated) - } - - func iterateCursorStreaming( - cursor: OpaquePointer, - continuation: AsyncThrowingStream.Continuation, - streamState: MongoStreamState - ) { - var docPtr: OpaquePointer? - var headerSent = false - var columns: [String] = [] - var columnTypeNames: [String] = [] - - while mongoc_cursor_next(cursor, &docPtr) { - if Task.isCancelled { - cleanup(streamState) - continuation.finish(throwing: CancellationError()) - return - } - - guard let doc = docPtr else { continue } - let dict = bsonToDict(doc) - - if !headerSent { - columns = BsonDocumentFlattener.unionColumns(from: [dict]) - let bsonTypes = BsonDocumentFlattener.columnTypes(for: columns, documents: [dict]) - columnTypeNames = bsonTypes.map { bsonTypeToStreamString($0) } - continuation.yield(.header(PluginStreamHeader( - columns: columns, - columnTypeNames: columnTypeNames - ))) - headerSent = true - } else { - for key in dict.keys.sorted() where !columns.contains(key) { - columns.append(key) - let type = BsonDocumentFlattener.columnTypes(for: [key], documents: [dict]) - columnTypeNames.append(bsonTypeToStreamString(type.first ?? 2)) - } - } - - let row: [PluginCellValue] = columns.map { column in - guard let value = dict[column] else { return .null } - if let data = value as? Data { - return .bytes(data) - } - return PluginCellValue.fromOptional(BsonDocumentFlattener.stringValue(for: value)) - } - continuation.yield(.rows([row])) - } - - var error = bson_error_t() - if mongoc_cursor_error(cursor, &error) { - cleanup(streamState) - continuation.finish(throwing: makeError(error)) - return - } - - if !headerSent { - continuation.yield(.header(PluginStreamHeader( - columns: ["_id"], - columnTypeNames: ["VARCHAR"] - ))) - } - - cleanup(streamState) - continuation.finish() - } - - private func cleanup(_ state: MongoStreamState) { - state.lock.lock() - let cur = state.cursor - let col = state.collection - let alreadyDrained = state.drained - state.drained = true - state.cursor = nil - state.collection = nil - state.lock.unlock() - guard !alreadyDrained else { return } - if let cur { mongoc_cursor_destroy(cur) } - if let col { mongoc_collection_destroy(col) } - } - - private func bsonTypeToStreamString(_ type: Int32) -> String { - switch type { - case 1: return "FLOAT" - case 2: return "VARCHAR" - case 3: return "JSON" - case 4: return "JSON" - case 5: return "BLOB" - case 7: return "VARCHAR" - case 8: return "BOOLEAN" - case 9: return "TIMESTAMP" - case 10: return "VARCHAR" - case 16: return "INTEGER" - case 18: return "BIGINT" - default: return "VARCHAR" - } - } -} -#endif final class MongoStreamState: @unchecked Sendable { var cursor: OpaquePointer? @@ -1235,7 +787,7 @@ final class MongoStreamState: @unchecked Sendable { // MARK: - BSON Helpers -private extension MongoDBConnection { +extension MongoDBConnection { /// Convert a JSON string to a bson_t pointer. Caller must call bson_destroy on the result. func jsonToBson(_ json: String) -> OpaquePointer? { #if canImport(CLibMongoc) diff --git a/Plugins/OracleDriverPlugin/OraclePlugin.swift b/Plugins/OracleDriverPlugin/OraclePlugin.swift index a1d47de61..0fffc29c8 100644 --- a/Plugins/OracleDriverPlugin/OraclePlugin.swift +++ b/Plugins/OracleDriverPlugin/OraclePlugin.swift @@ -1037,8 +1037,9 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = oracleQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let orderBy = oracleBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY 1" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: oracleQuoteIdentifier + ) ?? "ORDER BY 1" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1054,12 +1055,21 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) -> String? { let quotedTable = oracleQuoteIdentifier(table) var query = "SELECT * FROM \(quotedTable)" - let whereClause = oracleBuildWhereClause(filters: filters, logicMode: logicMode) + let whereClause = PluginSQLFilter.buildWhereClause( + filters: filters, + logicMode: logicMode, + quoteIdentifier: oracleQuoteIdentifier, + escapeValue: oracleEscapeValue, + regexCondition: { quoted, value in + "REGEXP_LIKE(\(quoted), '\(value.replacingOccurrences(of: "'", with: "''"))')" + } + ) if !whereClause.isEmpty { query += " WHERE \(whereClause)" } - let orderBy = oracleBuildOrderByClause(sortColumns: sortColumns, columns: columns) - ?? "ORDER BY 1" + let orderBy = PluginSQLFilter.buildOrderByClause( + sortColumns: sortColumns, columns: columns, quoteIdentifier: oracleQuoteIdentifier + ) ?? "ORDER BY 1" query += " \(orderBy) OFFSET \(offset) ROWS FETCH NEXT \(limit) ROWS ONLY" return query } @@ -1070,29 +1080,6 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { "\"\(identifier.replacingOccurrences(of: "\"", with: "\"\""))\"" } - private func oracleBuildOrderByClause( - sortColumns: [(columnIndex: Int, ascending: Bool)], - columns: [String] - ) -> String? { - let parts = sortColumns.compactMap { sortCol -> String? in - guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } - let columnName = columns[sortCol.columnIndex] - let direction = sortCol.ascending ? "ASC" : "DESC" - let quotedColumn = oracleQuoteIdentifier(columnName) - return "\(quotedColumn) \(direction)" - } - guard !parts.isEmpty else { return nil } - return "ORDER BY " + parts.joined(separator: ", ") - } - - private func oracleEscapeForLike(_ text: String) -> String { - text - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "%", with: "\\%") - .replacingOccurrences(of: "_", with: "\\_") - .replacingOccurrences(of: "'", with: "''") - } - private func oracleEscapeValue(_ value: String) -> String { let trimmed = value.trimmingCharacters(in: .whitespaces) if trimmed.caseInsensitiveCompare("NULL") == .orderedSame { return "NULL" } @@ -1100,65 +1087,6 @@ final class OraclePluginDriver: PluginDatabaseDriver, @unchecked Sendable { return "'\(trimmed.replacingOccurrences(of: "'", with: "''"))'" } - private func oracleBuildWhereClause( - filters: [(column: String, op: String, value: String)], - logicMode: String - ) -> String { - let conditions = filters.compactMap { filter -> String? in - oracleBuildFilterCondition(column: filter.column, op: filter.op, value: filter.value) - } - guard !conditions.isEmpty else { return "" } - let separator = logicMode == "and" ? " AND " : " OR " - return conditions.joined(separator: separator) - } - - private func oracleBuildFilterCondition(column: String, op: String, value: String) -> String? { - let quoted = oracleQuoteIdentifier(column) - switch op { - case "=": return "\(quoted) = \(oracleEscapeValue(value))" - case "!=": return "\(quoted) != \(oracleEscapeValue(value))" - case ">": return "\(quoted) > \(oracleEscapeValue(value))" - case ">=": return "\(quoted) >= \(oracleEscapeValue(value))" - case "<": return "\(quoted) < \(oracleEscapeValue(value))" - case "<=": return "\(quoted) <= \(oracleEscapeValue(value))" - case "IS NULL": return "\(quoted) IS NULL" - case "IS NOT NULL": return "\(quoted) IS NOT NULL" - case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" - case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" - case "CONTAINS": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)%' ESCAPE '\\'" - case "NOT CONTAINS": - let escaped = oracleEscapeForLike(value) - return "\(quoted) NOT LIKE '%\(escaped)%' ESCAPE '\\'" - case "STARTS WITH": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '\(escaped)%' ESCAPE '\\'" - case "ENDS WITH": - let escaped = oracleEscapeForLike(value) - return "\(quoted) LIKE '%\(escaped)' ESCAPE '\\'" - case "IN": - let values = value.split(separator: ",") - .map { oracleEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) IN (\(values))" - case "NOT IN": - let values = value.split(separator: ",") - .map { oracleEscapeValue($0.trimmingCharacters(in: .whitespaces)) } - .joined(separator: ", ") - return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" - case "BETWEEN": - let parts = value.split(separator: ",", maxSplits: 1) - guard parts.count == 2 else { return nil } - let v1 = oracleEscapeValue(parts[0].trimmingCharacters(in: .whitespaces)) - let v2 = oracleEscapeValue(parts[1].trimmingCharacters(in: .whitespaces)) - return "\(quoted) BETWEEN \(v1) AND \(v2)" - case "REGEX": - let escaped = value.replacingOccurrences(of: "'", with: "''") - return "REGEXP_LIKE(\(quoted), '\(escaped)')" - default: return nil - } - } // MARK: - Private Helpers diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift new file mode 100644 index 000000000..120200c2a --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+Operations.swift @@ -0,0 +1,510 @@ +// +// RedisPluginDriver+Operations.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + func executeOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .get, .set, .del, .keys, .scan, .type, .ttl, .pttl, .expire, .persist, .rename, .exists: + return try await executeKeyOperation(operation, connection: conn, startTime: startTime) + + case .hget, .hset, .hgetall, .hdel: + return try await executeHashOperation(operation, connection: conn, startTime: startTime) + + case .lrange, .lpush, .rpush, .llen: + return try await executeListOperation(operation, connection: conn, startTime: startTime) + + case .smembers, .sadd, .srem, .scard: + return try await executeSetOperation(operation, connection: conn, startTime: startTime) + + case .zrange, .zadd, .zrem, .zcard: + return try await executeSortedSetOperation(operation, connection: conn, startTime: startTime) + + case .xrange, .xlen: + return try await executeStreamOperation(operation, connection: conn, startTime: startTime) + + case .ping, .info, .dbsize, .flushdb, .select, .configGet, .configSet, .command, .multi, .exec, .discard: + return try await executeServerOperation(operation, connection: conn, startTime: startTime) + } + } + + // MARK: - Key Operations + + func executeKeyOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .get(let key): + let result = try await conn.executeCommand(["GET", key]) + let value = result.stringValue + return PluginQueryResult( + columns: ["Key", "Value"], + columnTypeNames: ["String", "String"], + rows: [[key, value].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .set(let key, let value, let options): + var args = ["SET", key, value] + if let opts = options { + if let ex = opts.ex { args += ["EX", String(ex)] } + if let px = opts.px { args += ["PX", String(px)] } + if let exat = opts.exat { args += ["EXAT", String(exat)] } + if let pxat = opts.pxat { args += ["PXAT", String(pxat)] } + if opts.nx { args.append("NX") } + if opts.xx { args.append("XX") } + } + _ = try await conn.executeCommand(args) + return buildStatusResult("OK", startTime: startTime) + + case .del(let keys): + let args = ["DEL"] + keys + let result = try await conn.executeCommand(args) + let deleted = result.intValue ?? 0 + return PluginQueryResult( + columns: ["deleted"], + columnTypeNames: ["Int64"], + rows: [[String(deleted)].asCells], + rowsAffected: deleted, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .keys(let pattern): + let result = try await conn.executeCommand(["KEYS", pattern]) + guard let items = result.arrayValue else { + return buildEmptyKeyResult(startTime: startTime) + } + let keys = items.map { redisReplyToString($0) } + let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) + let keysTruncated = keys.count > PluginRowLimits.emergencyMax + return try await buildKeyBrowseResult( + keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated + ) + + case .scan(let cursor, let pattern, let count): + var args = ["SCAN", String(cursor)] + if let p = pattern { args += ["MATCH", p] } + if let c = count { args += ["COUNT", String(c)] } + let result = try await conn.executeCommand(args) + return try await handleScanResult(result, connection: conn, startTime: startTime) + + case .type(let key): + let result = try await conn.executeCommand(["TYPE", key]) + let typeName = result.stringValue ?? "none" + return PluginQueryResult( + columns: ["Key", "Type"], + columnTypeNames: ["String", "String"], + rows: [[key, typeName].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .ttl(let key): + let result = try await conn.executeCommand(["TTL", key]) + let ttl = result.intValue ?? -1 + return PluginQueryResult( + columns: ["Key", "TTL"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(ttl)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .pttl(let key): + let result = try await conn.executeCommand(["PTTL", key]) + let pttl = result.intValue ?? -1 + return PluginQueryResult( + columns: ["Key", "PTTL"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(pttl)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .expire(let key, let seconds): + let result = try await conn.executeCommand(["EXPIRE", key, String(seconds)]) + let success = (result.intValue ?? 0) == 1 + return buildStatusResult(success ? "OK" : "Key not found", startTime: startTime) + + case .persist(let key): + let result = try await conn.executeCommand(["PERSIST", key]) + let success = (result.intValue ?? 0) == 1 + return buildStatusResult(success ? "OK" : "Key not found or no TTL", startTime: startTime) + + case .rename(let key, let newKey): + let reply = try await conn.executeCommand(["RENAME", key, newKey]) + if case .error(let msg) = reply { + throw RedisPluginError(code: 0, message: "RENAME failed: \(msg)") + } + return buildStatusResult("OK", startTime: startTime) + + case .exists(let keys): + let args = ["EXISTS"] + keys + let result = try await conn.executeCommand(args) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["exists"], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeKeyOperation") + } + } + + // MARK: - Hash Operations + + func executeHashOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .hget(let key, let field): + let result = try await conn.executeCommand(["HGET", key, field]) + let value = result.stringValue + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: [[field, value].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .hset(let key, let fieldValues): + var args = ["HSET", key] + for (field, value) in fieldValues { + args += [field, value] + } + let result = try await conn.executeCommand(args) + let added = result.intValue ?? 0 + return PluginQueryResult( + columns: ["added"], + columnTypeNames: ["Int64"], + rows: [[String(added)].asCells], + rowsAffected: added, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .hgetall(let key): + let result = try await conn.executeCommand(["HGETALL", key]) + return buildHashResult(result, startTime: startTime) + + case .hdel(let key, let fields): + let args = ["HDEL", key] + fields + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeHashOperation") + } + } + + // MARK: - List Operations + + func executeListOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .lrange(let key, let start, let stop): + let result = try await conn.executeCommand(["LRANGE", key, String(start), String(stop)]) + return buildListResult(result, startOffset: start, startTime: startTime) + + case .lpush(let key, let values): + let args = ["LPUSH", key] + values + let result = try await conn.executeCommand(args) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["length"], + columnTypeNames: ["Int64"], + rows: [[String(length)].asCells], + rowsAffected: values.count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .rpush(let key, let values): + let args = ["RPUSH", key] + values + let result = try await conn.executeCommand(args) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["length"], + columnTypeNames: ["Int64"], + rows: [[String(length)].asCells], + rowsAffected: values.count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .llen(let key): + let result = try await conn.executeCommand(["LLEN", key]) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Length"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(length)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeListOperation") + } + } + + // MARK: - Set Operations + + func executeSetOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .smembers(let key): + let result = try await conn.executeCommand(["SMEMBERS", key]) + return buildSetResult(result, startTime: startTime) + + case .sadd(let key, let members): + let args = ["SADD", key] + members + let result = try await conn.executeCommand(args) + let added = result.intValue ?? 0 + return PluginQueryResult( + columns: ["added"], + columnTypeNames: ["Int64"], + rows: [[String(added)].asCells], + rowsAffected: added, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .srem(let key, let members): + let args = ["SREM", key] + members + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .scard(let key): + let result = try await conn.executeCommand(["SCARD", key]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Cardinality"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeSetOperation") + } + } + + // MARK: - Sorted Set Operations + + func executeSortedSetOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .zrange(let key, let start, let stop, let flags): + var args = ["ZRANGE", key, start, stop] + args += flags + let withScores = flags.contains("WITHSCORES") + let result = try await conn.executeCommand(args) + return buildSortedSetResult(result, withScores: withScores, startTime: startTime) + + case .zadd(let key, let flags, let scoreMembers): + var args = ["ZADD", key] + args += flags + for (score, member) in scoreMembers { + args += [String(score), member] + } + let result = try await conn.executeCommand(args) + if flags.contains("INCR") { + // INCR mode returns the new score (or nil for NX miss) + let scoreStr = result.stringValue ?? "nil" + return PluginQueryResult( + columns: ["score"], + columnTypeNames: ["String"], + rows: [[scoreStr].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + let count = result.intValue ?? 0 + let columnName = flags.contains("CH") ? "changed" : "added" + return PluginQueryResult( + columns: [columnName], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: count, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .zrem(let key, let members): + let args = ["ZREM", key] + members + let result = try await conn.executeCommand(args) + let removed = result.intValue ?? 0 + return PluginQueryResult( + columns: ["removed"], + columnTypeNames: ["Int64"], + rows: [[String(removed)].asCells], + rowsAffected: removed, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .zcard(let key): + let result = try await conn.executeCommand(["ZCARD", key]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Cardinality"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeSortedSetOperation") + } + } + + // MARK: - Stream Operations + + func executeStreamOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .xrange(let key, let start, let end, let count): + var args = ["XRANGE", key, start, end] + if let c = count { args += ["COUNT", String(c)] } + let result = try await conn.executeCommand(args) + return buildStreamResult(result, startTime: startTime) + + case .xlen(let key): + let result = try await conn.executeCommand(["XLEN", key]) + let length = result.intValue ?? 0 + return PluginQueryResult( + columns: ["Key", "Length"], + columnTypeNames: ["String", "Int64"], + rows: [[key, String(length)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeStreamOperation") + } + } + + // MARK: - Server Operations + + func executeServerOperation( + _ operation: RedisOperation, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + switch operation { + case .ping: + _ = try await conn.executeCommand(["PING"]) + return PluginQueryResult( + columns: ["ok"], + columnTypeNames: ["Int32"], + rows: [["1"].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .info(let section): + var args = ["INFO"] + if let s = section { args.append(s) } + let result = try await conn.executeCommand(args) + let infoText = result.stringValue ?? String(describing: result) + return PluginQueryResult( + columns: ["info"], + columnTypeNames: ["String"], + rows: [[infoText].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .dbsize: + let result = try await conn.executeCommand(["DBSIZE"]) + let count = result.intValue ?? 0 + return PluginQueryResult( + columns: ["keys"], + columnTypeNames: ["Int64"], + rows: [[String(count)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .flushdb: + _ = try await conn.executeCommand(["FLUSHDB"]) + return buildStatusResult("OK", startTime: startTime) + + case .select(let database): + try await conn.selectDatabase(database) + cachedScanPattern = nil + cachedScanKeys = nil + return buildStatusResult("OK", startTime: startTime) + + case .configGet(let parameter): + let result = try await conn.executeCommand(["CONFIG", "GET", parameter]) + return buildConfigResult(result, startTime: startTime) + + case .configSet(let parameter, let value): + _ = try await conn.executeCommand(["CONFIG", "SET", parameter, value]) + return buildStatusResult("OK", startTime: startTime) + + case .command(let args): + let result = try await conn.executeCommand(args) + return buildGenericResult(result, startTime: startTime) + + case .multi: + _ = try await conn.executeCommand(["MULTI"]) + return buildStatusResult("OK", startTime: startTime) + + case .exec: + let result = try await conn.executeCommand(["EXEC"]) + return buildGenericResult(result, startTime: startTime) + + case .discard: + _ = try await conn.executeCommand(["DISCARD"]) + return buildStatusResult("OK", startTime: startTime) + + default: + throw RedisPluginError(code: 0, message: "Unexpected operation in executeServerOperation") + } + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift new file mode 100644 index 000000000..61d160ccd --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+ResultBuilding.swift @@ -0,0 +1,486 @@ +// +// RedisPluginDriver+ResultBuilding.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + static let previewLimit = 100 + static let previewMaxChars = 1_000 + + func buildKeyBrowseResult( + keys: [String], + connection conn: RedisPluginConnection, + startTime: Date, + isTruncated: Bool = false + ) async throws -> PluginQueryResult { + guard !keys.isEmpty else { + return buildEmptyKeyResult(startTime: startTime) + } + + var typeAndTtlCommands: [[String]] = [] + typeAndTtlCommands.reserveCapacity(keys.count * 2) + for key in keys { + typeAndTtlCommands.append(["TYPE", key]) + typeAndTtlCommands.append(["TTL", key]) + } + let typeAndTtlReplies = try await conn.executePipeline(typeAndTtlCommands) + + var typeNames: [String] = [] + typeNames.reserveCapacity(keys.count) + var ttlValues: [Int] = [] + ttlValues.reserveCapacity(keys.count) + for i in 0 ..< keys.count { + let typeName = (typeAndTtlReplies[i * 2].stringValue ?? "unknown").uppercased() + let ttl = typeAndTtlReplies[i * 2 + 1].intValue ?? -1 + typeNames.append(typeName) + ttlValues.append(ttl) + } + + var previewCommands: [[String]] = [] + previewCommands.reserveCapacity(keys.count) + var previewCommandIndices: [Int] = [] + previewCommandIndices.reserveCapacity(keys.count) + + for (i, key) in keys.enumerated() { + let command: [String]? = previewCommandForType(typeNames[i], key: key) + if let command { + previewCommandIndices.append(previewCommands.count) + previewCommands.append(command) + } else { + previewCommandIndices.append(-1) + } + } + + var previewReplies: [RedisReply] = [] + if !previewCommands.isEmpty { + previewReplies = try await conn.executePipeline(previewCommands) + } + + var rows: [[PluginCellValue]] = [] + rows.reserveCapacity(keys.count) + for (i, key) in keys.enumerated() { + let ttlStr = String(ttlValues[i]) + let pipelineIndex = previewCommandIndices[i] + let preview: String? + if pipelineIndex >= 0, pipelineIndex < previewReplies.count { + preview = formatPreviewReply( + previewReplies[pipelineIndex], type: typeNames[i] + ) + } else { + preview = nil + } + rows.append([key, typeNames[i], ttlStr, preview].asCells) + } + + return PluginQueryResult( + columns: ["Key", "Type", "TTL", "Value"], + columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime), + isTruncated: isTruncated + ) + } + + func previewCommandForType(_ type: String, key: String) -> [String]? { + switch type.lowercased() { + case "string": + return ["GET", key] + case "hash": + return ["HSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "list": + return ["LRANGE", key, "0", String(Self.previewLimit - 1)] + case "set": + return ["SSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "zset": + return ["ZRANGE", key, "0", String(Self.previewLimit - 1), "WITHSCORES"] + case "stream": + return ["XREVRANGE", key, "+", "-", "COUNT", "5"] + default: + return nil + } + } + + func formatPreviewReply(_ reply: RedisReply, type: String) -> String? { + switch type.lowercased() { + case "string": + return truncatePreview(redisReplyToString(reply)) + + case "hash": + let array: [RedisReply] + if case .array(let scanResult) = reply, + scanResult.count == 2, + let items = scanResult[1].arrayValue { + array = items + } else if let items = reply.arrayValue, !items.isEmpty { + array = items + } else { + return "{}" + } + guard !array.isEmpty else { return "{}" } + var pairs: [String] = [] + var idx = 0 + while idx + 1 < array.count { + let field = redisReplyToString(array[idx]) + let value = redisReplyToString(array[idx + 1]) + pairs.append( + "\"\(escapeJsonString(field))\":\"\(escapeJsonString(value))\"" + ) + idx += 2 + } + return truncatePreview("{\(pairs.joined(separator: ","))}") + + case "list": + guard let items = reply.arrayValue else { return "[]" } + let quoted = items.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } + return truncatePreview("[\(quoted.joined(separator: ", "))]") + + case "set": + let members: [RedisReply] + if case .array(let scanResult) = reply, + scanResult.count == 2, + let items = scanResult[1].arrayValue { + members = items + } else if let items = reply.arrayValue { + members = items + } else { + return "[]" + } + let quoted = members.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } + return truncatePreview("[\(quoted.joined(separator: ", "))]") + + case "zset": + // Parse WITHSCORES result: alternating member, score pairs + guard let items = reply.arrayValue, !items.isEmpty else { return "[]" } + var pairs: [String] = [] + var i = 0 + while i + 1 < items.count { + pairs.append("\(redisReplyToString(items[i])):\(redisReplyToString(items[i + 1]))") + i += 2 + } + return truncatePreview(pairs.joined(separator: ", ")) + + case "stream": + // Parse XREVRANGE result: array of [id, [field, value, ...]] entries + guard let entries = reply.arrayValue, !entries.isEmpty else { + return "(0 entries)" + } + var entryStrings: [String] = [] + for entry in entries { + guard let parts = entry.arrayValue, parts.count >= 2, + let fields = parts[1].arrayValue else { + continue + } + let entryId = redisReplyToString(parts[0]) + var fieldPairs: [String] = [] + var j = 0 + while j + 1 < fields.count { + fieldPairs.append("\(redisReplyToString(fields[j]))=\(redisReplyToString(fields[j + 1]))") + j += 2 + } + entryStrings.append("\(entryId): \(fieldPairs.joined(separator: ", "))") + } + return truncatePreview(entryStrings.joined(separator: "; ")) + + default: + return nil + } + } + + func truncatePreview(_ value: String?) -> String? { + guard let value else { return nil } + let nsValue = value as NSString + if nsValue.length > Self.previewMaxChars { + return nsValue.substring(to: Self.previewMaxChars) + "..." + } + return value + } + + func escapeJsonString(_ str: String) -> String { + var result = "" + for scalar in str.unicodeScalars { + switch scalar { + case "\\": result += "\\\\" + case "\"": result += "\\\"" + case "\n": result += "\\n" + case "\r": result += "\\r" + case "\t": result += "\\t" + default: + if scalar.value < 0x20 { + result += String(format: "\\u%04X", scalar.value) + } else { + result += String(scalar) + } + } + } + return result + } + + func buildEmptyKeyResult(startTime: Date) -> PluginQueryResult { + PluginQueryResult( + columns: ["Key", "Type", "TTL", "Value"], + columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildStatusResult(_ message: String, startTime: Date) -> PluginQueryResult { + PluginQueryResult( + columns: ["status"], + columnTypeNames: ["String"], + rows: [[message].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildGenericResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + switch result { + case .string(let s), .status(let s): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[s].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .integer(let i): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["Int64"], + rows: [[String(i)].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .data(let d): + let str = String(data: d, encoding: .utf8) ?? d.base64EncodedString() + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[str].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .array(let items): + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .error(let e): + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [[e].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + + case .null: + return PluginQueryResult( + columns: ["result"], + columnTypeNames: ["String"], + rows: [["(nil)"].asCells], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + } + + func redisReplyToString(_ reply: RedisReply) -> String { + switch reply { + case .string(let s), .status(let s), .error(let s): return s + case .integer(let i): return String(i) + case .data(let d): return String(data: d, encoding: .utf8) ?? d.base64EncodedString() + case .array(let items): return "[\(items.map { redisReplyToString($0) }.joined(separator: ", "))]" + case .null: return "(nil)" + } + } + + func buildHashResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue, !items.isEmpty else { + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + + return PluginQueryResult( + columns: ["Field", "Value"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildListResult(_ result: RedisReply, startOffset: Int = 0, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: ["Index", "Value"], + columnTypeNames: ["Int64", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + let rows = items.enumerated().map { index, item -> [PluginCellValue] in + ([String(startOffset + index), redisReplyToString(item)] as [String?]).asCells + } + + return PluginQueryResult( + columns: ["Index", "Value"], + columnTypeNames: ["Int64", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildSetResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildSortedSetResult(_ result: RedisReply, withScores: Bool, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue else { + return PluginQueryResult( + columns: withScores ? ["Member", "Score"] : ["Member"], + columnTypeNames: withScores ? ["String", "Double"] : ["String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + if withScores { + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + return PluginQueryResult( + columns: ["Member", "Score"], + columnTypeNames: ["String", "Double"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } else { + let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } + return PluginQueryResult( + columns: ["Member"], + columnTypeNames: ["String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + } + + func buildStreamResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let entries = result.arrayValue else { + return PluginQueryResult( + columns: ["ID", "Fields"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + for entry in entries { + guard let entryParts = entry.arrayValue, entryParts.count >= 2, + let fields = entryParts[1].arrayValue else { + continue + } + let entryId = redisReplyToString(entryParts[0]) + + var fieldPairs: [String] = [] + var i = 0 + while i + 1 < fields.count { + fieldPairs.append("\(redisReplyToString(fields[i]))=\(redisReplyToString(fields[i + 1]))") + i += 2 + } + rows.append([entryId, fieldPairs.joined(separator: ", ")].asCells) + } + + return PluginQueryResult( + columns: ["ID", "Fields"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + func buildConfigResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + guard let items = result.arrayValue, !items.isEmpty else { + return PluginQueryResult( + columns: ["Parameter", "Value"], + columnTypeNames: ["String", "String"], + rows: [], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + + var rows: [[PluginCellValue]] = [] + var i = 0 + while i + 1 < items.count { + rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) + i += 2 + } + + return PluginQueryResult( + columns: ["Parameter", "Value"], + columnTypeNames: ["String", "String"], + rows: rows, + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift new file mode 100644 index 000000000..ebd86c01b --- /dev/null +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver+Scan.swift @@ -0,0 +1,85 @@ +// +// RedisPluginDriver+Scan.swift +// RedisDriverPlugin +// + +import Foundation +import OSLog +import TableProPluginKit + +extension RedisPluginDriver { + func scanAllKeys( + connection conn: RedisPluginConnection, + pattern: String?, + maxKeys: Int + ) async throws -> [String] { + var allKeys: [String] = [] + var cursor = "0" + + repeat { + var args = ["SCAN", cursor] + if let p = pattern { + args += ["MATCH", p] + } + args += ["COUNT", "1000"] + + let result = try await conn.executeCommand(args) + + guard case .array(let scanResult) = result, + scanResult.count == 2 else { + break + } + + let nextCursor: String + switch scanResult[0] { + case .string(let s): nextCursor = s + case .status(let s): nextCursor = s + case .data(let d): nextCursor = String(data: d, encoding: .utf8) ?? "0" + default: nextCursor = "0" + } + cursor = nextCursor + + if case .array(let keyReplies) = scanResult[1] { + for reply in keyReplies { + switch reply { + case .string(let k): allKeys.append(k) + case .data(let d): + if let k = String(data: d, encoding: .utf8) { allKeys.append(k) } + default: break + } + } + } + + if allKeys.count >= maxKeys { + allKeys = Array(allKeys.prefix(maxKeys)) + break + } + } while cursor != "0" + + return allKeys.sorted() + } + + func handleScanResult( + _ result: RedisReply, + connection conn: RedisPluginConnection, + startTime: Date + ) async throws -> PluginQueryResult { + guard case .array(let scanResult) = result, + scanResult.count == 2, + case .array(let keyReplies) = scanResult[1] else { + return buildEmptyKeyResult(startTime: startTime) + } + + let keys = keyReplies.compactMap { reply -> String? in + if case .string(let k) = reply { return k } + if case .data(let d) = reply { return String(data: d, encoding: .utf8) } + return nil + } + + let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) + let keysTruncated = keys.count > PluginRowLimits.emergencyMax + return try await buildKeyBrowseResult( + keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated + ) + } +} diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift index f38c68a57..3e1954d1a 100644 --- a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift @@ -11,19 +11,19 @@ import Foundation import OSLog import TableProPluginKit -private extension Array where Element == String? { +extension Array where Element == String? { var asCells: [PluginCellValue] { map(PluginCellValue.fromOptional) } } -private extension Array where Element == String { +extension Array where Element == String { var asCells: [PluginCellValue] { map(PluginCellValue.text) } } -private extension Array where Element == [String?] { +extension Array where Element == [String?] { var asCellRows: [[PluginCellValue]] { map { $0.map(PluginCellValue.fromOptional) } } } -private extension Array where Element == [String] { +extension Array where Element == [String] { var asCellRows: [[PluginCellValue]] { map { $0.map(PluginCellValue.text) } } } @@ -35,8 +35,8 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { private static let maxScanKeys = PluginRowLimits.emergencyMax - private var cachedScanPattern: String? - private var cachedScanKeys: [String]? + var cachedScanPattern: String? + var cachedScanKeys: [String]? var serverVersion: String? { redisConnection?.serverVersion() @@ -630,1066 +630,3 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { ) } } - -// MARK: - Operation Dispatch - -private extension RedisPluginDriver { - func executeOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .get, .set, .del, .keys, .scan, .type, .ttl, .pttl, .expire, .persist, .rename, .exists: - return try await executeKeyOperation(operation, connection: conn, startTime: startTime) - - case .hget, .hset, .hgetall, .hdel: - return try await executeHashOperation(operation, connection: conn, startTime: startTime) - - case .lrange, .lpush, .rpush, .llen: - return try await executeListOperation(operation, connection: conn, startTime: startTime) - - case .smembers, .sadd, .srem, .scard: - return try await executeSetOperation(operation, connection: conn, startTime: startTime) - - case .zrange, .zadd, .zrem, .zcard: - return try await executeSortedSetOperation(operation, connection: conn, startTime: startTime) - - case .xrange, .xlen: - return try await executeStreamOperation(operation, connection: conn, startTime: startTime) - - case .ping, .info, .dbsize, .flushdb, .select, .configGet, .configSet, .command, .multi, .exec, .discard: - return try await executeServerOperation(operation, connection: conn, startTime: startTime) - } - } - - // MARK: - Key Operations - - func executeKeyOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .get(let key): - let result = try await conn.executeCommand(["GET", key]) - let value = result.stringValue - return PluginQueryResult( - columns: ["Key", "Value"], - columnTypeNames: ["String", "String"], - rows: [[key, value].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .set(let key, let value, let options): - var args = ["SET", key, value] - if let opts = options { - if let ex = opts.ex { args += ["EX", String(ex)] } - if let px = opts.px { args += ["PX", String(px)] } - if let exat = opts.exat { args += ["EXAT", String(exat)] } - if let pxat = opts.pxat { args += ["PXAT", String(pxat)] } - if opts.nx { args.append("NX") } - if opts.xx { args.append("XX") } - } - _ = try await conn.executeCommand(args) - return buildStatusResult("OK", startTime: startTime) - - case .del(let keys): - let args = ["DEL"] + keys - let result = try await conn.executeCommand(args) - let deleted = result.intValue ?? 0 - return PluginQueryResult( - columns: ["deleted"], - columnTypeNames: ["Int64"], - rows: [[String(deleted)].asCells], - rowsAffected: deleted, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .keys(let pattern): - let result = try await conn.executeCommand(["KEYS", pattern]) - guard let items = result.arrayValue else { - return buildEmptyKeyResult(startTime: startTime) - } - let keys = items.map { redisReplyToString($0) } - let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) - let keysTruncated = keys.count > PluginRowLimits.emergencyMax - return try await buildKeyBrowseResult( - keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated - ) - - case .scan(let cursor, let pattern, let count): - var args = ["SCAN", String(cursor)] - if let p = pattern { args += ["MATCH", p] } - if let c = count { args += ["COUNT", String(c)] } - let result = try await conn.executeCommand(args) - return try await handleScanResult(result, connection: conn, startTime: startTime) - - case .type(let key): - let result = try await conn.executeCommand(["TYPE", key]) - let typeName = result.stringValue ?? "none" - return PluginQueryResult( - columns: ["Key", "Type"], - columnTypeNames: ["String", "String"], - rows: [[key, typeName].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .ttl(let key): - let result = try await conn.executeCommand(["TTL", key]) - let ttl = result.intValue ?? -1 - return PluginQueryResult( - columns: ["Key", "TTL"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(ttl)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .pttl(let key): - let result = try await conn.executeCommand(["PTTL", key]) - let pttl = result.intValue ?? -1 - return PluginQueryResult( - columns: ["Key", "PTTL"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(pttl)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .expire(let key, let seconds): - let result = try await conn.executeCommand(["EXPIRE", key, String(seconds)]) - let success = (result.intValue ?? 0) == 1 - return buildStatusResult(success ? "OK" : "Key not found", startTime: startTime) - - case .persist(let key): - let result = try await conn.executeCommand(["PERSIST", key]) - let success = (result.intValue ?? 0) == 1 - return buildStatusResult(success ? "OK" : "Key not found or no TTL", startTime: startTime) - - case .rename(let key, let newKey): - let reply = try await conn.executeCommand(["RENAME", key, newKey]) - if case .error(let msg) = reply { - throw RedisPluginError(code: 0, message: "RENAME failed: \(msg)") - } - return buildStatusResult("OK", startTime: startTime) - - case .exists(let keys): - let args = ["EXISTS"] + keys - let result = try await conn.executeCommand(args) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["exists"], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeKeyOperation") - } - } - - // MARK: - Hash Operations - - func executeHashOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .hget(let key, let field): - let result = try await conn.executeCommand(["HGET", key, field]) - let value = result.stringValue - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: [[field, value].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .hset(let key, let fieldValues): - var args = ["HSET", key] - for (field, value) in fieldValues { - args += [field, value] - } - let result = try await conn.executeCommand(args) - let added = result.intValue ?? 0 - return PluginQueryResult( - columns: ["added"], - columnTypeNames: ["Int64"], - rows: [[String(added)].asCells], - rowsAffected: added, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .hgetall(let key): - let result = try await conn.executeCommand(["HGETALL", key]) - return buildHashResult(result, startTime: startTime) - - case .hdel(let key, let fields): - let args = ["HDEL", key] + fields - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeHashOperation") - } - } - - // MARK: - List Operations - - func executeListOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .lrange(let key, let start, let stop): - let result = try await conn.executeCommand(["LRANGE", key, String(start), String(stop)]) - return buildListResult(result, startOffset: start, startTime: startTime) - - case .lpush(let key, let values): - let args = ["LPUSH", key] + values - let result = try await conn.executeCommand(args) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["length"], - columnTypeNames: ["Int64"], - rows: [[String(length)].asCells], - rowsAffected: values.count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .rpush(let key, let values): - let args = ["RPUSH", key] + values - let result = try await conn.executeCommand(args) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["length"], - columnTypeNames: ["Int64"], - rows: [[String(length)].asCells], - rowsAffected: values.count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .llen(let key): - let result = try await conn.executeCommand(["LLEN", key]) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Length"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(length)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeListOperation") - } - } - - // MARK: - Set Operations - - func executeSetOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .smembers(let key): - let result = try await conn.executeCommand(["SMEMBERS", key]) - return buildSetResult(result, startTime: startTime) - - case .sadd(let key, let members): - let args = ["SADD", key] + members - let result = try await conn.executeCommand(args) - let added = result.intValue ?? 0 - return PluginQueryResult( - columns: ["added"], - columnTypeNames: ["Int64"], - rows: [[String(added)].asCells], - rowsAffected: added, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .srem(let key, let members): - let args = ["SREM", key] + members - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .scard(let key): - let result = try await conn.executeCommand(["SCARD", key]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Cardinality"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeSetOperation") - } - } - - // MARK: - Sorted Set Operations - - func executeSortedSetOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .zrange(let key, let start, let stop, let flags): - var args = ["ZRANGE", key, start, stop] - args += flags - let withScores = flags.contains("WITHSCORES") - let result = try await conn.executeCommand(args) - return buildSortedSetResult(result, withScores: withScores, startTime: startTime) - - case .zadd(let key, let flags, let scoreMembers): - var args = ["ZADD", key] - args += flags - for (score, member) in scoreMembers { - args += [String(score), member] - } - let result = try await conn.executeCommand(args) - if flags.contains("INCR") { - // INCR mode returns the new score (or nil for NX miss) - let scoreStr = result.stringValue ?? "nil" - return PluginQueryResult( - columns: ["score"], - columnTypeNames: ["String"], - rows: [[scoreStr].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - let count = result.intValue ?? 0 - let columnName = flags.contains("CH") ? "changed" : "added" - return PluginQueryResult( - columns: [columnName], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: count, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .zrem(let key, let members): - let args = ["ZREM", key] + members - let result = try await conn.executeCommand(args) - let removed = result.intValue ?? 0 - return PluginQueryResult( - columns: ["removed"], - columnTypeNames: ["Int64"], - rows: [[String(removed)].asCells], - rowsAffected: removed, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .zcard(let key): - let result = try await conn.executeCommand(["ZCARD", key]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Cardinality"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeSortedSetOperation") - } - } - - // MARK: - Stream Operations - - func executeStreamOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .xrange(let key, let start, let end, let count): - var args = ["XRANGE", key, start, end] - if let c = count { args += ["COUNT", String(c)] } - let result = try await conn.executeCommand(args) - return buildStreamResult(result, startTime: startTime) - - case .xlen(let key): - let result = try await conn.executeCommand(["XLEN", key]) - let length = result.intValue ?? 0 - return PluginQueryResult( - columns: ["Key", "Length"], - columnTypeNames: ["String", "Int64"], - rows: [[key, String(length)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeStreamOperation") - } - } - - // MARK: - Server Operations - - func executeServerOperation( - _ operation: RedisOperation, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - switch operation { - case .ping: - _ = try await conn.executeCommand(["PING"]) - return PluginQueryResult( - columns: ["ok"], - columnTypeNames: ["Int32"], - rows: [["1"].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .info(let section): - var args = ["INFO"] - if let s = section { args.append(s) } - let result = try await conn.executeCommand(args) - let infoText = result.stringValue ?? String(describing: result) - return PluginQueryResult( - columns: ["info"], - columnTypeNames: ["String"], - rows: [[infoText].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .dbsize: - let result = try await conn.executeCommand(["DBSIZE"]) - let count = result.intValue ?? 0 - return PluginQueryResult( - columns: ["keys"], - columnTypeNames: ["Int64"], - rows: [[String(count)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .flushdb: - _ = try await conn.executeCommand(["FLUSHDB"]) - return buildStatusResult("OK", startTime: startTime) - - case .select(let database): - try await conn.selectDatabase(database) - cachedScanPattern = nil - cachedScanKeys = nil - return buildStatusResult("OK", startTime: startTime) - - case .configGet(let parameter): - let result = try await conn.executeCommand(["CONFIG", "GET", parameter]) - return buildConfigResult(result, startTime: startTime) - - case .configSet(let parameter, let value): - _ = try await conn.executeCommand(["CONFIG", "SET", parameter, value]) - return buildStatusResult("OK", startTime: startTime) - - case .command(let args): - let result = try await conn.executeCommand(args) - return buildGenericResult(result, startTime: startTime) - - case .multi: - _ = try await conn.executeCommand(["MULTI"]) - return buildStatusResult("OK", startTime: startTime) - - case .exec: - let result = try await conn.executeCommand(["EXEC"]) - return buildGenericResult(result, startTime: startTime) - - case .discard: - _ = try await conn.executeCommand(["DISCARD"]) - return buildStatusResult("OK", startTime: startTime) - - default: - throw RedisPluginError(code: 0, message: "Unexpected operation in executeServerOperation") - } - } -} - -// MARK: - SCAN Helpers - -private extension RedisPluginDriver { - func scanAllKeys( - connection conn: RedisPluginConnection, - pattern: String?, - maxKeys: Int - ) async throws -> [String] { - var allKeys: [String] = [] - var cursor = "0" - - repeat { - var args = ["SCAN", cursor] - if let p = pattern { - args += ["MATCH", p] - } - args += ["COUNT", "1000"] - - let result = try await conn.executeCommand(args) - - guard case .array(let scanResult) = result, - scanResult.count == 2 else { - break - } - - let nextCursor: String - switch scanResult[0] { - case .string(let s): nextCursor = s - case .status(let s): nextCursor = s - case .data(let d): nextCursor = String(data: d, encoding: .utf8) ?? "0" - default: nextCursor = "0" - } - cursor = nextCursor - - if case .array(let keyReplies) = scanResult[1] { - for reply in keyReplies { - switch reply { - case .string(let k): allKeys.append(k) - case .data(let d): - if let k = String(data: d, encoding: .utf8) { allKeys.append(k) } - default: break - } - } - } - - if allKeys.count >= maxKeys { - allKeys = Array(allKeys.prefix(maxKeys)) - break - } - } while cursor != "0" - - return allKeys.sorted() - } - - func handleScanResult( - _ result: RedisReply, - connection conn: RedisPluginConnection, - startTime: Date - ) async throws -> PluginQueryResult { - guard case .array(let scanResult) = result, - scanResult.count == 2, - case .array(let keyReplies) = scanResult[1] else { - return buildEmptyKeyResult(startTime: startTime) - } - - let keys = keyReplies.compactMap { reply -> String? in - if case .string(let k) = reply { return k } - if case .data(let d) = reply { return String(data: d, encoding: .utf8) } - return nil - } - - let capped = Array(keys.prefix(PluginRowLimits.emergencyMax)) - let keysTruncated = keys.count > PluginRowLimits.emergencyMax - return try await buildKeyBrowseResult( - keys: capped, connection: conn, startTime: startTime, isTruncated: keysTruncated - ) - } -} - -// MARK: - Result Building - -private extension RedisPluginDriver { - static let previewLimit = 100 - static let previewMaxChars = 1_000 - - func buildKeyBrowseResult( - keys: [String], - connection conn: RedisPluginConnection, - startTime: Date, - isTruncated: Bool = false - ) async throws -> PluginQueryResult { - guard !keys.isEmpty else { - return buildEmptyKeyResult(startTime: startTime) - } - - var typeAndTtlCommands: [[String]] = [] - typeAndTtlCommands.reserveCapacity(keys.count * 2) - for key in keys { - typeAndTtlCommands.append(["TYPE", key]) - typeAndTtlCommands.append(["TTL", key]) - } - let typeAndTtlReplies = try await conn.executePipeline(typeAndTtlCommands) - - var typeNames: [String] = [] - typeNames.reserveCapacity(keys.count) - var ttlValues: [Int] = [] - ttlValues.reserveCapacity(keys.count) - for i in 0 ..< keys.count { - let typeName = (typeAndTtlReplies[i * 2].stringValue ?? "unknown").uppercased() - let ttl = typeAndTtlReplies[i * 2 + 1].intValue ?? -1 - typeNames.append(typeName) - ttlValues.append(ttl) - } - - var previewCommands: [[String]] = [] - previewCommands.reserveCapacity(keys.count) - var previewCommandIndices: [Int] = [] - previewCommandIndices.reserveCapacity(keys.count) - - for (i, key) in keys.enumerated() { - let command: [String]? = previewCommandForType(typeNames[i], key: key) - if let command { - previewCommandIndices.append(previewCommands.count) - previewCommands.append(command) - } else { - previewCommandIndices.append(-1) - } - } - - var previewReplies: [RedisReply] = [] - if !previewCommands.isEmpty { - previewReplies = try await conn.executePipeline(previewCommands) - } - - var rows: [[PluginCellValue]] = [] - rows.reserveCapacity(keys.count) - for (i, key) in keys.enumerated() { - let ttlStr = String(ttlValues[i]) - let pipelineIndex = previewCommandIndices[i] - let preview: String? - if pipelineIndex >= 0, pipelineIndex < previewReplies.count { - preview = formatPreviewReply( - previewReplies[pipelineIndex], type: typeNames[i] - ) - } else { - preview = nil - } - rows.append([key, typeNames[i], ttlStr, preview].asCells) - } - - return PluginQueryResult( - columns: ["Key", "Type", "TTL", "Value"], - columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime), - isTruncated: isTruncated - ) - } - - func previewCommandForType(_ type: String, key: String) -> [String]? { - switch type.lowercased() { - case "string": - return ["GET", key] - case "hash": - return ["HSCAN", key, "0", "COUNT", String(Self.previewLimit)] - case "list": - return ["LRANGE", key, "0", String(Self.previewLimit - 1)] - case "set": - return ["SSCAN", key, "0", "COUNT", String(Self.previewLimit)] - case "zset": - return ["ZRANGE", key, "0", String(Self.previewLimit - 1), "WITHSCORES"] - case "stream": - return ["XREVRANGE", key, "+", "-", "COUNT", "5"] - default: - return nil - } - } - - func formatPreviewReply(_ reply: RedisReply, type: String) -> String? { - switch type.lowercased() { - case "string": - return truncatePreview(redisReplyToString(reply)) - - case "hash": - let array: [RedisReply] - if case .array(let scanResult) = reply, - scanResult.count == 2, - let items = scanResult[1].arrayValue { - array = items - } else if let items = reply.arrayValue, !items.isEmpty { - array = items - } else { - return "{}" - } - guard !array.isEmpty else { return "{}" } - var pairs: [String] = [] - var idx = 0 - while idx + 1 < array.count { - let field = redisReplyToString(array[idx]) - let value = redisReplyToString(array[idx + 1]) - pairs.append( - "\"\(escapeJsonString(field))\":\"\(escapeJsonString(value))\"" - ) - idx += 2 - } - return truncatePreview("{\(pairs.joined(separator: ","))}") - - case "list": - guard let items = reply.arrayValue else { return "[]" } - let quoted = items.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } - return truncatePreview("[\(quoted.joined(separator: ", "))]") - - case "set": - let members: [RedisReply] - if case .array(let scanResult) = reply, - scanResult.count == 2, - let items = scanResult[1].arrayValue { - members = items - } else if let items = reply.arrayValue { - members = items - } else { - return "[]" - } - let quoted = members.map { "\"\(escapeJsonString(redisReplyToString($0)))\"" } - return truncatePreview("[\(quoted.joined(separator: ", "))]") - - case "zset": - // Parse WITHSCORES result: alternating member, score pairs - guard let items = reply.arrayValue, !items.isEmpty else { return "[]" } - var pairs: [String] = [] - var i = 0 - while i + 1 < items.count { - pairs.append("\(redisReplyToString(items[i])):\(redisReplyToString(items[i + 1]))") - i += 2 - } - return truncatePreview(pairs.joined(separator: ", ")) - - case "stream": - // Parse XREVRANGE result: array of [id, [field, value, ...]] entries - guard let entries = reply.arrayValue, !entries.isEmpty else { - return "(0 entries)" - } - var entryStrings: [String] = [] - for entry in entries { - guard let parts = entry.arrayValue, parts.count >= 2, - let fields = parts[1].arrayValue else { - continue - } - let entryId = redisReplyToString(parts[0]) - var fieldPairs: [String] = [] - var j = 0 - while j + 1 < fields.count { - fieldPairs.append("\(redisReplyToString(fields[j]))=\(redisReplyToString(fields[j + 1]))") - j += 2 - } - entryStrings.append("\(entryId): \(fieldPairs.joined(separator: ", "))") - } - return truncatePreview(entryStrings.joined(separator: "; ")) - - default: - return nil - } - } - - func truncatePreview(_ value: String?) -> String? { - guard let value else { return nil } - let nsValue = value as NSString - if nsValue.length > Self.previewMaxChars { - return nsValue.substring(to: Self.previewMaxChars) + "..." - } - return value - } - - func escapeJsonString(_ str: String) -> String { - var result = "" - for scalar in str.unicodeScalars { - switch scalar { - case "\\": result += "\\\\" - case "\"": result += "\\\"" - case "\n": result += "\\n" - case "\r": result += "\\r" - case "\t": result += "\\t" - default: - if scalar.value < 0x20 { - result += String(format: "\\u%04X", scalar.value) - } else { - result += String(scalar) - } - } - } - return result - } - - func buildEmptyKeyResult(startTime: Date) -> PluginQueryResult { - PluginQueryResult( - columns: ["Key", "Type", "TTL", "Value"], - columnTypeNames: ["String", "RedisType", "RedisInt", "RedisRaw"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildStatusResult(_ message: String, startTime: Date) -> PluginQueryResult { - PluginQueryResult( - columns: ["status"], - columnTypeNames: ["String"], - rows: [[message].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildGenericResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - switch result { - case .string(let s), .status(let s): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[s].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .integer(let i): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["Int64"], - rows: [[String(i)].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .data(let d): - let str = String(data: d, encoding: .utf8) ?? d.base64EncodedString() - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[str].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .array(let items): - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .error(let e): - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [[e].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - - case .null: - return PluginQueryResult( - columns: ["result"], - columnTypeNames: ["String"], - rows: [["(nil)"].asCells], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - } - - func redisReplyToString(_ reply: RedisReply) -> String { - switch reply { - case .string(let s), .status(let s), .error(let s): return s - case .integer(let i): return String(i) - case .data(let d): return String(data: d, encoding: .utf8) ?? d.base64EncodedString() - case .array(let items): return "[\(items.map { redisReplyToString($0) }.joined(separator: ", "))]" - case .null: return "(nil)" - } - } - - func buildHashResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue, !items.isEmpty else { - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - - return PluginQueryResult( - columns: ["Field", "Value"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildListResult(_ result: RedisReply, startOffset: Int = 0, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: ["Index", "Value"], - columnTypeNames: ["Int64", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - let rows = items.enumerated().map { index, item -> [PluginCellValue] in - ([String(startOffset + index), redisReplyToString(item)] as [String?]).asCells - } - - return PluginQueryResult( - columns: ["Index", "Value"], - columnTypeNames: ["Int64", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildSetResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildSortedSetResult(_ result: RedisReply, withScores: Bool, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue else { - return PluginQueryResult( - columns: withScores ? ["Member", "Score"] : ["Member"], - columnTypeNames: withScores ? ["String", "Double"] : ["String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - if withScores { - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - return PluginQueryResult( - columns: ["Member", "Score"], - columnTypeNames: ["String", "Double"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } else { - let rows = items.map { ([redisReplyToString($0)] as [String?]).asCells } - return PluginQueryResult( - columns: ["Member"], - columnTypeNames: ["String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - } - - func buildStreamResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let entries = result.arrayValue else { - return PluginQueryResult( - columns: ["ID", "Fields"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - for entry in entries { - guard let entryParts = entry.arrayValue, entryParts.count >= 2, - let fields = entryParts[1].arrayValue else { - continue - } - let entryId = redisReplyToString(entryParts[0]) - - var fieldPairs: [String] = [] - var i = 0 - while i + 1 < fields.count { - fieldPairs.append("\(redisReplyToString(fields[i]))=\(redisReplyToString(fields[i + 1]))") - i += 2 - } - rows.append([entryId, fieldPairs.joined(separator: ", ")].asCells) - } - - return PluginQueryResult( - columns: ["ID", "Fields"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - func buildConfigResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { - guard let items = result.arrayValue, !items.isEmpty else { - return PluginQueryResult( - columns: ["Parameter", "Value"], - columnTypeNames: ["String", "String"], - rows: [], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } - - var rows: [[PluginCellValue]] = [] - var i = 0 - while i + 1 < items.count { - rows.append([redisReplyToString(items[i]), redisReplyToString(items[i + 1])].asCells) - i += 2 - } - - return PluginQueryResult( - columns: ["Parameter", "Value"], - columnTypeNames: ["String", "String"], - rows: rows, - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } -} diff --git a/Plugins/SQLExportPlugin/SQLExportPlugin.swift b/Plugins/SQLExportPlugin/SQLExportPlugin.swift index 00a1d515a..578c31e6d 100644 --- a/Plugins/SQLExportPlugin/SQLExportPlugin.swift +++ b/Plugins/SQLExportPlugin/SQLExportPlugin.swift @@ -526,7 +526,7 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { let insertPrefix = "INSERT INTO \(tableRef) (\(quotedColumns))\(overriding) VALUES\n" let numericIndices: Set = Set(includedColumnIndices.filter { idx in - idx < columnTypeNames.count && isNumericColumnType(columnTypeNames[idx]) + idx < columnTypeNames.count && PluginExportUtilities.isNumericColumnType(columnTypeNames[idx]) }) let effectiveBatchSize = batchSize <= 1 ? 1 : batchSize @@ -546,7 +546,7 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { let hex = data.map { String(format: "%02X", $0) }.joined() return "X'\(hex)'" case .text(let val): - if numericIndices.contains(colIndex) && isNumericLiteral(val) { + if numericIndices.contains(colIndex) && PluginNumericLiteral.isValid(val) { return val } let escaped = dataSource.escapeStringLiteral(val) @@ -571,19 +571,6 @@ final class SQLExportPlugin: ExportFormatPlugin, SettablePlugin { } } - private func isNumericColumnType(_ typeName: String) -> Bool { - let numericPrefixes = [ - "int", "bigint", "decimal", "float", "double", "numeric", - "real", "smallint", "tinyint", "mediumint", "integer", "number" - ] - let lower = typeName.lowercased() - return numericPrefixes.contains { lower.hasPrefix($0) } - } - - private func isNumericLiteral(_ val: String) -> Bool { - val.allSatisfy { $0.isNumber || $0 == "." || $0 == "-" || $0 == "+" || $0 == "e" || $0 == "E" } - } - private func compressFile(source: URL, destination: URL) async throws { let gzipPath = "/usr/bin/gzip" guard FileManager.default.isExecutableFile(atPath: gzipPath) else { diff --git a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift index 3f469a184..1f9f1330b 100644 --- a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift +++ b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift @@ -1,5 +1,44 @@ import Foundation +public enum PluginNumericLiteral { + public static func isValid(_ value: String) -> Bool { + guard !value.isEmpty else { return false } + var scanner = value.makeIterator() + var hasDigit = false + var hasDot = false + var hasE = false + + var first = true + while let c = scanner.next() { + if first { + first = false + if c == "-" || c == "+" { continue } + } + if c.isNumber { + hasDigit = true + continue + } + if c == "." && !hasDot && !hasE { + hasDot = true + continue + } + if (c == "e" || c == "E") && hasDigit && !hasE { + hasE = true + hasDigit = false + if let next = scanner.next() { + if next == "+" || next == "-" || next.isNumber { + if next.isNumber { hasDigit = true } + continue + } + } + return false + } + return false + } + return hasDigit + } +} + @frozen public enum ParameterStyle: String, Sendable { case questionMark // ? @@ -535,40 +574,7 @@ public extension PluginDatabaseDriver { } static func isNumericLiteral(_ value: String) -> Bool { - guard !value.isEmpty else { return false } - var scanner = value.makeIterator() - var hasDigit = false - var hasDot = false - var hasE = false - - var first = true - while let c = scanner.next() { - if first { - first = false - if c == "-" || c == "+" { continue } - } - if c.isNumber { - hasDigit = true - continue - } - if c == "." && !hasDot && !hasE { - hasDot = true - continue - } - if (c == "e" || c == "E") && hasDigit && !hasE { - hasE = true - hasDigit = false - if let next = scanner.next() { - if next == "+" || next == "-" || next.isNumber { - if next.isNumber { hasDigit = true } - continue - } - } - return false - } - return false - } - return hasDigit + PluginNumericLiteral.isValid(value) } func executeUserQuery(query: String, rowCap: Int?, parameters: [PluginCellValue]?) async throws -> PluginQueryResult { @@ -592,3 +598,99 @@ public extension PluginDatabaseDriver { ) } } + +public enum PluginSQLFilter { + public static func escapeForLike(_ text: String) -> String { + text + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "%", with: "\\%") + .replacingOccurrences(of: "_", with: "\\_") + .replacingOccurrences(of: "'", with: "''") + } + + public static func buildOrderByClause( + sortColumns: [(columnIndex: Int, ascending: Bool)], + columns: [String], + quoteIdentifier: (String) -> String + ) -> String? { + let parts = sortColumns.compactMap { sortCol -> String? in + guard sortCol.columnIndex >= 0, sortCol.columnIndex < columns.count else { return nil } + let direction = sortCol.ascending ? "ASC" : "DESC" + return "\(quoteIdentifier(columns[sortCol.columnIndex])) \(direction)" + } + guard !parts.isEmpty else { return nil } + return "ORDER BY " + parts.joined(separator: ", ") + } + + public static func buildWhereClause( + filters: [(column: String, op: String, value: String)], + logicMode: String, + quoteIdentifier: (String) -> String, + escapeValue: (String) -> String, + regexCondition: (_ quotedColumn: String, _ value: String) -> String? + ) -> String { + let conditions = filters.compactMap { filter in + buildFilterCondition( + column: filter.column, + op: filter.op, + value: filter.value, + quoteIdentifier: quoteIdentifier, + escapeValue: escapeValue, + regexCondition: regexCondition + ) + } + guard !conditions.isEmpty else { return "" } + let separator = logicMode == "and" ? " AND " : " OR " + return conditions.joined(separator: separator) + } + + public static func buildFilterCondition( + column: String, + op: String, + value: String, + quoteIdentifier: (String) -> String, + escapeValue: (String) -> String, + regexCondition: (_ quotedColumn: String, _ value: String) -> String? + ) -> String? { + let quoted = quoteIdentifier(column) + switch op { + case "=": return "\(quoted) = \(escapeValue(value))" + case "!=": return "\(quoted) != \(escapeValue(value))" + case ">": return "\(quoted) > \(escapeValue(value))" + case ">=": return "\(quoted) >= \(escapeValue(value))" + case "<": return "\(quoted) < \(escapeValue(value))" + case "<=": return "\(quoted) <= \(escapeValue(value))" + case "IS NULL": return "\(quoted) IS NULL" + case "IS NOT NULL": return "\(quoted) IS NOT NULL" + case "IS EMPTY": return "(\(quoted) IS NULL OR \(quoted) = '')" + case "IS NOT EMPTY": return "(\(quoted) IS NOT NULL AND \(quoted) != '')" + case "CONTAINS": + return "\(quoted) LIKE '%\(escapeForLike(value))%' ESCAPE '\\'" + case "NOT CONTAINS": + return "\(quoted) NOT LIKE '%\(escapeForLike(value))%' ESCAPE '\\'" + case "STARTS WITH": + return "\(quoted) LIKE '\(escapeForLike(value))%' ESCAPE '\\'" + case "ENDS WITH": + return "\(quoted) LIKE '%\(escapeForLike(value))' ESCAPE '\\'" + case "IN": + let values = value.split(separator: ",") + .map { escapeValue($0.trimmingCharacters(in: .whitespaces)) } + .joined(separator: ", ") + return values.isEmpty ? nil : "\(quoted) IN (\(values))" + case "NOT IN": + let values = value.split(separator: ",") + .map { escapeValue($0.trimmingCharacters(in: .whitespaces)) } + .joined(separator: ", ") + return values.isEmpty ? nil : "\(quoted) NOT IN (\(values))" + case "BETWEEN": + let parts = value.split(separator: ",", maxSplits: 1) + guard parts.count == 2 else { return nil } + let v1 = escapeValue(parts[0].trimmingCharacters(in: .whitespaces)) + let v2 = escapeValue(parts[1].trimmingCharacters(in: .whitespaces)) + return "\(quoted) BETWEEN \(v1) AND \(v2)" + case "REGEX": + return regexCondition(quoted, value) + default: return nil + } + } +} diff --git a/Plugins/TableProPluginKit/PluginExportUtilities.swift b/Plugins/TableProPluginKit/PluginExportUtilities.swift index 5bdf9f718..222ec7ee6 100644 --- a/Plugins/TableProPluginKit/PluginExportUtilities.swift +++ b/Plugins/TableProPluginKit/PluginExportUtilities.swift @@ -84,6 +84,15 @@ public enum PluginExportUtilities { result = result.replacingOccurrences(of: "--", with: "") return result } + + public static func isNumericColumnType(_ typeName: String) -> Bool { + let numericPrefixes = [ + "int", "bigint", "decimal", "float", "double", "numeric", + "real", "smallint", "tinyint", "mediumint", "integer", "number" + ] + let lower = typeName.lowercased() + return numericPrefixes.contains { lower.hasPrefix($0) } + } } public extension String { diff --git a/TablePro/Core/ChangeTracking/DataChangeManager.swift b/TablePro/Core/ChangeTracking/DataChangeManager.swift index 1980e0f75..e7f3c5ea6 100644 --- a/TablePro/Core/ChangeTracking/DataChangeManager.swift +++ b/TablePro/Core/ChangeTracking/DataChangeManager.swift @@ -53,7 +53,7 @@ final class DataChangeManager: ChangeManaging { var primaryKeyColumns: [String] = [] /// First PK column, for contexts that need a single column (paste, filters) var primaryKeyColumn: String? { primaryKeyColumns.first } - var databaseType: DatabaseType = .mysql + var databaseType: DatabaseType? var pluginDriver: (any PluginDatabaseDriver)? var columns: [String] = [] @@ -94,7 +94,7 @@ final class DataChangeManager: ChangeManaging { tableName: String, columns: [String], primaryKeyColumns: [String], - databaseType: DatabaseType = .mysql, + databaseType: DatabaseType, triggerReload: Bool = true ) { self.tableName = tableName @@ -407,9 +407,15 @@ final class DataChangeManager: ChangeManaging { } } + guard let databaseType else { + throw DatabaseError.queryFailed( + "Cannot generate statements: table dialect not configured" + ) + } + if PluginManager.shared.editorLanguage(for: databaseType) != .sql { throw DatabaseError.queryFailed( - "Cannot generate statements for \(databaseType.rawValue) — plugin driver not initialized" + "Cannot generate statements for \(databaseType.rawValue): plugin driver not initialized" ) } diff --git a/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift b/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift index ad6c3872d..58e6361b2 100644 --- a/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift +++ b/TablePro/Core/Cloudflare/CloudflareTunnelManager.swift @@ -310,8 +310,12 @@ actor CloudflareTunnelManager { UserDefaults.standard.removeObject(forKey: Self.stalePidsDefaultsKey) return } - guard let data = try? JSONEncoder().encode(records) else { return } - UserDefaults.standard.set(data, forKey: Self.stalePidsDefaultsKey) + do { + let data = try JSONEncoder().encode(records) + UserDefaults.standard.set(data, forKey: Self.stalePidsDefaultsKey) + } catch { + Self.logger.error("Failed to persist cloudflared PID records, leaked processes may survive to next launch: \(error.localizedDescription, privacy: .public)") + } } private static func isLiveCloudflared(_ record: CloudflaredPidRecord) -> Bool { diff --git a/TablePro/Core/Database/AWS/AWSSSO.swift b/TablePro/Core/Database/AWS/AWSSSO.swift index d33a86828..695b43cdc 100644 --- a/TablePro/Core/Database/AWS/AWSSSO.swift +++ b/TablePro/Core/Database/AWS/AWSSSO.swift @@ -300,6 +300,6 @@ enum AWSSSO { data.withUnsafeBytes { ptr in _ = CC_SHA1(ptr.baseAddress, CC_LONG(data.count), &hash) } - return hash.map { String(format: "%02x", $0) }.joined() + return hash.hexEncoded } } diff --git a/TablePro/Core/Database/AWS/AWSSigV4.swift b/TablePro/Core/Database/AWS/AWSSigV4.swift index 63d38b6fd..db96d0ff0 100644 --- a/TablePro/Core/Database/AWS/AWSSigV4.swift +++ b/TablePro/Core/Database/AWS/AWSSigV4.swift @@ -27,7 +27,7 @@ enum AWSSigV4 { } static func hmacHex(key: Data, data: Data) -> String { - hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined() + hmac(key: key, data: data).hexEncoded } static func sha256Hex(_ data: Data) -> String { @@ -35,7 +35,7 @@ enum AWSSigV4 { data.withUnsafeBytes { ptr in _ = CC_SHA256(ptr.baseAddress, CC_LONG(data.count), &hash) } - return hash.map { String(format: "%02x", $0) }.joined() + return hash.hexEncoded } static func deriveSigningKey(secretKey: String, dateStamp: String, region: String, service: String) -> Data { diff --git a/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift index 5177af407..4204f0cef 100644 --- a/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift +++ b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift @@ -207,7 +207,7 @@ public actor MCPBearerTokenAuthenticator: MCPAuthenticator { internal static func fingerprint(of token: String) -> String { guard let data = token.data(using: .utf8) else { return "" } let digest = SHA256.hash(data: data) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded return String(hex.prefix(16)) } } diff --git a/TablePro/Core/MCP/MCPTokenStore.swift b/TablePro/Core/MCP/MCPTokenStore.swift index f1b848848..bc71debf1 100644 --- a/TablePro/Core/MCP/MCPTokenStore.swift +++ b/TablePro/Core/MCP/MCPTokenStore.swift @@ -341,7 +341,7 @@ actor MCPTokenStore { let input = salt + plaintext guard let data = input.data(using: .utf8) else { return "" } let digest = SHA256.hash(data: data) - return digest.map { String(format: "%02x", $0) }.joined() + return digest.hexEncoded } private func base64UrlEncode(_ data: Data) -> String { diff --git a/TablePro/Core/Plugins/PluginDriverAdapter.swift b/TablePro/Core/Plugins/PluginDriverAdapter.swift index d4d9955b1..e95e5275b 100644 --- a/TablePro/Core/Plugins/PluginDriverAdapter.swift +++ b/TablePro/Core/Plugins/PluginDriverAdapter.swift @@ -85,7 +85,7 @@ final class PluginDriverAdapter: DatabaseDriver, SchemaSwitchable { case let d as Date: return Self.iso8601Formatter.string(from: d) case let data as Data: - return data.map { String(format: "%02x", $0) }.joined() + return data.hexEncoded case let uuid as UUID: return uuid.uuidString default: diff --git a/TablePro/Core/Plugins/PluginInstaller.swift b/TablePro/Core/Plugins/PluginInstaller.swift index 6a146b3cd..c06bcf454 100644 --- a/TablePro/Core/Plugins/PluginInstaller.swift +++ b/TablePro/Core/Plugins/PluginInstaller.swift @@ -197,7 +197,7 @@ actor PluginInstaller { let payload = try Data(contentsOf: tempDownloadURL) let digest = SHA256.hash(data: payload) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded guard hex == binary.sha256.lowercased() else { throw PluginError.checksumMismatch } diff --git a/TablePro/Core/SSH/SSHPathUtilities.swift b/TablePro/Core/SSH/SSHPathUtilities.swift index ff995fb18..db75e4557 100644 --- a/TablePro/Core/SSH/SSHPathUtilities.swift +++ b/TablePro/Core/SSH/SSHPathUtilities.swift @@ -85,7 +85,7 @@ struct SSHTokenContext: Sendable { if result.contains("%C") { let basis = "\(localHostnameFQDN())\(hostname ?? "")\(port.map(String.init) ?? "")\(remoteUser ?? "")" let digest = Insecure.SHA1.hash(data: Data(basis.utf8)) - let hex = digest.map { String(format: "%02x", $0) }.joined() + let hex = digest.hexEncoded result = result.replacingOccurrences(of: "%C", with: hex) } diff --git a/TablePro/Core/SchemaTracking/StructureChangeManager.swift b/TablePro/Core/SchemaTracking/StructureChangeManager.swift index da9b9976a..8a606edd5 100644 --- a/TablePro/Core/SchemaTracking/StructureChangeManager.swift +++ b/TablePro/Core/SchemaTracking/StructureChangeManager.swift @@ -32,7 +32,6 @@ final class StructureChangeManager: ChangeManaging { var workingPrimaryKey: [String] = [] var tableName: String? - var databaseType: DatabaseType = .mysql // MARK: - Undo/Redo Support @@ -61,11 +60,9 @@ final class StructureChangeManager: ChangeManaging { columns: [ColumnInfo], indexes: [IndexInfo], foreignKeys: [ForeignKeyInfo], - primaryKey: [String], - databaseType: DatabaseType + primaryKey: [String] ) { self.tableName = tableName - self.databaseType = databaseType // Convert to definitions self.currentColumns = columns.map { EditableColumnDefinition.from($0) } diff --git a/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift b/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift index 1e919dabb..82fef04d9 100644 --- a/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift +++ b/TablePro/Core/Services/Formatting/ValueDisplayFormatService.swift @@ -116,7 +116,7 @@ final class ValueDisplayFormatService { // Try raw binary bytes (isoLatin1 encoding from MySQL) if let data = rawValue.data(using: .isoLatin1), data.count == 16 { let bytes = [UInt8](data) - let hex = bytes.map { String(format: "%02x", $0) }.joined() + let hex = bytes.hexEncoded return insertUuidHyphens(hex) } diff --git a/TablePro/Extensions/Sequence+HexEncoded.swift b/TablePro/Extensions/Sequence+HexEncoded.swift new file mode 100644 index 000000000..fe5a1c002 --- /dev/null +++ b/TablePro/Extensions/Sequence+HexEncoded.swift @@ -0,0 +1,7 @@ +import Foundation + +extension Sequence where Element == UInt8 { + var hexEncoded: String { + map { String(format: "%02x", $0) }.joined() + } +} diff --git a/TablePro/Extensions/String+SHA256.swift b/TablePro/Extensions/String+SHA256.swift index 001a933e2..09277d317 100644 --- a/TablePro/Extensions/String+SHA256.swift +++ b/TablePro/Extensions/String+SHA256.swift @@ -13,6 +13,6 @@ extension String { var sha256: String { let data = Data(utf8) let digest = SHA256.hash(data: data) - return digest.map { String(format: "%02x", $0) }.joined() + return digest.hexEncoded } } diff --git a/TablePro/Models/Connection/DatabaseConnection.swift b/TablePro/Models/Connection/DatabaseConnection.swift index c420c17bd..8196a64bd 100644 --- a/TablePro/Models/Connection/DatabaseConnection.swift +++ b/TablePro/Models/Connection/DatabaseConnection.swift @@ -64,9 +64,6 @@ extension DatabaseType { static var allKnownTypes: [DatabaseType] { PluginMetadataRegistry.shared.allRegisteredTypeIds().map { DatabaseType(rawValue: $0) } } - - /// Compatibility shim for CaseIterable call sites. - static var allCases: [DatabaseType] { allKnownTypes } } extension DatabaseType { diff --git a/TablePro/Theme/ThemeLayout.swift b/TablePro/Theme/ThemeFonts.swift similarity index 98% rename from TablePro/Theme/ThemeLayout.swift rename to TablePro/Theme/ThemeFonts.swift index 4f796657c..8fa0901dd 100644 --- a/TablePro/Theme/ThemeLayout.swift +++ b/TablePro/Theme/ThemeFonts.swift @@ -1,5 +1,5 @@ // -// ThemeLayout.swift +// ThemeFonts.swift // TablePro // diff --git a/TablePro/Theme/ThemeRegistryInstaller.swift b/TablePro/Theme/ThemeRegistryInstaller.swift index 75aca2d4e..159ba82f4 100644 --- a/TablePro/Theme/ThemeRegistryInstaller.swift +++ b/TablePro/Theme/ThemeRegistryInstaller.swift @@ -213,7 +213,7 @@ internal final class ThemeRegistryInstaller { let downloadedData = try Data(contentsOf: tempDownloadURL) let digest = SHA256.hash(data: downloadedData) - let hexChecksum = digest.map { String(format: "%02x", $0) }.joined() + let hexChecksum = digest.hexEncoded if hexChecksum != resolved.sha256.lowercased() { throw PluginError.checksumMismatch diff --git a/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift b/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift index 45d9afe1f..6df0fd756 100644 --- a/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift +++ b/TablePro/Views/Integrations/IntegrationsActivityLogPane.swift @@ -233,18 +233,26 @@ struct IntegrationsActivityLogPane: View { panel.canCreateDirectories = true panel.title = String(localized: "Export Activity Log") - guard panel.runModal() == .OK, let url = panel.url else { return } + if let window = AlertHelper.resolveWindow(nil) { + panel.beginSheetModal(for: window) { response in + guard response == .OK, let url = panel.url else { return } + writeActivityLog(to: url) + } + } else { + guard panel.runModal() == .OK, let url = panel.url else { return } + writeActivityLog(to: url) + } + } - let csv = csvString(for: filteredEntries) + private func writeActivityLog(to url: URL) { do { - try csv.write(to: url, atomically: true, encoding: .utf8) + try csvString(for: filteredEntries).write(to: url, atomically: true, encoding: .utf8) } catch { - let alert = NSAlert() - alert.messageText = String(localized: "Could not export activity log") - alert.informativeText = error.localizedDescription - alert.alertStyle = .warning - alert.addButton(withTitle: String(localized: "OK")) - alert.runModal() + AlertHelper.showErrorSheet( + title: String(localized: "Could not export activity log"), + message: error.localizedDescription, + window: nil + ) } } diff --git a/TablePro/Views/Structure/TableStructureView+DataLoading.swift b/TablePro/Views/Structure/TableStructureView+DataLoading.swift index 920ac4470..8d5342743 100644 --- a/TablePro/Views/Structure/TableStructureView+DataLoading.swift +++ b/TablePro/Views/Structure/TableStructureView+DataLoading.swift @@ -98,8 +98,7 @@ extension TableStructureView { columns: columns, indexes: indexes, foreignKeys: foreignKeys, - primaryKey: primaryKey, - databaseType: connection.type + primaryKey: primaryKey ) } diff --git a/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift b/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift index 3d486f880..6047b4467 100644 --- a/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift +++ b/TableProTests/Core/ChangeTracking/AnyChangeManagerTests.swift @@ -18,7 +18,7 @@ struct AnyChangeManagerTests { @Test("DataChangeManager wrapper: hasChanges forwards correctly") func dataManagerHasChangesForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) #expect(wrapper.hasChanges == false) @@ -32,7 +32,7 @@ struct AnyChangeManagerTests { @Test("DataChangeManager wrapper: reloadVersion forwards correctly") func dataManagerReloadVersionForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) let initialVersion = wrapper.reloadVersion @@ -44,7 +44,7 @@ struct AnyChangeManagerTests { @Test("isRowDeleted delegates correctly for DataChangeManager") func isRowDeletedDelegatesCorrectly() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) #expect(wrapper.isRowDeleted(0) == false) @@ -57,7 +57,7 @@ struct AnyChangeManagerTests { @Test("recordCellChange forwards to DataChangeManager") func recordCellChangeForwards() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) let wrapper = AnyChangeManager(dataManager) wrapper.recordCellChange(rowIndex: 0, columnIndex: 1, columnName: "name", oldValue: "Alice", newValue: "Bob", originalRow: ["1", "Alice"]) @@ -69,7 +69,7 @@ struct AnyChangeManagerTests { @Test("No retain cycle — wrapper can be deallocated") func noRetainCycleOnWrapper() { let dataManager = DataChangeManager() - dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"]) + dataManager.configureForTable(tableName: "users", columns: ["id", "name"], primaryKeyColumns: ["id"], databaseType: .mysql) weak var weakWrapper: AnyChangeManager? diff --git a/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift b/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift index a106f57fa..62cd7165d 100644 --- a/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift +++ b/TableProTests/Core/ChangeTracking/DataChangeManagerExtendedTests.swift @@ -25,7 +25,8 @@ struct DataChangeManagerExtendedTests { manager.configureForTable( tableName: "test_table", columns: columns, - primaryKeyColumns: [pk].compactMap { $0 } + primaryKeyColumns: [pk].compactMap { $0 }, + databaseType: .mysql ) return manager } @@ -694,6 +695,7 @@ struct DataChangeManagerExtendedTests { tableName: "test", columns: ["a", "b"], primaryKeyColumns: ["a"], + databaseType: .mysql, triggerReload: false ) #expect(manager.reloadVersion == before) diff --git a/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift b/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift index 9ef1a5968..9763e24a0 100644 --- a/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift +++ b/TableProTests/Core/ChangeTracking/DataChangeManagerTests.swift @@ -40,13 +40,31 @@ struct DataChangeManagerTests { #expect(manager.databaseType == .postgresql) } + @Test("generateSQL throws when the table dialect is not configured") + func generateSQLThrowsWhenDialectNotConfigured() { + let manager = DataChangeManager() + manager.recordCellChange( + rowIndex: 0, + columnIndex: 1, + columnName: "name", + oldValue: "Alice", + newValue: "Bob" + ) + + #expect(manager.databaseType == nil) + #expect(throws: (any Error).self) { + _ = try manager.generateSQL() + } + } + @Test("configureForTable clears existing changes") func configureForTableClearsChanges() async { let manager = DataChangeManager() manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -61,7 +79,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "products", columns: ["id", "title"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) #expect(!manager.hasChanges) @@ -86,7 +105,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -106,7 +126,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -132,7 +153,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -153,7 +175,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -184,7 +207,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -214,7 +238,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -246,7 +271,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordRowDeletion(rowIndex: 0, originalRow: ["1", "Alice"]) @@ -260,7 +286,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -286,7 +313,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordRowDeletion(rowIndex: 2, originalRow: ["3", "Charlie"]) @@ -303,7 +331,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) let rows: [(rowIndex: Int, originalRow: [PluginCellValue])] = [ @@ -327,7 +356,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -352,7 +382,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -377,7 +408,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -397,7 +429,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -421,7 +454,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -443,7 +477,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( @@ -484,7 +519,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) let initialVersion = manager.reloadVersion @@ -506,7 +542,8 @@ struct DataChangeManagerTests { manager.configureForTable( tableName: "users", columns: ["id", "name"], - primaryKeyColumns: ["id"] + primaryKeyColumns: ["id"], + databaseType: .mysql ) manager.recordCellChange( diff --git a/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift b/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift index 06fea744e..397888708 100644 --- a/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift +++ b/TableProTests/Core/SchemaTracking/StructureChangeManagerPKTests.swift @@ -57,8 +57,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumns(), indexes: sampleIndexes(), foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) let idCol = manager.workingColumns.first { $0.name == "id" } @@ -83,8 +82,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumnsNoPK(), indexes: sampleIndexes(), foreignKeys: [], - primaryKey: ["id"], - databaseType: .postgresql + primaryKey: ["id"] ) // The working columns should have isPrimaryKey set based on the primaryKey parameter @@ -115,8 +113,7 @@ struct StructureChangeManagerPKTests { columns: columns, indexes: [], foreignKeys: [], - primaryKey: ["tenant_id", "user_id"], - databaseType: .postgresql + primaryKey: ["tenant_id", "user_id"] ) let tenantCol = manager.workingColumns.first { $0.name == "tenant_id" } @@ -140,8 +137,7 @@ struct StructureChangeManagerPKTests { columns: sampleColumnsNoPK(), indexes: [], foreignKeys: [], - primaryKey: [], - databaseType: .postgresql + primaryKey: [] ) for col in manager.workingColumns { diff --git a/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift b/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift index 020df04c8..6d864e27f 100644 --- a/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift +++ b/TableProTests/Core/SchemaTracking/StructureChangeManagerUndoTests.swift @@ -40,8 +40,7 @@ struct StructureChangeManagerUndoTests { columns: columns, indexes: indexes, foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) } diff --git a/TableProTests/Core/Sync/ConflictResolverTests.swift b/TableProTests/Core/Sync/ConflictResolverTests.swift new file mode 100644 index 000000000..7a2e93b8d --- /dev/null +++ b/TableProTests/Core/Sync/ConflictResolverTests.swift @@ -0,0 +1,89 @@ +// +// ConflictResolverTests.swift +// TableProTests +// + +import CloudKit +import Foundation +import Testing + +@testable import TablePro + +@Suite("ConflictResolver", .serialized) +@MainActor +struct ConflictResolverTests { + private let resolver = ConflictResolver.shared + + init() { + while resolver.hasConflicts { + _ = resolver.resolveCurrentConflict(keepLocal: false) + } + } + + private func makeConflict(local: [String: String], server: [String: String]) -> SyncConflict { + let localRecord = CKRecord(recordType: "Connection") + for (key, value) in local { + localRecord[key] = value as CKRecordValue + } + let serverRecord = CKRecord(recordType: "Connection") + for (key, value) in server { + serverRecord[key] = value as CKRecordValue + } + return SyncConflict( + recordType: .connection, + entityName: "users", + localRecord: localRecord, + serverRecord: serverRecord, + localModifiedAt: Date(timeIntervalSince1970: 100), + serverModifiedAt: Date(timeIntervalSince1970: 200) + ) + } + + @Test("addConflict queues the conflict as current") + func addConflictQueues() { + #expect(!resolver.hasConflicts) + resolver.addConflict(makeConflict(local: ["name": "L"], server: ["name": "S"])) + + #expect(resolver.hasConflicts) + #expect(resolver.currentConflict?.entityName == "users") + + _ = resolver.resolveCurrentConflict(keepLocal: false) + } + + @Test("Keeping the server version discards the conflict and returns nil") + func keepServerReturnsNil() { + resolver.addConflict(makeConflict(local: ["name": "L"], server: ["name": "S"])) + + let result = resolver.resolveCurrentConflict(keepLocal: false) + + #expect(result == nil) + #expect(!resolver.hasConflicts) + } + + @Test("Keeping local copies local field values onto the server record") + func keepLocalCopiesFieldsOntoServerRecord() { + resolver.addConflict(makeConflict(local: ["name": "Local"], server: ["name": "Server"])) + + let resolved = resolver.resolveCurrentConflict(keepLocal: true) + + #expect(resolved?["name"] as? String == "Local") + #expect(!resolver.hasConflicts) + } + + @Test("Conflicts are resolved in FIFO order") + func conflictsResolveInFifoOrder() { + resolver.addConflict(makeConflict(local: ["name": "first"], server: ["name": "s1"])) + resolver.addConflict(makeConflict(local: ["name": "second"], server: ["name": "s2"])) + + #expect(resolver.currentConflict?.localRecord["name"] as? String == "first") + _ = resolver.resolveCurrentConflict(keepLocal: false) + #expect(resolver.currentConflict?.localRecord["name"] as? String == "second") + _ = resolver.resolveCurrentConflict(keepLocal: false) + #expect(!resolver.hasConflicts) + } + + @Test("Resolving with no pending conflicts returns nil") + func resolveWithNoConflictsReturnsNil() { + #expect(resolver.resolveCurrentConflict(keepLocal: false) == nil) + } +} diff --git a/TableProTests/Core/Sync/SyncChangeTrackerTests.swift b/TableProTests/Core/Sync/SyncChangeTrackerTests.swift new file mode 100644 index 000000000..d15237de4 --- /dev/null +++ b/TableProTests/Core/Sync/SyncChangeTrackerTests.swift @@ -0,0 +1,79 @@ +// +// SyncChangeTrackerTests.swift +// TableProTests +// + +import Foundation +import Testing + +@testable import TablePro + +@Suite("SyncChangeTracker") +@MainActor +struct SyncChangeTrackerTests { + private let metadata: SyncMetadataStorage + private let tracker: SyncChangeTracker + + init() { + let unique = UUID().uuidString + let syncDefaults = UserDefaults(suiteName: "com.TablePro.tests.SyncChangeTracker.\(unique)")! + metadata = SyncMetadataStorage(userDefaults: syncDefaults) + tracker = SyncChangeTracker(metadataStorage: metadata) + } + + @Test("markDirty records the id as dirty") + func markDirtyAddsId() { + tracker.markDirty(.connection, id: "conn-1") + #expect(tracker.dirtyRecords(for: .connection) == ["conn-1"]) + } + + @Test("markDirty with multiple ids records all of them") + func markDirtyMultiple() { + tracker.markDirty(.connection, ids: ["a", "b", "c"]) + #expect(tracker.dirtyRecords(for: .connection) == ["a", "b", "c"]) + } + + @Test("markDirty with an empty id list records nothing") + func markDirtyEmptyIsNoop() { + tracker.markDirty(.connection, ids: []) + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + } + + @Test("markDeleted clears the dirty flag and records a tombstone") + func markDeletedClearsDirtyAndTombstones() { + tracker.markDirty(.connection, id: "conn-1") + tracker.markDeleted(.connection, id: "conn-1") + + #expect(!tracker.dirtyRecords(for: .connection).contains("conn-1")) + #expect(metadata.tombstones(for: .connection).contains { $0.id == "conn-1" }) + } + + @Test("Suppression makes markDirty and markDeleted no-ops") + func suppressionDisablesTracking() { + tracker.isSuppressed = true + tracker.markDirty(.connection, id: "conn-1") + tracker.markDeleted(.group, id: "group-1") + + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + #expect(metadata.tombstones(for: .group).isEmpty) + } + + @Test("clearDirty removes one id; clearAllDirty clears the type") + func clearDirtyBehavior() { + tracker.markDirty(.connection, ids: ["a", "b"]) + tracker.clearDirty(.connection, id: "a") + #expect(tracker.dirtyRecords(for: .connection) == ["b"]) + + tracker.clearAllDirty(.connection) + #expect(tracker.dirtyRecords(for: .connection).isEmpty) + } + + @Test("Dirty records are scoped per record type") + func dirtyRecordsScopedByType() { + tracker.markDirty(.connection, id: "x") + tracker.markDirty(.group, id: "y") + + #expect(tracker.dirtyRecords(for: .connection) == ["x"]) + #expect(tracker.dirtyRecords(for: .group) == ["y"]) + } +} diff --git a/TableProTests/Models/DatabaseTypeCassandraTests.swift b/TableProTests/Models/DatabaseTypeCassandraTests.swift index 0ad36d743..013961300 100644 --- a/TableProTests/Models/DatabaseTypeCassandraTests.swift +++ b/TableProTests/Models/DatabaseTypeCassandraTests.swift @@ -86,11 +86,11 @@ struct DatabaseTypeCassandraTests { @Test("Cassandra included in allCases") func cassandraIncludedInAllCases() { - #expect(DatabaseType.allCases.contains(.cassandra)) + #expect(DatabaseType.allKnownTypes.contains(.cassandra)) } @Test("ScyllaDB included in allCases") func scylladbIncludedInAllCases() { - #expect(DatabaseType.allCases.contains(.scylladb)) + #expect(DatabaseType.allKnownTypes.contains(.scylladb)) } } diff --git a/TableProTests/Models/DatabaseTypeCockroachDBTests.swift b/TableProTests/Models/DatabaseTypeCockroachDBTests.swift index 5ad8a4f71..80723a128 100644 --- a/TableProTests/Models/DatabaseTypeCockroachDBTests.swift +++ b/TableProTests/Models/DatabaseTypeCockroachDBTests.swift @@ -68,6 +68,6 @@ struct DatabaseTypeCockroachDBTests { @Test("allCases shim contains cockroachdb") func allCasesContainsCockroachDB() { - #expect(DatabaseType.allCases.contains(.cockroachdb)) + #expect(DatabaseType.allKnownTypes.contains(.cockroachdb)) } } diff --git a/TableProTests/Models/DatabaseTypeMSSQLTests.swift b/TableProTests/Models/DatabaseTypeMSSQLTests.swift index 7b2f9d882..2c93361dc 100644 --- a/TableProTests/Models/DatabaseTypeMSSQLTests.swift +++ b/TableProTests/Models/DatabaseTypeMSSQLTests.swift @@ -53,6 +53,6 @@ struct DatabaseTypeMSSQLTests { @Test("allCases shim contains mssql") func allCasesContainsMSSql() { - #expect(DatabaseType.allCases.contains(.mssql)) + #expect(DatabaseType.allKnownTypes.contains(.mssql)) } } diff --git a/TableProTests/Models/DatabaseTypeRedisTests.swift b/TableProTests/Models/DatabaseTypeRedisTests.swift index 0f0cc44c2..8e1b2ee7b 100644 --- a/TableProTests/Models/DatabaseTypeRedisTests.swift +++ b/TableProTests/Models/DatabaseTypeRedisTests.swift @@ -46,6 +46,6 @@ struct DatabaseTypeRedisTests { @Test("Included in allCases shim") func includedInAllCases() { - #expect(DatabaseType.allCases.contains(.redis)) + #expect(DatabaseType.allKnownTypes.contains(.redis)) } } diff --git a/TableProTests/Models/DatabaseTypeTests.swift b/TableProTests/Models/DatabaseTypeTests.swift index ee4d1616b..215b8cde2 100644 --- a/TableProTests/Models/DatabaseTypeTests.swift +++ b/TableProTests/Models/DatabaseTypeTests.swift @@ -47,11 +47,6 @@ struct DatabaseTypeTests { #expect(knownTypes.count >= 5) } - @Test("allCases shim matches allKnownTypes") - func testAllCasesShim() { - #expect(DatabaseType.allCases == DatabaseType.allKnownTypes) - } - @Test("Raw value matches display name", arguments: [ (DatabaseType.mysql, "MySQL"), (DatabaseType.mariadb, "MariaDB"), diff --git a/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift b/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift index f8e2d47fd..ff40043e9 100644 --- a/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift +++ b/TableProTests/Views/Structure/StructureEditingSupportFieldDiffTests.swift @@ -251,8 +251,7 @@ struct StructureChangeManagerUndoDeleteTests { columns: columns, indexes: [], foreignKeys: [], - primaryKey: ["id"], - databaseType: .mysql + primaryKey: ["id"] ) return manager }