Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/DIRAC/Core/Utilities/MySQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,9 @@ def __escapeString(self, myString, connection=None):
# self.log.debug('__escape_string: Could not escape string', '"%s"' % myString)
return S_ERROR(DErrno.EMYSQL, "__escape_string: Could not escape string")

escape_string = connection.escape_string(myString.encode()).decode()
escape_string = connection.string_literal(myString.encode()).decode()
# self.log.debug('__escape_string: returns', '"%s"' % escape_string)
return S_OK(f'"{escape_string}"')
return S_OK(escape_string)
except Exception as x:
return self._except("__escape_string", x, "Could not escape string", myString)

Expand Down
7 changes: 4 additions & 3 deletions src/DIRAC/FrameworkSystem/Client/ComponentInstaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
from DIRAC.Core.Utilities.Version import getVersion
from DIRAC.FrameworkSystem.Client.ComponentMonitoringClient import ComponentMonitoringClient

SQL_IDENTIFIER_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def _safeFloat(value):
try:
Expand Down Expand Up @@ -2055,9 +2057,8 @@ def installDatabase(self, dbName):
"""
Install requested DB in MySQL server
"""
import MySQLdb

dbName = MySQLdb.escape_string(dbName.encode()).decode()
if not SQL_IDENTIFIER_RE.match(dbName):
return S_ERROR(f"Invalid database name '{dbName}'")
if not self.mysqlRootPwd:
rootPwdPath = cfgInstallPath("Database", "RootPwd")
return S_ERROR(f"Missing {rootPwdPath} in {self.cfgFile}")
Expand Down
101 changes: 101 additions & 0 deletions tests/Integration/Core/Test_MySQLDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,104 @@ def test_deleteEntries(name, fields, requiredFields, values, table, cond, expect
result = mysqlDB.getCounters(name, fields, {})
assert result["OK"], result["Message"]
assert result["Value"] == []


# Escape string tests

escape_table = {
"EscapeTestTable": {
"Fields": {
"ID": "INTEGER UNIQUE NOT NULL AUTO_INCREMENT",
"Payload": "TEXT",
},
"PrimaryKey": "ID",
}
}


def _expect_quoted(expected_inner, actual):
"""Check that *actual* is a properly quoted SQL value containing
*expected_inner* inside, regardless of whether single or double
quotes are used."""
if len(actual) < 2:
raise AssertionError(f"Value too short to be quoted: {actual!r}")
for q in ("'", '"'):
if actual.startswith(q) and actual.endswith(q) and actual[1:-1] == expected_inner:
return True
raise AssertionError(
f"Expected a properly quoted value containing {expected_inner!r} " f"(wrapped in ' or \") but got {actual!r}"
)


# Define test cases as (input_val, expected_inner_body) tuples.
# expected_inner_body is what the escaped output contains inside the quotes.
# Use None when the inner body is the same as the input (no escaping needed).
_ESCAPE_CASES = [
("hello world", None),
("O'Reilly", r"O\'Reilly"),
('say "hi"', r"say \"hi\""),
(r"C:\path", r"C:\\path"),
("ab\x00cd", r"ab\0cd"),
("", None),
("café", None),
("日本語", None),
("'; DROP TABLE EscapeTestTable; --", r"\'; DROP TABLE EscapeTestTable; --"),
(r"test\0value", r"test\\0value"),
]


@pytest.mark.parametrize("input_val, expected_inner", _ESCAPE_CASES)
def test_escape_string_escapes_special_chars(input_val, expected_inner):
"""Test that _escapeString properly escapes special characters using a real connection."""
mysqlDB = setupDB()
result = mysqlDB._escapeString(input_val)
assert result["OK"], f"escape_string failed for {input_val!r}: {result.get('Message', '')}"

escaped = result["Value"]
# If expected_inner is None, default to the input value (no escaping expected)
expected_inner = expected_inner if expected_inner is not None else input_val
_expect_quoted(expected_inner, escaped)

# The escaped value should be safe for SQL — verify by inserting it
result = mysqlDB._createTables(escape_table, force=True)
assert result["OK"], result["Message"]

# Insert via direct query using the escaped value
safe_query = f"INSERT INTO EscapeTestTable (Payload) VALUES ({escaped})"
result = mysqlDB._update(safe_query)
assert result["OK"], f"Insert failed with escaped value {escaped!r}: {result.get('Message', '')}"

# Verify we can retrieve it back
result = mysqlDB.getFields("EscapeTestTable", ["Payload"])
assert result["OK"], result["Message"]
# The last inserted row should match the original input
assert input_val in result["Value"][-1][0] if result["Value"] else False, (
f"Retrieved value does not contain original input\n"
f" input: {input_val!r}\n"
f" escaped: {escaped!r}\n"
f" retrieved: {result['Value'][-1][0]!r}"
)


@pytest.mark.parametrize(
"sql_func",
[
"UTC_TIMESTAMP()",
"TIMESTAMPDIFF(MICROSECOND, col1, col2)",
"TIMESTAMPADD(DAY, 1, col1)",
],
)
def test_escape_string_passthrough_sql_functions(sql_func):
"""Recognised SQL function calls are returned unchanged, without escaping."""
mysqlDB = setupDB()
result = mysqlDB._escapeString(sql_func)
assert result["OK"], f"escape_string failed for {sql_func!r}: {result.get('Message', '')}"
assert result["Value"] == sql_func, f"Expected {sql_func!r}, got {result['Value']!r}"


def test_escape_string_accepts_bytes():
"""Bytes input should be decoded before escaping."""
mysqlDB = setupDB()
result = mysqlDB._escapeString(b"hello bytes")
assert result["OK"], result["Message"]
_expect_quoted("hello bytes", result["Value"])
Loading