diff --git a/test-requirements.txt b/test-requirements.txt index 61783e4..96b77f6 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,10 +1,12 @@ fakeredis httpx lmdb +networkx numpy pytest pytest-asyncio pytest-cov pytest-mock +scipy tox tox-venv \ No newline at end of file diff --git a/tests/unit/aragorn/test_aragorn_examine_query.py b/tests/unit/aragorn/test_aragorn_examine_query.py new file mode 100644 index 0000000..6f1ed1a --- /dev/null +++ b/tests/unit/aragorn/test_aragorn_examine_query.py @@ -0,0 +1,244 @@ +"""Tests for ``workers.aragorn.worker.examine_query`` and the workflow +selection logic on the entrypoint. + +These exercise the pure-logic ``examine_query`` against the same query +shapes the production code routes (lookup, infer, pathfinder, mixed) and +verify that the auto-generated workflow matches the query type. +""" + +import copy +import json +import logging + +import pytest + +from tests.helpers.generate_messages import creative_query +from workers.aragorn.worker import aragorn, examine_query + +logger = logging.getLogger(__name__) + + +def _make_task(message_lookup, workflow=None): + """Build the (msg_id, fields) tuple a worker receives.""" + return [ + "test", + { + "query_id": "test", + "response_id": "test_response", + "workflow": json.dumps(workflow), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +def test_examine_query_pure_lookup_returns_no_question_or_answer(): + """All-lookup queries: infer=False, pathfinder=False, no nodes returned.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["MONDO:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + } + infer, q, a, pathfinder = examine_query(msg) + assert (infer, pathfinder) == (False, False) + assert q is None and a is None + + +def test_examine_query_inferred_one_hop_returns_question_and_answer(): + msg = copy.deepcopy(creative_query) + infer, q, a, pathfinder = examine_query(msg) + assert infer is True + assert 
pathfinder is False + # The fixture pins ON, leaves SN unbound -> answer node SN + assert q == "ON" and a == "SN" + + +def test_examine_query_pathfinder_returns_pathfinder_true(): + msg = { + "message": { + "query_graph": { + "nodes": {}, + "edges": {}, + "paths": {"p1": {"subject": "n0", "object": "n1"}}, + } + } + } + infer, q, a, pathfinder = examine_query(msg) + assert pathfinder is True + # No edges, so not infer + assert infer is False + + +def test_examine_query_rejects_multiple_paths(): + msg = { + "message": { + "query_graph": { + "nodes": {}, + "edges": {}, + "paths": {"p1": {}, "p2": {}}, + } + } + } + with pytest.raises(Exception, match="single path"): + examine_query(msg) + + +def test_examine_query_rejects_mixed_path_and_edges(): + msg = { + "message": { + "query_graph": { + "nodes": {}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + "paths": {"p1": {}}, + } + } + } + with pytest.raises(Exception, match="Mixed mode pathfinder"): + examine_query(msg) + + +def test_examine_query_rejects_multiple_inferred_edges(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + }, + } + } + } + with pytest.raises(Exception, match="single infer edge"): + examine_query(msg) + + +def test_examine_query_rejects_mixed_lookup_and_infer(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b"}, + }, + } + } + } + with pytest.raises(Exception, match="Mixed infer and lookup"): + examine_query(msg) + + +def test_examine_query_rejects_both_creative_nodes_pinned(): + msg = { + "message": { + "query_graph": { + "nodes": { + "a": {"ids": ["X:1"]}, + "b": {"ids": ["Y:2"]}, + }, + "edges": { + "e0": { + "subject": "a", + "object": 
"b", + "knowledge_type": "inferred", + }, + }, + } + } + } + with pytest.raises(Exception, match="Both nodes of creative edge pinned"): + examine_query(msg) + + +def test_examine_query_rejects_no_creative_node_pinned(): + msg = { + "message": { + "query_graph": { + "nodes": { + "a": {}, + "b": {}, + }, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "knowledge_type": "inferred", + }, + }, + } + } + } + with pytest.raises(Exception, match="No nodes of creative edge pinned"): + examine_query(msg) + + +@pytest.mark.asyncio +async def test_aragorn_pathfinder_workflow(redis_mock, mocker): + """A pathfinder query routes through aragorn.pathfinder and gandalf.rehydrate.""" + mocker.patch( + "workers.aragorn.worker.get_message", + return_value={ + "message": { + "query_graph": { + "nodes": {}, + "edges": {}, + "paths": {"p1": {"subject": "a", "object": "b"}}, + } + } + }, + ) + task = _make_task(None) + await aragorn(task, logger) + workflow = json.loads(task[1]["workflow"]) + assert [op["id"] for op in workflow] == [ + "aragorn.pathfinder", + "score_paths", + "sort_results_score", + "filter_analyses_top_n", + "filter_kgraph_orphans", + "gandalf.rehydrate", + ] + + +@pytest.mark.asyncio +async def test_aragorn_lookup_only_workflow(redis_mock, mocker): + """A pure-lookup query (no inferred edges) gets the lookup workflow.""" + mocker.patch( + "workers.aragorn.worker.get_message", + return_value={ + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + }, + ) + task = _make_task(None) + await aragorn(task, logger) + workflow = json.loads(task[1]["workflow"]) + assert [op["id"] for op in workflow] == [ + "aragorn.lookup", + "aragorn.omnicorp", + "aragorn.score", + "sort_results_score", + "filter_results_top_n", + "filter_kgraph_orphans", + ] + + +@pytest.mark.asyncio +async def test_aragorn_preexisting_workflow_is_preserved(redis_mock, mocker): + """If a workflow is already on the 
task, aragorn() should not overwrite it.""" + mocker.patch( + "workers.aragorn.worker.get_message", + return_value=copy.deepcopy(creative_query), + ) + custom_workflow = [{"id": "aragorn.lookup"}] + task = _make_task(custom_workflow, workflow=custom_workflow) + await aragorn(task, logger) + assert json.loads(task[1]["workflow"]) == custom_workflow diff --git a/tests/unit/aragorn/test_aragorn_lookup_helpers.py b/tests/unit/aragorn/test_aragorn_lookup_helpers.py new file mode 100644 index 0000000..f701ff7 --- /dev/null +++ b/tests/unit/aragorn/test_aragorn_lookup_helpers.py @@ -0,0 +1,381 @@ +"""Tests for the pure helpers in ``workers.aragorn_lookup.worker``. + +Covers ``examine_query``, ``get_infer_parameters``, ``get_rule_key``, and +``expand_aragorn_query``. The async ``aragorn_lookup`` entrypoint is already +covered in ``test_aragorn_lookup.py`` for the creative path; here we add +coverage for the non-infer / gandalf branches by mocking out the network. +""" + +import copy +import json +import logging + +import pytest + +from tests.helpers.generate_messages import creative_query +from workers.aragorn_lookup import worker as lookup_worker +from workers.aragorn_lookup.worker import ( + aragorn_lookup, + examine_query, + expand_aragorn_query, + get_infer_parameters, + get_rule_key, +) + +logger = logging.getLogger(__name__) + + +def test_examine_query_lookup_no_pinned_required(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + } + infer, q, a, pathfinder = examine_query(msg) + assert (infer, pathfinder) == (False, False) + assert q is None and a is None + + +def test_examine_query_inferred_returns_question_and_answer(): + msg = copy.deepcopy(creative_query) + infer, q, a, pathfinder = examine_query(msg) + assert infer is True + assert pathfinder is False + assert (q, a) == ("ON", "SN") + + +def test_examine_query_pathfinder_three_inferred_edges(): + msg = { + "message": { + 
"query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}, "c": {}, "d": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "b", "object": "c", "knowledge_type": "inferred"}, + "e2": {"subject": "c", "object": "d", "knowledge_type": "inferred"}, + }, + } + } + } + _, _, _, pathfinder = examine_query(msg) + assert pathfinder is True + + +def test_examine_query_rejects_two_inferred_when_not_pathfinder(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + }, + } + } + } + with pytest.raises(Exception, match="single infer edge"): + examine_query(msg) + + +def test_get_infer_parameters_extracts_source_input_form(): + """A creative query with the subject pinned: source_input=True, input_id is the + subject's id.""" + msg = copy.deepcopy(creative_query) + # creative_query pins ON (the object). Force the subject to be pinned for this case. 
+ msg["message"]["query_graph"]["nodes"]["SN"]["ids"] = ["CHEBI:1"] + msg["message"]["query_graph"]["nodes"]["ON"].pop("ids", None) + input_id, predicate, qualifiers, source, source_input, target, qedge = ( + get_infer_parameters(msg) + ) + assert input_id == "CHEBI:1" + assert predicate == "biolink:treats" + assert qualifiers == {} + assert source == "SN" + assert target == "ON" + assert source_input is True + assert qedge == "e0" + + +def test_get_infer_parameters_extracts_target_input_form(): + """When the object is pinned (creative_query default): source_input=False.""" + msg = copy.deepcopy(creative_query) + input_id, _, _, _, source_input, _, _ = get_infer_parameters(msg) + assert source_input is False + assert input_id == "MONDO:0001" + + +def test_get_infer_parameters_with_qualifier_constraints(): + msg = copy.deepcopy(creative_query) + msg["message"]["query_graph"]["edges"]["e0"]["qualifier_constraints"] = [ + { + "qualifier_set": [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + } + ] + } + ] + _, _, qualifiers, _, _, _, _ = get_infer_parameters(msg) + assert qualifiers == { + "qualifier_constraints": [ + { + "qualifier_set": [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + } + ] + } + ] + } + + +def test_get_rule_key_no_qualifiers_returns_predicate_only(): + key = get_rule_key("biolink:treats", {}, logger) + assert json.loads(key) == {"predicate": "biolink:treats"} + + +def test_get_rule_key_with_aspect_and_direction(): + qualifiers = { + "qualifier_constraints": [ + { + "qualifier_set": [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + }, + { + "qualifier_type_id": "biolink:object_direction_qualifier", + "qualifier_value": "increased", + }, + ] + } + ] + } + key = get_rule_key("biolink:affects", qualifiers, logger) + assert json.loads(key) == { + "object_aspect_qualifier": "activity", + 
"object_direction_qualifier": "increased", + "predicate": "biolink:affects", + } + + +def test_get_rule_key_empty_qualifier_constraints_falls_back_to_predicate(): + """If qualifier_constraints is an empty list, only predicate ends up in the key.""" + key = get_rule_key("biolink:treats", {"qualifier_constraints": []}, logger) + assert json.loads(key) == {"predicate": "biolink:treats"} + + +def test_get_rule_key_empty_qualifier_set_falls_back_to_predicate(): + key = get_rule_key( + "biolink:treats", {"qualifier_constraints": [{"qualifier_set": []}]}, logger + ) + assert json.loads(key) == {"predicate": "biolink:treats"} + + +def test_expand_aragorn_query_includes_direct_query_with_no_expansions(mocker): + """Without any matching AMIE rule, expand_aragorn_query still emits a single + direct (non-inferred) query.""" + mocker.patch( + "workers.aragorn_lookup.worker.json.load", + return_value={}, # empty AMIE expansions + ) + msg = copy.deepcopy(creative_query) + msg["parameters"] = {"timeout": 60, "tiers": [0]} + msg["submitter"] = "test" + out = expand_aragorn_query(msg, logger) + assert len(out) == 1 + direct = out[0] + assert "knowledge_type" not in direct["message"]["query_graph"]["edges"]["e0"] + + +def test_expand_aragorn_query_appends_amie_rule_template(mocker): + """When AMIE has a rule matching the query key, an extra expanded message is + appended.""" + msg = copy.deepcopy(creative_query) + msg["parameters"] = {"timeout": 60, "tiers": [0]} + msg["submitter"] = "test" + + # A trivial 1-edge expansion template; the worker expects the template to + # contain a top-level "query_graph" key. 
+ rule_template = { + "query_graph": { + "nodes": { + "$source": { + "categories": ["biolink:ChemicalEntity"], + "ids": ["$source_id"], + }, + "$target": { + "categories": ["biolink:DiseaseOrPhenotypicFeature"], + "ids": ["$target_id"], + }, + }, + "edges": { + "expanded": { + "subject": "$source", + "object": "$target", + "predicates": ["biolink:related_to"], + } + }, + } + } + + mocker.patch( + "workers.aragorn_lookup.worker.json.load", + return_value={ + json.dumps({"predicate": "biolink:treats"}): [{"template": rule_template}], + }, + ) + + out = expand_aragorn_query(msg, logger) + # Direct query + one AMIE expansion. + assert len(out) == 2 + expanded = out[1] + qg = expanded["message"]["query_graph"] + # SN is the unpinned (answer) node; its IDs list should have been removed. + assert "ids" not in qg["nodes"]["SN"] + # ON (the pinned input) keeps its CURIE through template substitution. + assert qg["nodes"]["ON"]["ids"] == ["MONDO:0001"] + + +@pytest.mark.asyncio +async def test_aragorn_lookup_handles_examine_query_failure(redis_mock, mocker): + """When examine_query raises (e.g. 
mixed query), aragorn_lookup logs and + returns ``(None, 500)`` rather than propagating the exception.""" + bad_message = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b"}, # mixed -> raises + }, + } + } + } + mocker.patch( + "workers.aragorn_lookup.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=bad_message, + ) + + task = [ + "test", + { + "query_id": "test", + "response_id": "test_response", + "workflow": json.dumps([{"id": "aragorn.lookup"}]), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + out = await aragorn_lookup(task, logger) + assert out == (None, 500) + + +@pytest.mark.asyncio +async def test_aragorn_lookup_pure_lookup_path_calls_kg_retrieval(redis_mock, mocker): + """A non-inferred query (no creative work needed) goes straight to the kg + retrieval URL without expansion.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + # Non-zero timeout so the polling loop body executes once and breaks + # when get_running_callbacks returns []. 
+ "parameters": {"timeout": 5}, + } + mocker.patch( + "workers.aragorn_lookup.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + mocker.patch( + "workers.aragorn_lookup.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mock_running = mocker.patch( + "workers.aragorn_lookup.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + + task = [ + "test", + { + "query_id": "test", + "response_id": "test_response", + "workflow": json.dumps([{"id": "aragorn.lookup"}]), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + await aragorn_lookup(task, logger) + assert mock_post.called + assert mock_running.called + + +@pytest.mark.asyncio +async def test_aragorn_lookup_gandalf_lookup_uses_add_task(redis_mock, mocker): + """When ``parameters.gandalf=True``, the lookup hands off to the gandalf + stream via add_task instead of hitting the kg retrieval URL.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"gandalf": True, "timeout": 5}, + } + mocker.patch( + "workers.aragorn_lookup.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + mocker.patch( + "workers.aragorn_lookup.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_lookup.worker.save_message", + new_callable=mocker.AsyncMock, + ) + mock_add_task = mocker.patch( + "workers.aragorn_lookup.worker.add_task", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_lookup.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + task = [ + "test", + { + "query_id": "test", + "response_id": "test_response", + "workflow": json.dumps([{"id": "aragorn.lookup"}]), + "log_level": 
"20", + "otel": json.dumps({}), + }, + ] + await aragorn_lookup(task, logger) + # The gandalf add_task should have been called with the gandalf stream. + assert mock_add_task.called + args, _ = mock_add_task.call_args + assert args[0] == "gandalf" + assert args[1]["target"] == "aragorn" diff --git a/tests/unit/aragorn/test_aragorn_score_helpers.py b/tests/unit/aragorn/test_aragorn_score_helpers.py new file mode 100644 index 0000000..5cb0ad5 --- /dev/null +++ b/tests/unit/aragorn/test_aragorn_score_helpers.py @@ -0,0 +1,296 @@ +"""Tests for the pure helpers in ``workers.aragorn_score.worker``. + +Covers: + +- ``get_base_weight`` / ``get_source_weight`` / ``get_source_sigmoid`` — + per-source weight lookups and the 0-centered sigmoid transform. +- ``get_profile`` — selects between the blended/clinical/correlated/curated + weight profiles. +- ``kirchhoff`` — Kirchhoff index given a graph laplacian. +- ``get_edge_support_kg`` — recursive support-graph kg extraction. +""" + +import math + +import numpy as np +import pytest + +from workers.aragorn_score.worker import ( + BLENDED_PROFILE, + CLINICAL_PROFILE, + CORRELATED_PROFILE, + CURATED_PROFILE, + DEFAULT_WEIGHT, + get_base_weight, + get_edge_support_kg, + get_profile, + get_source_sigmoid, + get_source_weight, + kirchhoff, +) + +# --- get_base_weight ------------------------------------------------------ + + +def test_get_base_weight_known_source(): + assert get_base_weight("infores:omnicorp") == 0 + + +def test_get_base_weight_unknown_source_falls_back_to_default(): + assert get_base_weight("infores:not-real") == DEFAULT_WEIGHT + + +def test_get_base_weight_custom_weights_table(): + """When a caller provides a custom base_weights table (no default_weight + key), unknown sources still resolve via DEFAULT_WEIGHT.""" + out = get_base_weight("infores:not-real", base_weights={"infores:foo": 0.7}) + assert out == DEFAULT_WEIGHT + + +# --- get_source_weight ---------------------------------------------------- + + +def 
test_get_source_weight_known_source_known_property(): + assert get_source_weight("infores:omnicorp", "literature_co-occurrence") == 1 + + +def test_get_source_weight_known_source_unknown_property(): + """Unknown property of a known source -> 0 (unknown_property).""" + out = get_source_weight("infores:omnicorp", "not-real") + assert out == 0 + + +def test_get_source_weight_unknown_source_uses_unknown_source_weight(): + """Unknown source falls back to the unknown_source_weight table.""" + out = get_source_weight("infores:nope", "publications") + assert out == 1 + + +# --- get_source_sigmoid --------------------------------------------------- + + +def test_sigmoid_at_midpoint_is_centered(): + """At the midpoint, the sigmoid evaluates to (lower + upper) / 2.""" + parameters = BLENDED_PROFILE["unknown_source_transformation"]["publications"] + val = get_source_sigmoid(parameters["midpoint"]) + assert val == pytest.approx((parameters["lower"] + parameters["upper"]) / 2) + + +def test_sigmoid_saturates_below_for_negative_rate(): + """A negative rate (e.g. p_value) means high values approach the lower + bound; low values approach the upper bound.""" + p_val_low = get_source_sigmoid( + 0.0, + source="infores:genetics-data-provider", + property="p_value", + ) + p_val_high = get_source_sigmoid( + 1.0, + source="infores:genetics-data-provider", + property="p_value", + ) + # Negative rate: high p_value -> low sigmoid; low p_value -> high sigmoid. 
+ assert p_val_low > p_val_high + + +def test_sigmoid_unknown_property_uses_unknown_property_default(): + """Property not in the source's transformation table falls back to the + unknown_source_transformation['unknown_property'] entry, which has lower + and upper both at 0 -> sigmoid is always 0.""" + val = get_source_sigmoid( + 100.0, + source="infores:omnicorp", + property="not-real", + ) + assert val == 0 + + +# --- get_profile ---------------------------------------------------------- + + +@pytest.mark.parametrize( + "name, expected_table", + [ + ("blended", BLENDED_PROFILE), + ("clinical", CLINICAL_PROFILE), + ("correlated", CORRELATED_PROFILE), + ("curated", CURATED_PROFILE), + ], +) +def test_get_profile_returns_matching_tables(name, expected_table): + sw, usw, st, ust, bw = get_profile(name) + assert sw == expected_table["source_weights"] + assert usw == expected_table["unknown_source_weight"] + assert st == expected_table["source_transformation"] + assert ust == expected_table["unknown_source_transformation"] + assert bw == expected_table["base_weights"] + + +def test_get_profile_unknown_name_falls_back_to_blended(): + sw, *_ = get_profile("not-a-real-profile") + assert sw == BLENDED_PROFILE["source_weights"] + + +# --- kirchhoff ------------------------------------------------------------ + + +def test_kirchhoff_returns_neg_inf_for_invalid_probe_index(): + """When the probe references an index out of range of the laplacian, + the function catches the IndexError and returns -inf.""" + L = np.array([[1.0, -1.0], [-1.0, 1.0]]) + out = kirchhoff(L, [(0, 5)]) # 5 is out of range + assert out == -np.inf + + +def test_kirchhoff_two_node_unit_resistor_returns_unit_distance(): + """The Kirchhoff index between connected unit-weight nodes in a 2-node + graph is 1.""" + # Laplacian for a 2-node graph with a single unit edge. 
+ L = np.array([[1.0, -1.0], [-1.0, 1.0]]) + val = kirchhoff(L, [(0, 1)]) + assert val == pytest.approx(1.0) + + +def test_kirchhoff_returns_real_finite_value_for_three_node_chain(): + """A 3-node line graph 0--1--2 with unit resistors has pairwise effective + resistances 1, 1 and 2, i.e. Kirchhoff index 4 when summed over all + pairs. We test a single probe pair, which is well-defined.""" + # Laplacian for a 3-node line graph. + L = np.array( + [ + [1.0, -1.0, 0.0], + [-1.0, 2.0, -1.0], + [0.0, -1.0, 1.0], + ] + ) + out = kirchhoff(L, [(0, 1)]) + # A single unit edge: effective resistance is 1. + assert out == pytest.approx(1.0, abs=1e-6) + + +# --- get_edge_support_kg -------------------------------------------------- + + +def test_get_edge_support_kg_returns_empty_when_edge_missing(): + """Calling on an edge id that's not in the kg returns the empty default.""" + out = get_edge_support_kg("missing", {"edges": {}}, {}) + assert out == {"node_ids": set(), "edge_ids": set()} + + +def test_get_edge_support_kg_returns_empty_when_edge_has_no_attributes(): + """An edge with no attributes contributes nothing (the function bails).""" + kg = { + "edges": { + "e1": {"subject": "A", "object": "B"}, + } + } + out = get_edge_support_kg("e1", kg, {}) + assert out == {"node_ids": set(), "edge_ids": set()} + + +def test_get_edge_support_kg_collects_edge_endpoints_when_attributes_present(): + """An edge with any attributes records the edge id and both endpoints.""" + kg = { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + {"attribute_type_id": "biolink:has_evidence", "value": 5}, + ], + } + } + } + out = get_edge_support_kg("e1", kg, {}) + assert out == {"node_ids": {"A", "B"}, "edge_ids": {"e1"}} + + +def test_get_edge_support_kg_recurses_into_support_graphs(): + """Support graphs reference more edges that should be collected too.""" + kg = { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + { + 
"attribute_type_id": "biolink:support_graphs", + "value": ["aux1"], + }, + ], + }, + "e_support": { + "subject": "B", + "object": "C", + "attributes": [ + {"attribute_type_id": "biolink:has_evidence", "value": 1}, + ], + }, + } + } + aux_graphs = {"aux1": {"edges": ["e_support"], "nodes": ["EXTRA"]}} + out = get_edge_support_kg("e1", kg, aux_graphs) + assert out["edge_ids"] == {"e1", "e_support"} + assert {"A", "B", "C", "EXTRA"}.issubset(out["node_ids"]) + + +def test_get_edge_support_kg_skips_missing_aux_graph(): + """Support graphs that aren't in aux_graphs are silently skipped.""" + kg = { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + { + "attribute_type_id": "biolink:support_graphs", + "value": ["doesnt_exist"], + }, + ], + } + } + } + out = get_edge_support_kg("e1", kg, {}) + assert out["edge_ids"] == {"e1"} + + +def test_get_edge_support_kg_skips_missing_support_edges(): + """Support graph edges that aren't in the kg are silently skipped (defends + against malformed TRAPI).""" + kg = { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + { + "attribute_type_id": "biolink:support_graphs", + "value": ["aux1"], + }, + ], + } + } + } + aux_graphs = {"aux1": {"edges": ["nonexistent"], "nodes": []}} + out = get_edge_support_kg("e1", kg, aux_graphs) + assert out["edge_ids"] == {"e1"} + + +def test_get_edge_support_kg_threads_through_existing_accumulator(): + """If the caller passes in an existing edge_kg, results are merged into it.""" + kg = { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + {"attribute_type_id": "biolink:has_evidence", "value": 1}, + ], + } + } + } + pre_existing = {"node_ids": {"PRE"}, "edge_ids": {"PRE_EDGE"}} + out = get_edge_support_kg("e1", kg, {}, edge_kg=pre_existing) + assert "PRE" in out["node_ids"] + assert "PRE_EDGE" in out["edge_ids"] + assert "A" in out["node_ids"] and "e1" in out["edge_ids"] diff --git 
a/tests/unit/aragorn/test_aragorn_score_ranker.py b/tests/unit/aragorn/test_aragorn_score_ranker.py new file mode 100644 index 0000000..5bb28fa --- /dev/null +++ b/tests/unit/aragorn/test_aragorn_score_ranker.py @@ -0,0 +1,675 @@ +"""Tests for the ``Ranker`` class in ``workers.aragorn_score.worker``. + +Exercises the methods used during scoring: + +- ``__init__`` profile selection +- ``probes`` (probe node selection from the qgraph) +- ``get_rgraph`` (per-analysis r-graph construction, support graphs) +- ``get_omnicorp_node_pubs`` (publication-count attribute extraction) +- ``get_edge_values`` (attribute parsing for publications, p_value, + literature co-occurrence, affinity) +- ``graph_laplacian`` (weighted laplacian + zero-row pruning) +- ``rank`` (sorted result ordering after scoring) +- ``score`` (per-analysis scoring with degenerate laplacian -> 0) +""" + +import copy +import logging + +import numpy as np +import pytest + +from tests.helpers.generate_messages import response_1 +from workers.aragorn_score.worker import ( + BLENDED_PROFILE, + CLINICAL_PROFILE, + Ranker, +) + +logger = logging.getLogger(__name__) + + +# --- __init__ profile selection ------------------------------------------ + + +def test_ranker_init_uses_blended_profile_by_default(): + msg = response_1["message"] + r = Ranker(msg, logger) + assert r.source_weights == BLENDED_PROFILE["source_weights"] + assert r.base_weights == BLENDED_PROFILE["base_weights"] + + +def test_ranker_init_can_select_clinical_profile(): + msg = response_1["message"] + r = Ranker(msg, logger, profile="clinical") + assert r.source_weights == CLINICAL_PROFILE["source_weights"] + + +def test_ranker_init_handles_minimal_message(): + """A minimal message shouldn't crash; ranker handles missing kg/qgraph.""" + r = Ranker({}, logger) + assert r.kgraph == {"nodes": {}, "edges": {}} + assert r.qgraph == {"nodes": {}, "edges": {}} + assert r.agraphs == {} + + +# --- probes 
-------------------------------------------------------------- + + +def test_probes_returns_q_node_pairs_for_one_hop_graph(): + """For a 1-edge query graph, probes() returns a single (n0, n1) pair.""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + probes = r.probes() + assert len(probes) == 1 + assert set(probes[0]) == {"n0", "n1"} + + +def test_probes_matches_response_1_fixture(): + """The shared TRAPI fixture has n0->n1 connectivity; probes returns one + pair containing both query node ids.""" + msg = response_1["message"] + r = Ranker(msg, logger) + probes = r.probes() + assert len(probes) >= 1 + flat = {n for pair in probes for n in pair} + assert flat <= set(msg["query_graph"]["nodes"]) + + +# --- get_omnicorp_node_pubs --------------------------------------------- + + +def test_get_omnicorp_node_pubs_finds_attribute(): + """Picks up the omnicorp_article_count attribute and caches it.""" + msg = { + "knowledge_graph": { + "nodes": { + "MONDO:1": { + "attributes": [ + { + "original_attribute_name": "omnicorp_article_count", + "value": 100, + } + ] + } + }, + "edges": {}, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + assert r.get_omnicorp_node_pubs("MONDO:1") == 100 + # Cached on the instance. + assert r.node_pubs["MONDO:1"] == 100 + # Second call hits the cache. 
+ assert r.get_omnicorp_node_pubs("MONDO:1") == 100 + + +def test_get_omnicorp_node_pubs_recognises_alternate_name(): + """The newer ``num_publications`` attribute name is also recognised.""" + msg = { + "knowledge_graph": { + "nodes": { + "MONDO:1": { + "attributes": [ + {"original_attribute_name": "num_publications", "value": 7} + ] + } + }, + "edges": {}, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + assert r.get_omnicorp_node_pubs("MONDO:1") == 7 + + +def test_get_omnicorp_node_pubs_returns_zero_when_no_attribute(): + msg = { + "knowledge_graph": { + "nodes": {"MONDO:1": {"attributes": []}}, + "edges": {}, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + assert r.get_omnicorp_node_pubs("MONDO:1") == 0 + + +def test_get_omnicorp_node_pubs_handles_unparseable_value(): + """A non-numeric string value falls back to 0 rather than raising.""" + msg = { + "knowledge_graph": { + "nodes": { + "MONDO:1": { + "attributes": [ + { + "original_attribute_name": "omnicorp_article_count", + "value": "not-a-number", + } + ] + } + }, + "edges": {}, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + assert r.get_omnicorp_node_pubs("MONDO:1") == 0 + + +def test_get_omnicorp_node_pubs_raises_for_unknown_node(): + msg = { + "knowledge_graph": {"nodes": {}, "edges": {}}, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + with pytest.raises(KeyError, match="ghost"): + r.get_omnicorp_node_pubs("ghost") + + +# --- get_edge_values ----------------------------------------------------- + + +def _make_msg_with_edge(edge): + return { + "knowledge_graph": { + "nodes": {}, + "edges": {"e1": edge}, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + + +def test_get_edge_values_publications_split_by_pipe(): + """``["PMID:1|2|3"]`` should be split into 3 publications.""" + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", 
"resource_role": "primary_knowledge_source"} + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["PMID:1|2|3"]} + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + # publications adds a 'publications' entry under the source. + pub_data = vals["infores:foo"]["publications"] + assert pub_data["value"] == 3 + + +def test_get_edge_values_publications_split_by_comma(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", "resource_role": "primary_knowledge_source"} + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["PMID:1,2,3,4"]} + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + assert vals["infores:foo"]["publications"]["value"] == 4 + + +def test_get_edge_values_publications_string_value_becomes_single_pub(): + """A bare string value (not a list) turns into a 1-element list.""" + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", "resource_role": "primary_knowledge_source"} + ], + "attributes": [{"original_attribute_name": "publications", "value": "PMID:1"}], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + assert vals["infores:foo"]["publications"]["value"] == 1 + + +def test_get_edge_values_evidence_count_overwrites_num_publications(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", "resource_role": "primary_knowledge_source"} + ], + "attributes": [ + {"attribute_type_id": "biolink:evidence_count", "value": 42}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + assert vals["infores:foo"]["publications"]["value"] == 42 + + +def test_get_edge_values_p_value_extracts_numeric(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:genetics-data-provider", + "resource_role": 
"primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "p_value", "value": 0.001}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + p_data = vals["infores:genetics-data-provider"]["p_value"] + assert p_data["value"] == 0.001 + + +def test_get_edge_values_p_value_unwraps_list(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:genetics-data-provider", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "p_value", "value": [0.005]}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + p_data = vals["infores:genetics-data-provider"]["p_value"] + assert p_data["value"] == 0.005 + + +def test_get_edge_values_p_value_string_parsed_to_float(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:genetics-data-provider", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "p_value", "value": "0.005"}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + p_data = vals["infores:genetics-data-provider"]["p_value"] + assert p_data["value"] == 0.005 + + +def test_get_edge_values_p_value_unparseable_string_logged_as_none(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:genetics-data-provider", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "p_value", "value": "not-a-number"}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + # Unparseable -> p_value remains None and so the source dict lacks p_value. 
+ assert "p_value" not in vals["infores:genetics-data-provider"] + + +def test_get_edge_values_no_primary_source_uses_unspecified(): + """An edge with no primary_knowledge_source defaults to 'unspecified'.""" + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", "resource_role": "supporting_data_source"}, + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["P:1"]}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + assert "unspecified" in vals + + +def test_get_edge_values_caches_per_edge_id(): + """Calling twice for the same edge id returns the cached dict.""" + edge = { + "subject": "A", + "object": "B", + "sources": [ + {"resource_id": "infores:foo", "resource_role": "primary_knowledge_source"} + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["PMID:1"]} + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + first = r.get_edge_values("e1") + second = r.get_edge_values("e1") + assert first is second # identity, not just equality + + +def test_get_edge_values_raises_for_unknown_edge(): + msg = { + "knowledge_graph": {"nodes": {}, "edges": {}}, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + with pytest.raises(KeyError, match="ghost"): + r.get_edge_values("ghost") + + +def test_get_edge_values_affinity_and_confidence_score_extracted(): + edge = { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:text-mining-provider-targeted", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "affinity", "value": 0.5}, + {"original_attribute_name": "biolink:tmkp_confidence_score", "value": 0.8}, + ], + } + r = Ranker(_make_msg_with_edge(edge), logger) + vals = r.get_edge_values("e1") + assert vals["infores:text-mining-provider-targeted"]["affinity"]["value"] == 0.5 + + +def 
test_get_edge_values_literature_cooccurrence_with_omnicorp_predicate(): + """The literature_cooccurrence pathway requires the predicate + biolink:occurs_together_in_literature_with and a biolink:has_count + attribute. Subject and object node pubs are pulled from + omnicorp_article_count attributes.""" + msg = { + "knowledge_graph": { + "nodes": { + "A": { + "attributes": [ + { + "original_attribute_name": "omnicorp_article_count", + "value": 100, + } + ] + }, + "B": { + "attributes": [ + { + "original_attribute_name": "omnicorp_article_count", + "value": 100, + } + ] + }, + }, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "predicate": "biolink:occurs_together_in_literature_with", + "sources": [ + { + "resource_id": "infores:omnicorp", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"attribute_type_id": "biolink:has_count", "value": 50}, + ], + } + }, + }, + "query_graph": {"nodes": {}, "edges": {}}, + } + r = Ranker(msg, logger) + vals = r.get_edge_values("e1") + assert "literature_coocurrence" in vals["infores:omnicorp"] + # Cov >= 0 by construction. 
+ assert vals["infores:omnicorp"]["literature_coocurrence"]["value"] >= 0 + + +# --- get_rgraph ---------------------------------------------------------- + + +def test_get_rgraph_builds_one_rgraph_per_analysis(): + """Each analysis on a result yields its own r_graph.""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": { + "nodes": {"K0": {}, "K1": {}}, + "edges": { + "ke0": {"subject": "K0", "object": "K1", "attributes": []}, + }, + }, + "auxiliary_graphs": {}, + } + r = Ranker(msg, logger) + result = { + "node_bindings": { + "n0": [{"id": "K0"}], + "n1": [{"id": "K1"}], + }, + "analyses": [ + {"edge_bindings": {"e0": [{"id": "ke0"}]}}, + {"edge_bindings": {"e0": [{"id": "ke0"}]}}, + ], + } + r_graphs = r.get_rgraph(result) + assert len(r_graphs) == 2 + for rg in r_graphs: + assert rg["nodes"] == {"n0", "n1"} + # One edge tuple (n0, n1, ke0) + assert any(e[2] == "ke0" for e in rg["edges"]) + + +def test_get_rgraph_skips_edges_not_in_kgraph(): + """An edge_binding that references a missing kg edge is logged and + skipped (the r_graph just lacks that tuple).""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": {"nodes": {"K0": {}, "K1": {}}, "edges": {}}, + "auxiliary_graphs": {}, + } + r = Ranker(msg, logger) + result = { + "node_bindings": {"n0": [{"id": "K0"}], "n1": [{"id": "K1"}]}, + "analyses": [{"edge_bindings": {"e0": [{"id": "missing-edge"}]}}], + } + rgs = r.get_rgraph(result) + assert rgs[0]["edges"] == set() + + +def test_get_rgraph_pulls_in_support_graph_nodes_and_edges(): + """An analysis ``support_graphs`` entry pulls aux-graph nodes/edges into + the r_graph.""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": { + "nodes": {"K0": {}, "K1": {}, "K_SUP": {}}, + "edges": { + "ke0": 
{"subject": "K0", "object": "K1", "attributes": []}, + "k_sup_edge": { + "subject": "K0", + "object": "K_SUP", + "attributes": [], + }, + }, + }, + "auxiliary_graphs": { + "aux1": {"edges": ["k_sup_edge"], "nodes": ["K_SUP"]}, + }, + } + r = Ranker(msg, logger) + result = { + "node_bindings": {"n0": [{"id": "K0"}], "n1": [{"id": "K1"}]}, + "analyses": [ + { + "edge_bindings": {"e0": [{"id": "ke0"}]}, + "support_graphs": ["aux1"], + } + ], + } + rgs = r.get_rgraph(result) + rg = rgs[0] + edge_ids = {e[2] for e in rg["edges"]} + assert "k_sup_edge" in edge_ids + + +# --- graph_laplacian ---------------------------------------------------- + + +def test_graph_laplacian_two_node_graph_yields_2x2_matrix(): + """A trivial graph with one edge between two q-nodes returns a 2x2 + laplacian after pruning. Probe inds align to the kept nodes.""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": { + "nodes": {}, + "edges": { + "ke0": { + "subject": "K0", + "object": "K1", + "sources": [ + { + "resource_id": "infores:foo", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["P:1"]} + ], + } + }, + }, + "auxiliary_graphs": {}, + } + r = Ranker(msg, logger) + r_graph = { + "nodes": {"n0", "n1"}, + "edges": {("n0", "n1", "ke0")}, + } + L, probe_inds, details = r.graph_laplacian(r_graph, [("n0", "n1")]) + assert L.shape == (2, 2) + # Diagonal is positive; off-diagonal negative. + assert L[0, 0] > 0 + assert L[0, 1] < 0 + # Probe inds map to kept node positions. 
+ assert probe_inds == [(0, 1)] or probe_inds == [(1, 0)] + + +def test_graph_laplacian_drops_zero_rows_unless_probe(): + """Nodes that have no edges (zero rows) are pruned, but probes are + protected.""" + msg = { + "query_graph": { + "nodes": {"n0": {}, "n1": {}, "n2": {}}, + "edges": {"e0": {"subject": "n0", "object": "n1"}}, + }, + "knowledge_graph": { + "nodes": {}, + "edges": { + "ke0": { + "subject": "K0", + "object": "K1", + "sources": [ + { + "resource_id": "infores:foo", + "resource_role": "primary_knowledge_source", + } + ], + "attributes": [ + {"original_attribute_name": "publications", "value": ["P:1"]} + ], + } + }, + }, + "auxiliary_graphs": {}, + } + r = Ranker(msg, logger) + # n2 has no edges -> a zero row in the unfiltered laplacian. + r_graph = { + "nodes": {"n0", "n1", "n2"}, + "edges": {("n0", "n1", "ke0")}, + } + L, probe_inds, _ = r.graph_laplacian(r_graph, [("n0", "n1")]) + # n2 was dropped -> 2x2. + assert L.shape == (2, 2) + + +# --- score and rank ------------------------------------------------------ + + +def test_score_returns_float_for_well_formed_answer(): + """The fixture response has well-formed analyses; score() should return + floats in [0, 1] (after exp(-kirchhoff)).""" + msg = copy.deepcopy(response_1["message"]) + r = Ranker(msg, logger) + answer = msg["results"][0] + scored, details = r.score(answer) + score = scored["analyses"][0]["score"] + assert isinstance(score, float) + # Scores from the existing fixture are in (0.063, 0.064) range; broader + # check here so the test tolerates ranker drift. + assert -1 <= score <= 1 + + +def test_rank_returns_answers_sorted_by_score(): + """rank() returns answers ordered by their max analysis score.""" + msg = copy.deepcopy(response_1["message"]) + r = Ranker(msg, logger) + ranked = r.rank(copy.deepcopy(msg["results"])) + assert len(ranked) == len(msg["results"]) + # All ranked entries have analyses with scores assigned. 
+ scores = [ + max(a.get("score", 0) for a in ans["analyses"]) + for ans in ranked + if ans.get("analyses") + ] + # Ascending order (rank() uses ``sorted`` without reverse=True). + assert scores == sorted(scores) + + +def test_score_jaccard_like_returns_score_over_one_minus_score(): + """jaccard_like=True transforms each score to s/(1-s).""" + msg = copy.deepcopy(response_1["message"]) + r = Ranker(msg, logger) + scored, _ = r.score(copy.deepcopy(msg["results"][0]), jaccard_like=True) + raw, _ = Ranker(msg, logger).score(copy.deepcopy(msg["results"][0])) + raw_score = raw["analyses"][0]["score"] + if 0 < raw_score < 1: + assert scored["analyses"][0]["score"] == pytest.approx( + raw_score / (1 - raw_score) + ) diff --git a/tests/unit/test_aragorn_pathfinder.py b/tests/unit/test_aragorn_pathfinder.py new file mode 100644 index 0000000..a73cdc1 --- /dev/null +++ b/tests/unit/test_aragorn_pathfinder.py @@ -0,0 +1,224 @@ +"""Tests for ``workers.aragorn_pathfinder.worker.shadowfax``. + +The pathfinder entry point validates the input pathfinder query, builds a +3-hop expanded query, and delegates to the gandalf stream. We exercise the +validation paths and the happy path (mocking out postgres + redis). 
+""" + +import json +import logging + +import pytest + +from workers.aragorn_pathfinder.worker import shadowfax + +logger = logging.getLogger(__name__) + + +def _make_task(workflow=None): + return [ + "test", + { + "query_id": "qid-1", + "response_id": "rid-1", + "workflow": json.dumps(workflow if workflow is not None else []), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +def _pathfinder_message(constraints=None): + msg = { + "message": { + "query_graph": { + "nodes": { + "n0": {"ids": ["MONDO:0001"]}, + "n1": {"ids": ["MONDO:0002"]}, + }, + "paths": { + "p0": {"subject": "n0", "object": "n1"}, + }, + } + }, + "parameters": {"timeout": 5}, + } + if constraints is not None: + msg["message"]["query_graph"]["paths"]["p0"]["constraints"] = constraints + return msg + + +@pytest.mark.asyncio +async def test_shadowfax_happy_path_dispatches_to_gandalf(redis_mock, mocker): + """A valid pathfinder query: callback registered, threehop saved, gandalf + task enqueued, polling loop exits when no callbacks remain.""" + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=_pathfinder_message(), + ) + mock_add_cb = mocker.patch( + "workers.aragorn_pathfinder.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mock_save = mocker.patch( + "workers.aragorn_pathfinder.worker.save_message", + new_callable=mocker.AsyncMock, + ) + mock_add_task = mocker.patch( + "workers.aragorn_pathfinder.worker.add_task", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + await shadowfax(_make_task(), logger) + + assert mock_add_cb.called + assert mock_save.called + assert mock_add_task.called + target_stream, payload = mock_add_task.call_args.args[:2] + assert target_stream == "gandalf" + assert payload["target"] == "aragorn" + + # The threehop saved to redis should be a 3-edge query whose 
endpoints + # match the original pinned pathfinder nodes. + saved_id, saved_threehop = mock_save.call_args.args[:2] + qg = saved_threehop["message"]["query_graph"] + assert "n0" in qg["nodes"] and "n1" in qg["nodes"] + assert "intermediate_0" in qg["nodes"] and "intermediate_1" in qg["nodes"] + assert set(qg["edges"].keys()) == {"e0", "e1", "e2"} + # Endpoints stitch through the intermediates. + assert qg["edges"]["e0"]["subject"] == "n0" + assert qg["edges"]["e0"]["object"] == "intermediate_0" + assert qg["edges"]["e2"]["object"] == "n1" + + +@pytest.mark.asyncio +async def test_shadowfax_requires_two_distinct_pinned_nodes(redis_mock, mocker): + """Only one distinct pinned id is invalid for a pathfinder query.""" + msg = _pathfinder_message() + msg["message"]["query_graph"]["nodes"]["n1"]["ids"] = ["MONDO:0001"] # duplicate + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + with pytest.raises(Exception, match="two pinned nodes"): + await shadowfax(_make_task(), logger) + + +@pytest.mark.asyncio +async def test_shadowfax_rejects_multiple_constraints(redis_mock, mocker): + """Multiple constraints on the path is unsupported.""" + msg = _pathfinder_message( + constraints=[ + {"intermediate_categories": ["biolink:Gene"]}, + {"intermediate_categories": ["biolink:Disease"]}, + ] + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + with pytest.raises(Exception, match="multiple constraints"): + await shadowfax(_make_task(), logger) + + +@pytest.mark.asyncio +async def test_shadowfax_rejects_multiple_intermediate_categories(redis_mock, mocker): + """A single constraint may not list multiple intermediate categories.""" + msg = _pathfinder_message( + constraints=[{"intermediate_categories": ["biolink:Gene", "biolink:Disease"]}] + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + 
new_callable=mocker.AsyncMock, + return_value=msg, + ) + with pytest.raises(Exception, match="multiple intermediate categories"): + await shadowfax(_make_task(), logger) + + +@pytest.mark.asyncio +async def test_shadowfax_uses_intermediate_category_from_constraint(redis_mock, mocker): + """When a constraint provides an intermediate category, the threehop's + intermediates carry that category instead of biolink:NamedThing.""" + msg = _pathfinder_message( + constraints=[{"intermediate_categories": ["biolink:Gene"]}] + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mock_save = mocker.patch( + "workers.aragorn_pathfinder.worker.save_message", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.add_task", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + await shadowfax(_make_task(), logger) + threehop = mock_save.call_args.args[1] + nodes = threehop["message"]["query_graph"]["nodes"] + assert nodes["intermediate_0"]["categories"] == ["biolink:Gene"] + assert nodes["intermediate_1"]["categories"] == ["biolink:Gene"] + + +@pytest.mark.asyncio +async def test_shadowfax_propagates_gandalf_parameters(redis_mock, mocker): + """Custom gandalf_parameters in the input should ride along into the + saved threehop's parameters.""" + msg = _pathfinder_message() + msg["parameters"]["gandalf_parameters"] = { + "min_information_content": 1, + "max_node_degree": 10, + "dehydrated": False, + } + mocker.patch( + "workers.aragorn_pathfinder.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mock_save = 
mocker.patch( + "workers.aragorn_pathfinder.worker.save_message", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.add_task", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.aragorn_pathfinder.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + await shadowfax(_make_task(), logger) + threehop = mock_save.call_args.args[1] + gp = threehop["parameters"]["gandalf_parameters"] + assert gp == { + "min_information_content": 1, + "max_node_degree": 10, + "dehydrated": False, + } diff --git a/tests/unit/test_broker.py b/tests/unit/test_broker.py new file mode 100644 index 0000000..0344169 --- /dev/null +++ b/tests/unit/test_broker.py @@ -0,0 +1,151 @@ +"""Tests for ``shepherd_utils.broker``. + +Exercise the redis-streams add/get/ack flow against fakeredis, and the +pubsub-backed lock acquisition / release scripts. +""" + +import asyncio +import logging + +import pytest + +from shepherd_utils.broker import ( + acquire_lock, + add_task, + create_consumer_group, + get_task, + mark_task_as_complete, + remove_lock, +) + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_create_consumer_group_swallows_existing_group(redis_mock): + """Creating a group twice on the same stream should not raise.""" + await create_consumer_group("stream1", "consumer", logger) + # Second call would normally raise BUSYGROUP; broker swallows it. 
+ await create_consumer_group("stream1", "consumer", logger) + + +@pytest.mark.asyncio +async def test_add_and_get_task_roundtrip(redis_mock): + payload = { + "query_id": "q1", + "response_id": "r1", + "workflow": "[]", + "log_level": "20", + "otel": "{}", + } + await add_task("teststream", payload, logger) + msg_id, fields = await get_task("teststream", "consumer", "test", logger) + assert isinstance(msg_id, str) + assert fields == payload + + +@pytest.mark.asyncio +async def test_get_task_returns_none_when_no_messages(redis_mock): + """A short, empty xreadgroup call should return None and not raise.""" + # The broker's get_task calls create_consumer_group first; we still hit + # the block timeout. fakeredis returns immediately because no messages + # are in the stream after the group is created. + out = await get_task("empty_stream", "consumer", "test", logger) + assert out is None + + +@pytest.mark.asyncio +async def test_mark_task_as_complete_acks_message(redis_mock): + """An ACKed message should leave the pending list.""" + await add_task("ackstream", {"q": "1"}, logger) + msg_id, _ = await get_task("ackstream", "consumer", "tester", logger) + + pending_before = await redis_mock["broker"].xpending("ackstream", "consumer") + assert pending_before["pending"] == 1 + + await mark_task_as_complete("ackstream", "consumer", msg_id, logger) + + pending_after = await redis_mock["broker"].xpending("ackstream", "consumer") + assert pending_after["pending"] == 0 + + +@pytest.mark.asyncio +async def test_acquire_lock_returns_true_when_lock_is_free(redis_mock): + got = await acquire_lock("resource-1", "consumer-A", logger) + assert got is True + val = await redis_mock["lock"].get("resource-1") + assert val == "consumer-A" + + +@pytest.mark.asyncio +async def test_remove_lock_releases_only_when_token_matches(redis_mock, mocker): + """The unlock Lua script only removes the key when the consumer id matches. 
+ + fakeredis does not implement ``evalsha``, so we replace ``register_script`` + with a python emulator that runs the same logic (compare-and-delete). + """ + + def fake_register_script(_script): + async def _runner(keys, args): + (key,) = keys + (token,) = args + current = await redis_mock["lock"].get(key) + if current == token: + await redis_mock["lock"].delete(key) + return 1 + return 0 + + return _runner + + mocker.patch.object( + redis_mock["lock"], "register_script", side_effect=fake_register_script + ) + + await acquire_lock("resource-2", "consumer-A", logger) + # Wrong owner can't release. + await remove_lock("resource-2", "consumer-B", logger) + assert await redis_mock["lock"].get("resource-2") == "consumer-A" + + await remove_lock("resource-2", "consumer-A", logger) + assert await redis_mock["lock"].get("resource-2") is None + + +@pytest.mark.asyncio +async def test_acquire_lock_blocks_until_other_consumer_releases(redis_mock, mocker): + """A second acquire while the lock is held should fail. + + We patch the pubsub message wait so the loop iterates fast instead of + really waiting 5s per try. + """ + + # Patch pubsub.get_message to always raise asyncio.TimeoutError so the + # loop falls through quickly. This emulates "no notification arrived". + real_pubsub = redis_mock["lock"].pubsub + + class FastPubSub: + def __init__(self): + self._inner = real_pubsub() + + async def subscribe(self, channel): + return await self._inner.subscribe(channel) + + async def unsubscribe(self, channel): + return await self._inner.unsubscribe(channel) + + async def aclose(self): + return await self._inner.aclose() + + async def get_message(self, ignore_subscribe_messages=True, timeout=5): + raise asyncio.TimeoutError + + mocker.patch.object(redis_mock["lock"], "pubsub", FastPubSub) + + # First consumer holds the lock. + assert await acquire_lock("resource-3", "consumer-A", logger) is True + + # Second consumer tries; can't get it. 
With our fast pubsub stub this + # falls through 12 iterations almost instantly. + got = await acquire_lock("resource-3", "consumer-B", logger) + assert got is False + # And the original holder's value is unchanged. + assert await redis_mock["lock"].get("resource-3") == "consumer-A" diff --git a/tests/unit/test_bte.py b/tests/unit/test_bte.py new file mode 100644 index 0000000..6dc1ae4 --- /dev/null +++ b/tests/unit/test_bte.py @@ -0,0 +1,214 @@ +"""Tests for ``workers.bte.worker``: ``examine_query`` and the entry-point +workflow construction.""" + +import copy +import json +import logging + +import pytest + +from tests.helpers.generate_messages import creative_query +from workers.bte.worker import bte, examine_query + +logger = logging.getLogger(__name__) + + +def test_examine_query_pure_lookup_returns_no_question_or_answer(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + } + infer, q, a, pathfinder = examine_query(msg) + assert (infer, pathfinder) == (False, False) + assert q is None and a is None + + +def test_examine_query_inferred_returns_question_and_answer(): + msg = copy.deepcopy(creative_query) + infer, q, a, pathfinder = examine_query(msg) + assert infer is True + assert (q, a) == ("ON", "SN") + + +def test_examine_query_pathfinder_three_inferred_edges(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}, "c": {}, "d": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "b", "object": "c", "knowledge_type": "inferred"}, + "e2": {"subject": "c", "object": "d", "knowledge_type": "inferred"}, + }, + } + } + } + _, _, _, pathfinder = examine_query(msg) + assert pathfinder is True + + +def test_examine_query_rejects_two_inferred_when_not_pathfinder(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", 
"object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + }, + } + } + } + with pytest.raises(Exception, match="single infer edge"): + examine_query(msg) + + +def test_examine_query_rejects_mixed_lookup_and_infer(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "a", "object": "b"}, + }, + } + } + } + with pytest.raises(Exception, match="Mixed infer and lookup"): + examine_query(msg) + + +def test_examine_query_rejects_both_creative_nodes_pinned(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {"ids": ["Y:2"]}}, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "knowledge_type": "inferred", + }, + }, + } + } + } + with pytest.raises(Exception, match="Both nodes of creative edge pinned"): + examine_query(msg) + + +def test_examine_query_rejects_no_creative_node_pinned(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {}, "b": {}}, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "knowledge_type": "inferred", + }, + }, + } + } + } + with pytest.raises(Exception, match="No nodes of creative edge pinned"): + examine_query(msg) + + +def _make_task(workflow=None): + return [ + "test", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps(workflow), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +@pytest.mark.asyncio +async def test_bte_lookup_workflow(redis_mock, mocker): + """Pure-lookup query: install the standard BTE workflow.""" + mocker.patch( + "workers.bte.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={ + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + }, + ) + task = _make_task(None) + await bte(task, logger) + workflow = 
json.loads(task[1]["workflow"]) + assert [op["id"] for op in workflow] == [ + "bte.lookup", + "aragorn.omnicorp", + "aragorn.score", + "sort_results_score", + "filter_results_top_n", + "filter_kgraph_orphans", + ] + + +@pytest.mark.asyncio +async def test_bte_inferred_workflow_matches_lookup(redis_mock, mocker): + """Inferred (creative) workflow currently mirrors the lookup workflow.""" + mocker.patch( + "workers.bte.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=copy.deepcopy(creative_query), + ) + task = _make_task(None) + await bte(task, logger) + workflow = json.loads(task[1]["workflow"]) + assert [op["id"] for op in workflow][0] == "bte.lookup" + + +@pytest.mark.asyncio +async def test_bte_rejects_pathfinder_query(redis_mock, mocker): + """BTE explicitly rejects pathfinder queries.""" + pathfinder_msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}, "c": {}, "d": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "b", "object": "c", "knowledge_type": "inferred"}, + "e2": {"subject": "c", "object": "d", "knowledge_type": "inferred"}, + }, + } + } + } + mocker.patch( + "workers.bte.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=pathfinder_msg, + ) + with pytest.raises(Exception, match="does not support Pathfinder"): + await bte(_make_task(None), logger) + + +@pytest.mark.asyncio +async def test_bte_preexisting_workflow_is_preserved(redis_mock, mocker): + """If the task already has a workflow, bte() leaves it alone.""" + mocker.patch( + "workers.bte.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=copy.deepcopy(creative_query), + ) + custom = [{"id": "bte.lookup"}] + task = _make_task(custom) + await bte(task, logger) + assert json.loads(task[1]["workflow"]) == custom diff --git a/tests/unit/test_bte_lookup.py b/tests/unit/test_bte_lookup.py new file mode 100644 index 0000000..e26112a --- /dev/null +++ 
b/tests/unit/test_bte_lookup.py @@ -0,0 +1,321 @@ +"""Tests for ``workers.bte_lookup.worker`` helpers and the lookup entry point. + +Covers ``examine_query``, ``get_params``, ``match_templates``, +``fill_templates``, ``expand_bte_query``, plus pure-lookup and +inferred-with-templates branches of ``bte_lookup``. +""" + +import copy +import json +import logging + +import pytest + +from tests.helpers.generate_messages import creative_query +from workers.bte_lookup.worker import ( + AsyncResponse, + bte_lookup, + examine_query, + expand_bte_query, + fill_templates, + get_params, + match_templates, +) + +logger = logging.getLogger(__name__) + + +# --- examine_query --------------------------------------------------------- + + +def test_examine_query_lookup_only(): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + } + infer, _, _, pathfinder = examine_query(msg) + assert (infer, pathfinder) == (False, False) + + +def test_examine_query_inferred_returns_creative_pair(): + msg = copy.deepcopy(creative_query) + infer, q, a, _ = examine_query(msg) + assert infer is True + assert q == "ON" and a == "SN" + + +# --- get_params ------------------------------------------------------------ + + +def test_get_params_extracts_full_tuple(): + qg = { + "nodes": { + "a": {"categories": ["biolink:ChemicalEntity"], "ids": ["CHEBI:1"]}, + "b": {"categories": ["biolink:Disease"]}, + }, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "predicates": ["biolink:treats"], + "qualifier_constraints": [ + { + "qualifier_set": [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + } + ] + } + ], + } + }, + } + s_key, s_type, s_curie, o_key, o_type, o_curie, predicate, qualifiers = get_params( + qg + ) + assert s_key == "a" + assert s_type == "biolink:ChemicalEntity" + assert s_curie == "CHEBI:1" + assert o_key == "b" + assert o_type == "biolink:Disease" + 
assert o_curie is None + assert predicate == "biolink:treats" + assert qualifiers == {"biolink:object_aspect_qualifier": "activity"} + + +def test_get_params_no_qualifier_constraints_returns_empty_dict(): + qg = { + "nodes": { + "a": {"categories": ["biolink:ChemicalEntity"], "ids": ["CHEBI:1"]}, + "b": {"categories": ["biolink:Disease"]}, + }, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "predicates": ["biolink:treats"], + } + }, + } + *_, qualifiers = get_params(qg) + assert qualifiers == {} + + +# --- match_templates ------------------------------------------------------- + + +def test_match_templates_returns_paths_for_drug_treats_disease(): + """A subject/object/predicate combo that the production + ``template_groups.json`` lists as Drug-treats-Disease should match + actual template files on disk.""" + paths = match_templates( + subject_type="biolink:Drug", + object_type="biolink:Disease", + predicate="biolink:treats", + qualifiers={}, + logger=logger, + ) + # Should at least find one template for the Drug-treats-Disease group. + assert paths + assert all(p.suffix == ".json" for p in paths) + + +def test_match_templates_no_match_returns_empty(): + """Nonsense types should not match any group.""" + paths = match_templates( + subject_type="biolink:NotAThing", + object_type="biolink:NotAThing", + predicate="biolink:not_real", + qualifiers={}, + logger=logger, + ) + assert paths == [] + + +def test_match_templates_strips_biolink_prefix_when_matching(): + """The matcher removes the ``biolink:`` prefix before checking + template_groups.json (which lists bare names like 'Drug').""" + no_prefix = match_templates( + subject_type="Drug", + object_type="Disease", + predicate="treats", + qualifiers={}, + logger=logger, + ) + with_prefix = match_templates( + subject_type="biolink:Drug", + object_type="biolink:Disease", + predicate="biolink:treats", + qualifiers={}, + logger=logger, + ) + # Both forms should yield the same set of templates. 
+ assert {p.name for p in no_prefix} == {p.name for p in with_prefix} + + +# --- fill_templates -------------------------------------------------------- + + +def test_fill_templates_substitutes_subject_curie(): + """When the subject curie is provided, templates substitute source_id and + delete the target's ids.""" + paths = match_templates( + subject_type="biolink:Drug", + object_type="biolink:Disease", + predicate="biolink:treats", + qualifiers={}, + logger=logger, + ) + assert paths + filled = fill_templates( + paths=[paths[0]], + query_body={"parameters": {"timeout": 60}, "submitter": "test"}, + subject_key="SN", + subject_curie="CHEBI:1", + object_key="ON", + object_curie=None, + ) + assert len(filled) == 1 + qg = filled[0]["message"]["query_graph"] + # Source got the curie; target has no ids + assert qg["nodes"]["SN"]["ids"] == ["CHEBI:1"] + assert "ids" not in qg["nodes"]["ON"] + assert filled[0]["workflow"] == [{"id": "lookup"}] + assert filled[0]["submitter"] == "test" + + +def test_fill_templates_substitutes_object_curie_when_subject_unset(): + """The mirror direction: object pinned, subject empty.""" + paths = match_templates( + subject_type="biolink:Drug", + object_type="biolink:Disease", + predicate="biolink:treats", + qualifiers={}, + logger=logger, + ) + filled = fill_templates( + paths=[paths[0]], + query_body={"parameters": {"timeout": 60}, "submitter": "test"}, + subject_key="SN", + subject_curie=None, + object_key="ON", + object_curie="MONDO:1", + ) + qg = filled[0]["message"]["query_graph"] + assert qg["nodes"]["ON"]["ids"] == ["MONDO:1"] + assert "ids" not in qg["nodes"]["SN"] + + +# --- expand_bte_query ------------------------------------------------------ + + +def test_expand_bte_query_returns_empty_for_missing_query_graph(): + out = expand_bte_query({"message": {}}, logger) + assert out == [] + + +def test_expand_bte_query_includes_direct_query_first(): + """The first message should be the direct (non-inferred) version of the + query; 
subsequent messages are template expansions.""" + msg = copy.deepcopy(creative_query) + msg["message"]["query_graph"]["nodes"]["SN"]["categories"] = [ + "biolink:Drug", + ] + msg["message"]["query_graph"]["nodes"]["ON"]["categories"] = [ + "biolink:Disease", + ] + msg["parameters"] = {"timeout": 60} + msg["submitter"] = "test" + out = expand_bte_query(msg, logger) + assert len(out) >= 1 + direct = out[0] + # Direct query has had knowledge_type stripped from each edge + assert "knowledge_type" not in direct["message"]["query_graph"]["edges"]["e0"] + + +# --- AsyncResponse dataclass ---------------------------------------------- + + +def test_async_response_dataclass_defaults(): + r = AsyncResponse(status_code=200, success=True, callback_id="abc") + assert r.error is None + assert r.success is True + + +# --- bte_lookup entry point ----------------------------------------------- + + +def _make_task(): + return [ + "test", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps([{"id": "bte.lookup"}]), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +@pytest.mark.asyncio +async def test_bte_lookup_pure_lookup_calls_kg_retrieval(redis_mock, mocker): + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch( + "workers.bte_lookup.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=msg, + ) + mocker.patch( + "workers.bte_lookup.worker.add_callback_id", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.bte_lookup.worker.get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + await bte_lookup(_make_task(), logger) + assert mock_post.called + + +@pytest.mark.asyncio +async def 
test_bte_lookup_rejects_pathfinder_query(redis_mock, mocker): + pathfinder_msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}, "c": {}, "d": {}}, + "edges": { + "e0": {"subject": "a", "object": "b", "knowledge_type": "inferred"}, + "e1": {"subject": "b", "object": "c", "knowledge_type": "inferred"}, + "e2": {"subject": "c", "object": "d", "knowledge_type": "inferred"}, + }, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch( + "workers.bte_lookup.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=pathfinder_msg, + ) + with pytest.raises(Exception, match="does not support Pathfinder"): + await bte_lookup(_make_task(), logger) diff --git a/tests/unit/test_bte_lookup_branches.py b/tests/unit/test_bte_lookup_branches.py new file mode 100644 index 0000000..332c6a2 --- /dev/null +++ b/tests/unit/test_bte_lookup_branches.py @@ -0,0 +1,545 @@ +"""Branch-coverage tests for ``workers.bte_lookup.worker`` paths not +exercised by ``test_bte_lookup.py``. + +Covers: + +- ``run_async_lookup`` happy path and HTTPX-raises path. +- ``bte_lookup`` inferred fanout: expand_bte_query is called, requests + fire via ``run_async_lookup``, failed responses get their callback ids + removed. +- ``bte_lookup`` polling-loop branches: in-progress callbacks delay, + timeout cleanup branch. +- ``process_task`` happy and failure paths. 
+""" + +import asyncio +import json +import logging + +import httpx +import pytest + +from workers.bte_lookup import worker as btel +from workers.bte_lookup.worker import ( + AsyncResponse, + bte_lookup, + process_task, + run_async_lookup, +) + +logger = logging.getLogger(__name__) + + +def _make_task(): + return [ + "test", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps([{"id": "bte.lookup"}]), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +# --- run_async_lookup ---------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_async_lookup_returns_success_on_200(redis_mock, mocker): + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + + fake_response = mocker.Mock() + fake_response.status_code = 200 + client = mocker.Mock() + client.post = mocker.AsyncMock(return_value=fake_response) + + out = await run_async_lookup(client, {"message": {}}, "qid", logger) + assert isinstance(out, AsyncResponse) + assert out.success is True + assert out.status_code == 200 + assert out.error is None + + +@pytest.mark.asyncio +async def test_run_async_lookup_returns_failure_on_non_200(redis_mock, mocker): + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + + fake_response = mocker.Mock() + fake_response.status_code = 500 + client = mocker.Mock() + client.post = mocker.AsyncMock(return_value=fake_response) + + out = await run_async_lookup(client, {"message": {}}, "qid", logger) + assert out.success is False + assert out.status_code == 500 + + +@pytest.mark.asyncio +async def test_run_async_lookup_returns_500_when_post_raises(redis_mock, mocker): + """A network exception is caught and surfaced as a 500 ``AsyncResponse`` + with the error string populated.""" + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + + client = mocker.Mock() + client.post = mocker.AsyncMock(side_effect=httpx.ConnectError("boom")) + + out = await 
run_async_lookup(client, {"message": {}}, "qid", logger) + assert out.success is False + assert out.status_code == 500 + assert "boom" in out.error + + +@pytest.mark.asyncio +async def test_run_async_lookup_writes_callback_url_into_message(redis_mock, mocker): + """The function mutates the supplied message to include a callback URL + routed back to the BTE callback endpoint.""" + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + + fake_response = mocker.Mock() + fake_response.status_code = 200 + client = mocker.Mock() + client.post = mocker.AsyncMock(return_value=fake_response) + + msg = {"message": {}} + await run_async_lookup(client, msg, "qid", logger) + assert "callback" in msg + assert "/bte/callback/" in msg["callback"] + + +# --- bte_lookup inferred fanout ----------------------------------------- + + +@pytest.mark.asyncio +async def test_bte_lookup_inferred_fans_out_via_expand_bte_query(redis_mock, mocker): + """Inferred query: expand_bte_query is called, multiple + run_async_lookup calls fire (one per expanded message).""" + inferred_msg = { + "message": { + "query_graph": { + "nodes": { + "a": {"ids": ["X:1"], "categories": ["biolink:Drug"]}, + "b": {"categories": ["biolink:Disease"]}, + }, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "knowledge_type": "inferred", + "predicates": ["biolink:treats"], + } + }, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, + "get_message", + new_callable=mocker.AsyncMock, + return_value=inferred_msg, + ) + expanded = [ + {"message": {"query_graph": {}}, "parameters": {}, "submitter": "test"}, + {"message": {"query_graph": {}}, "parameters": {}, "submitter": "test"}, + ] + mocker.patch.object(btel, "expand_bte_query", return_value=expanded) + mock_run = mocker.patch.object( + btel, + "run_async_lookup", + new_callable=mocker.AsyncMock, + return_value=AsyncResponse(status_code=200, success=True, callback_id="cb-1"), + ) + mocker.patch.object( + btel, + 
"get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + await bte_lookup(_make_task(), logger) + # One run_async_lookup per expanded message. + assert mock_run.call_count == 2 + + +@pytest.mark.asyncio +async def test_bte_lookup_inferred_removes_failed_callback_ids(redis_mock, mocker): + """A run_async_lookup that returns an unsuccessful AsyncResponse should + trigger remove_callback_id.""" + inferred_msg = { + "message": { + "query_graph": { + "nodes": { + "a": {"ids": ["X:1"], "categories": ["biolink:Drug"]}, + "b": {"categories": ["biolink:Disease"]}, + }, + "edges": { + "e0": { + "subject": "a", + "object": "b", + "knowledge_type": "inferred", + "predicates": ["biolink:treats"], + } + }, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, + "get_message", + new_callable=mocker.AsyncMock, + return_value=inferred_msg, + ) + mocker.patch.object( + btel, + "expand_bte_query", + return_value=[ + {"message": {"query_graph": {}}, "parameters": {}, "submitter": "t"} + ], + ) + mocker.patch.object( + btel, + "run_async_lookup", + new_callable=mocker.AsyncMock, + return_value=AsyncResponse( + status_code=500, success=False, callback_id="failed-cb", error="x" + ), + ) + mock_remove = mocker.patch.object( + btel, "remove_callback_id", new_callable=mocker.AsyncMock + ) + mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + await bte_lookup(_make_task(), logger) + mock_remove.assert_awaited_once_with("failed-cb", logger) + + +@pytest.mark.asyncio +async def test_bte_lookup_inferred_logs_exception_responses(redis_mock, mocker): + """An exception in ``asyncio.gather`` (return_exceptions=True) is logged + but no remove_callback_id is called for it.""" + inferred_msg = { + "message": { + "query_graph": { + "nodes": { + "a": {"ids": ["X:1"], "categories": ["biolink:Drug"]}, + "b": {"categories": ["biolink:Disease"]}, + }, + "edges": { + "e0": { + "subject": "a", + 
"object": "b", + "knowledge_type": "inferred", + "predicates": ["biolink:treats"], + } + }, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, + "get_message", + new_callable=mocker.AsyncMock, + return_value=inferred_msg, + ) + mocker.patch.object( + btel, + "expand_bte_query", + return_value=[ + {"message": {"query_graph": {}}, "parameters": {}, "submitter": "t"} + ], + ) + # run_async_lookup raises -> gather captures and returns exception. + mocker.patch.object( + btel, + "run_async_lookup", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("boom"), + ) + mock_remove = mocker.patch.object( + btel, "remove_callback_id", new_callable=mocker.AsyncMock + ) + mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + await bte_lookup(_make_task(), logger) + assert not mock_remove.called + + +# --- bte_lookup timeout / polling branches ------------------------------ + + +@pytest.mark.asyncio +async def test_bte_lookup_polling_loop_iterates_until_callbacks_drain( + redis_mock, mocker +): + """The polling loop sees in-progress callbacks for one iteration, then + they drain on the next iteration and the loop breaks.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, "get_message", new_callable=mocker.AsyncMock, return_value=msg + ) + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + # First call: in-progress; second call: drained. + running = mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + side_effect=[[("running",)], []], + ) + # Don't actually sleep. 
+ mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + await bte_lookup(_make_task(), logger) + assert running.call_count == 2 + + +@pytest.mark.asyncio +async def test_bte_lookup_polling_loop_retries_after_db_error(redis_mock, mocker): + """An exception on get_running_callbacks doesn't abort the loop; it sleeps + and retries on the next iteration.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, "get_message", new_callable=mocker.AsyncMock, return_value=msg + ) + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + running = mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + side_effect=[RuntimeError("pg dead"), []], + ) + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + await bte_lookup(_make_task(), logger) + assert running.call_count == 2 + + +@pytest.mark.asyncio +async def test_bte_lookup_timeout_triggers_cleanup_callbacks(redis_mock, mocker): + """When the polling loop never sees a drained queue and time runs out, + cleanup_callbacks is called.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + # Use a tiny but non-zero timeout. We patch time.time to control flow. + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, "get_message", new_callable=mocker.AsyncMock, return_value=msg + ) + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + # Always say there's a callback in progress. 
+ mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[("still-running",)], + ) + + # Patch time.time so the loop goes through one iteration and then exceeds + # the 5s timeout. We can't predict exactly how many ``time.time`` calls + # the worker makes, so use an iterator that returns "late" forever after + # the first two calls. + def _fake_time(): + yield 0 + yield 0.1 + while True: + yield 100 + + gen = _fake_time() + mocker.patch.object(btel.time, "time", side_effect=lambda: next(gen)) + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + mock_cleanup = mocker.patch.object( + btel, "cleanup_callbacks", new_callable=mocker.AsyncMock + ) + await bte_lookup(_make_task(), logger) + assert mock_cleanup.called + + +# --- bte_lookup non-infer branch hits the http POST -------------------- + + +@pytest.mark.asyncio +async def test_bte_lookup_non_infer_branch_writes_callback_url(redis_mock, mocker): + """Pure-lookup queries set message['callback'] before posting to the + retrieval URL.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, "get_message", new_callable=mocker.AsyncMock, return_value=msg + ) + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + await bte_lookup(_make_task(), logger) + posted_msg = mock_post.call_args.kwargs["json"] + assert "callback" in posted_msg + assert "/bte/callback/" in posted_msg["callback"] + + +# --- bte_lookup adds default submitter ---------------------------------- + + +@pytest.mark.asyncio +async def 
test_bte_lookup_adds_default_submitter_when_missing(redis_mock, mocker): + """If the input message has no submitter, bte_lookup populates one with + the shepherd-bte infores string.""" + msg = { + "message": { + "query_graph": { + "nodes": {"a": {"ids": ["X:1"]}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + }, + "parameters": {"timeout": 5}, + } + mocker.patch.object( + btel, "get_message", new_callable=mocker.AsyncMock, return_value=msg + ) + mocker.patch.object(btel, "add_callback_id", new_callable=mocker.AsyncMock) + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mocker.Mock(status_code=200), + ) + mocker.patch.object( + btel, + "get_running_callbacks", + new_callable=mocker.AsyncMock, + return_value=[], + ) + await bte_lookup(_make_task(), logger) + sent = mock_post.call_args.kwargs["json"] + assert sent["submitter"].startswith("infores:shepherd-bte") + + +# --- process_task ------------------------------------------------------- + + +class _Limiter: + def __init__(self): + self.released = False + + def release(self): + self.released = True + + +@pytest.mark.asyncio +async def test_bte_lookup_process_task_happy_path(redis_mock, mocker): + mocker.patch.object(btel, "bte_lookup", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(btel, "wrap_up_task", new_callable=mocker.AsyncMock) + limiter = _Limiter() + await process_task(_make_task(), None, logger, limiter) + assert mock_wrap.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_bte_lookup_process_task_routes_failure_to_handle_task_failure( + redis_mock, mocker +): + mocker.patch.object( + btel, + "bte_lookup", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("kaboom"), + ) + mock_failure = mocker.patch.object( + btel, "handle_task_failure", new_callable=mocker.AsyncMock + ) + limiter = _Limiter() + await process_task(_make_task(), None, logger, limiter) + assert mock_failure.called + 
assert limiter.released + + +@pytest.mark.asyncio +async def test_bte_lookup_process_task_swallows_cancellation(redis_mock, mocker): + mocker.patch.object( + btel, + "bte_lookup", + new_callable=mocker.AsyncMock, + side_effect=asyncio.CancelledError, + ) + mock_failure = mocker.patch.object( + btel, "handle_task_failure", new_callable=mocker.AsyncMock + ) + limiter = _Limiter() + await process_task(_make_task(), None, logger, limiter) + assert not mock_failure.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_bte_lookup_process_task_swallows_wrap_up_failure(redis_mock, mocker): + """A wrap_up_task failure should be logged but not escape.""" + mocker.patch.object(btel, "bte_lookup", new_callable=mocker.AsyncMock) + mocker.patch.object( + btel, + "wrap_up_task", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("redis dropped"), + ) + limiter = _Limiter() + # Should not raise. + await process_task(_make_task(), None, logger, limiter) + assert limiter.released diff --git a/tests/unit/test_db.py b/tests/unit/test_db.py new file mode 100644 index 0000000..5884e77 --- /dev/null +++ b/tests/unit/test_db.py @@ -0,0 +1,179 @@ +"""Tests for ``shepherd_utils.db`` helpers. + +Covers the pure codecs (``encode_message``/``decode_message``) and the +Redis-backed read/write functions that don't touch postgres. Postgres-backed +helpers (``add_query``, ``add_callback_id`` etc.) are exercised via +``postgres_mock`` from the conftest. 
+""" + +import logging + +import orjson +import pytest + +from shepherd_utils.db import ( + decode_message, + encode_message, + get_logs, + get_message, + get_message_sync, + save_logs, + save_message, + save_message_sync, +) +from shepherd_utils.logger import QueryLogger + +logger = logging.getLogger(__name__) + + +def test_encode_decode_roundtrip_preserves_payload(): + payload = { + "message": { + "results": [{"score": 0.5}], + "knowledge_graph": {"nodes": {"A": {}}, "edges": {}}, + } + } + encoded = encode_message(payload) + assert isinstance(encoded, (bytes, bytearray)) + assert decode_message(encoded) == payload + + +def test_encode_message_compresses_repeating_input(): + """zstd should achieve good compression on a redundant payload.""" + big_payload = {"message": {"results": [{"x": "y" * 1000}] * 50}} + encoded = encode_message(big_payload) + assert len(encoded) < len(orjson.dumps(big_payload)) + + +@pytest.mark.asyncio +async def test_save_and_get_message_roundtrip(redis_mock): + payload = {"message": {"results": [{"score": 0.42}]}} + await save_message("rid-1", payload, logger) + fetched = await get_message("rid-1", logger) + assert fetched == payload + + +@pytest.mark.asyncio +async def test_get_message_raises_keyerror_for_missing(redis_mock): + with pytest.raises(KeyError, match="missing-id"): + await get_message("missing-id", logger) + + +def test_save_and_get_message_sync_roundtrip(redis_mock, mocker): + """The sync variants are used inside ProcessPoolExecutor workers; route + them through fakeredis by patching the lazy client accessor.""" + sync_client = mocker.Mock() + storage = {} + + def _set(key, blob, ex=None): + storage[key] = blob + + def _get(key): + return storage.get(key) + + sync_client.set.side_effect = _set + sync_client.get.side_effect = _get + mocker.patch("shepherd_utils.db._get_sync_data_db", return_value=sync_client) + + payload = {"message": {"foo": "bar"}} + save_message_sync("sid-1", payload) + assert get_message_sync("sid-1") == 
payload + + +def test_get_message_sync_raises_keyerror_for_missing(mocker): + sync_client = mocker.Mock() + sync_client.get.return_value = None + mocker.patch("shepherd_utils.db._get_sync_data_db", return_value=sync_client) + + with pytest.raises(KeyError, match="missing-sid"): + get_message_sync("missing-sid") + + +@pytest.mark.asyncio +async def test_save_logs_appends_query_log_handler_records(redis_mock): + """save_logs reads logs from a QueryLogHandler attached to the logger and + persists them (newest-first reversed) into the logs db.""" + handler = QueryLogger().log_handler + sub_logger = logging.getLogger("test.save_logs.appends") + sub_logger.handlers.clear() + sub_logger.addHandler(handler) + sub_logger.setLevel(logging.DEBUG) + sub_logger.info("first message") + sub_logger.info("second message") + try: + await save_logs("resp-1", sub_logger) + finally: + sub_logger.removeHandler(handler) + + raw = await redis_mock["logs"].get("resp-1") + assert raw is not None + logs = orjson.loads(raw) + messages = [entry["message"] for entry in logs] + # Insertion order: handler emits to a deque (appendleft), reversed in + # save_logs, so logs end up oldest-first. 
+ assert messages == ["first message", "second message"] + + +@pytest.mark.asyncio +async def test_save_logs_extends_existing_logs(redis_mock): + """A pre-existing log array in redis is preserved and extended.""" + existing = [ + {"message": "from-earlier", "timestamp": "2024-01-01T00:00:00", "level": "INFO"} + ] + await redis_mock["logs"].set("resp-2", orjson.dumps(existing)) + + handler = QueryLogger().log_handler + sub_logger = logging.getLogger("test.save_logs.extends") + sub_logger.handlers.clear() + sub_logger.addHandler(handler) + sub_logger.setLevel(logging.DEBUG) + sub_logger.info("new entry") + try: + await save_logs("resp-2", sub_logger) + finally: + sub_logger.removeHandler(handler) + + raw = await redis_mock["logs"].get("resp-2") + logs = orjson.loads(raw) + assert [entry["message"] for entry in logs] == ["from-earlier", "new entry"] + + +@pytest.mark.asyncio +async def test_get_logs_returns_empty_list_when_missing(redis_mock): + """Reading logs for an unknown response id should return an empty list.""" + out = await get_logs("does-not-exist", logger) + assert out == [] + + +@pytest.mark.asyncio +async def test_get_logs_returns_stored_logs(redis_mock): + stored = [{"message": "hello", "timestamp": "ts", "level": "INFO"}] + await redis_mock["logs"].set("resp-3", orjson.dumps(stored)) + out = await get_logs("resp-3", logger) + assert out == stored + + +@pytest.mark.asyncio +async def test_save_message_retries_on_failure(redis_mock, mocker): + """The first call to data_db_client.set raises; save_message should sleep + and retry rather than dropping the message.""" + real_set = redis_mock["data"].set + set_mock = mocker.AsyncMock() + call_state = {"calls": 0} + + async def flaky_set(*args, **kwargs): + call_state["calls"] += 1 + if call_state["calls"] == 1: + raise RuntimeError("simulated transient failure") + return await real_set(*args, **kwargs) + + set_mock.side_effect = flaky_set + mocker.patch("shepherd_utils.db.data_db_client.set", set_mock) + + # 
Patch sleep so the test doesn't block. + mocker.patch("asyncio.sleep", new=mocker.AsyncMock()) + + await save_message("retry-1", {"a": 1}, logger) + # First call failed, second call succeeded via real_set; expect at least 2. + assert call_state["calls"] >= 2 + assert await get_message("retry-1", logger) == {"a": 1} diff --git a/tests/unit/test_db_postgres.py b/tests/unit/test_db_postgres.py new file mode 100644 index 0000000..2ecfb8a --- /dev/null +++ b/tests/unit/test_db_postgres.py @@ -0,0 +1,276 @@ +"""Tests for the postgres-backed helpers in ``shepherd_utils.db``. + +These cover ``add_query``, ``add_callback_id``, ``remove_callback_id``, +``get_running_callbacks``, ``cleanup_callbacks``, ``get_callback_query_id``, +``get_query_state``, ``set_query_completed``, plus the redis-only ``add_query`` +storage path. + +Each test patches ``shepherd_utils.db.pool`` with a custom AsyncMock chain so +the postgres path uses an in-process fake. +""" + +import logging +from unittest.mock import AsyncMock + +import pytest +from psycopg import OperationalError +from psycopg_pool import AsyncConnectionPool + +from shepherd_utils import db + +logger = logging.getLogger(__name__) + + +def _install_pool_mock( + mocker, *, cursor_fetchone=None, cursor_fetchall=None, raise_on_execute=None +): + """Install a postgres pool mock at ``shepherd_utils.db.pool``. + + Returns ``(mock_conn, mock_pool)`` so tests can assert on the mock calls. + The pool's ``connection(60)`` async context yields ``mock_conn``. + ``conn.execute`` returns a cursor mock whose ``fetchone()`` / ``fetchall()`` + return the values supplied here. 
+ """ + mock_cursor = AsyncMock() + mock_cursor.fetchone = AsyncMock(return_value=cursor_fetchone) + mock_cursor.fetchall = AsyncMock(return_value=cursor_fetchall or []) + + mock_conn = AsyncMock() + if raise_on_execute is not None: + mock_conn.execute.side_effect = raise_on_execute + else: + mock_conn.execute.return_value = mock_cursor + + mock_pool = AsyncMock(spec=AsyncConnectionPool) + mock_pool.connection.return_value.__aenter__.return_value = mock_conn + mock_pool.connection.return_value.__aexit__.return_value = None + mocker.patch.object(db, "pool", mock_pool) + return mock_conn, mock_pool + + +# --- add_query ------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_add_query_persists_to_redis_and_postgres(redis_mock, mocker): + """``add_query`` writes the encoded query to redis (twice — query_id and + response_id) and inserts a row into shepherd_brain.""" + mock_conn, _ = _install_pool_mock(mocker) + await db.add_query( + "qid-1", "rid-1", {"message": {}}, callback_url=None, logger=logger + ) + + # Both ids made it into redis. + assert await redis_mock["data"].exists("qid-1") + assert await redis_mock["data"].exists("rid-1") + + # Single INSERT issued. 
+ sql, params = mock_conn.execute.call_args.args + assert "INSERT INTO shepherd_brain" in sql + assert params == ("qid-1", "rid-1", None, "QUEUED", "OK") + assert mock_conn.commit.called + + +@pytest.mark.asyncio +async def test_add_query_raises_when_redis_save_fails(redis_mock, mocker): + """If both initial sets fail, add_query raises rather than continuing.""" + mocker.patch.object( + db.data_db_client, + "set", + new_callable=mocker.AsyncMock, + side_effect=Exception("simulated"), + ) + with pytest.raises(Exception, match="Failed to save initial query"): + await db.add_query( + "qid-1", "rid-1", {"message": {}}, callback_url=None, logger=logger + ) + + +@pytest.mark.asyncio +async def test_add_query_raises_when_postgres_insert_fails(redis_mock, mocker): + _install_pool_mock(mocker, raise_on_execute=Exception("pg down")) + with pytest.raises(Exception, match="Failed to save initial query state"): + await db.add_query( + "qid-1", + "rid-1", + {"message": {}}, + callback_url="http://cb", + logger=logger, + ) + + +# --- add_callback_id / remove_callback_id --------------------------------- + + +@pytest.mark.asyncio +async def test_add_callback_id_inserts_row(mocker): + mock_conn, _ = _install_pool_mock(mocker) + await db.add_callback_id("qid", "cb-1", '{"trace": 1}', logger) + sql, params = mock_conn.execute.call_args.args + assert "INSERT INTO callbacks" in sql + assert params == ("qid", "cb-1", '{"trace": 1}') + assert mock_conn.commit.called + + +@pytest.mark.asyncio +async def test_add_callback_id_retries_on_operational_error(mocker): + """OperationalError causes a retry with exponential backoff. 
The second + attempt succeeds.""" + mock_cursor = AsyncMock() + mock_conn = AsyncMock() + mock_conn.execute.side_effect = [ + OperationalError("first failed"), + mock_cursor, + ] + mock_pool = AsyncMock(spec=AsyncConnectionPool) + mock_pool.connection.return_value.__aenter__.return_value = mock_conn + mock_pool.connection.return_value.__aexit__.return_value = None + mocker.patch.object(db, "pool", mock_pool) + # Don't actually sleep. + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + + await db.add_callback_id("qid", "cb-1", "{}", logger) + assert mock_conn.execute.call_count == 2 + + +@pytest.mark.asyncio +async def test_add_callback_id_swallows_non_operational_errors(mocker): + """Non-OperationalError exceptions are logged and the function returns + without retrying.""" + _install_pool_mock(mocker, raise_on_execute=ValueError("bad")) + # Should not raise. + await db.add_callback_id("qid", "cb-1", "{}", logger) + + +@pytest.mark.asyncio +async def test_remove_callback_id_runs_delete(mocker): + mock_conn, _ = _install_pool_mock(mocker) + await db.remove_callback_id("cb-1", logger) + sql, params = mock_conn.execute.call_args.args + assert "DELETE FROM callbacks" in sql + assert params == ("cb-1",) + + +@pytest.mark.asyncio +async def test_remove_callback_id_retries_on_operational_error(mocker): + mock_cursor = AsyncMock() + mock_conn = AsyncMock() + mock_conn.execute.side_effect = [OperationalError("x"), mock_cursor] + mock_pool = AsyncMock(spec=AsyncConnectionPool) + mock_pool.connection.return_value.__aenter__.return_value = mock_conn + mock_pool.connection.return_value.__aexit__.return_value = None + mocker.patch.object(db, "pool", mock_pool) + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + await db.remove_callback_id("cb-1", logger) + assert mock_conn.execute.call_count == 2 + + +# --- get_running_callbacks / cleanup_callbacks ---------------------------- + + +@pytest.mark.asyncio +async def 
test_get_running_callbacks_returns_rows(mocker): + rows = [("cb-1",), ("cb-2",)] + mock_conn, _ = _install_pool_mock(mocker, cursor_fetchall=rows) + out = await db.get_running_callbacks("qid", logger) + assert out == rows + + +@pytest.mark.asyncio +async def test_get_running_callbacks_propagates_non_operational_error(mocker): + """Non-OperationalError exceptions are re-raised so callers can decide.""" + _install_pool_mock(mocker, raise_on_execute=RuntimeError("kaboom")) + with pytest.raises(RuntimeError, match="kaboom"): + await db.get_running_callbacks("qid", logger) + + +@pytest.mark.asyncio +async def test_cleanup_callbacks_runs_delete(mocker): + mock_conn, _ = _install_pool_mock(mocker) + await db.cleanup_callbacks("qid-99", logger) + sql, params = mock_conn.execute.call_args.args + assert "DELETE FROM callbacks" in sql + assert params == ("qid-99",) + + +# --- get_callback_query_id ------------------------------------------------ + + +@pytest.mark.asyncio +async def test_get_callback_query_id_returns_row(mocker): + """Found row is returned as-is (a (query_id, otel_trace) tuple).""" + mock_conn, _ = _install_pool_mock(mocker, cursor_fetchone=("qid-1", "{}")) + out = await db.get_callback_query_id("cb-1", logger) + assert out == ("qid-1", "{}") + + +@pytest.mark.asyncio +async def test_get_callback_query_id_returns_none_when_missing(mocker): + _install_pool_mock(mocker, cursor_fetchone=None) + out = await db.get_callback_query_id("missing", logger) + assert out is None + + +# --- get_query_state / set_query_completed -------------------------------- + + +@pytest.mark.asyncio +async def test_get_query_state_returns_full_row(mocker): + fake_row = ("qid-1", "now", None, "QUEUED", "OK", None, None, "rid-1", None) + _install_pool_mock(mocker, cursor_fetchone=fake_row) + out = await db.get_query_state("qid-1", logger) + assert out == fake_row + + +@pytest.mark.asyncio +async def test_get_query_state_returns_none_when_missing(mocker): + _install_pool_mock(mocker, 
cursor_fetchone=None) + out = await db.get_query_state("ghost", logger) + assert out is None + + +@pytest.mark.asyncio +async def test_set_query_completed_runs_update(mocker): + mock_conn, _ = _install_pool_mock(mocker) + await db.set_query_completed("qid-1", "OK", logger) + sql, params = mock_conn.execute.call_args.args + assert "UPDATE shepherd_brain" in sql + assert params == ("OK", "qid-1") + assert mock_conn.commit.called + + +@pytest.mark.asyncio +async def test_set_query_completed_retries_on_operational_error(mocker): + mock_cursor = AsyncMock() + mock_conn = AsyncMock() + mock_conn.execute.side_effect = [OperationalError("x"), mock_cursor] + mock_pool = AsyncMock(spec=AsyncConnectionPool) + mock_pool.connection.return_value.__aenter__.return_value = mock_conn + mock_pool.connection.return_value.__aexit__.return_value = None + mocker.patch.object(db, "pool", mock_pool) + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + await db.set_query_completed("qid-1", "OK", logger) + assert mock_conn.execute.call_count == 2 + + +# --- initialize_db / shutdown_db ------------------------------------------ + + +@pytest.mark.asyncio +async def test_initialize_and_shutdown_open_and_close_pool(mocker): + mock_pool = AsyncMock(spec=AsyncConnectionPool) + mocker.patch.object(db, "pool", mock_pool) + await db.initialize_db() + assert mock_pool.open.called + await db.shutdown_db() + assert mock_pool.close.called + + +# --- check_connection ----------------------------------------------------- + + +@pytest.mark.asyncio +async def test_check_connection_executes_select_one(): + mock_conn = AsyncMock() + await db.check_connection(mock_conn) + mock_conn.execute.assert_awaited_once_with("SELECT 1") diff --git a/tests/unit/test_filter_analyses_top_n.py b/tests/unit/test_filter_analyses_top_n.py new file mode 100644 index 0000000..1ead5aa --- /dev/null +++ b/tests/unit/test_filter_analyses_top_n.py @@ -0,0 +1,90 @@ +"""Tests for ``workers.filter_analyses_top_n.worker``. 
+ +The worker truncates each result's ``analyses`` array to ``max_analyses`` +items (defaulting to 1000). +""" + +import json +import logging + +import pytest + +from shepherd_utils.db import get_message +from workers.filter_analyses_top_n.worker import filter_analyses_top_n + +logger = logging.getLogger(__name__) + + +def _make_task(workflow): + return [ + "test", + { + "query_id": "test", + "response_id": "test_response", + "workflow": json.dumps(workflow), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +@pytest.mark.asyncio +async def test_filter_analyses_top_n_truncates_to_max(redis_mock, mocker): + """Each result's analyses list should be capped at max_analyses.""" + mocker.patch( + "workers.filter_analyses_top_n.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={ + "message": { + "results": [ + {"analyses": [{"score": i} for i in range(5)]}, + {"analyses": [{"score": i} for i in range(2)]}, + ] + } + }, + ) + + await filter_analyses_top_n( + _make_task([{"id": "filter_analyses_top_n", "max_analyses": 2}]), + logger, + ) + saved = await get_message("test_response", logger) + assert len(saved["message"]["results"][0]["analyses"]) == 2 + assert len(saved["message"]["results"][1]["analyses"]) == 2 + + +@pytest.mark.asyncio +async def test_filter_analyses_top_n_default_cap_when_unset(redis_mock, mocker): + """No max_analyses on the workflow op falls back to 1000.""" + mocker.patch( + "workers.filter_analyses_top_n.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={ + "message": { + "results": [ + {"analyses": [{"score": i} for i in range(5)]}, + ] + } + }, + ) + + await filter_analyses_top_n(_make_task([{"id": "filter_analyses_top_n"}]), logger) + saved = await get_message("test_response", logger) + # No truncation when results are smaller than the default cap. 
+ assert len(saved["message"]["results"][0]["analyses"]) == 5 + + +@pytest.mark.asyncio +async def test_filter_analyses_top_n_handles_empty_results(redis_mock, mocker): + """No results: the worker should round-trip the empty list with no error.""" + mocker.patch( + "workers.filter_analyses_top_n.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={"message": {"results": []}}, + ) + + await filter_analyses_top_n( + _make_task([{"id": "filter_analyses_top_n", "max_analyses": 5}]), logger + ) + saved = await get_message("test_response", logger) + assert saved["message"]["results"] == [] diff --git a/tests/unit/test_finish_query_edge_cases.py b/tests/unit/test_finish_query_edge_cases.py new file mode 100644 index 0000000..d8d2e32 --- /dev/null +++ b/tests/unit/test_finish_query_edge_cases.py @@ -0,0 +1,190 @@ +"""Additional ``workers.finish_query.worker`` tests covering paths the +existing happy-path tests don't reach. +""" + +import json +import logging + +import pytest + +from workers.finish_query.worker import finish_query + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_finish_query_skips_callback_when_state_missing(redis_mock, mocker): + """If get_query_state returns None, don't try to fetch a message or POST. + + The query is not marked completed in this branch either -- nothing in the + db to update. 
+ """ + mock_query_state = mocker.patch( + "workers.finish_query.worker.get_query_state", + new_callable=mocker.AsyncMock, + return_value=None, + ) + mock_set_query_completed = mocker.patch( + "workers.finish_query.worker.set_query_completed", + new_callable=mocker.AsyncMock, + ) + mock_get_message = mocker.patch( + "workers.finish_query.worker.get_message", + new_callable=mocker.AsyncMock, + ) + mock_post = mocker.patch("httpx.AsyncClient.post", new_callable=mocker.AsyncMock) + + await finish_query( + [ + "test", + { + "query_id": "ghost", + "response_id": "ignored", + "workflow": json.dumps([]), + "log_level": "20", + }, + ], + logger, + ) + + assert mock_query_state.called + assert not mock_set_query_completed.called + assert not mock_get_message.called + assert not mock_post.called + + +@pytest.mark.asyncio +async def test_finish_query_propagates_status_to_set_query_completed( + redis_mock, mocker +): + """An ERROR status (set on a failure-routed task) should be passed along + to set_query_completed.""" + mocker.patch( + "workers.finish_query.worker.get_query_state", + new_callable=mocker.AsyncMock, + return_value=["", "", "", "", "", "", "", "rid", None], # sync query + ) + mock_set_query_completed = mocker.patch( + "workers.finish_query.worker.set_query_completed", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.finish_query.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={"message": {}}, + ) + + await finish_query( + [ + "test", + { + "query_id": "test", + "response_id": "rid", + "workflow": json.dumps([]), + "log_level": "20", + "status": "ERROR", + }, + ], + logger, + ) + + mock_set_query_completed.assert_called_once_with("test", "ERROR", logger) + + +@pytest.mark.asyncio +async def test_finish_async_query_retries_callback_on_failure(redis_mock, mocker): + """If the first POST raises, finish_query should retry up to CALLBACK_RETRIES + times with backoff before giving up and still mark the query completed.""" + 
mocker.patch( + "workers.finish_query.worker.get_query_state", + new_callable=mocker.AsyncMock, + return_value=["", "", "", "", "", "", "", "rid", "http://callback"], + ) + mock_set_query_completed = mocker.patch( + "workers.finish_query.worker.set_query_completed", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.finish_query.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={"message": {"results": []}}, + ) + mocker.patch( + "workers.finish_query.worker.get_logs", + new_callable=mocker.AsyncMock, + return_value=[], + ) + + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + side_effect=Exception("simulated network error"), + ) + # Don't actually sleep between retries. + mocker.patch("asyncio.sleep", new_callable=mocker.AsyncMock) + + await finish_query( + [ + "test", + { + "query_id": "test", + "response_id": "rid", + "workflow": json.dumps([]), + "log_level": "20", + }, + ], + logger, + ) + # 3 retries baked into the worker. 
+ assert mock_post.call_count == 3 + assert mock_set_query_completed.called + + +@pytest.mark.asyncio +async def test_finish_async_query_attaches_logs_to_message_payload(redis_mock, mocker): + """The async callback POST should send the message with logs attached.""" + mocker.patch( + "workers.finish_query.worker.get_query_state", + new_callable=mocker.AsyncMock, + return_value=["", "", "", "", "", "", "", "rid", "http://callback"], + ) + mocker.patch( + "workers.finish_query.worker.set_query_completed", + new_callable=mocker.AsyncMock, + ) + mocker.patch( + "workers.finish_query.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={"message": {"results": []}}, + ) + mocker.patch( + "workers.finish_query.worker.get_logs", + new_callable=mocker.AsyncMock, + return_value=[ + {"message": "log line", "timestamp": "ts", "level": "INFO"}, + ], + ) + mock_response = mocker.Mock() + mock_response.raise_for_status = mocker.Mock() + mock_post = mocker.patch( + "httpx.AsyncClient.post", + new_callable=mocker.AsyncMock, + return_value=mock_response, + ) + + await finish_query( + [ + "test", + { + "query_id": "test", + "response_id": "rid", + "workflow": json.dumps([]), + "log_level": "20", + }, + ], + logger, + ) + assert mock_post.called + posted_payload = mock_post.call_args.kwargs["json"] + assert "logs" in posted_payload + assert posted_payload["logs"][0]["message"] == "log line" diff --git a/tests/unit/test_inject_shepherd_arax_provenance.py b/tests/unit/test_inject_shepherd_arax_provenance.py new file mode 100644 index 0000000..6c50dc0 --- /dev/null +++ b/tests/unit/test_inject_shepherd_arax_provenance.py @@ -0,0 +1,186 @@ +"""Tests for ``workers.arax.inject_shepherd_arax_provenance``. + +The shepherd-arax injector tags every kgraph edge with an aggregator +``infores:shepherd-arax`` source so downstream consumers can attribute +provenance back through the Shepherd layer. 
+ +Note: ``workers/arax/worker.py`` uses a bare relative import +(``from inject_shepherd_arax_provenance import ...``) which is not valid +under the project's package layout. We import the helper module directly +through its package path here. +""" + +import copy + +from workers.arax.inject_shepherd_arax_provenance import ( + SHEPHERD_ARAX_SOURCE, + add_shepherd_arax_to_edge_sources, +) + + +def test_adds_source_when_missing_entirely(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": {"subject": "A", "object": "B"}, + }, + "nodes": {"A": {}, "B": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + sources = out["message"]["knowledge_graph"]["edges"]["e1"]["sources"] + assert len(sources) == 1 + assert sources[0]["resource_id"] == "infores:shepherd-arax" + + +def test_appends_source_when_other_sources_already_present(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "sources": [ + { + "resource_id": "infores:other", + "resource_role": "primary_knowledge_source", + } + ], + } + }, + "nodes": {"A": {}, "B": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + sources = out["message"]["knowledge_graph"]["edges"]["e1"]["sources"] + ids = [s["resource_id"] for s in sources] + assert "infores:other" in ids + assert "infores:shepherd-arax" in ids + + +def test_idempotent_when_source_already_present(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "sources": [dict(SHEPHERD_ARAX_SOURCE)], + } + }, + "nodes": {"A": {}, "B": {}}, + } + } + } + snapshot = copy.deepcopy(response) + out = add_shepherd_arax_to_edge_sources(response) + # No duplicate source appended. + assert len(out["message"]["knowledge_graph"]["edges"]["e1"]["sources"]) == 1 + # Equivalent to the original. 
+ assert out == snapshot + + +def test_adds_source_to_every_edge(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": {"subject": "A", "object": "B"}, + "e2": {"subject": "B", "object": "C"}, + "e3": {"subject": "C", "object": "D"}, + }, + "nodes": {"A": {}, "B": {}, "C": {}, "D": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + edges = out["message"]["knowledge_graph"]["edges"] + for eid, edge in edges.items(): + assert any( + s.get("resource_id") == "infores:shepherd-arax" for s in edge["sources"] + ), f"missing shepherd-arax source on {eid}" + + +def test_returns_input_unchanged_when_message_missing(): + response = {"not_a_message": True} + out = add_shepherd_arax_to_edge_sources(response) + assert out is response + + +def test_returns_input_unchanged_when_kg_not_a_dict(): + response = {"message": {"knowledge_graph": "not-a-dict"}} + out = add_shepherd_arax_to_edge_sources(response) + assert out is response + + +def test_returns_input_unchanged_when_edges_not_a_dict(): + response = {"message": {"knowledge_graph": {"edges": []}}} + out = add_shepherd_arax_to_edge_sources(response) + # The function returns early; edges stays a list. + assert out["message"]["knowledge_graph"]["edges"] == [] + + +def test_skips_non_dict_edge_entries(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "valid": {"subject": "A", "object": "B"}, + "garbage": "not-a-dict", + }, + "nodes": {"A": {}, "B": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + edges = out["message"]["knowledge_graph"]["edges"] + # Valid edge gets the source. + assert edges["valid"]["sources"][0]["resource_id"] == "infores:shepherd-arax" + # Garbage entry untouched. 
+ assert edges["garbage"] == "not-a-dict" + + +def test_skips_edge_with_non_list_sources(): + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": { + "subject": "A", + "object": "B", + "sources": "not-a-list", + } + }, + "nodes": {"A": {}, "B": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + # ``sources`` was not a list; injector skips it without crashing. + assert out["message"]["knowledge_graph"]["edges"]["e1"]["sources"] == "not-a-list" + + +def test_each_inserted_source_is_an_independent_copy(): + """Edges should not share a mutable source dict instance.""" + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": {"subject": "A", "object": "B"}, + "e2": {"subject": "C", "object": "D"}, + }, + "nodes": {"A": {}, "B": {}, "C": {}, "D": {}}, + } + } + } + out = add_shepherd_arax_to_edge_sources(response) + s1 = out["message"]["knowledge_graph"]["edges"]["e1"]["sources"][0] + s2 = out["message"]["knowledge_graph"]["edges"]["e2"]["sources"][0] + # Same content, different identity. + assert s1 == s2 + assert s1 is not s2 diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py new file mode 100644 index 0000000..d3eeb02 --- /dev/null +++ b/tests/unit/test_logger.py @@ -0,0 +1,103 @@ +"""Tests for ``shepherd_utils.logger``: per-query log handler + formatter.""" + +import logging +import os + +from shepherd_utils.logger import ( + QueryLogger, + ReasonerLogEntryFormatter, + get_logging_config, +) + + +def _make_record(msg, level=logging.INFO, name="testlogger"): + return logging.LogRecord( + name=name, + level=level, + pathname=__file__, + lineno=10, + msg=msg, + args=None, + exc_info=None, + ) + + +def test_reasoner_formatter_string_message(): + formatter = ReasonerLogEntryFormatter() + out = formatter.format(_make_record("hello")) + assert out["message"] == "hello" + assert out["level"] == "INFO" + # Timestamp is iso8601-ish (T separator). 
+ assert "T" in out["timestamp"] + + +def test_reasoner_formatter_dict_message_merges_extra_keys(): + formatter = ReasonerLogEntryFormatter() + record = _make_record( + {"message": "embedded", "extra": "data"}, level=logging.WARNING + ) + out = formatter.format(record) + assert out["message"] == "embedded" + assert out["extra"] == "data" + assert out["level"] == "WARNING" + + +def test_query_logger_handler_collects_records_newest_first(): + """Records should be appended to the front of the deque (appendleft).""" + ql = QueryLogger() + handler = ql.log_handler + sub_logger = logging.getLogger("test_query_logger.collects") + sub_logger.handlers.clear() + sub_logger.addHandler(handler) + sub_logger.setLevel(logging.DEBUG) + try: + sub_logger.info("first") + sub_logger.info("second") + contents = list(handler.contents()) + # newest first + assert [c["message"] for c in contents] == ["second", "first"] + finally: + sub_logger.removeHandler(handler) + + +def test_query_logger_handler_named_query_log_handler(): + """save_logs in db.py looks up the handler by name; verify the contract.""" + handler = QueryLogger().log_handler + assert handler.name == "query_log_handler" + + +def test_query_logger_respects_maxlen(): + """When a maxlen is set, oldest records get dropped.""" + ql = QueryLogger(maxlen=2) + handler = ql.log_handler + sub_logger = logging.getLogger("test_query_logger.maxlen") + sub_logger.handlers.clear() + sub_logger.addHandler(handler) + sub_logger.setLevel(logging.DEBUG) + try: + for i in range(5): + sub_logger.info(f"msg-{i}") + contents = list(handler.contents()) + # Newest two + assert [c["message"] for c in contents] == ["msg-4", "msg-3"] + finally: + sub_logger.removeHandler(handler) + + +def test_get_logging_config_local_includes_file_handler(monkeypatch, tmp_path): + """Outside Kubernetes, the config should set up a rotating file handler.""" + monkeypatch.delenv("KUBERNETES_SERVICE_HOST", raising=False) + monkeypatch.chdir(tmp_path) + config = 
get_logging_config() + assert "file" in config["handlers"] + assert "console" in config["handlers"] + assert set(config["loggers"]["shepherd"]["handlers"]) == {"console", "file"} + # The function eagerly creates the logs/ dir for file output. + assert os.path.isdir(tmp_path / "logs") + + +def test_get_logging_config_kubernetes_skips_file_handler(monkeypatch): + monkeypatch.setenv("KUBERNETES_SERVICE_HOST", "10.0.0.1") + config = get_logging_config() + assert "file" not in config["handlers"] + assert config["loggers"]["shepherd"]["handlers"] == ["console"] diff --git a/tests/unit/test_merge_message_helpers.py b/tests/unit/test_merge_message_helpers.py new file mode 100644 index 0000000..c5cb9f3 --- /dev/null +++ b/tests/unit/test_merge_message_helpers.py @@ -0,0 +1,560 @@ +"""Tests for the pure helpers in ``workers.merge_message.worker``. + +These cover: + +- ``get_edgeset`` — kgraph-edge frozenset extraction from a result. +- ``create_aux_graph`` — analysis -> aux graph conversion. +- ``add_knowledge_edge`` — synthetic creative-mode kg edge construction. +- ``_normalize_query`` / ``queries_equivalent`` — query graph equivalence. +- ``has_unique_nodes`` / ``filter_repeated_nodes`` — duplicate-node filtering. +- ``get_promiscuous_qnodes`` / ``remove_promiscuous_knode_results`` / + ``filter_promiscuous_results`` — over-popular knode pruning. +- ``get_answer_node`` — finding the unpinned qnode. +- ``group_results_by_qnode`` — grouping creative + lookup results. +- ``merge_messages`` — the top-level merge (lookup, creative, pathfinder, error). 
+""" + +import copy +import logging + +import pytest + +from tests.helpers.generate_messages import ( + creative_query, + generate_query, + generate_response, + response_1, + response_2, +) +from workers.merge_message.worker import ( + add_knowledge_edge, + create_aux_graph, + filter_promiscuous_results, + filter_repeated_nodes, + get_answer_node, + get_edgeset, + get_promiscuous_qnodes, + group_results_by_qnode, + has_unique_nodes, + merge_messages, + queries_equivalent, + remove_promiscuous_knode_results, + _normalize_query, +) + +logger = logging.getLogger(__name__) + + +def test_get_edgeset_collapses_all_edge_bindings(): + result = { + "analyses": [ + { + "edge_bindings": { + "e0": [{"id": "k1"}, {"id": "k2"}], + "e1": [{"id": "k3"}], + } + }, + {"edge_bindings": {"e2": [{"id": "k4"}]}}, + ] + } + out = get_edgeset(result) + assert out == frozenset({"k1", "k2", "k3", "k4"}) + + +def test_create_aux_graph_returns_uuid_and_edge_list(): + analysis = { + "edge_bindings": { + "e0": [{"id": "kedge_a"}, {"id": "kedge_b"}], + "e1": [{"id": "kedge_c"}], + } + } + aux_id, aux_graph = create_aux_graph(analysis) + assert isinstance(aux_id, str) and len(aux_id) > 0 + assert sorted(aux_graph["edges"]) == ["kedge_a", "kedge_b", "kedge_c"] + assert aux_graph["attributes"] == [] + + +def test_add_knowledge_edge_with_object_pinned_uses_answer_as_subject(): + """When the object qnode is pinned (creative_query default), the new + knowledge edge uses the answer CURIE as subject and pinned id as object.""" + result_message = copy.deepcopy(creative_query) + result_message["message"]["knowledge_graph"] = {"nodes": {}, "edges": {}} + new_edge_id = add_knowledge_edge( + target="aragorn", + result_message=result_message, + aux_graph_ids=["aux1", "aux2"], + answer="CHEBI:NEW", + ) + edges = result_message["message"]["knowledge_graph"]["edges"] + new_edge = edges[new_edge_id] + assert new_edge["subject"] == "CHEBI:NEW" # answer was the unpinned subject + assert new_edge["object"] == 
"MONDO:0001" # pinned object's id + assert new_edge["predicate"] == "biolink:treats" + # Aux graph id list lands in the support_graphs attribute. + sg_attrs = [ + a + for a in new_edge["attributes"] + if a["attribute_type_id"] == "biolink:support_graphs" + ] + assert sg_attrs[0]["value"] == ["aux1", "aux2"] + # Source is shepherd-{target}. + assert new_edge["sources"][0]["resource_id"] == "infores:shepherd-aragorn" + + +def test_add_knowledge_edge_with_subject_pinned_uses_answer_as_object(): + """Mirror of the above: subject is pinned, answer is the object.""" + msg = copy.deepcopy(creative_query) + msg["message"]["query_graph"]["nodes"]["SN"]["ids"] = ["CHEBI:1"] + msg["message"]["query_graph"]["nodes"]["ON"].pop("ids") + msg["message"]["knowledge_graph"] = {"nodes": {}, "edges": {}} + new_edge_id = add_knowledge_edge( + target="aragorn", + result_message=msg, + aux_graph_ids=["aux1"], + answer="MONDO:NEW", + ) + new_edge = msg["message"]["knowledge_graph"]["edges"][new_edge_id] + assert new_edge["subject"] == "CHEBI:1" + assert new_edge["object"] == "MONDO:NEW" + + +def test_add_knowledge_edge_passes_through_qualifier_constraints(): + msg = copy.deepcopy(creative_query) + msg["message"]["query_graph"]["edges"]["e0"]["qualifier_constraints"] = [ + { + "qualifier_set": [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + } + ] + } + ] + msg["message"]["knowledge_graph"] = {"nodes": {}, "edges": {}} + new_edge_id = add_knowledge_edge( + target="aragorn", result_message=msg, aux_graph_ids=["a"], answer="CHEBI:NEW" + ) + new_edge = msg["message"]["knowledge_graph"]["edges"][new_edge_id] + assert new_edge["qualifiers"] == [ + { + "qualifier_type_id": "biolink:object_aspect_qualifier", + "qualifier_value": "activity", + } + ] + + +def test_normalize_query_collapses_optional_and_synonym_predicates(): + """Empty constraint lists, BATCH set_interpretation, and biolink:treats → + biolink:treats_or_applied_or_studied_to_treat all 
get normalized away.""" + q = { + "nodes": { + "n": { + "ids": None, + "categories": None, + "is_set": False, + "set_interpretation": "BATCH", + "constraints": [], + "member_ids": [], + } + }, + "edges": { + "e": { + "subject": "n", + "object": "n", + "predicates": ["biolink:treats"], + "knowledge_type": "lookup", + "attribute_constraints": [], + "qualifier_constraints": [], + } + }, + } + out = _normalize_query(q) + n = out["nodes"]["n"] + assert "ids" not in n and "categories" not in n + assert "is_set" not in n + assert "set_interpretation" not in n + assert "constraints" not in n and "member_ids" not in n + e = out["edges"]["e"] + assert e["predicates"] == ["biolink:treats_or_applied_or_studied_to_treat"] + assert "knowledge_type" not in e + assert "attribute_constraints" not in e and "qualifier_constraints" not in e + + +def test_queries_equivalent_treats_predicate_synonyms_as_same(): + a = { + "nodes": {"x": {"ids": ["MONDO:1"]}}, + "edges": { + "e": {"subject": "x", "object": "x", "predicates": ["biolink:treats"]} + }, + } + b = { + "nodes": {"x": {"ids": ["MONDO:1"]}}, + "edges": { + "e": { + "subject": "x", + "object": "x", + "predicates": ["biolink:treats_or_applied_or_studied_to_treat"], + } + }, + } + assert queries_equivalent(a, b) is True + + +def test_queries_equivalent_distinguishes_different_predicates(): + a = { + "nodes": {"x": {}}, + "edges": { + "e": {"subject": "x", "object": "x", "predicates": ["biolink:related_to"]} + }, + } + b = { + "nodes": {"x": {}}, + "edges": { + "e": {"subject": "x", "object": "x", "predicates": ["biolink:treats"]} + }, + } + assert queries_equivalent(a, b) is False + + +def test_normalize_query_does_not_mutate_input(): + """The previous implementation deep-copied; the new one must not mutate.""" + q = { + "nodes": {"n": {"ids": None, "is_set": False, "set_interpretation": "BATCH"}}, + "edges": {}, + } + snapshot = copy.deepcopy(q) + _normalize_query(q) + assert q == snapshot + + +def 
test_has_unique_nodes_false_when_two_qnodes_share_binding(): + result = { + "node_bindings": { + "n0": [{"id": "A"}], + "n1": [{"id": "A"}], # duplicate + } + } + assert has_unique_nodes(result) is False + + +def test_has_unique_nodes_true_for_distinct_bindings(): + result = { + "node_bindings": { + "n0": [{"id": "A"}], + "n1": [{"id": "B"}], + } + } + assert has_unique_nodes(result) is True + + +def test_filter_repeated_nodes_drops_results_with_repeated_kvalues(): + response = { + "message": { + "query_graph": {}, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [ + { + "node_bindings": {"a": [{"id": "X"}], "b": [{"id": "X"}]}, + "analyses": [{"edge_bindings": {}}], + }, + { + "node_bindings": {"a": [{"id": "Y"}], "b": [{"id": "Z"}]}, + "analyses": [{"edge_bindings": {}}], + }, + ], + } + } + filter_repeated_nodes(response, logger) + remaining = response["message"]["results"] + assert len(remaining) == 1 + assert remaining[0]["node_bindings"]["a"][0]["id"] == "Y" + + +def test_filter_repeated_nodes_no_results_is_a_noop(): + response = {"message": {"results": []}} + filter_repeated_nodes(response, logger) + assert response["message"]["results"] == [] + + +def test_get_promiscuous_qnodes_finds_shared_subject(): + """Two edges sharing a subject with the same predicate -> the subject is the + promiscuous (center) node candidate.""" + response = { + "message": { + "query_graph": { + "nodes": {"a": {}, "b": {}, "c": {}, "d": {}}, + "edges": { + "e1": {"subject": "c", "object": "a", "predicates": ["biolink:p"]}, + "e2": {"subject": "c", "object": "b", "predicates": ["biolink:p"]}, + "e3": {"subject": "d", "object": "a", "predicates": ["biolink:p"]}, + }, + } + } + } + out = get_promiscuous_qnodes(response) + assert "c" in out + + +def test_get_promiscuous_qnodes_returns_empty_for_few_edges(): + response = {"message": {"query_graph": {"nodes": {}, "edges": {"e1": {}}}}} + assert get_promiscuous_qnodes(response) == [] + + +def 
test_remove_promiscuous_knode_results_drops_overrepresented_knode(): + """Construct an oversubscribed knode and verify it gets pruned.""" + response = { + "message": { + "results": [{"node_bindings": {"qx": [{"id": "BOZO"}]}} for _ in range(15)] + + [ + {"node_bindings": {"qx": [{"id": "GOOD"}]}}, + ], + } + } + remove_promiscuous_knode_results(MAX_C=10, qnode="qx", response=response) + remaining_ids = [ + r["node_bindings"]["qx"][0]["id"] for r in response["message"]["results"] + ] + assert "BOZO" not in remaining_ids + assert remaining_ids == ["GOOD"] + + +def test_filter_promiscuous_results_short_circuits_when_results_below_threshold(): + response = {"message": {"results": [{"node_bindings": {}, "analyses": []}]}} + # Should be a noop; no query_graph access needed. + filter_promiscuous_results(response, logger) + assert len(response["message"]["results"]) == 1 + + +def test_get_answer_node_returns_unpinned_qnode(): + qg = {"nodes": {"a": {"ids": ["X:1"]}, "b": {}}} + assert get_answer_node(qg) == "b" + + +def test_get_answer_node_returns_none_when_multiple_unpinned(): + qg = {"nodes": {"a": {}, "b": {}}} + assert get_answer_node(qg) is None + + +def test_group_results_by_qnode_partitions_into_creative_and_lookup(): + result_message = { + "message": { + "query_graph": {}, + "results": [ + { + "node_bindings": { + "qn": [{"id": "X"}], + }, + "analyses": [{"edge_bindings": {}}], + }, + ], + } + } + lookup_results = [ + { + "node_bindings": {"qn": [{"id": "X"}]}, + "analyses": [{"edge_bindings": {}}], + }, + { + "node_bindings": {"qn": [{"id": "Y"}]}, + "analyses": [{"edge_bindings": {}}], + }, + ] + grouped = group_results_by_qnode("qn", result_message, lookup_results) + # X has both a creative result and a lookup result. 
+ x_key = frozenset(["X"]) + y_key = frozenset(["Y"]) + assert len(grouped[x_key]["creative"]) == 1 + assert len(grouped[x_key]["lookup"]) == 1 + assert grouped[y_key]["creative"] == [] + assert len(grouped[y_key]["lookup"]) == 1 + + +def test_merge_messages_unsupported_query_type_raises(): + """A query graph without ``edges`` or ``paths`` should raise.""" + with pytest.raises(TypeError, match="Unsupported"): + merge_messages( + target="t", + original_query_graph={"nodes": {}}, + response={"message": {"query_graph": {"nodes": {}}, "results": []}}, + new_response={"message": {"query_graph": {"nodes": {}}, "results": []}}, + logger=logger, + ) + + +def test_merge_messages_lookup_only_returns_new_response_directly(): + """A direct lookup with no creative answer node returns the new response as-is.""" + qg = { + "nodes": {"a": {"ids": ["X:1"]}, "b": {"ids": ["Y:2"]}}, + "edges": {"e0": {"subject": "a", "object": "b", "predicates": ["biolink:p"]}}, + } + new_response = { + "message": { + "query_graph": qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "results": [ + { + "node_bindings": { + "a": [{"id": "X:1"}], + "b": [{"id": "Y:2"}], + }, + "analyses": [{"edge_bindings": {}}], + }, + ], + "auxiliary_graphs": {}, + } + } + out = merge_messages( + target="t", + original_query_graph=qg, + response={ + "message": { + "query_graph": qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "results": [], + "auxiliary_graphs": {}, + } + }, + new_response=new_response, + logger=logger, + ) + # No creative answer node -> direct lookup return path. 
+ assert len(out["message"]["results"]) == 1 + + +def test_merge_messages_combines_aux_graphs(): + """Aux graphs from both messages should land in the merged auxiliary_graphs.""" + response = generate_response() + callback = copy.deepcopy(response_2) + response["message"]["auxiliary_graphs"]["aux_a"] = { + "edges": ["a-edge"], + "attributes": [{"foo": "bar"}], + } + callback["message"]["auxiliary_graphs"]["aux_b"] = { + "edges": ["b-edge"], + "attributes": [], + } + out = merge_messages( + target="aragorn", + original_query_graph=response_1["message"]["query_graph"], + response=response, + new_response=callback, + logger=logger, + ) + aux = out["message"]["auxiliary_graphs"] + assert "aux_a" in aux + assert "aux_b" in aux + + +def test_merge_messages_pathfinder_query_returns_single_result(): + """A pathfinder query (paths instead of edges) should return a single + pathfinder result with the start/end node IDs threaded through.""" + pathfinder_qg = { + "nodes": { + "n0": {"ids": ["MONDO:0001"]}, + "n1": {"ids": ["MONDO:0002"]}, + }, + "paths": { + "p1": {"subject": "n0", "object": "n1"}, + }, + } + new_response = { + "message": { + "query_graph": pathfinder_qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [ + { + "node_bindings": { + "n0": [{"id": "MONDO:0001"}], + "n1": [{"id": "MONDO:0002"}], + }, + "analyses": [ + { + "edge_bindings": {"p1": [{"id": "kedge_pathfinder"}]}, + } + ], + "score": 0.42, + } + ], + } + } + response = { + "message": { + "query_graph": pathfinder_qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [], + } + } + out = merge_messages( + target="aragorn", + original_query_graph=pathfinder_qg, + response=response, + new_response=new_response, + logger=logger, + ) + assert len(out["message"]["results"]) == 1 + pf = out["message"]["results"][0] + assert pf["node_bindings"]["n0"][0]["id"] == "MONDO:0001" + assert pf["node_bindings"]["n1"][0]["id"] == "MONDO:0002" 
+ # An auxiliary graph should have been created and associated with an + # analysis path binding. + assert pf["analyses"][0]["score"] == 0.42 + assert "p1" in pf["analyses"][0]["path_bindings"] + + +def test_merge_messages_pathfinder_missing_endpoint_raises(): + """A pathfinder query with no subject/object should raise KeyError.""" + bad_qg = {"nodes": {}, "paths": {"p1": {"subject": None, "object": "n1"}}} + with pytest.raises(KeyError, match="subject or object"): + merge_messages( + target="t", + original_query_graph=bad_qg, + response={ + "message": { + "query_graph": bad_qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [], + } + }, + new_response={ + "message": { + "query_graph": bad_qg, + "knowledge_graph": {"nodes": {}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [], + } + }, + logger=logger, + ) + + +def test_merge_messages_creative_query_creates_knowledge_edges(): + """A creative-mode query (matches existing test fixture) should have its + inferred creative results converted into synthetic knowledge edges.""" + original_qg = generate_query()["message"]["query_graph"] + response = generate_response() + callback = copy.deepcopy(response_2) + + out = merge_messages( + target="aragorn", + original_query_graph=original_qg, + response=response, + new_response=callback, + logger=logger, + ) + edges = out["message"]["knowledge_graph"]["edges"] + # Every creative result should have produced at least one synthetic edge + # whose source is the shepherd-aragorn primary kg. 
+ shepherd_edges = [ + e + for e in edges.values() + if any( + s.get("resource_id") == "infores:shepherd-aragorn" + and s.get("resource_role") == "primary_knowledge_source" + for s in e.get("sources", []) + ) + ] + assert shepherd_edges diff --git a/tests/unit/test_misc_branch_coverage.py b/tests/unit/test_misc_branch_coverage.py new file mode 100644 index 0000000..e1b2f97 --- /dev/null +++ b/tests/unit/test_misc_branch_coverage.py @@ -0,0 +1,405 @@ +"""Branch-coverage tests targeting edge cases not exercised by other tests. + +Covers: + +- ``shepherd_utils.shared.merge_kgraph`` overlapping-edge attribute and + source merge paths. +- ``shepherd_utils.shared.validate_message`` edge with bad attribute path. +- ``shepherd_utils.shared.get_next_operation`` corner case (single op). +- ``workers.aragorn_omnicorp.worker`` helpers: ``create_log_entry``, + ``add_node_pmid_counts`` (default count = 0, attributes init), the + setnode branch of ``generate_curie_pairs``, and ``add_shared_pmid_counts`` + reuse of an existing OMNICORP support graph. +- ``workers.finish_query.worker.process_task`` happy and failure paths. 
+""" + +import asyncio +import json +import logging + +import pytest + +from shepherd_utils.shared import merge_kgraph, validate_message + +logger = logging.getLogger(__name__) + + +# --- merge_kgraph overlapping-edge merge paths ---------------------------- + + +def test_merge_kgraph_overlapping_edge_merges_attributes(): + """When the same edge id appears in both messages, the new attributes get + combined onto the existing edge.""" + og = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": [{"attribute_type_id": "biolink:foo", "value": 1}], + "sources": [ + { + "resource_id": "infores:original", + "resource_role": "primary_knowledge_source", + } + ], + } + }, + } + new = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": [{"attribute_type_id": "biolink:bar", "value": 2}], + "sources": [ + { + "resource_id": "infores:other", + "resource_role": "supporting_data_source", + } + ], + } + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + edge = merged["edges"]["shared"] + type_ids = {a["attribute_type_id"] for a in edge["attributes"]} + assert type_ids == {"biolink:foo", "biolink:bar"} + resource_ids = {s["resource_id"] for s in edge["sources"]} + # original + new sources combine; aggregator is NOT added in the merge + # path because the edge already existed. 
+ assert {"infores:original", "infores:other"}.issubset(resource_ids) + + +def test_merge_kgraph_overlapping_edge_adopts_attrs_when_existing_empty(): + """If the existing edge has no attributes, the incoming attributes are + adopted directly (no combine_unique_dicts call).""" + og = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": [], + "sources": [ + { + "resource_id": "infores:original", + "resource_role": "primary_knowledge_source", + } + ], + } + }, + } + new_attrs = [{"attribute_type_id": "biolink:bar", "value": 2}] + new = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": new_attrs, + "sources": [], + } + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + assert merged["edges"]["shared"]["attributes"] == new_attrs + + +def test_merge_kgraph_overlapping_edge_adopts_sources_when_existing_empty(): + new_sources = [ + {"resource_id": "infores:other", "resource_role": "supporting_data_source"} + ] + og = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": [], + "sources": [], + } + }, + } + new = { + "nodes": {}, + "edges": { + "shared": { + "subject": "A", + "object": "B", + "attributes": [], + "sources": new_sources, + } + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + assert merged["edges"]["shared"]["sources"] == new_sources + + +# --- validate_message attribute-loop branch ------------------------------- + + +def test_validate_message_skips_attributes_typo_field(tmp_path, monkeypatch): + """The implementation looks up the misspelled ``attibutes`` field, so the + attribute support_graph check is effectively dead. 
We verify it doesn't + crash on edges that have well-formed ``attributes`` but the typo'd field + is absent.""" + monkeypatch.chdir(tmp_path) + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}}, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + { + "attribute_type_id": "biolink:support_graphs", + "value": ["aux1"], + } + ], + } + }, + }, + "auxiliary_graphs": {"aux1": {"edges": []}}, + } + } + validate_message(message, logger) + assert not (tmp_path / "invalid_message.json").exists() + + +# --- aragorn_omnicorp helpers --------------------------------------------- + + +def test_create_log_entry_returns_shaped_dict(): + from workers.aragorn_omnicorp.worker import create_log_entry + + entry = create_log_entry("hi", "INFO", code="X1") + assert entry["message"] == "hi" + assert entry["level"] == "INFO" + assert entry["code"] == "X1" + assert "timestamp" in entry + + +def test_create_log_entry_default_code_none(): + from workers.aragorn_omnicorp.worker import create_log_entry + + entry = create_log_entry("hi", "WARNING") + assert entry["code"] is None + + +@pytest.mark.asyncio +async def test_add_node_pmid_counts_uses_zero_for_missing_curies(): + """add_node_pmid_counts attaches an attribute with value=0 to nodes whose + curie isn't in the counts dict.""" + from workers.aragorn_omnicorp.worker import add_node_pmid_counts + + kgraph = { + "nodes": { + "MONDO:0001": {"attributes": []}, + "MISSING:CURIE": {"attributes": None}, # exercises the None-init branch + } + } + await add_node_pmid_counts(kgraph, {"MONDO:0001": 42}) + found = kgraph["nodes"]["MONDO:0001"]["attributes"] + found_attr = next( + a for a in found if a.get("original_attribute_name") == "omnicorp_article_count" + ) + assert found_attr["value"] == 42 + # Missing curie still gets the attribute, value = 0. 
+ missing_attrs = kgraph["nodes"]["MISSING:CURIE"]["attributes"] + missing_attr = next( + a + for a in missing_attrs + if a.get("original_attribute_name") == "omnicorp_article_count" + ) + assert missing_attr["value"] == 0 + + +@pytest.mark.asyncio +async def test_add_shared_pmid_counts_reuses_existing_omnicorp_support_graph(): + """If an analysis already has an OMNICORP support graph, new co-occurrence + edges should be appended to that one rather than creating a second.""" + from workers.aragorn_omnicorp.worker import add_shared_pmid_counts + + message = { + "knowledge_graph": {"nodes": {"A": {}, "B": {}}, "edges": {}}, + "auxiliary_graphs": { + "OMNICORP_support_graph_existing": {"edges": [], "attributes": []}, + }, + "results": [ + { + "analyses": [ + { + "edge_bindings": {}, + "support_graphs": ["OMNICORP_support_graph_existing"], + } + ] + } + ], + } + pair_to_answer = {("A", "B"): {(0, 0)}} + await add_shared_pmid_counts(message, {("A", "B"): 7}, pair_to_answer) + sgs = message["results"][0]["analyses"][0]["support_graphs"] + # No new OMNICORP support graph created; the existing one was reused. + assert "OMNICORP_support_graph_existing" in sgs + assert sum(1 for s in set(sgs) if s.startswith("OMNICORP_support_graph")) == 1 + # The existing aux graph picked up the new co-occurrence edge. 
+ assert ( + len(message["auxiliary_graphs"]["OMNICORP_support_graph_existing"]["edges"]) + == 1 + ) + + +@pytest.mark.asyncio +async def test_add_shared_pmid_counts_skips_zero_publication_counts(): + """A pair with publication_count == 0 should not produce any edges.""" + from workers.aragorn_omnicorp.worker import add_shared_pmid_counts + + message = { + "knowledge_graph": {"nodes": {"A": {}, "B": {}}, "edges": {}}, + "auxiliary_graphs": {}, + "results": [ + { + "analyses": [ + { + "edge_bindings": {}, + } + ] + } + ], + } + await add_shared_pmid_counts(message, {("A", "B"): 0}, {("A", "B"): {(0, 0)}}) + assert message["knowledge_graph"]["edges"] == {} + + +@pytest.mark.asyncio +async def test_generate_curie_pairs_includes_setnode_pairings(): + """When the qgraph has setnodes, every cross-product with non-set nodes + becomes a candidate pair.""" + from workers.aragorn_omnicorp.worker import generate_curie_pairs + + answers = [ + { + "node_bindings": { + "qset": [{"id": "S1"}, {"id": "S2"}], + "qother": [{"id": "O1"}], + }, + "analyses": [ + { + "edge_bindings": {}, + } + ], + } + ] + qgraph_setnodes = {"qset"} + node_pub_counts = {"S1": 1, "S2": 1, "O1": 1} + message = { + "knowledge_graph": {"edges": {}}, + "auxiliary_graphs": {}, + } + pair_to_answer = await generate_curie_pairs( + answers, qgraph_setnodes, node_pub_counts, message, logger + ) + # S1<->O1 and S2<->O1 are the setnode-vs-nonset pairs. 
+ assert (("O1", "S1") in pair_to_answer) or (("S1", "O1") in pair_to_answer) + assert (("O1", "S2") in pair_to_answer) or (("S2", "O1") in pair_to_answer) + + +# --- finish_query.process_task ------------------------------------------- + + +def _make_finish_task(): + return [ + "msg-id", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps([]), + "log_level": "20", + "otel": "{}", + }, + ] + + +class _Limiter: + def __init__(self): + self.released = False + + def release(self): + self.released = True + + +@pytest.mark.asyncio +async def test_finish_query_process_task_acks_on_success(redis_mock, mocker): + from workers.finish_query import worker as fq + + mocker.patch.object(fq, "finish_query", new_callable=mocker.AsyncMock) + mock_ack = mocker.patch.object( + fq, "mark_task_as_complete", new_callable=mocker.AsyncMock + ) + limiter = _Limiter() + await fq.process_task(_make_finish_task(), None, logger, limiter) + assert mock_ack.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_finish_query_process_task_acks_on_failure(redis_mock, mocker): + """Even when ``finish_query`` raises, ``mark_task_as_complete`` is still + called in the ``finally`` block.""" + from workers.finish_query import worker as fq + + mocker.patch.object( + fq, + "finish_query", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("kaboom"), + ) + mock_ack = mocker.patch.object( + fq, "mark_task_as_complete", new_callable=mocker.AsyncMock + ) + limiter = _Limiter() + await fq.process_task(_make_finish_task(), None, logger, limiter) + assert mock_ack.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_finish_query_process_task_swallows_ack_failure(redis_mock, mocker): + """An ``mark_task_as_complete`` failure should be logged but not escape.""" + from workers.finish_query import worker as fq + + mocker.patch.object(fq, "finish_query", new_callable=mocker.AsyncMock) + mocker.patch.object( + fq, + "mark_task_as_complete", + 
new_callable=mocker.AsyncMock, + side_effect=RuntimeError("ack failed"), + ) + limiter = _Limiter() + await fq.process_task(_make_finish_task(), None, logger, limiter) + assert limiter.released + + +@pytest.mark.asyncio +async def test_finish_query_process_task_handles_cancellation(redis_mock, mocker): + """A CancelledError inside finish_query should not escape, and ack still + runs.""" + from workers.finish_query import worker as fq + + mocker.patch.object( + fq, + "finish_query", + new_callable=mocker.AsyncMock, + side_effect=asyncio.CancelledError, + ) + mock_ack = mocker.patch.object( + fq, "mark_task_as_complete", new_callable=mocker.AsyncMock + ) + limiter = _Limiter() + await fq.process_task(_make_finish_task(), None, logger, limiter) + assert mock_ack.called + assert limiter.released diff --git a/tests/unit/test_process_task_wrappers.py b/tests/unit/test_process_task_wrappers.py new file mode 100644 index 0000000..db36e61 --- /dev/null +++ b/tests/unit/test_process_task_wrappers.py @@ -0,0 +1,447 @@ +"""Tests for the ``process_task`` wrappers across multiple workers. + +The pattern is repeated across most workers: call the worker's main async +function, then ``wrap_up_task`` on success or ``handle_task_failure`` on +exception. We verify the cancellation and failure branches are wired up +correctly. + +Workers covered: + +- ``filter_kgraph_orphans`` +- ``filter_results_top_n`` +- ``filter_analyses_top_n`` +- ``sort_results_score`` +- ``example_ara`` +- ``example_score`` +- ``example_lookup`` +- ``aragorn`` +- ``aragorn_pathfinder`` +- ``bte`` +""" + +import asyncio +import json +import logging + +import pytest + +logger = logging.getLogger(__name__) + + +def _make_task(stream_id): + return [ + "msg-id", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps([{"id": stream_id}]), + "log_level": "20", + "otel": "{}", + }, + ] + + +class _Limiter: + """Stand-in for the ``asyncio.Semaphore`` ``process_task`` releases. 
+ + We just record whether ``release()`` was called. + """ + + def __init__(self): + self.released = False + + def release(self): + self.released = True + + +# --- filter_kgraph_orphans ------------------------------------------------ + + +@pytest.mark.asyncio +async def test_filter_kgraph_orphans_process_task_happy_path(redis_mock, mocker): + from workers.filter_kgraph_orphans import worker as fko + + mocker.patch.object(fko, "do_filter_kgraph_orphans", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(fko, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await fko.process_task(_make_task("filter_kgraph_orphans"), None, logger, limiter) + assert mock_wrap.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_filter_kgraph_orphans_process_task_failure_routes_to_failure_handler( + redis_mock, mocker +): + from workers.filter_kgraph_orphans import worker as fko + + mocker.patch.object( + fko, + "do_filter_kgraph_orphans", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("kaboom"), + ) + mock_failure = mocker.patch.object( + fko, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await fko.process_task(_make_task("filter_kgraph_orphans"), None, logger, limiter) + assert mock_failure.called + assert limiter.released + + +@pytest.mark.asyncio +async def test_filter_kgraph_orphans_process_task_cancellation_does_not_route_failure( + redis_mock, mocker +): + """A CancelledError should be logged but not routed to handle_task_failure.""" + from workers.filter_kgraph_orphans import worker as fko + + mocker.patch.object( + fko, + "do_filter_kgraph_orphans", + new_callable=mocker.AsyncMock, + side_effect=asyncio.CancelledError, + ) + mock_failure = mocker.patch.object( + fko, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await fko.process_task(_make_task("filter_kgraph_orphans"), None, logger, limiter) + assert not mock_failure.called + 
assert limiter.released + + +# --- filter_results_top_n ------------------------------------------------- + + +@pytest.mark.asyncio +async def test_filter_results_top_n_process_task_happy_path(redis_mock, mocker): + from workers.filter_results_top_n import worker as frt + + mocker.patch.object(frt, "filter_results_top_n", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(frt, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await frt.process_task(_make_task("filter_results_top_n"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_filter_results_top_n_process_task_failure(redis_mock, mocker): + from workers.filter_results_top_n import worker as frt + + mocker.patch.object( + frt, + "filter_results_top_n", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("oops"), + ) + mock_failure = mocker.patch.object( + frt, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await frt.process_task(_make_task("filter_results_top_n"), None, logger, limiter) + assert mock_failure.called + + +# --- filter_analyses_top_n ------------------------------------------------ + + +@pytest.mark.asyncio +async def test_filter_analyses_top_n_process_task_happy_path(redis_mock, mocker): + from workers.filter_analyses_top_n import worker as fan + + mocker.patch.object(fan, "filter_analyses_top_n", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(fan, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await fan.process_task(_make_task("filter_analyses_top_n"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_filter_analyses_top_n_process_task_failure(redis_mock, mocker): + from workers.filter_analyses_top_n import worker as fan + + mocker.patch.object( + fan, + "filter_analyses_top_n", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( 
+ fan, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await fan.process_task(_make_task("filter_analyses_top_n"), None, logger, limiter) + assert mock_failure.called + + +# --- sort_results_score --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_sort_results_score_process_task_happy_path(redis_mock, mocker): + from workers.sort_results_score import worker as srs + + mocker.patch.object(srs, "sort_results_score", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(srs, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await srs.process_task(_make_task("sort_results_score"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_sort_results_score_process_task_failure(redis_mock, mocker): + from workers.sort_results_score import worker as srs + + mocker.patch.object( + srs, + "sort_results_score", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + srs, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await srs.process_task(_make_task("sort_results_score"), None, logger, limiter) + assert mock_failure.called + + +# --- example_ara ---------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_example_ara_process_task_happy_path(redis_mock, mocker): + from workers.example_ara import worker as eara + + mocker.patch.object(eara, "example_ara", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(eara, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await eara.process_task(_make_task("example"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_example_ara_process_task_failure(redis_mock, mocker): + from workers.example_ara import worker as eara + + mocker.patch.object( + eara, + "example_ara", + 
new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + eara, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await eara.process_task(_make_task("example"), None, logger, limiter) + assert mock_failure.called + + +# --- example_score -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_example_score_process_task_happy_path(redis_mock, mocker): + from workers.example_score import worker as escore + + mocker.patch.object(escore, "example_score", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object( + escore, "wrap_up_task", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await escore.process_task(_make_task("example.score"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_example_score_process_task_failure(redis_mock, mocker): + from workers.example_score import worker as escore + + mocker.patch.object( + escore, + "example_score", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + escore, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await escore.process_task(_make_task("example.score"), None, logger, limiter) + assert mock_failure.called + + +# --- example_lookup ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_example_lookup_process_task_happy_path(redis_mock, mocker): + from workers.example_lookup import worker as elookup + + mocker.patch.object(elookup, "example_lookup", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object( + elookup, "wrap_up_task", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await elookup.process_task(_make_task("example.lookup"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_example_lookup_process_task_failure(redis_mock, 
mocker): + from workers.example_lookup import worker as elookup + + mocker.patch.object( + elookup, + "example_lookup", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + elookup, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await elookup.process_task(_make_task("example.lookup"), None, logger, limiter) + assert mock_failure.called + + +# --- aragorn ------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_aragorn_process_task_happy_path(redis_mock, mocker): + from workers.aragorn import worker as ar + + mocker.patch.object(ar, "aragorn", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(ar, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await ar.process_task(_make_task("aragorn"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_aragorn_process_task_failure(redis_mock, mocker): + from workers.aragorn import worker as ar + + mocker.patch.object( + ar, + "aragorn", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + ar, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await ar.process_task(_make_task("aragorn"), None, logger, limiter) + assert mock_failure.called + + +# --- aragorn_pathfinder -------------------------------------------------- + + +@pytest.mark.asyncio +async def test_aragorn_pathfinder_process_task_happy_path(redis_mock, mocker): + from workers.aragorn_pathfinder import worker as apf + + mocker.patch.object(apf, "shadowfax", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(apf, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await apf.process_task(_make_task("aragorn.pathfinder"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def 
test_aragorn_pathfinder_process_task_failure(redis_mock, mocker): + from workers.aragorn_pathfinder import worker as apf + + mocker.patch.object( + apf, + "shadowfax", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + apf, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await apf.process_task(_make_task("aragorn.pathfinder"), None, logger, limiter) + assert mock_failure.called + + +# --- bte ----------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_bte_process_task_happy_path(redis_mock, mocker): + from workers.bte import worker as bte + + mocker.patch.object(bte, "bte", new_callable=mocker.AsyncMock) + mock_wrap = mocker.patch.object(bte, "wrap_up_task", new_callable=mocker.AsyncMock) + + limiter = _Limiter() + await bte.process_task(_make_task("bte"), None, logger, limiter) + assert mock_wrap.called + + +@pytest.mark.asyncio +async def test_bte_process_task_failure(redis_mock, mocker): + from workers.bte import worker as bte + + mocker.patch.object( + bte, + "bte", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("nope"), + ) + mock_failure = mocker.patch.object( + bte, "handle_task_failure", new_callable=mocker.AsyncMock + ) + + limiter = _Limiter() + await bte.process_task(_make_task("bte"), None, logger, limiter) + assert mock_failure.called + + +# --- wrap_up_task failures shouldn't escape ------------------------------- + + +@pytest.mark.asyncio +async def test_process_task_swallows_wrap_up_failures(redis_mock, mocker): + """The pattern is: try/except wrap_up_task — failures get logged but + don't escape the wrapper. 
Verify on filter_kgraph_orphans.""" + from workers.filter_kgraph_orphans import worker as fko + + mocker.patch.object(fko, "do_filter_kgraph_orphans", new_callable=mocker.AsyncMock) + mocker.patch.object( + fko, + "wrap_up_task", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("redis dropped"), + ) + limiter = _Limiter() + # Should not raise; logged and limiter still released. + await fko.process_task(_make_task("filter_kgraph_orphans"), None, logger, limiter) + assert limiter.released diff --git a/tests/unit/test_shared_utils.py b/tests/unit/test_shared_utils.py new file mode 100644 index 0000000..d209574 --- /dev/null +++ b/tests/unit/test_shared_utils.py @@ -0,0 +1,558 @@ +"""Tests for ``shepherd_utils.shared`` helper functions. + +These exercise the pure helpers (``combine_unique_dicts``, ``merge_kgraph``, +``is_support_edge``, the recursive support-graph traversal, and +``filter_kgraph_orphans``) plus the broker-driven workflow helpers +(``get_next_operation``, ``wrap_up_task``, ``handle_task_failure``). +""" + +import json +import logging + +import pytest + +from shepherd_utils.broker import get_task +from shepherd_utils.shared import ( + combine_unique_dicts, + filter_kgraph_orphans, + get_next_operation, + handle_task_failure, + is_support_edge, + merge_kgraph, + recursive_get_auxgraph_edges, + recursive_get_edge_support_graphs, + validate_message, + wrap_up_task, +) + +logger = logging.getLogger(__name__) + + +def test_get_next_operation_returns_first_op(): + workflow = [{"id": "first"}, {"id": "second"}] + next_op, returned = get_next_operation(workflow) + assert next_op == {"id": "first"} + # Function returns the same workflow back; it does not mutate it. 
+ assert returned is workflow + assert returned == [{"id": "first"}, {"id": "second"}] + + +def test_combine_unique_dicts_dedupes_across_lists(): + a = [{"x": 1}, {"y": 2}] + b = [{"x": 1}, {"z": 3}] + out = combine_unique_dicts(a, b, logger) + assert {"x": 1} in out + assert {"y": 2} in out + assert {"z": 3} in out + assert len(out) == 3 + + +def test_combine_unique_dicts_treats_key_order_as_equal(): + a = [{"a": 1, "b": 2}] + b = [{"b": 2, "a": 1}] + out = combine_unique_dicts(a, b, logger) + assert out == [{"a": 1, "b": 2}] + + +def test_combine_unique_dicts_handles_unhashable_values_via_default_str(): + """Items with non-JSON values (e.g. sets) shouldn't blow up: default=str + makes the signature stable.""" + + class NotSerializable: + def __str__(self): + return "constant-token" + + a = [{"v": NotSerializable()}] + b = [{"v": NotSerializable()}] + out = combine_unique_dicts(a, b, logger) + assert len(out) == 1 + + +def test_is_support_edge_true_for_support_graph_attribute(): + edge = { + "attributes": [ + {"attribute_type_id": "biolink:support_graphs", "value": ["aux1"]}, + ], + } + assert is_support_edge(edge) is True + + +def test_is_support_edge_false_for_no_attributes(): + assert is_support_edge({}) is False + + +def test_is_support_edge_false_for_non_support_attributes(): + edge = { + "attributes": [ + {"attribute_type_id": "biolink:has_evidence", "value": 5}, + ], + } + assert is_support_edge(edge) is False + + +def test_merge_kgraph_adds_new_nodes_and_appends_aggregator_source(): + og = {"nodes": {}, "edges": {}} + new = { + "nodes": { + "MONDO:1": { + "name": "n1", + "categories": ["biolink:Disease"], + "attributes": [], + }, + }, + "edges": { + "e1": { + "subject": "MONDO:1", + "object": "MONDO:2", + "attributes": [], + "sources": [ + { + "resource_id": "infores:original", + "resource_role": "primary_knowledge_source", + } + ], + }, + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + assert "MONDO:1" in merged["nodes"] + assert "e1" in 
merged["edges"] + sources = merged["edges"]["e1"]["sources"] + aggregator_ids = [s["resource_id"] for s in sources] + assert "infores:test" in aggregator_ids + + +def test_merge_kgraph_merges_existing_node_attributes_and_categories(): + og = { + "nodes": { + "MONDO:1": { + "name": "", + "categories": ["biolink:Disease"], + "attributes": [{"attribute_type_id": "biolink:foo", "value": 1}], + }, + }, + "edges": {}, + } + new = { + "nodes": { + "MONDO:1": { + "name": "Updated", + "categories": ["biolink:DiseaseOrPhenotypicFeature"], + "attributes": [{"attribute_type_id": "biolink:bar", "value": 2}], + }, + }, + "edges": {}, + } + merged = merge_kgraph(og, new, "infores:test", logger) + node = merged["nodes"]["MONDO:1"] + assert node["name"] == "Updated" + assert set(node["categories"]) == { + "biolink:Disease", + "biolink:DiseaseOrPhenotypicFeature", + } + # Both attributes preserved through dedupe. + type_ids = {a["attribute_type_id"] for a in node["attributes"]} + assert {"biolink:foo", "biolink:bar"}.issubset(type_ids) + + +def test_merge_kgraph_does_not_double_aggregator_source(): + aggregator = { + "resource_id": "infores:test", + "resource_role": "aggregator_knowledge_source", + "upstream_resource_ids": ["infores:retriever"], + } + og = {"nodes": {}, "edges": {}} + new = { + "nodes": {}, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [], + "sources": [aggregator], + }, + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + sources = merged["edges"]["e1"]["sources"] + matching = [s for s in sources if s["resource_id"] == "infores:test"] + assert len(matching) == 1 + + +def test_merge_kgraph_does_not_append_aggregator_to_support_edge(): + """Support edges (those carrying biolink:support_graphs attributes) should + not have a new aggregator source appended.""" + og = {"nodes": {}, "edges": {}} + new = { + "nodes": {}, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + {"attribute_type_id": 
"biolink:support_graphs", "value": ["aux1"]}, + ], + "sources": [ + { + "resource_id": "infores:original", + "resource_role": "primary_knowledge_source", + } + ], + }, + }, + } + merged = merge_kgraph(og, new, "infores:test", logger) + sources = merged["edges"]["e1"]["sources"] + assert all(s["resource_id"] != "infores:test" for s in sources) + + +def test_recursive_get_edge_support_graphs_short_circuits_on_visited_edge(): + """Edges already in the visited set should not be re-traversed (otherwise + circular support graphs would loop forever).""" + edges = {"already-seen"} + auxgraphs = set() + nodes = set() + out = recursive_get_edge_support_graphs( + "already-seen", + edges, + auxgraphs, + message_edges={"already-seen": {"subject": "A", "object": "B"}}, + message_auxgraphs={}, + nodes=nodes, + ) + assert out == (edges, auxgraphs, nodes) + + +def test_recursive_get_edge_support_graphs_walks_into_auxgraphs(): + message_edges = { + "edge1": { + "subject": "A", + "object": "B", + "attributes": [ + {"attribute_type_id": "biolink:support_graphs", "value": ["aux1"]}, + ], + }, + "edge2": { + "subject": "B", + "object": "C", + "attributes": [], + }, + } + message_auxgraphs = {"aux1": {"edges": ["edge2"]}} + edges, auxgraphs, nodes = recursive_get_edge_support_graphs( + "edge1", set(), set(), message_edges, message_auxgraphs, set() + ) + assert {"edge1", "edge2"} == edges + assert {"aux1"} == auxgraphs + assert {"A", "B", "C"} == nodes + + +def test_recursive_get_auxgraph_edges_raises_for_missing_aux_edge(): + message_auxgraphs = {"aux1": {"edges": ["missing"]}} + with pytest.raises(KeyError, match="missing"): + recursive_get_auxgraph_edges( + "aux1", + set(), + set(), + message_edges={}, + message_auxgraphs=message_auxgraphs, + nodes=set(), + ) + + +def test_recursive_get_edge_support_graphs_raises_for_unknown_auxgraph(): + message_edges = { + "edge1": { + "subject": "A", + "object": "B", + "attributes": [ + {"attribute_type_id": "biolink:support_graphs", "value": 
["missing"]}, + ], + } + } + with pytest.raises(KeyError, match="missing"): + recursive_get_edge_support_graphs( + "edge1", set(), set(), message_edges, message_auxgraphs={}, nodes=set() + ) + + +def test_validate_message_keeps_valid_message(tmp_path, monkeypatch): + """A message whose edges all reference existing nodes should not write the + invalid_message.json side effect file.""" + monkeypatch.chdir(tmp_path) + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}}, + "edges": {"e1": {"subject": "A", "object": "B"}}, + }, + "auxiliary_graphs": {}, + }, + } + validate_message(message, logger) + assert not (tmp_path / "invalid_message.json").exists() + + +def test_validate_message_dumps_invalid_message_on_missing_node(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}}, + "edges": {"e1": {"subject": "A", "object": "MISSING"}}, + }, + "auxiliary_graphs": {}, + }, + } + validate_message(message, logger) + out = tmp_path / "invalid_message.json" + assert out.exists() + with open(out, encoding="utf-8") as f: + dumped = json.load(f) + assert dumped == message + + +def test_filter_kgraph_orphans_keeps_support_graph_chain(): + """An edge that references an aux graph that references another edge: all + of that should be retained.""" + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}, "C": {}, "ORPHAN": {}}, + "edges": { + "result_edge": { + "subject": "A", + "object": "B", + "attributes": [ + { + "attribute_type_id": "biolink:support_graphs", + "value": ["aux1"], + } + ], + }, + "support_edge": { + "subject": "B", + "object": "C", + "attributes": [], + }, + "orphan_edge": { + "subject": "ORPHAN", + "object": "A", + "attributes": [], + }, + }, + }, + "auxiliary_graphs": { + "aux1": {"edges": ["support_edge"]}, + "aux_orphan": {"edges": ["orphan_edge"]}, + }, + "results": [ + { + "node_bindings": { + "qn1": [{"id": "A"}], + "qn2": [{"id": "B"}], 
+ }, + "analyses": [ + {"edge_bindings": {"e0": [{"id": "result_edge"}]}}, + ], + } + ], + }, + } + filter_kgraph_orphans(message, logger) + nodes = message["message"]["knowledge_graph"]["nodes"] + edges = message["message"]["knowledge_graph"]["edges"] + auxgraphs = message["message"]["auxiliary_graphs"] + assert set(nodes.keys()) == {"A", "B", "C"} + assert set(edges.keys()) == {"result_edge", "support_edge"} + assert set(auxgraphs.keys()) == {"aux1"} + + +def test_filter_kgraph_orphans_warns_on_missing_aux_edge_and_continues(): + """If a support graph references an edge that's not in the kg, we drop it + but don't crash the whole filter.""" + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}}, + "edges": { + "result_edge": { + "subject": "A", + "object": "B", + "attributes": [ + { + "attribute_type_id": "biolink:support_graphs", + "value": ["bad_aux"], + } + ], + }, + }, + }, + "auxiliary_graphs": { + "bad_aux": {"edges": ["nonexistent_edge"]}, + }, + "results": [ + { + "node_bindings": {"qn1": [{"id": "A"}], "qn2": [{"id": "B"}]}, + "analyses": [ + {"edge_bindings": {"e0": [{"id": "result_edge"}]}}, + ], + } + ], + }, + } + filter_kgraph_orphans(message, logger) + # result_edge is still kept because it was directly bound, even if its + # support graph couldn't be fully resolved. 
+ assert "result_edge" in message["message"]["knowledge_graph"]["edges"] + + +def test_filter_kgraph_orphans_handles_path_bindings_and_support_graphs(): + """Pathfinder-style results: path_bindings + support_graphs analyses both + pull aux graphs in.""" + message = { + "message": { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}, "C": {}}, + "edges": { + "edge_via_path": {"subject": "A", "object": "B", "attributes": []}, + "edge_via_support": { + "subject": "B", + "object": "C", + "attributes": [], + }, + }, + }, + "auxiliary_graphs": { + "aux_path": {"edges": ["edge_via_path"]}, + "aux_support": {"edges": ["edge_via_support"]}, + }, + "results": [ + { + "node_bindings": {"qn1": [{"id": "A"}]}, + "analyses": [ + { + "edge_bindings": {}, + "path_bindings": { + "p0": [{"id": "aux_path"}], + }, + "support_graphs": ["aux_support"], + }, + ], + } + ], + }, + } + filter_kgraph_orphans(message, logger) + auxgraphs = message["message"]["auxiliary_graphs"] + assert {"aux_path", "aux_support"}.issubset(auxgraphs.keys()) + edges = message["message"]["knowledge_graph"]["edges"] + assert {"edge_via_path", "edge_via_support"}.issubset(edges.keys()) + + +def test_filter_kgraph_orphans_creates_empty_kg_when_results_present_but_kg_missing(): + """If results reference an edge but no knowledge_graph exists yet, the + filter should still make the message structurally valid (empty kg).""" + message = { + "message": { + "results": [ + { + "node_bindings": {"qn": [{"id": "A"}]}, + "analyses": [{"edge_bindings": {"e0": [{"id": "missing"}]}}], + } + ], + }, + } + filter_kgraph_orphans(message, logger) + assert message["message"]["knowledge_graph"] == {"nodes": {}, "edges": {}} + assert message["message"]["auxiliary_graphs"] == {} + + +@pytest.mark.asyncio +async def test_wrap_up_task_pops_completed_op_and_queues_next(redis_mock): + """If the worker stream matches the head of the workflow, that op is + popped and the next one is enqueued for processing.""" + task = [ + "msg-1", + { + 
"query_id": "q1", + "response_id": "r1", + "workflow": json.dumps( + [ + {"id": "stream_a"}, + {"id": "stream_b"}, + ] + ), + "log_level": "20", + "otel": "{}", + }, + ] + await wrap_up_task("stream_a", "consumer", task, logger) + next_task = await get_task("stream_b", "consumer", "test", logger) + assert next_task is not None + workflow = json.loads(next_task[1]["workflow"]) + assert [op["id"] for op in workflow] == ["stream_b"] + + +@pytest.mark.asyncio +async def test_wrap_up_task_routes_empty_workflow_to_finish_query(redis_mock): + task = [ + "msg-2", + { + "query_id": "q1", + "response_id": "r1", + "workflow": json.dumps([{"id": "stream_a"}]), + "log_level": "20", + "otel": "{}", + }, + ] + await wrap_up_task("stream_a", "consumer", task, logger) + next_task = await get_task("finish_query", "consumer", "test", logger) + assert next_task is not None + workflow = json.loads(next_task[1]["workflow"]) + assert workflow == [] + + +@pytest.mark.asyncio +async def test_wrap_up_task_does_not_pop_for_entry_worker(redis_mock): + """Entry workers (whose stream name doesn't match any workflow op id) + should run the first op in the workflow, not skip it.""" + task = [ + "msg-3", + { + "query_id": "q1", + "response_id": "r1", + "workflow": json.dumps( + [ + {"id": "real_op"}, + {"id": "next_op"}, + ] + ), + "log_level": "20", + "otel": "{}", + }, + ] + await wrap_up_task("entry_stream", "consumer", task, logger) + next_task = await get_task("real_op", "consumer", "test", logger) + assert next_task is not None + workflow = json.loads(next_task[1]["workflow"]) + # real_op was the first item and must still be present (not consumed). 
+ assert [op["id"] for op in workflow] == ["real_op", "next_op"] + + +@pytest.mark.asyncio +async def test_handle_task_failure_routes_to_finish_query_with_error_status(redis_mock): + task = [ + "msg-4", + { + "query_id": "q1", + "response_id": "r1", + "workflow": json.dumps([{"id": "broken_op"}]), + "log_level": "20", + "otel": "{}", + }, + ] + await handle_task_failure("broken_op", "consumer", task, logger) + next_task = await get_task("finish_query", "consumer", "test", logger) + assert next_task is not None + assert next_task[1]["status"] == "ERROR" diff --git a/tests/unit/test_sipr.py b/tests/unit/test_sipr.py new file mode 100644 index 0000000..954bf2e --- /dev/null +++ b/tests/unit/test_sipr.py @@ -0,0 +1,295 @@ +"""Tests for the SIPR (Set-Input Page Rank) worker. + +Covers the pure helpers that don't need network access: + +- ``write_trapi``: TRAPI query construction for the neighborhood expansion +- ``get_nodes``: kgraph filtering + PPR-based truncation +- ``distribute_weights``: personalized PageRank over a small TRAPI kgraph + +Plus the ``sipr`` entrypoint's two main branches (set-input and not). 
+""" + +import json +import logging + +import pytest + +from workers.sipr import worker as sipr_worker + +logger = logging.getLogger(__name__) + + +def test_write_trapi_first_hop_uses_named_thing(): + """Hop 0 keeps a generic NamedThing object category.""" + query = sipr_worker.write_trapi(["MONDO:0001"], hop_num=0) + qg = query["message"]["query_graph"] + assert qg["nodes"]["n0"]["ids"] == ["MONDO:0001"] + assert qg["nodes"]["n1"]["categories"] == ["biolink:NamedThing"] + assert qg["edges"]["e0"]["predicates"] == ["biolink:related_to"] + + +def test_write_trapi_second_hop_narrows_object_categories(): + """Hop 1 narrows to a specific list of biolink categories.""" + query = sipr_worker.write_trapi(["MONDO:0001"], hop_num=1) + qg = query["message"]["query_graph"] + cats = qg["nodes"]["n1"]["categories"] + assert "biolink:Disease" in cats + assert "biolink:Gene" in cats + assert "biolink:NamedThing" not in cats + + +def test_distribute_weights_assigns_higher_score_to_seed(): + """Personalized PageRank should bias mass toward the personalization + targets (the input ``target_nodes``).""" + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": {"subject": "A", "object": "B"}, + "e2": {"subject": "B", "object": "C"}, + "e3": {"subject": "C", "object": "A"}, + }, + "nodes": {"A": {}, "B": {}, "C": {}}, + } + } + } + ppr = sipr_worker.distribute_weights([response], ["A"], logger) + assert set(ppr.keys()) == {"A", "B", "C"} + # Seed node A should outrank the rest. 
+ assert ppr["A"] > ppr["B"] + assert ppr["A"] > ppr["C"] + + +def test_distribute_weights_no_targets_uses_uniform_personalization(): + """No target nodes -> uniform PageRank, no error raised.""" + response = { + "message": { + "knowledge_graph": { + "edges": {"e1": {"subject": "A", "object": "B"}}, + "nodes": {"A": {}, "B": {}}, + } + } + } + ppr = sipr_worker.distribute_weights([response], [], logger) + assert set(ppr.keys()) == {"A", "B"} + + +def test_get_nodes_drops_subclass_and_related_to_edges(): + """The filterer drops subclass_of and related_to edges from the kg + before scoring.""" + response = { + "message": { + "knowledge_graph": { + "edges": { + "drop_subclass": { + "subject": "MONDO:1", + "object": "MONDO:2", + "predicate": "biolink:subclass_of", + }, + "drop_related": { + "subject": "MONDO:1", + "object": "MONDO:3", + "predicate": "biolink:related_to", + }, + "keep": { + "subject": "MONDO:1", + "object": "CHEBI:1", + "predicate": "biolink:treats", + }, + }, + "nodes": { + "MONDO:1": {"name": "n1"}, + "MONDO:2": {"name": "n2"}, + "MONDO:3": {"name": "n3"}, + "CHEBI:1": {"name": "c1"}, + }, + } + } + } + kept_nodes, filtered = sipr_worker.get_nodes(response, ["MONDO:1"], logger) + edges = filtered["message"]["knowledge_graph"]["edges"] + # Only the non-subclass / non-related_to edge survives the initial filter. 
+ assert "keep" in edges + assert "drop_subclass" not in edges + assert "drop_related" not in edges + + +def test_get_nodes_drops_hp_to_hp_edges(): + """Edges where both endpoints are HP CURIEs get filtered out.""" + response = { + "message": { + "knowledge_graph": { + "edges": { + "hp_to_hp": { + "subject": "HP:0001", + "object": "HP:0002", + "predicate": "biolink:phenotype_of", + }, + "mixed": { + "subject": "MONDO:0001", + "object": "HP:0003", + "predicate": "biolink:phenotype_of", + }, + }, + "nodes": { + "HP:0001": {}, + "HP:0002": {}, + "HP:0003": {}, + "MONDO:0001": {}, + }, + } + } + } + _, filtered = sipr_worker.get_nodes(response, ["MONDO:0001"], logger) + edges = filtered["message"]["knowledge_graph"]["edges"] + assert "hp_to_hp" not in edges + assert "mixed" in edges + + +def test_get_nodes_excludes_hp_curies_from_returned_node_list(): + """The function returns the top scored non-HP nodes (truncated to 15).""" + response = { + "message": { + "knowledge_graph": { + "edges": { + "e1": { + "subject": "MONDO:0001", + "object": "HP:0001", + "predicate": "biolink:phenotype_of", + }, + "e2": { + "subject": "MONDO:0001", + "object": "CHEBI:0001", + "predicate": "biolink:treats", + }, + }, + "nodes": { + "MONDO:0001": {}, + "HP:0001": {}, + "CHEBI:0001": {}, + }, + } + } + } + kept_nodes, _ = sipr_worker.get_nodes(response, ["MONDO:0001"], logger) + # HP nodes excluded from the truncated return list + assert "HP:0001" not in kept_nodes + + +def _make_task(): + return [ + "test", + { + "query_id": "qid", + "response_id": "rid", + "workflow": json.dumps([{"id": "sipr"}]), + "log_level": "20", + "otel": json.dumps({}), + }, + ] + + +@pytest.mark.asyncio +async def test_sipr_skips_non_set_input_query(redis_mock, mocker): + """A query without any ``set_interpretation: MANY`` node should be skipped + silently (no save).""" + mocker.patch( + "workers.sipr.worker.get_message", + new_callable=mocker.AsyncMock, + return_value={ + "message": { + "query_graph": { + "nodes": 
{"a": {}, "b": {}}, + "edges": {"e0": {"subject": "a", "object": "b"}}, + } + } + }, + ) + mock_save = mocker.patch( + "workers.sipr.worker.save_message", + new_callable=mocker.AsyncMock, + ) + task = _make_task() + await sipr_worker.sipr(task, logger) + assert not mock_save.called + # The workflow on the task gets a stable sipr/sort_results_score appended. + workflow = json.loads(task[1]["workflow"]) + assert [op["id"] for op in workflow] == ["sipr", "sort_results_score"] + + +@pytest.mark.asyncio +async def test_sipr_set_input_runs_pagerank_and_saves_message(redis_mock, mocker): + """A set-input query: get_neighborhood is called, weights are + distributed, the final TRAPI message is saved with results.""" + set_input_query = { + "message": { + "query_graph": { + "nodes": { + "SN": { + "ids": ["MONDO:0001", "MONDO:0002"], + "set_interpretation": "MANY", + }, + "ON": {"categories": ["biolink:NamedThing"]}, + }, + "edges": { + "e0": {"subject": "SN", "object": "ON"}, + }, + } + } + } + mocker.patch( + "workers.sipr.worker.get_message", + new_callable=mocker.AsyncMock, + return_value=set_input_query, + ) + + # Stub out get_neighborhood so we don't reach the network at all. 
+ fake_neighborhood = [ + { + "message": { + "knowledge_graph": { + "nodes": { + "MONDO:0001": {"categories": ["biolink:Disease"], "name": "d1"}, + "MONDO:0002": {"categories": ["biolink:Disease"], "name": "d2"}, + "CHEBI:0001": { + "categories": ["biolink:ChemicalEntity"], + "name": "c1", + }, + }, + "edges": { + "e1": { + "subject": "MONDO:0001", + "object": "CHEBI:0001", + "predicate": "biolink:treats", + }, + "e2": { + "subject": "MONDO:0002", + "object": "CHEBI:0001", + "predicate": "biolink:treats", + }, + }, + } + } + } + ] + mocker.patch( + "workers.sipr.worker.get_neighborhood", + new_callable=mocker.AsyncMock, + return_value=fake_neighborhood, + ) + mock_save = mocker.patch( + "workers.sipr.worker.save_message", + new_callable=mocker.AsyncMock, + ) + + await sipr_worker.sipr(_make_task(), logger) + assert mock_save.called + saved_id, saved_msg = mock_save.call_args.args[:2] + assert saved_id == "rid" + # Ensure the input-set node ids ended up in the kg, plus we got at least + # one analysis result with a non-trivial score. + kg_nodes = saved_msg["message"]["knowledge_graph"]["nodes"] + assert "MONDO:0001" in kg_nodes and "MONDO:0002" in kg_nodes + if saved_msg["message"]["results"]: + analysis = saved_msg["message"]["results"][0]["analyses"][0] + assert "score" in analysis and analysis["score"] > 0 diff --git a/tests/unit/test_sipr_graph_filterer.py b/tests/unit/test_sipr_graph_filterer.py new file mode 100644 index 0000000..f0e7a0f --- /dev/null +++ b/tests/unit/test_sipr_graph_filterer.py @@ -0,0 +1,106 @@ +"""Tests for ``workers.sipr.graph_filterer``. + +Covers ``remove_isolated``, ``add_hub_node``, ``filter_graph_by_weight``, +and ``apply_to_graph``. See PR notes: ``apply_list_to_graph`` references +an undefined ``G`` and is currently broken; we don't test it. 
+""" + +import networkx as nx + +from workers.sipr import graph_filterer as gf + + +def test_remove_isolated_strips_disconnected_nodes(): + g = nx.Graph() + g.add_edge("A", "B") + g.add_node("LONELY") + out = gf.remove_isolated(g) + assert "LONELY" not in out.nodes + assert {"A", "B"} == set(out.nodes) + + +def test_remove_isolated_returns_copy_does_not_mutate_input(): + g = nx.Graph() + g.add_edge("A", "B") + g.add_node("LONELY") + _ = gf.remove_isolated(g) + assert "LONELY" in g.nodes # original untouched + + +def test_add_hub_node_connects_to_every_other_node(): + g = nx.Graph() + g.add_edges_from([("A", "B"), ("B", "C")]) + out = gf.add_hub_node(g, hub_name="HUB") + assert "HUB" in out.nodes + for n in {"A", "B", "C"}: + assert out.has_edge("HUB", n) + # Original graph is untouched. + assert "HUB" not in g.nodes + + +def test_add_hub_node_default_name(): + g = nx.Graph() + g.add_edge("A", "B") + out = gf.add_hub_node(g) + assert "Hub" in out.nodes + + +def test_filter_graph_by_weight_keep_above_drops_low_edges(): + g = nx.Graph() + g.add_edge("A", "B", weight=0.9) + g.add_edge("B", "C", weight=0.1) + g.add_edge("C", "D", weight=0.5) + out = gf.filter_graph_by_weight(g, weight_cutoff=0.5, keep_above=True) + assert out.has_edge("A", "B") + assert out.has_edge("C", "D") + assert not out.has_edge("B", "C") + + +def test_filter_graph_by_weight_keep_below_keeps_low_edges(): + g = nx.Graph() + g.add_edge("A", "B", weight=0.9) + g.add_edge("B", "C", weight=0.1) + out = gf.filter_graph_by_weight(g, weight_cutoff=0.5, keep_above=False) + assert not out.has_edge("A", "B") + assert out.has_edge("B", "C") + + +def test_filter_graph_by_weight_drops_orphans_after_filter(): + """After dropping the only edge a node has, the node should be removed + too (post-filter remove_isolated).""" + g = nx.Graph() + g.add_edge("A", "B", weight=0.1) + out = gf.filter_graph_by_weight(g, weight_cutoff=0.5, keep_above=True) + assert "A" not in out.nodes and "B" not in out.nodes + + +def 
test_filter_graph_by_weight_preserves_node_attributes(): + g = nx.Graph() + g.add_node("A", category="biolink:Disease") + g.add_node("B", category="biolink:Drug") + g.add_edge("A", "B", weight=0.9) + out = gf.filter_graph_by_weight(g, weight_cutoff=0.1, keep_above=True) + assert out.nodes["A"]["category"] == "biolink:Disease" + assert out.nodes["B"]["category"] == "biolink:Drug" + + +def test_filter_graph_by_weight_default_weight_when_missing(): + """Edges without a 'weight' attribute default to 0; filtering with + keep_above=True and a positive cutoff should drop them.""" + g = nx.Graph() + g.add_edge("A", "B") # no weight attribute + out = gf.filter_graph_by_weight(g, weight_cutoff=0.1, keep_above=True) + assert not out.has_edge("A", "B") + + +def test_apply_to_graph_runs_function_on_copy(): + g = nx.Graph() + g.add_edge("A", "B", weight=0.9) + g.add_edge("B", "C", weight=0.1) + out = gf.apply_to_graph( + g, lambda x: gf.filter_graph_by_weight(x, weight_cutoff=0.5, keep_above=True) + ) + assert out.has_edge("A", "B") + assert not out.has_edge("B", "C") + # Original unaffected. + assert g.has_edge("B", "C") diff --git a/tests/unit/test_trapi_to_networkx.py b/tests/unit/test_trapi_to_networkx.py new file mode 100644 index 0000000..25bf154 --- /dev/null +++ b/tests/unit/test_trapi_to_networkx.py @@ -0,0 +1,309 @@ +"""Tests for ``shepherd_utils.TRAPI_to_NetworkX``. + +These tests exercise the public ``trapi_kg_to_nx`` helper across its mode +matrix (multigraph vs. collapsed, directed vs. undirected, payload modes) +and its weight derivation/transform logic. 
+""" + +import json + +import networkx as nx +import pytest + +from shepherd_utils.TRAPI_to_NetworkX import trapi_kg_to_nx +from tests.helpers.generate_messages import generate_response + +SIMPLE_KG = { + "knowledge_graph": { + "nodes": { + "A": { + "categories": ["biolink:NamedThing"], + "name": "Node A", + "attributes": [ + { + "original_attribute_name": "synonym", + "value": "alpha", + }, + { + "original_attribute_name": "synonym", + "value": "alfa", + }, + ], + }, + "B": {"categories": ["biolink:NamedThing"], "name": "Node B"}, + }, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "predicate": "biolink:related_to", + "attributes": [ + { + "original_attribute_name": "ngd", + "value": 0.5, + }, + ], + }, + "e2": { + "subject": "A", + "object": "B", + "predicate": "biolink:treats", + "attributes": [ + { + "original_attribute_name": "ngd", + "value": 0.25, + }, + ], + }, + }, + } +} + + +def test_default_returns_multidigraph_with_full_metadata(): + g = trapi_kg_to_nx(SIMPLE_KG) + assert isinstance(g, nx.MultiDiGraph) + assert set(g.nodes) == {"A", "B"} + assert g.number_of_edges() == 2 + # Metadata preserved + edge_data = list(g.get_edge_data("A", "B").values()) + predicates = {e["predicate"] for e in edge_data} + assert predicates == {"biolink:related_to", "biolink:treats"} + # attributes_flat keys on the multigraph edge + assert all("attributes_flat" in e for e in edge_data) + + +def test_undirected_multigraph_returns_multigraph_class(): + g = trapi_kg_to_nx(SIMPLE_KG, directed=False) + assert isinstance(g, nx.MultiGraph) + # Edges are undirected + assert g.has_edge("A", "B") and g.has_edge("B", "A") + + +def test_collapsed_directed_returns_digraph(): + g = trapi_kg_to_nx(SIMPLE_KG, multigraph=False) + assert isinstance(g, nx.DiGraph) + assert g.number_of_edges() == 1 + # last-seen metadata wins for the (A,B) pair in collapsed/full mode + assert g["A"]["B"]["predicate"] in {"biolink:related_to", "biolink:treats"} + + +def 
test_collapsed_undirected_returns_graph(): + g = trapi_kg_to_nx(SIMPLE_KG, multigraph=False, directed=False) + assert isinstance(g, nx.Graph) + assert not isinstance(g, nx.DiGraph) + + +def test_weights_disabled_no_weight_on_edges(): + g = trapi_kg_to_nx(SIMPLE_KG) + for _, _, data in g.edges(data=True): + assert "weight" not in data + + +def test_weights_enabled_via_attribute(): + g = trapi_kg_to_nx(SIMPLE_KG, edge_weight_attr="ngd") + weights = sorted(d["weight"] for _, _, d in g.edges(data=True)) + assert weights == [0.25, 0.5] + + +def test_weights_default_when_attr_missing(): + g = trapi_kg_to_nx( + { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}}, + "edges": { + "e1": {"subject": "A", "object": "B", "attributes": []}, + }, + } + }, + edge_weight_attr="ngd", + default_weight=0.7, + ) + edge_data = next(iter(g.get_edge_data("A", "B").values())) + assert edge_data["weight"] == 0.7 + + +def test_weight_default_used_when_attr_unparseable(): + """If the attribute exists but isn't numeric, fall back to default_weight.""" + g = trapi_kg_to_nx( + { + "knowledge_graph": { + "nodes": {"A": {}, "B": {}}, + "edges": { + "e1": { + "subject": "A", + "object": "B", + "attributes": [ + {"original_attribute_name": "ngd", "value": "not-a-number"}, + ], + }, + }, + } + }, + edge_weight_attr="ngd", + default_weight=2.5, + ) + weights = [d["weight"] for _, _, d in g.edges(data=True)] + assert weights == [2.5] + + +def test_weight_transform_applied(): + g = trapi_kg_to_nx( + SIMPLE_KG, + edge_weight_attr="ngd", + edge_weight_transform=lambda x: x * 10, + ) + weights = sorted(d["weight"] for _, _, d in g.edges(data=True)) + assert weights == [2.5, 5.0] + + +def test_weight_transform_failure_falls_back_to_raw_value(): + """When the transform raises, the raw value (not default) is returned.""" + + def bad_transform(x): + raise RuntimeError("boom") + + g = trapi_kg_to_nx( + SIMPLE_KG, + edge_weight_attr="ngd", + edge_weight_transform=bad_transform, + default_weight=99.0, + ) + 
weights = sorted(d["weight"] for _, _, d in g.edges(data=True)) + assert weights == [0.25, 0.5] + + +@pytest.mark.parametrize( + "agg, expected", + [ + ("sum", 0.75), + ("max", 0.5), + ("min", 0.25), + ("first", 0.5), + ], +) +def test_collapsed_weight_only_aggregations(agg, expected): + g = trapi_kg_to_nx( + SIMPLE_KG, + multigraph=False, + edge_weight_attr="ngd", + edge_payload="weight_only", + weight_agg=agg, + ) + assert g["A"]["B"]["weight"] == pytest.approx(expected) + + +def test_collapsed_weight_only_unknown_agg_falls_back_to_sum(): + """An unrecognized string aggregator name should fall back to sum.""" + g = trapi_kg_to_nx( + SIMPLE_KG, + multigraph=False, + edge_weight_attr="ngd", + edge_payload="weight_only", + weight_agg="not-a-real-name", + ) + assert g["A"]["B"]["weight"] == pytest.approx(0.75) + + +def test_collapsed_weight_only_with_callable_aggregator(): + """A user callable aggregator: takes (existing, new), returns combined.""" + g = trapi_kg_to_nx( + SIMPLE_KG, + multigraph=False, + edge_weight_attr="ngd", + edge_payload="weight_only", + weight_agg=lambda a, b: a * b, + ) + assert g["A"]["B"]["weight"] == pytest.approx(0.5 * 0.25) + + +def test_collapsed_weight_only_with_no_weights_strips_attributes(): + g = trapi_kg_to_nx( + SIMPLE_KG, + multigraph=False, + edge_payload="weight_only", + ) + # No edge_weight_attr, so the bare edge has no attributes. 
+ assert g["A"]["B"] == {} + + +def test_invalid_edge_payload_raises(): + with pytest.raises(ValueError, match="edge_payload"): + trapi_kg_to_nx(SIMPLE_KG, edge_payload="something_else") + + +def test_full_payload_with_weight_agg_raises(): + with pytest.raises(ValueError, match="weight_agg"): + trapi_kg_to_nx(SIMPLE_KG, edge_payload="full", weight_agg="max") + + +def test_accepts_message_root_dict(): + """Real TRAPI responses wrap the KG in ``message.knowledge_graph`` -- the + function should walk into that automatically.""" + response = generate_response() + g = trapi_kg_to_nx(response, multigraph=False, directed=True) + # The fixture has 3 kg nodes + assert len(g.nodes) == 3 + # And 2 edges, but they go between different node pairs so collapsed + # graph should still have 2 directed edges. + assert g.number_of_edges() == 2 + + +def test_accepts_json_string_input(): + g = trapi_kg_to_nx(json.dumps(SIMPLE_KG)) + assert set(g.nodes) == {"A", "B"} + + +def test_accepts_bytes_input(): + g = trapi_kg_to_nx(json.dumps(SIMPLE_KG).encode("utf-8")) + assert g.number_of_edges() == 2 + + +def test_missing_kg_raises_keyerror(): + with pytest.raises(KeyError, match="knowledge_graph"): + trapi_kg_to_nx({"message": {"results": []}}) + + +def test_edge_endpoint_not_in_node_map_is_added_as_bare_node(): + g = trapi_kg_to_nx( + { + "knowledge_graph": { + "nodes": {"A": {"categories": ["biolink:NamedThing"]}}, + "edges": { + "e1": { + "subject": "A", + "object": "DANGLING", + "attributes": [], + } + }, + } + } + ) + assert "DANGLING" in g.nodes + + +def test_edge_with_missing_endpoint_is_skipped(): + g = trapi_kg_to_nx( + { + "knowledge_graph": { + "nodes": {"A": {}}, + "edges": { + "bad": {"subject": "A", "attributes": []}, # no object + }, + } + } + ) + assert g.number_of_edges() == 0 + + +def test_node_attributes_flat_dedupes_into_list_for_repeated_keys(): + g = trapi_kg_to_nx(SIMPLE_KG) + a_flat = g.nodes["A"]["attributes_flat"] + assert a_flat["synonym"] == ["alpha", "alfa"] + + 
+def test_multigraph_edge_id_uses_source_id(): + """In multigraph mode the edge ``id`` attribute comes from the TRAPI key.""" + g = trapi_kg_to_nx(SIMPLE_KG) + edge_ids = {data["id"] for _, _, data in g.edges(data=True)} + assert edge_ids == {"e1", "e2"} diff --git a/workers/sipr/__init__.py b/workers/sipr/__init__.py new file mode 100644 index 0000000..e69de29