From 50e43b36ca73fa2b3b4fc776c49ef742ba156852 Mon Sep 17 00:00:00 2001 From: Simon Fayer Date: Mon, 15 Jun 2026 11:32:46 +0100 Subject: [PATCH 1/2] fix: Parameterise SQL in TransformationSystem --- .../DB/TransformationDB.py | 270 +++++++++--------- 1 file changed, 134 insertions(+), 136 deletions(-) diff --git a/src/DIRAC/TransformationSystem/DB/TransformationDB.py b/src/DIRAC/TransformationSystem/DB/TransformationDB.py index 4c3b785aafb..43a5a215e04 100755 --- a/src/DIRAC/TransformationSystem/DB/TransformationDB.py +++ b/src/DIRAC/TransformationSystem/DB/TransformationDB.py @@ -194,7 +194,7 @@ def addTransformation( ] subst = ", ".join(f"%({name})s" if name not in unparameterised_columns else params[name] for name in params) - req = f"INSERT INTO Transformations ({', '.join(params)}) VALUES ({subst});" + req = f"INSERT INTO Transformations ({', '.join(params)}) VALUES ({subst})" # nosec res = self._update(req, args=params, conn=connection) if not res["OK"]: @@ -319,7 +319,7 @@ def getTransformations( sqlCmd = "INSERT INTO to_query_TransformationIDs (TransID) VALUES ( %s )" returnValueOrRaise(self._updatemany(sqlCmd, [(transID,) for transID in transIDs], conn=connection)) - req = "SELECT {} FROM Transformations {} {}".format( + req = "SELECT {} FROM Transformations {} {}".format( # nosec intListToString(columns), join_query, self.buildCondition(condDict, older, newer, timeStamp, orderAttribute, limit, offset=offset), @@ -383,8 +383,8 @@ def getTransformationParameters(self, transName, parameters, connection=False): def getTransformationWithStatus(self, status, connection=False): """Gets a list of the transformations with the supplied status""" - req = f"SELECT TransformationID FROM Transformations WHERE Status = '{status}';" - res = self._query(req, conn=connection) + req = "SELECT TransformationID FROM Transformations WHERE Status = %s" + res = self._query(req, args=(status,), conn=connection) if not res["OK"]: return res transIDs = [tupleIn[0] for tupleIn in 
res["Value"]] @@ -428,23 +428,22 @@ def __getTableDistinctAttributeValues( def __updateTransformationParameter(self, transID, paramName, paramValue, connection=False): if paramName not in self.mutable: return S_ERROR(f"Can not update the '{paramName}' transformation parameter") - res = self._escapeString(paramValue) - if not res["OK"]: - return S_ERROR("Failed to parse parameter value") - paramValue = res["Value"] - req = f"UPDATE Transformations SET {paramName}={paramValue}, LastUpdate=UTC_TIMESTAMP() WHERE TransformationID={transID}" - return self._update(req, conn=connection) + req = ( + f"UPDATE Transformations SET {paramName}=%s, LastUpdate=UTC_TIMESTAMP() WHERE TransformationID=%s" # nosec + ) + args = (paramValue, transID) + return self._update(req, args=args, conn=connection) def _getTransformationID(self, transName, connection=False): """Method returns ID of transformation with the name=""" try: transName = int(transName) - cmd = f"SELECT TransformationID from Transformations WHERE TransformationID={transName};" + cmd = "SELECT TransformationID from Transformations WHERE TransformationID=%s" except ValueError: if not isinstance(transName, str): return S_ERROR("Transformation should be ID or name") - cmd = f"SELECT TransformationID from Transformations WHERE TransformationName='{transName}';" - res = self._query(cmd, conn=connection) + cmd = "SELECT TransformationID from Transformations WHERE TransformationName=%s" + res = self._query(cmd, args=(str(transName),), conn=connection) if not res["OK"]: gLogger.error("Failed to obtain transformation ID for transformation", f"{transName}: {res['Message']}") return res @@ -454,7 +453,8 @@ def _getTransformationID(self, transName, connection=False): return S_OK(res["Value"][0][0]) def __deleteTransformation(self, transID, connection=False): - return self._update(f"DELETE FROM Transformations WHERE TransformationID={transID};", conn=connection) + cmd = "DELETE FROM Transformations WHERE TransformationID=%s" + return 
self._update(cmd, args=(transID,), conn=connection) def __updateFilterQueries(self, connection=False): """Get filters for all defined input streams in all the transformations.""" @@ -491,10 +491,6 @@ def setTransformationParameter(self, transName, paramName, paramValue, author="" if paramName in self.TRANSPARAMS: res = self.__updateTransformationParameter(transID, paramName, paramValue, connection=connection) if res["OK"]: - pv = self._escapeString(paramValue) - if not pv["OK"]: - return S_ERROR("Failed to parse parameter value") - paramValue = pv["Value"] if paramName == "Body": message = "Body updated" else: @@ -533,8 +529,8 @@ def deleteTransformationParameter(self, transName, paramName, author="", connect return res def __addAdditionalTransformationParameter(self, transID, paramName, paramValue, connection=False): - req = f"DELETE FROM AdditionalParameters WHERE TransformationID={transID} AND ParameterName='{paramName}'" - res = self._update(req, conn=connection) + req = "DELETE FROM AdditionalParameters WHERE TransformationID=%s AND ParameterName=%s" + res = self._update(req, args=(transID, paramName), conn=connection) if not res["OK"]: return res res = self._escapeString(paramValue) @@ -544,21 +540,15 @@ def __addAdditionalTransformationParameter(self, transID, paramName, paramValue, paramType = "StringType" if isinstance(paramValue, int): paramType = "IntType" - req = "INSERT INTO AdditionalParameters ({}) VALUES ({},'{}',{},'{}');".format( - ", ".join(self.ADDITIONALPARAMETERS), - transID, - paramName, - paramValue, - paramType, - ) - return self._update(req, conn=connection) + fields = ", ".join(self.ADDITIONALPARAMETERS) + req = f"INSERT INTO AdditionalParameters ({fields}) VALUES (%s, %s, %s, %s)" # nosec + args = (transID, str(paramName), str(paramValue), str(paramType)) + return self._update(req, args=args, conn=connection) def __getAdditionalParameters(self, transID, connection=False): - req = "SELECT %s FROM AdditionalParameters WHERE 
TransformationID = %d" % ( - ", ".join(self.ADDITIONALPARAMETERS), - transID, - ) - res = self._query(req, conn=connection) + fields = ",".join(self.ADDITIONALPARAMETERS) + req = f"SELECT {fields} FROM AdditionalParameters WHERE TransformationID = %s" # nosec + res = self._query(req, args=(transID,), conn=connection) if not res["OK"]: return res paramDict = {} @@ -572,10 +562,14 @@ def __deleteTransformationParameters(self, transID, parameters=None, connection= """Remove the parameters associated to a transformation""" if parameters is None: parameters = [] - req = f"DELETE FROM AdditionalParameters WHERE TransformationID={transID}" + req = "DELETE FROM AdditionalParameters WHERE TransformationID=%s" + args = [transID] if parameters: - req = f"{req} AND ParameterName IN ({stringListToString(parameters)});" - return self._update(req, conn=connection) + req += " AND ParameterName IN (" + req += ",".join(["%s"] * len(parameters)) + req += ")" + args.extend(parameters) + return self._update(req, args=args, conn=connection) ########################################################################### # @@ -633,8 +627,8 @@ def getTransformationFiles( elif not set(columns).issubset(all_columns): return S_ERROR(f"Invalid columns requested, valid columns are: {all_columns}") - req = ", ".join(f"df.{x}" if x == "LFN" else f"tf.{x}" for x in columns) - req = f"SELECT {req} FROM TransformationFiles tf" + fields = ", ".join(f"df.{x}" if x == "LFN" else f"tf.{x}" for x in columns) + req = f"SELECT {fields} FROM TransformationFiles tf" # nosec if "LFN" in columns or (condDict and "LFN" in condDict): req = f"{req} JOIN DataFiles df ON tf.FileID = df.FileID" @@ -678,7 +672,7 @@ def getFileSummary(self, lfns, connection=False): def setFileStatusForTransformation(self, transID, fileStatusDict=None, connection=False): """Set file status for the given transformation, based on - fileStatusDict {fileID_A: ('statusA',errorA), fileID_B: ('statusB',errorB), ...} + fileStatusDict {fileID_A: 
('statusA', errorA), fileID_B: ('statusB', errorB), ...} The ErrorCount is incremented if errorA flag is True """ @@ -686,24 +680,22 @@ def setFileStatusForTransformation(self, transID, fileStatusDict=None, connectio return S_OK() # Building the request with "ON DUPLICATE KEY UPDATE" - reqBase = "INSERT INTO TransformationFiles (TransformationID, FileID, Status, ErrorCount, LastUpdate) VALUES " - # Get fileID and status for each case: error and no error statusFileDict = {} for fileID, (status, error) in fileStatusDict.items(): statusFileDict.setdefault(error, []).append((fileID, status)) for error, fileIDStatusList in statusFileDict.items(): - req = reqBase + ",".join( - f"({transID}, {fileID}, '{status}', 0, UTC_TIMESTAMP())" for fileID, status in fileIDStatusList + insert_clause = ( + "INSERT INTO TransformationFiles " "(TransformationID, FileID, Status, ErrorCount, LastUpdate) VALUES " ) + on_duplicate = "ON DUPLICATE KEY UPDATE Status=VALUES(Status),LastUpdate=VALUES(LastUpdate)" if error: - # Increment the error counter when we requested - req += " ON DUPLICATE KEY UPDATE Status=VALUES(Status),ErrorCount=ErrorCount+1,LastUpdate=VALUES(LastUpdate)" - else: - req += " ON DUPLICATE KEY UPDATE Status=VALUES(Status),LastUpdate=VALUES(LastUpdate)" + on_duplicate += ",ErrorCount=ErrorCount+1" - result = self._update(req, conn=connection) + req = insert_clause + "(%s, %s, %s, 0, UTC_TIMESTAMP()) " + on_duplicate + args = [(transID, fileID, status) for fileID, status in fileIDStatusList] + result = self._updatemany(req, args, conn=connection) if not result["OK"]: return result return S_OK() @@ -753,11 +745,8 @@ def __addFilesToTransformation(self, transID, fileIDs, connection=False): returnValueOrRaise(self._updatemany(sqlCmd, [(fileID,) for fileID in fileIDs], conn=connection)) # Query existing files using JOIN - req = ( - "SELECT tf.FileID FROM TransformationFiles tf JOIN to_query_FileIDs t ON tf.FileID = t.FileID WHERE tf.TransformationID = %d;" - % transID - ) - 
res = returnValueOrRaise(self._query(req, conn=connection)) + req = "SELECT tf.FileID FROM TransformationFiles tf JOIN to_query_FileIDs t ON tf.FileID = t.FileID WHERE tf.TransformationID = %s" + res = returnValueOrRaise(self._query(req, args=(transID,), conn=connection)) # Remove already existing fileIDs using set difference for efficiency existingFileIDs = {tupleIn[0] for tupleIn in res} @@ -783,26 +772,29 @@ def __insertExistingTransformationFiles(self, transID, fileTuplesList, connectio gLogger.verbose( f"Adding first {len(fileTuples)} files in TransformationFiles (out of {len(fileTuplesList)})" ) - req = "INSERT INTO TransformationFiles (TransformationID,Status,TaskID,FileID,TargetSE,UsedSE,LastUpdate) VALUES" - candidates = False - + # Collect valid rows for parameterized bulk insert + validRows = [] for ft in fileTuples: _lfn, originalID, fileID, status, taskID, targetSE, usedSE, _errorCount, _lastUpdate, _insertTime = ft[ :10 ] if status not in ("Removed",): - candidates = True if not re.search("-", status): status = f"{status}-inherited" - if taskID: - # Should be readable up to 999,999 tasks: that field is an int(11) in the DB, not a string - taskID = 1000000 * int(originalID) + int(taskID) - req = f"{req} ({transID},'{status}','{taskID}',{fileID},'{targetSE}','{usedSE}',UTC_TIMESTAMP())," - if not candidates: + if taskID: + # Should be readable up to 999,999 tasks: that field is an int(11) in the DB, not a string + taskID = 1000000 * int(originalID) + int(taskID) + validRows.append((transID, status, taskID, fileID, targetSE, usedSE)) + if not validRows: continue - req = req.rstrip(",") - res = self._update(req, conn=connection) + # Parameterized bulk INSERT via executemany + req = ( + "INSERT INTO TransformationFiles " + "(TransformationID,Status,TaskID,FileID,TargetSE,UsedSE,LastUpdate) " + "VALUES (%s, %s, %s, %s, %s, %s, UTC_TIMESTAMP())" + ) + res = self._updatemany(req, validRows, conn=connection) if not res["OK"]: return res @@ -810,14 +802,11 @@ 
def __insertExistingTransformationFiles(self, transID, fileTuplesList, connectio def __assignTransformationFile(self, transID, taskID, se, fileIDs, connection=False): """Make necessary updates to the TransformationFiles table for the newly created task""" - req = "UPDATE TransformationFiles SET TaskID='%d',UsedSE='%s',Status='Assigned',LastUpdate=UTC_TIMESTAMP()" - req = (req + " WHERE TransformationID = %d AND FileID IN (%s);") % ( - taskID, - se, - transID, - intListToString(fileIDs), - ) - res = self._update(req, conn=connection) + req = "UPDATE TransformationFiles SET TaskID=%s,UsedSE=%s,Status='Assigned',LastUpdate=UTC_TIMESTAMP()" + req += " WHERE TransformationID = %s AND FileID IN (" + ",".join(["%s"] * len(fileIDs)) + ")" + args = [taskID, se, transID] + args.extend(fileIDs) + res = self._update(req, args=args, conn=connection) if not res["OK"]: gLogger.error("Failed to assign file to task", res["Message"]) values = [(transID, fileID, taskID) for fileID in fileIDs] @@ -828,26 +817,30 @@ def __assignTransformationFile(self, transID, taskID, se, fileIDs, connection=Fa return res def __setTransformationFileStatus(self, fileIDs, status, connection=False): - req = f"UPDATE TransformationFiles SET Status = '{status}' WHERE FileID IN ({intListToString(fileIDs)});" - res = self._update(req, conn=connection) + req = "UPDATE TransformationFiles SET Status = %s WHERE FileID IN (" + req += ",".join(["%s"] * len(fileIDs)) + req += ")" + args = [status] + args.extend([str(x) for x in fileIDs]) + res = self._update(req, args=args, conn=connection) if not res["OK"]: gLogger.error("Failed to update file status", res["Message"]) return res def __setTransformationFileUsedSE(self, fileIDs, usedSE, connection=False): - req = f"UPDATE TransformationFiles SET UsedSE = '{usedSE}' WHERE FileID IN ({intListToString(fileIDs)});" - res = self._update(req, conn=connection) + req = "UPDATE TransformationFiles SET UsedSE = %s WHERE FileID IN (" + req += ",".join(["%s"] * len(fileIDs)) + 
req += ")" + args = [usedSE] + args.extend([str(x) for x in fileIDs]) + res = self._update(req, args=args, conn=connection) if not res["OK"]: gLogger.error("Failed to update file usedSE", res["Message"]) return res def __resetTransformationFile(self, transID, taskID, connection=False): - req = ( - "UPDATE TransformationFiles SET TaskID=NULL, UsedSE='Unknown', Status='Unused'\ - WHERE TransformationID = %d AND TaskID=%d;" - % (transID, taskID) - ) - res = self._update(req, conn=connection) + req = "UPDATE TransformationFiles SET TaskID=NULL, UsedSE='Unknown', Status='Unused' WHERE TransformationID = %s AND TaskID=%s" + res = self._update(req, args=(transID, taskID), conn=connection) if not res["OK"]: gLogger.error("Failed to reset transformation file", res["Message"]) return res @@ -864,15 +857,12 @@ def __deleteTransformationFiles(self, transID, connection=False): # The IGNORE keyword will make sure we do not abort the full removal # on a foreign key error # https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#ignore-strict-comparison - req = ( - "DELETE IGNORE tf, df \ + req = "DELETE IGNORE tf, df \ FROM TransformationFiles tf \ JOIN DataFiles df \ ON tf.FileID=df.FileID \ - WHERE TransformationID = %d;" - % transID - ) - res = self._update(req, conn=connection) + WHERE TransformationID = %s" + res = self._update(req, args=(transID,), conn=connection) if not res["OK"]: gLogger.error("Failed to delete transformation files", res["Message"]) return res @@ -886,13 +876,13 @@ def __deleteTransformationFileTask(self, transID, taskID, connection=False): """Delete the file associated to a given task of a given transformation from the TransformationFileTasks table for transformation with TransformationID and TaskID """ - req = f"DELETE FROM TransformationFileTasks WHERE TransformationID={transID} AND TaskID={taskID}" - return self._update(req, conn=connection) + req = "DELETE FROM TransformationFileTasks WHERE TransformationID=%s AND TaskID=%s" + return self._update(req, 
args=(transID, taskID), conn=connection) def __deleteTransformationFileTasks(self, transID, connection=False): """Remove all associations between files, tasks and a transformation""" - req = f"DELETE FROM TransformationFileTasks WHERE TransformationID = {transID}" - res = self._update(req, conn=connection) + req = "DELETE FROM TransformationFileTasks WHERE TransformationID = %s" + res = self._update(req, args=(transID,), conn=connection) if not res["OK"]: gLogger.error("Failed to delete transformation files/task history", res["Message"]) return res @@ -915,10 +905,9 @@ def getTransformationTasks( connection=False, ): connection = self.__getConnection(connection) - req = "SELECT {} FROM TransformationTasks {}".format( - intListToString(self.TASKSPARAMS), - self.buildCondition(condDict, older, newer, timeStamp, orderAttribute, limit, offset=offset), - ) + fields = ", ".join(self.TASKSPARAMS) + req = f"SELECT {fields} FROM TransformationTasks " # nosec + req += self.buildCondition(condDict, older, newer, timeStamp, orderAttribute, limit, offset=offset) res = self._query(req, conn=connection) if not res["OK"]: return res @@ -1086,24 +1075,27 @@ def getTransformationTaskStats(self, transName="", connection=False): return S_OK(statusDict) def __setTaskParameterValue(self, transID, taskID, paramName, paramValue, connection=False): - req = f"UPDATE TransformationTasks SET {paramName}='{paramValue}', LastUpdateTime=UTC_TIMESTAMP()" - req = req + " WHERE TransformationID=%d AND TaskID=%d;" % (transID, taskID) - return self._update(req, conn=connection) + if paramName not in self.TASKSPARAMS: + return S_ERROR(f"Invalid task parameter: {paramName}") + req = f"UPDATE TransformationTasks SET {paramName}=%s, LastUpdateTime=UTC_TIMESTAMP()" # nosec + req += " WHERE TransformationID=%s AND TaskID=%s" + args = (paramValue, transID, taskID) + return self._update(req, args=args, conn=connection) def __deleteTransformationTasks(self, transID, connection=False): """Delete all the tasks 
from the TransformationTasks table for transformation with TransformationID""" - req = f"DELETE FROM TransformationTasks WHERE TransformationID={transID}" - return self._update(req, conn=connection) + req = "DELETE FROM TransformationTasks WHERE TransformationID=%s" + return self._update(req, args=(transID,), conn=connection) def __deleteTransformationTask(self, transID, taskID, connection=False): """Delete the task from the TransformationTasks table for transformation with TransformationID""" - req = f"DELETE FROM TransformationTasks WHERE TransformationID={transID} AND TaskID={taskID}" - return self._update(req, conn=connection) + req = "DELETE FROM TransformationTasks WHERE TransformationID=%s AND TaskID=%s" + return self._update(req, args=(transID, taskID), conn=connection) def __deleteTransformationMetaQueries(self, transID, connection=False): """Delete all the meta queries from the TransformationMetaQueries table for transformation with TransformationID""" - req = f"DELETE FROM TransformationMetaQueries WHERE TransformationID={transID}" - return self._update(req, conn=connection) + req = "DELETE FROM TransformationMetaQueries WHERE TransformationID=%s" + return self._update(req, args=(transID,), conn=connection) #################################################################### # @@ -1173,8 +1165,8 @@ def deleteTransformationMetaQuery(self, transName, queryType, author="", connect if not res["OK"]: return S_ERROR("Failed to parse the transformation query type") queryType = res["Value"] - req = "DELETE FROM TransformationMetaQueries WHERE TransformationID=%d AND QueryType=%s;" % (transID, queryType) - res = self._update(req, conn=connection) + req = f"DELETE FROM TransformationMetaQueries WHERE TransformationID={int(transID)} AND QueryType={queryType}" # nosec + res = self._update(req, conn=connection) if not res["OK"]: return res if res["Value"]: @@ -1202,7 +1194,7 @@ def getTransformationMetaQuery(self, transName, queryType, connection=False): return 
S_ERROR("Failed to parse the transformation query type") queryType = res["Value"] req = "SELECT MetaDataName,MetaDataValue,MetaDataType FROM TransformationMetaQueries" - req = req + " WHERE TransformationID=%d AND QueryType=%s;" % (transID, queryType) + req += f" WHERE TransformationID={transID} AND QueryType={queryType}" res = self._query(req, conn=connection) if not res["OK"]: return res @@ -1237,12 +1229,12 @@ def getTaskInputVector(self, transName, taskID, connection=False): taskIDList = [taskID] else: taskIDList = list(taskID) - taskString = ",".join([f"'{x}'" for x in taskIDList]) - req = "SELECT TaskID,InputVector FROM TaskInputs WHERE TaskID in (%s) AND TransformationID='%d';" % ( - taskString, - transID, - ) - res = self._query(req) + req = "SELECT TaskID,InputVector FROM TaskInputs WHERE TransformationID=%s AND TaskID in (" + req += ",".join(["%s"] * len(taskIDList)) + req += ")" + args = [transID] + args.extend(taskIDList) + res = self._query(req, args=args) inputVectorDict = {} if not res["OK"]: return res @@ -1262,10 +1254,12 @@ def __insertTaskInputs(self, transID, taskID, lfns, connection=False): def __deleteTransformationTaskInputs(self, transID, taskID=0, connection=False): """Delete all the tasks inputs from the TaskInputs table for transformation with TransformationID""" - req = f"DELETE FROM TaskInputs WHERE TransformationID={transID}" + req = "DELETE FROM TaskInputs WHERE TransformationID=%s" + args = [transID] if taskID: - req = f"{req} AND TaskID={taskID}" - return self._update(req, conn=connection) + req += " AND TaskID=%s" + args.append(taskID) + return self._update(req, args=args, conn=connection) ########################################################################### # @@ -1283,9 +1277,8 @@ def __updateTransformationLogging(self, transName, message, author, connection=F return res connection = res["Value"]["Connection"] transID = res["Value"]["TransformationID"] - req = "INSERT INTO TransformationLog 
(TransformationID,Message,Author,MessageDate)" - req = req + f" VALUES ({transID},'{message}','{author}',UTC_TIMESTAMP());" - return self._update(req, conn=connection) + req = "INSERT INTO TransformationLog (TransformationID,Message,Author,MessageDate) VALUES (%s, %s, %s, UTC_TIMESTAMP())" + return self._update(req, args=(transID, message, author), conn=connection) def getTransformationLogging(self, transName, connection=False): """Get logging info from the TransformationLog table""" @@ -1295,7 +1288,7 @@ def getTransformationLogging(self, transName, connection=False): connection = res["Value"]["Connection"] transID = res["Value"]["TransformationID"] req = "SELECT TransformationID, Message, Author, MessageDate FROM TransformationLog" - req = req + f" WHERE TransformationID={transID} ORDER BY MessageDate;" + req += f" WHERE TransformationID={transID} ORDER BY MessageDate" res = self._query(req) if not res["OK"]: return res @@ -1311,8 +1304,8 @@ def getTransformationLogging(self, transName, connection=False): def __deleteTransformationLog(self, transID, connection=False): """Remove the entries in the transformation log for a transformation""" - req = f"DELETE FROM TransformationLog WHERE TransformationID={transID}" - return self._update(req, conn=connection) + req = "DELETE FROM TransformationLog WHERE TransformationID=%s" + return self._update(req, args=(transID,), conn=connection) ########################################################################### # @@ -1360,8 +1353,11 @@ def __getFileIDsForLfns(self, lfns, connection=False): def __getLfnsForFileIDs(self, fileIDs, connection=False): """Get lfns for the given list of fileIDs""" - req = f"SELECT LFN,FileID FROM DataFiles WHERE FileID in ({stringListToString(fileIDs)});" - res = self._query(req, conn=connection) + req = "SELECT LFN,FileID FROM DataFiles WHERE FileID in (" + req += ",".join(["%s"] * len(fileIDs)) + req += ")" + args = [str(x) for x in fileIDs] + res = self._query(req, args=args, 
conn=connection) if not res["OK"]: return res fids = dict(res["Value"]) @@ -1377,8 +1373,8 @@ def __addDataFiles(self, lfns, connection=False): # Insert only files not found, and assume the LFN is unique in the table lfnFileIDs = res["Value"][1] for lfn in set(lfns) - set(lfnFileIDs): - req = f"INSERT INTO DataFiles (LFN,Status) VALUES ('{lfn}','New');" - res = self._update(req, conn=connection) + req = "INSERT INTO DataFiles (LFN,Status) VALUES (%s,'New')" + res = self._update(req, args=(lfn,), conn=connection) # If the LFN is duplicate we get an error and ignore it if res["OK"]: lfnFileIDs[lfn] = res["lastRowId"] @@ -1393,8 +1389,12 @@ def __addDataFiles(self, lfns, connection=False): def __setDataFileStatus(self, fileIDs, status, connection=False): """Set the status of the supplied files""" - req = f"UPDATE DataFiles SET Status = '{status}' WHERE FileID IN ({intListToString(fileIDs)});" - return self._update(req, conn=connection) + req = "UPDATE DataFiles SET Status = %s WHERE FileID IN (" + req += ",".join(["%s"] * len(fileIDs)) + req += ")" + args = [status] + args.extend([str(x) for x in fileIDs]) + return self._update(req, args=args, conn=connection) ########################################################################### # @@ -1436,10 +1436,8 @@ def addTaskForTransformation(self, transID, lfns=None, se="Unknown", connection= # Insert the task into the jobs table and retrieve the taskID self.lock.acquire() - req = "INSERT INTO TransformationTasks(TransformationID, ExternalStatus, ExternalID, TargetSE," - req = req + " CreationTime, LastUpdateTime)" - req = req + " VALUES (%s,'%s','%d','%s', UTC_TIMESTAMP(), UTC_TIMESTAMP());" % (transID, "Created", 0, se) - res = self._update(req, conn=connection) + req = "INSERT INTO TransformationTasks(TransformationID, ExternalStatus, ExternalID, TargetSE, CreationTime, LastUpdateTime) VALUES (%s, 'Created', 0, %s, UTC_TIMESTAMP(), UTC_TIMESTAMP())" + res = self._update(req, args=(transID, se), conn=connection) if not 
res["OK"]: self.lock.release() gLogger.error("Failed to publish task for transformation", res["Message"]) @@ -1555,10 +1553,10 @@ def __removeTransformationTask(self, transID, taskID, connection=False): def __checkUpdate(self, table, param, paramValue, selectDict=None, connection=False): """Check whether the update will perform an update""" - req = f"UPDATE {table} SET {param} = '{paramValue}'" + req = f"UPDATE {table} SET {param} = %s" # nosec if selectDict: - req = f"{req} {self.buildCondition(selectDict)}" - return self._update(req, conn=connection) + req += f" {self.buildCondition(selectDict)}" + return self._update(req, args=(paramValue,), conn=connection) def __getConnection(self, connection): if connection: From 49feb047b6142bef18124025a27396401e2cc074 Mon Sep 17 00:00:00 2001 From: Simon Fayer Date: Fri, 26 Jun 2026 15:12:46 +0100 Subject: [PATCH 2/2] fix: Replace escape_string with alternatives --- src/DIRAC/Core/Utilities/MySQL.py | 4 +- .../Client/ComponentInstaller.py | 7 +- tests/Integration/Core/Test_MySQLDB.py | 100 ++++++++++++++++++ 3 files changed, 106 insertions(+), 5 deletions(-) diff --git a/src/DIRAC/Core/Utilities/MySQL.py b/src/DIRAC/Core/Utilities/MySQL.py index a9d9ef08d3b..3c9de82368a 100755 --- a/src/DIRAC/Core/Utilities/MySQL.py +++ b/src/DIRAC/Core/Utilities/MySQL.py @@ -595,9 +595,9 @@ def __escapeString(self, myString, connection=None): # self.log.debug('__escape_string: Could not escape string', '"%s"' % myString) return S_ERROR(DErrno.EMYSQL, "__escape_string: Could not escape string") - escape_string = connection.escape_string(myString.encode()).decode() + escape_string = connection.string_literal(myString.encode()).decode() # self.log.debug('__escape_string: returns', '"%s"' % escape_string) - return S_OK(f'"{escape_string}"') + return S_OK(escape_string) except Exception as x: return self._except("__escape_string", x, "Could not escape string", myString) diff --git a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py 
b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py index b4bbca51e78..41e89570862 100644 --- a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py +++ b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py @@ -102,6 +102,8 @@ from DIRAC.Core.Utilities.Version import getVersion from DIRAC.FrameworkSystem.Client.ComponentMonitoringClient import ComponentMonitoringClient +SQL_IDENTIFIER_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + def _safeFloat(value): try: @@ -2054,9 +2056,8 @@ def installDatabase(self, dbName): """ Install requested DB in MySQL server """ - import MySQLdb - - dbName = MySQLdb.escape_string(dbName.encode()).decode() + if not SQL_IDENTIFIER_RE.match(dbName): + return S_ERROR(f"Invalid database name '{dbName}'") if not self.mysqlRootPwd: rootPwdPath = cfgInstallPath("Database", "RootPwd") return S_ERROR(f"Missing {rootPwdPath} in {self.cfgFile}") diff --git a/tests/Integration/Core/Test_MySQLDB.py b/tests/Integration/Core/Test_MySQLDB.py index 4b334356f44..373d8349046 100644 --- a/tests/Integration/Core/Test_MySQLDB.py +++ b/tests/Integration/Core/Test_MySQLDB.py @@ -294,3 +294,103 @@ def test_deleteEntries(name, fields, requiredFields, values, table, cond, expect result = mysqlDB.getCounters(name, fields, {}) assert result["OK"], result["Message"] assert result["Value"] == [] + + +# Escape string tests + +escape_table = { + "EscapeTestTable": { + "Fields": { + "ID": "INTEGER UNIQUE NOT NULL AUTO_INCREMENT", + "Payload": "TEXT", + }, + "PrimaryKey": "ID", + } +} + + +def _expect_quoted(expected_inner, actual): + """Check that *actual* is a properly quoted SQL value containing + *expected_inner* inside, regardless of whether single or double + quotes are used.""" + if len(actual) < 2: + raise AssertionError(f"Value too short to be quoted: {actual!r}") + for q in ("'", '"'): + if actual.startswith(q) and actual.endswith(q) and actual[1:-1] == expected_inner: + return True + raise AssertionError( + f"Expected a properly quoted value 
containing {expected_inner!r} " f"(wrapped in ' or \") but got {actual!r}" + ) + + +# Define test cases as (input_val, expected_inner_body) tuples. +# expected_inner_body is what the escaped output contains inside the quotes. +# Use None when the inner body is the same as the input (no escaping needed). +_ESCAPE_CASES = [ + ("hello world", None), + ("O'Reilly", r"O\'Reilly"), + ('say "hi"', r"say \"hi\""), + (r"C:\path", r"C:\\path"), + ("ab\x00cd", r"ab\0cd"), + ("", None), + ("café", None), + ("'; DROP TABLE EscapeTestTable; --", r"\'; DROP TABLE EscapeTestTable; --"), + (r"test\0value", r"test\\0value"), +] + + +@pytest.mark.parametrize("input_val, expected_inner", _ESCAPE_CASES) +def test_escape_string_escapes_special_chars(input_val, expected_inner): + """Test that _escapeString properly escapes special characters using a real connection.""" + mysqlDB = setupDB() + result = mysqlDB._escapeString(input_val) + assert result["OK"], f"escape_string failed for {input_val!r}: {result.get('Message', '')}" + + escaped = result["Value"] + # If expected_inner is None, default to the input value (no escaping expected) + expected_inner = expected_inner if expected_inner is not None else input_val + _expect_quoted(expected_inner, escaped) + + # The escaped value should be safe for SQL — verify by inserting it + result = mysqlDB._createTables(escape_table, force=True) + assert result["OK"], result["Message"] + + # Insert via direct query using the escaped value + safe_query = f"INSERT INTO EscapeTestTable (Payload) VALUES ({escaped})" + result = mysqlDB._update(safe_query) + assert result["OK"], f"Insert failed with escaped value {escaped!r}: {result.get('Message', '')}" + + # Verify we can retrieve it back + result = mysqlDB.getFields("EscapeTestTable", ["Payload"]) + assert result["OK"], result["Message"] + # The last inserted row should match the original input + assert input_val in result["Value"][-1][0] if result["Value"] else False, ( + f"Retrieved value does not 
contain original input\n" + f" input: {input_val!r}\n" + f" escaped: {escaped!r}\n" + f" retrieved: {result['Value'][-1][0]!r}" + ) + + +@pytest.mark.parametrize( + "sql_func", + [ + "UTC_TIMESTAMP()", + "TIMESTAMPDIFF(MICROSECOND, col1, col2)", + "TIMESTAMPADD(DAY, 1, col1)", + ], +) +def test_escape_string_passthrough_sql_functions(sql_func): + """Recognised SQL function calls are returned unchanged, without escaping.""" + mysqlDB = setupDB() + result = mysqlDB._escapeString(sql_func) + assert result["OK"], f"escape_string failed for {sql_func!r}: {result.get('Message', '')}" + assert result["Value"] == sql_func, f"Expected {sql_func!r}, got {result['Value']!r}" + + +def test_escape_string_accepts_bytes(): + """Bytes input should be decoded before escaping.""" + mysqlDB = setupDB() + result = mysqlDB._escapeString(b"hello bytes") + assert result["OK"], result["Message"] + _expect_quoted("hello bytes", result["Value"])