From bd26ee777e98848fb9af2172ab3d41898fc43460 Mon Sep 17 00:00:00 2001
From: Simon Fayer
Date: Fri, 26 Jun 2026 15:12:46 +0100
Subject: [PATCH] fix: Replace escape_string with alternatives

__escapeString in MySQL.py now quotes values with
connection.string_literal(), which returns the value already wrapped in
quotes. ComponentInstaller.installDatabase no longer escapes the database
name; it rejects anything that is not a plain identifier. The check uses
fullmatch() so that a name ending in a newline is rejected as well.
---
 src/DIRAC/Core/Utilities/MySQL.py             |   4 +-
 .../Client/ComponentInstaller.py              |   9 +-
 tests/Integration/Core/Test_MySQLDB.py        | 101 ++++++++++++++++++
 3 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/src/DIRAC/Core/Utilities/MySQL.py b/src/DIRAC/Core/Utilities/MySQL.py
index 49383b1cdd7..2fc1dd71d8d 100755
--- a/src/DIRAC/Core/Utilities/MySQL.py
+++ b/src/DIRAC/Core/Utilities/MySQL.py
@@ -624,9 +624,9 @@ def __escapeString(self, myString, connection=None):
                     # self.log.debug('__escape_string: Could not escape string', '"%s"' % myString)
                     return S_ERROR(DErrno.EMYSQL, "__escape_string: Could not escape string")
 
-            escape_string = connection.escape_string(myString.encode()).decode()
+            escape_string = connection.string_literal(myString.encode()).decode()
             # self.log.debug('__escape_string: returns', '"%s"' % escape_string)
-            return S_OK(f'"{escape_string}"')
+            return S_OK(escape_string)
 
         except Exception as x:
             return self._except("__escape_string", x, "Could not escape string", myString)
diff --git a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py
index d35fb36582a..2062e7348f9 100644
--- a/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py
+++ b/src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py
@@ -103,6 +103,10 @@ from DIRAC.Core.Utilities.Version import getVersion
 
 from DIRAC.FrameworkSystem.Client.ComponentMonitoringClient import ComponentMonitoringClient
 
+# Allow-list for database names: ASCII letters, digits and underscores, not starting with a digit.
+# Check with fullmatch(): with match(), "$" would also accept a name ending in a newline.
+SQL_IDENTIFIER_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+
 
 def _safeFloat(value):
     try:
@@ -2055,9 +2059,8 @@ def installDatabase(self, dbName):
         """
         Install requested DB in MySQL server
         """
-        import MySQLdb
-
-        dbName = MySQLdb.escape_string(dbName.encode()).decode()
+        if not SQL_IDENTIFIER_RE.fullmatch(dbName):
+            return S_ERROR(f"Invalid database name '{dbName}'")
         if not self.mysqlRootPwd:
             rootPwdPath = cfgInstallPath("Database", "RootPwd")
             return S_ERROR(f"Missing {rootPwdPath} in {self.cfgFile}")
