Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 133 additions & 4 deletions specifyweb/backend/stored_queries/build_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from typing import Optional
from specifyweb.specify.models_utils.load_datamodel import Datamodel, Table, Field, Relationship
from typing import NamedTuple, Optional
from specifyweb.specify.models_utils.load_datamodel import (
Datamodel,
Field,
ManyToMany,
Relationship,
Table,
)
from sqlalchemy import Table as Table_Sqlalchemy, Column, ForeignKey, types, orm, MetaData
from sqlalchemy.dialects.mysql import BIT as mysql_bit_type
metadata = MetaData()


MANY_TO_MANY_TABLES = {
"Project_colobj": {
"table": "project_colobj",
"id_column": "ProjectColObjID",
"through_fields": {
"collectionobject": {
"model": "CollectionObject",
"column": "CollectionObjectID",
},
"project": {
"model": "Project",
"column": "ProjectID",
},
},
},
}


class ManyToManyJoinInfo(NamedTuple):
table: str
local_column: str
remote_column: str

class BaseIdAlias:
id_attr_name: Optional[str] = None

Expand Down Expand Up @@ -65,6 +95,69 @@ def make_column(flddef: Field):
nullable = not flddef.required)


def make_many_to_many_table(datamodel: Datamodel, table_info):
columns = [Column(table_info["id_column"], types.Integer, primary_key=True)]

for through_field in table_info["through_fields"].values():
remote_tabledef = datamodel.get_table(through_field["model"])
if remote_tabledef is None:
return

fk_target = ".".join((remote_tabledef.table, remote_tabledef.idColumn))
columns.append(
Column(
through_field["column"],
types.Integer,
ForeignKey(fk_target),
nullable=False,
)
)

return Table_Sqlalchemy(table_info["table"], metadata, *columns)


def get_many_to_many_join_info(
datamodel: Datamodel, reldef: Relationship
) -> ManyToManyJoinInfo | None:
if not isinstance(reldef, ManyToMany):
return None

table_info = MANY_TO_MANY_TABLES.get(reldef.through_model)
if table_info is None:
return None

local_field = table_info["through_fields"].get(reldef.through_field)
if local_field is None:
return None

remote_field = None
related_table = datamodel.get_table(reldef.relatedModelName)
related_relationship = (
related_table.get_field(reldef.otherSideName, strict=False)
if related_table is not None and reldef.otherSideName
else None
)
remote_through_field = getattr(related_relationship, "through_field", None)
if remote_through_field is not None:
remote_field = table_info["through_fields"].get(remote_through_field)

if remote_field is None:
remote_fields = [
field
for through_field, field in table_info["through_fields"].items()
if through_field != reldef.through_field
]
if len(remote_fields) != 1:
return None
remote_field = remote_fields[0]

return ManyToManyJoinInfo(
table=table_info["table"],
local_column=local_field["column"],
remote_column=remote_field["column"],
)


field_type_map = {
'text' : types.Text,
'json' : types.JSON,
Expand All @@ -84,7 +177,15 @@ def make_column(flddef: Field):
}

def make_tables(datamodel: Datamodel):
return {td.table: make_table(datamodel, td) for td in datamodel.tables}
tables = {td.table: make_table(datamodel, td) for td in datamodel.tables}

for table_info in MANY_TO_MANY_TABLES.values():
if table_info["table"] not in tables:
table = make_many_to_many_table(datamodel, table_info)
if table is not None:
tables[table_info["table"]] = table

return tables

def make_classes(datamodel: Datamodel):
def make_class(tabledef):
Expand All @@ -106,6 +207,35 @@ def map_class(tabledef):
table = tables[ tabledef.table ]

def make_relationship(reldef):
if isinstance(reldef, ManyToMany):
join_info = get_many_to_many_join_info(datamodel, reldef)
remote_tabledef = datamodel.get_table(reldef.relatedModelName)
if (
join_info is None
or remote_tabledef is None
or reldef.relatedModelName not in classes
or join_info.table not in tables
):
return

remote_class = classes[reldef.relatedModelName]
remote_table = tables[remote_tabledef.table]
secondary_table = tables[join_info.table]

return reldef.name, orm.relationship(
remote_class,
secondary=secondary_table,
primaryjoin=(
table.c[tabledef.idColumn]
== secondary_table.c[join_info.local_column]
),
secondaryjoin=(
remote_table.c[remote_tabledef.idColumn]
== secondary_table.c[join_info.remote_column]
),
viewonly=True,
)

if not hasattr(reldef, 'column') or not reldef.column or reldef.relatedModelName not in classes:
return

Expand Down Expand Up @@ -141,4 +271,3 @@ def make_relationship(reldef):

for tabledef in datamodel.tables:
map_class(tabledef)

25 changes: 19 additions & 6 deletions specifyweb/backend/stored_queries/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from specifyweb.backend.stored_queries.queryfield import QueryField

