Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export DOCKER_HOST=unix://$HOME/.docker/run/docker.sock

### PostgreSQL Backend

DataJoint supports PostgreSQL 15+ as an alternative to MySQL 8+. To install the PostgreSQL driver:
DataJoint supports MySQL 8.0.13+ and PostgreSQL 15+ as production database backends. To install the PostgreSQL driver:

```bash
pip install -e ".[postgres]" # Installs psycopg2-binary
Expand Down
23 changes: 2 additions & 21 deletions src/datajoint/adapters/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def create_table_sql(
fk_cols = ", ".join(self.quote_identifier(col) for col in fk["columns"])
ref_cols = ", ".join(self.quote_identifier(col) for col in fk["ref_columns"])
lines.append(
f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) " f"ON UPDATE CASCADE ON DELETE RESTRICT"
f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) ON UPDATE CASCADE ON DELETE RESTRICT"
)

# Indexes
Expand Down Expand Up @@ -735,26 +735,7 @@ def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[st
return result

def get_indexes_sql(self, schema_name: str, table_name: str) -> str:
"""Query to get index definitions.
Note: For MySQL 8.0.13+, EXPRESSION column contains the expression for
functional indexes. COLUMN_NAME is NULL for such indexes.
On MySQL < 8.0.13 and MariaDB, the EXPRESSION column does not exist;
heading.py falls back to get_indexes_sql_fallback() in that case.
"""
return (
f"SELECT INDEX_NAME as index_name, "
f"COALESCE(COLUMN_NAME, CONCAT('(', EXPRESSION, ')')) as column_name, "
f"NON_UNIQUE as non_unique, SEQ_IN_INDEX as seq_in_index "
f"FROM information_schema.statistics "
f"WHERE table_schema = {self.quote_string(schema_name)} "
f"AND table_name = {self.quote_string(table_name)} "
f"AND index_name != 'PRIMARY' "
f"ORDER BY index_name, seq_in_index"
)

def get_indexes_sql_fallback(self, schema_name: str, table_name: str) -> str:
"""Fallback index query for MySQL < 8.0.13 and MariaDB (no EXPRESSION column)."""
"""Query to get index definitions. Functional indexes (NULL COLUMN_NAME) are skipped downstream."""
return (
f"SELECT INDEX_NAME as index_name, "
f"COLUMN_NAME as column_name, "
Expand Down
14 changes: 14 additions & 0 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ def conn(
return conn.connection


def _warn_if_mariadb(version_str: str) -> None:
"""Emit a UserWarning if `version_str` looks like MariaDB. No-op for MySQL."""
if "MariaDB" in version_str:
warnings.warn(
f"MariaDB is not officially supported by DataJoint 2.x "
f"(server reports {version_str}). Compatibility is best-effort "
f"and may break in future releases.",
UserWarning,
stacklevel=3,
)


class EmulatedCursor:
"""acts like a cursor"""

Expand Down Expand Up @@ -221,6 +233,8 @@ def __init__(
f"{self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}{db_str}"
)
self.connection_id = self.adapter.get_connection_id(self._conn)
if self.adapter.backend == "mysql":
_warn_if_mariadb(self.query("SELECT @@version").fetchone()[0])
else:
raise errors.LostConnectionError(
f"Connection failed {self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}"
Expand Down
25 changes: 4 additions & 21 deletions src/datajoint/heading.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,30 +551,13 @@ def _init_from_database(self) -> None:

# Read and tabulate secondary indexes
keys = defaultdict(dict)
try:
index_rows = conn.query(
adapter.get_indexes_sql(database, table_name),
as_dict=True,
)
except Exception:
# Fall back for MySQL < 8.0.13 / MariaDB (no EXPRESSION column)
index_rows = (
conn.query(
adapter.get_indexes_sql_fallback(database, table_name),
as_dict=True,
)
if hasattr(adapter, "get_indexes_sql_fallback")
else []
)
for item in index_rows:
# Note: adapter.get_indexes_sql() already filters out PRIMARY key
# MySQL/PostgreSQL adapters return: index_name, column_name, non_unique
for item in conn.query(
adapter.get_indexes_sql(database, table_name),
as_dict=True,
):
index_name = item.get("index_name") or item.get("Key_name")
seq = item.get("seq_in_index") or item.get("Seq_in_index") or len(keys[index_name]) + 1
column = item.get("column_name") or item.get("Column_name")
# MySQL EXPRESSION column stores escaped single quotes - unescape them
if column:
column = column.replace("\\'", "'")
non_unique = item.get("non_unique") or item.get("Non_unique")
nullable = item.get("nullable") or (item.get("Null", "NO").lower() == "yes")

Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def test_insert_update(schema_json):
assert not q


@pytest.mark.skip(
reason="Functional indexes are not currently round-tripped through Heading.indexes; "
"describe() drops them. Re-enable when functional-index introspection is restored."
)
def test_describe(schema_json):
rel = Team()
context = inspect.currentframe().f_globals
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_connection_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Unit tests for the MariaDB compatibility warning emitted at connect time."""

import warnings

import pytest

from datajoint.connection import _warn_if_mariadb


@pytest.mark.parametrize(
"version_str",
[
"10.11.5-MariaDB",
"10.5.5-MariaDB-1~bionic",
"5.5.68-MariaDB",
],
)
def test_warn_on_mariadb(version_str):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
_warn_if_mariadb(version_str)
assert len(caught) == 1
assert issubclass(caught[0].category, UserWarning)
assert "MariaDB is not officially supported" in str(caught[0].message)
assert version_str in str(caught[0].message)


@pytest.mark.parametrize(
"version_str",
[
"8.0.40",
"8.0.13",
"8.0.40-0ubuntu0.22.04.1",
"8.4.2-log",
"9.0.0",
],
)
def test_no_warn_on_mysql(version_str):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
_warn_if_mariadb(version_str)
assert caught == []
Loading