diff --git a/tests/Integration/Core/Test_MySQLDB.py b/tests/Integration/Core/Test_MySQLDB.py
index 4b334356f44..c89a1e1a58d 100644
--- a/tests/Integration/Core/Test_MySQLDB.py
+++ b/tests/Integration/Core/Test_MySQLDB.py
@@ -294,3 +294,104 @@ def test_deleteEntries(name, fields, requiredFields, values, table, cond, expect
     result = mysqlDB.getCounters(name, fields, {})
     assert result["OK"], result["Message"]
     assert result["Value"] == []
+
+
+# Escape string tests
+
+escape_table = {
+    "EscapeTestTable": {
+        "Fields": {
+            "ID": "INTEGER UNIQUE NOT NULL AUTO_INCREMENT",
+            "Payload": "TEXT",
+        },
+        "PrimaryKey": "ID",
+    }
+}
+
+
+def _expect_quoted(expected_inner, actual):
+    """Check that *actual* is a properly quoted SQL value containing
+    *expected_inner* inside, regardless of whether single or double
+    quotes are used."""
+    if len(actual) < 2:
+        raise AssertionError(f"Value too short to be quoted: {actual!r}")
+    for q in ("'", '"'):
+        if actual.startswith(q) and actual.endswith(q) and actual[1:-1] == expected_inner:
+            return True
+    raise AssertionError(
+        f"Expected a properly quoted value containing {expected_inner!r} " f"(wrapped in ' or \") but got {actual!r}"
+    )
+
+
+# Define test cases as (input_val, expected_inner_body) tuples.
+# expected_inner_body is what the escaped output contains inside the quotes.
+# Use None when the inner body is the same as the input (no escaping needed).
+_ESCAPE_CASES = [
+    ("hello world", None),
+    ("O'Reilly", r"O\'Reilly"),
+    ('say "hi"', r"say \"hi\""),
+    (r"C:\path", r"C:\\path"),
+    ("ab\x00cd", r"ab\0cd"),
+    ("", None),
+    ("café", None),
+    ("日本語", None),
+    ("'; DROP TABLE EscapeTestTable; --", r"\'; DROP TABLE EscapeTestTable; --"),
+    (r"test\0value", r"test\\0value"),
+]
+
+
+@pytest.mark.parametrize("input_val, expected_inner", _ESCAPE_CASES)
+def test_escape_string_escapes_special_chars(input_val, expected_inner):
+    """Test that _escapeString properly escapes special characters using a real connection."""
+    mysqlDB = setupDB()
+    result = mysqlDB._escapeString(input_val)
+    assert result["OK"], f"escape_string failed for {input_val!r}: {result.get('Message', '')}"
+
+    escaped = result["Value"]
+    # If expected_inner is None, default to the input value (no escaping expected)
+    expected_inner = expected_inner if expected_inner is not None else input_val
+    _expect_quoted(expected_inner, escaped)
+
+    # The escaped value should be safe for SQL — verify by inserting it
+    result = mysqlDB._createTables(escape_table, force=True)
+    assert result["OK"], result["Message"]
+
+    # Insert via direct query using the escaped value
+    safe_query = f"INSERT INTO EscapeTestTable (Payload) VALUES ({escaped})"
+    result = mysqlDB._update(safe_query)
+    assert result["OK"], f"Insert failed with escaped value {escaped!r}: {result.get('Message', '')}"
+
+    # Verify we can retrieve it back
+    result = mysqlDB.getFields("EscapeTestTable", ["Payload"])
+    assert result["OK"], result["Message"]
+    assert result["Value"], "No rows returned from EscapeTestTable"
+    # The last row must round-trip to exactly the original input (a substring check passes trivially for "")
+    retrieved = result["Value"][-1][0]
+    assert retrieved == input_val, (
+        "Retrieved value does not match original input\n"
+        f"  input: {input_val!r}, escaped: {escaped!r}, retrieved: {retrieved!r}"
+    )
+
+
+@pytest.mark.parametrize(
+    "sql_func",
+    [
+        "UTC_TIMESTAMP()",
+        "TIMESTAMPDIFF(MICROSECOND, col1, col2)",
+        "TIMESTAMPADD(DAY, 1, col1)",
+    ],
+)
+def test_escape_string_passthrough_sql_functions(sql_func):
+    """Recognised SQL function calls are returned unchanged, without escaping."""
+    mysqlDB = setupDB()
+    result = mysqlDB._escapeString(sql_func)
+    assert result["OK"], f"escape_string failed for {sql_func!r}: {result.get('Message', '')}"
+    assert result["Value"] == sql_func, f"Expected {sql_func!r}, got {result['Value']!r}"
+
+
+def test_escape_string_accepts_bytes():
+    """Bytes input should be decoded before escaping."""
+    mysqlDB = setupDB()
+    result = mysqlDB._escapeString(b"hello bytes")
+    assert result["OK"], result["Message"]
+    _expect_quoted("hello bytes", result["Value"])