from . import models
from .build_models import get_many_to_many_join_info
from .group_concat import group_concat
from .blank_nulls import blank_nulls
from .query_construct import QueryConstruct
Expand Down Expand Up @@ -342,12 +343,24 @@ def aggregate(self, query: QueryConstruct,
limit = None if limit == '' or int(limit) == 0 else limit
orm_table = getattr(models, field.relatedModelName)

join_column = list(inspect(
getattr(orm_table, field.otherSideName)).property.local_columns)[0]
subquery_query = Query([]) \
.select_from(orm_table) \
.filter(join_column == rel_table._id) \
.correlate(rel_table)
join_info = get_many_to_many_join_info(datamodel, field)
if join_info is not None:
secondary_table = models.tables[join_info.table]
subquery_query = Query([]) \
.select_from(orm_table) \
.join(
secondary_table,
secondary_table.c[join_info.remote_column] == orm_table._id,
) \
.filter(secondary_table.c[join_info.local_column] == rel_table._id) \
.correlate(rel_table)
else:
join_column = list(inspect(
getattr(orm_table, field.otherSideName)).property.local_columns)[0]
subquery_query = Query([]) \
.select_from(orm_table) \
.filter(join_column == rel_table._id) \
.correlate(rel_table)

try:
from_table_name = query.query.selectable.froms[0].name.lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,37 @@
from unittest.mock import patch, Mock


PROJECT_FORMATTERS = """
<formatters>
<format
name="Project"
title="Project"
class="edu.ku.brc.specify.datamodel.Project"
default="true"
>
<switch single="true">
<fields>
<field>projectName</field>
</fields>
</switch>
</format>
<aggregators>
<aggregator
name="Project"
title="Project"
class="edu.ku.brc.specify.datamodel.Project"
default="true"
separator="; "
ending=""
count="0"
format="Project"
orderfieldname=""
/>
</aggregators>
</formatters>
"""


class TestRunEphemeralQuery(SQLAlchemySetup):

@staticmethod
Expand Down Expand Up @@ -97,6 +128,99 @@ def test_query_with_displayed_date_parts(self, context: Mock):
result,
)

@patch("specifyweb.backend.stored_queries.format.app_resource.get_app_resource")
@patch("specifyweb.backend.stored_queries.execution.models.session_context")
def test_query_project_many_to_many_relationship(
self, context: Mock, get_app_resource: Mock
):
context.side_effect = TestRunEphemeralQuery.test_session_context
get_app_resource.return_value = (PROJECT_FORMATTERS, None, None)

project = models.Project.objects.create(
collectionmemberid=self.collection.id,
projectname="Test Project",
)
models.Project_colobj.objects.create(
project=project,
collectionobject=self.collectionobjects[0],
)

query = deepcopy(simple_query)
query["fields"] = [
{
"fieldname": "projectName",
"formatname": None,
"isdisplay": True,
"isnot": False,
"isrelfld": False,
"operstart": 8,
"position": 0,
"sorttype": 0,
"startvalue": "",
"stringid": "1,66-projects.project.projectName",
"isstrict": False,
},
{
"fieldname": "projects",
"formatname": None,
"isdisplay": True,
"isnot": False,
"isrelfld": True,
"operstart": 8,
"position": 1,
"sorttype": 0,
"startvalue": "",
"stringid": "1,66-projects.project.projects",
"isstrict": False,
},
]

result = run_ephemeral_query(self.collection, self.specifyuser, query)

self.assertEqual(
{
"results": [
(self.collectionobjects[0].id, "Test Project", "Test Project"),
(self.collectionobjects[1].id, None, ""),
(self.collectionobjects[2].id, None, ""),
(self.collectionobjects[3].id, None, ""),
(self.collectionobjects[4].id, None, ""),
]
},
result,
)

accession = models.Accession.objects.create(
accessionnumber="2026-001",
division=self.division,
)
self._update(self.collectionobjects[0], {"accession": accession})

query = deepcopy(simple_query)
query["contexttableid"] = 7
query["fields"] = [
{
"fieldname": "projects",
"formatname": None,
"isdisplay": True,
"isnot": False,
"isrelfld": True,
"operstart": 8,
"position": 0,
"sorttype": 0,
"startvalue": "",
"stringid": "7,1-collectionObjects,66-projects.project.projects",
"isstrict": False,
},
]

result = run_ephemeral_query(self.collection, self.specifyuser, query)

self.assertEqual(
{"results": [(accession.id, "Test Project")]},
result,
)


class TestRunEphemeralQueryByRank(SqlTreeSetup):

Expand Down Expand Up @@ -147,4 +271,4 @@ def test_negated_contains_on_tree_rank_field(self, context: Mock):
(self.collectionobjects[3].id, None),
(self.collectionobjects[4].id, None),
],
)
)
2 changes: 0 additions & 2 deletions specifyweb/backend/stored_queries/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def test_sqlalchemy_model_errors(self):
"AutoNumberingScheme": {"not_found": ["collections", "disciplines", "divisions"]},
"Collection": {"not_found": ["numberingSchemes", "userGroups"]},
"CollectionObject": {
"not_found": ["projects"],
"incorrect_direction": {"cojo": ["onetomany", "onetoone"]},
},
"DNASequencingRun": {
Expand All @@ -286,7 +285,6 @@ def test_sqlalchemy_model_errors(self):
"localityDetails": ["onetomany", "zerotoone"],
}
},
"Project": {"not_found": ["collectionObjects"]},
"SpExportSchema": {"not_found": ["spExportSchemaMappings"]},
"SpExportSchemaMapping": {"not_found": ["spExportSchemas"]},
"SpPermission": {"not_found": ["principals"]},
Expand Down
Loading