diff --git a/TCT/TCT_neighborhood_finder.py b/TCT/TCT_neighborhood_finder.py index 4c8018a..068a56d 100644 --- a/TCT/TCT_neighborhood_finder.py +++ b/TCT/TCT_neighborhood_finder.py @@ -1,7 +1,8 @@ from collections import Counter -from .TCT import sele_predicates_API, format_query_json, parse_KG, rank_by_primary_infores +from .TCT import sele_predicates_API, parse_KG, rank_by_primary_infores from .TCT_pathfinder import generate_score_results, build_query_graph +from .translator_query import format_query_json def parse_results_for_neighborhood_finder(start_node_id:str, results:dict, start_node_categories=None, end_node_categories=None, diff --git a/TCT/TCT_pathfinder.py b/TCT/TCT_pathfinder.py index b4eaa28..f3c9109 100644 --- a/TCT/TCT_pathfinder.py +++ b/TCT/TCT_pathfinder.py @@ -5,62 +5,7 @@ from . import node_normalizer from . import translator_query -from .TCT import sele_predicates_API, parse_KG, rank_by_primary_infores, merge_ranking_by_number_of_infores - -def format_query_json_for_pathfinder(subject_ids, - object_ids=None, - subject_categories=None, - object_categories=None, - predicates=None): - ''' - Example input: - subject_ids = ["NCBIGene:3845"] - object_ids = [] - subject_categories = ["biolink:Gene"] - object_categories = ["biolink:Gene"] - predicates = ["biolink:positively_correlated_with", "biolink:physically_interacts_with"] - ''' - query_json_temp = { - "message": { - "query_graph": { - - "edges": { - "e00": { - "subject": "n00", - "object": "n01", - "predicates": predicates - } - }, - "nodes": { - "n00": { - "ids":subject_ids, # required - #"categories":[] # optional, if not provided, it will be empty - }, - "n01": { - #"ids":[], - "categories":[] # required - } - } - } - } - } - - if len(subject_ids) > 0: - query_json_temp["message"]["query_graph"]["nodes"]["n00"]["ids"] = subject_ids - - if object_ids is not None and len(object_ids) > 0: - query_json_temp["message"]["query_graph"]["nodes"]["n01"]["ids"] = object_ids - - if subject_categories is not None and len(subject_categories) > 0: - query_json_temp["message"]["query_graph"]["nodes"]["n00"]["categories"] = subject_categories - - if object_categories is not None and len(object_categories) > 0: - query_json_temp["message"]["query_graph"]["nodes"]["n01"]["categories"] = object_categories - - if predicates is not None and len(predicates) > 0: - query_json_temp["message"]["query_graph"]["edges"]["e00"]["predicates"] = predicates - - return query_json_temp +from .TCT import sele_predicates_API def build_query_graph(start_node_id, end_node_id, start_node_categories=None, end_node_categories=None): @@ -289,14 +234,14 @@ def pathfinder(input_node1_id:str, input_node2_id:str, sele_predicates2, sele_APIs2, API_URLs2 = sele_predicates_API(intermediate_categories, input_node2_category, metaKG, APInames) - query_json1 = format_query_json_for_pathfinder(input_node1_list, # a list of identifiers for input node1 + query_json1 = translator_query.format_query_json(input_node1_list, # a list of identifiers for input node1 [], # id list for the intermediate node, it can be empty list if only want to query node1 input_node1_category, # a list of categories of input node1 intermediate_categories, # a list of categories of the intermediate node sele_predicates1) # a list of predicates # for the second hop, we want the predicates to be... - query_json2 = format_query_json_for_pathfinder([], + query_json2 = translator_query.format_query_json([], input_node2_list, intermediate_categories, # a list of categories of input node2 input_node2_category, # a list of categories of the intermediate node diff --git a/TCT/translator_query.py b/TCT/translator_query.py index 870b153..9baaac8 100644 --- a/TCT/translator_query.py +++ b/TCT/translator_query.py @@ -43,6 +43,90 @@ def get_translator_API_predicates() -> tuple[dict, pandas.DataFrame, dict]: return APInames, metaKG, API_predicates + +def build_attribute_constraint(attribute_id, operator, value, name=None, is_not=False): + """ + This creates an attribute constraint for a TRAPI query dict. + """ + if name is None: + name = '' + output = { + 'id': attribute_id, + 'operator': operator, + 'value': value, + 'name': name + } + if is_not: + output['not'] = True + return output + + +def format_query_json(subject_ids:list[str], + object_ids:list[str]|None = None, + subject_categories:list[str]|None = None, + object_categories:list[str]|None = None, + predicates:list[str]|None = None, + attribute_constraints:list[dict]|None = None, + ) -> dict: + ''' + Formats a query dict, with optional constraints. + + Example input: + subject_ids = ["NCBIGene:3845"] + object_ids = [] + subject_categories = ["biolink:Gene"] + object_categories = ["biolink:Gene"] + predicates = ["biolink:positively_correlated_with", "biolink:physically_interacts_with"] + attribute_constraints = [build_attribute_constraint('biolink:has_total', '>', 2)] + ''' + #edited Dec 5, 2023 + query_json_temp = { + "message": { + "query_graph": { + + "edges": { + "e00": { + #"e1": { + "subject": "n01", + "object": "n00", + "predicates": predicates + } + }, + "nodes": { + "n00": { + "ids":subject_ids, # required + #"categories":[] # optional, if not provided, it will be empty + }, + "n01": { + #"ids":[], + "categories":[] # required + }} + } + } + } + + if attribute_constraints is not None and len(attribute_constraints) > 0: + query_json_temp['message']['query_graph']['edges']['attribute_constraints'] = attribute_constraints + + if subject_ids is not None and len(subject_ids) > 0: + query_json_temp["message"]["query_graph"]["nodes"]["n00"]["ids"] = subject_ids + + if object_ids is not None and len(object_ids) > 0: + query_json_temp["message"]["query_graph"]["nodes"]["n01"]["ids"] = object_ids + + if subject_categories is not None and len(subject_categories) > 0: + query_json_temp["message"]["query_graph"]["nodes"]["n01"]["categories"] = subject_categories + + if object_categories is not None and len(object_categories) > 0: + query_json_temp["message"]["query_graph"]["nodes"]["n01"]["categories"] = object_categories + + if predicates is not None and len(predicates) > 0: + query_json_temp["message"]["query_graph"]["edges"]["e00"]["predicates"] = predicates + + return query_json_temp + + + def optimize_query_json(query_json, API_name_cur, API_predicates): ''' Optimize the query JSON by removing predicates that are not supported by the selected APIs.