From a555b832d6f775ca717f86a0a7b9cff5752059d3 Mon Sep 17 00:00:00 2001 From: Vishal Gupta Date: Fri, 22 May 2026 09:08:30 +0000 Subject: [PATCH] Use asynchronous aggregation queries --- .../ingestion-helper/aggregation_utils.py | 199 +++++++++++++----- .../workflow/ingestion-helper/main.py | 27 ++- .../workflow/spanner-ingestion-workflow.yaml | 90 +++++--- 3 files changed, 229 insertions(+), 87 deletions(-) diff --git a/import-automation/workflow/ingestion-helper/aggregation_utils.py b/import-automation/workflow/ingestion-helper/aggregation_utils.py index 0993ac32b8..07aab2fb65 100644 --- a/import-automation/workflow/ingestion-helper/aggregation_utils.py +++ b/import-automation/workflow/ingestion-helper/aggregation_utils.py @@ -21,8 +21,10 @@ logging.getLogger().setLevel(logging.INFO) + class BigQueryExecutor: """Handles BigQuery client initialization and query execution.""" + def __init__(self, connection_id: str, project_id: str, @@ -35,7 +37,8 @@ def __init__(self, self.database_id = database_id self.location = location try: - self.client = bigquery.Client(project=self.project_id, location=self.location) + self.client = bigquery.Client(project=self.project_id, + location=self.location) except Exception as e: logging.warning(f"Failed to initialize BigQuery client: {e}") self.client = None @@ -44,47 +47,110 @@ def get_spanner_destination_uri(self) -> str: """Returns the Spanner destination URI for EXPORT DATA.""" return f"https://spanner.googleapis.com/projects/{self.project_id}/instances/{self.instance_id}/databases/{self.database_id}" - def execute(self, query: str, job_config: Optional[bigquery.QueryJobConfig] = None) -> bigquery.table.RowIterator: + def execute( + self, + query: str, + job_config: Optional[bigquery.QueryJobConfig] = None + ) -> bigquery.table.RowIterator: """Executes a query and returns the result.""" - if not self.client: - logging.error("BigQuery client not initialized") - raise RuntimeError("BigQuery client not initialized") - start_time = time.time() - logging.info(f"Executing query (first 100 chars): {query.strip()[:100]}...") - try: - query_job = self.client.query(query, job_config=job_config) + query_job = self.execute(query, job_config) result = query_job.result() duration = time.time() - start_time - logging.info(f"Query completed in {duration:.2f}s. Job ID: {query_job.job_id}") + logging.info( + f"Query completed in {duration:.2f}s. Job ID: {query_job.job_id}" + ) return result except Exception as e: - logging.error(f"Query execution failed after {time.time() - start_time:.2f}s: {e}") + logging.error( + f"Query execution failed after {time.time() - start_time:.2f}s: {e}" + ) + raise + + def execute( + self, + query: str, + job_config: Optional[bigquery.QueryJobConfig] = None + ) -> bigquery.job.QueryJob: + """Submits a query asynchronously and returns the QueryJob.""" + if not self.client: + logging.error("BigQuery client not initialized") + raise RuntimeError("BigQuery client not initialized") + + logging.info( + f"Submitting query (first 100 chars): {query.strip()[:100]}...") + + try: + query_job = self.client.query(query, job_config=job_config) + logging.info(f"Query submitted. Job ID: {query_job.job_id}") + return query_job + except Exception as e: + logging.error(f"Failed to submit query: {e}") raise + def get_jobs_status(self, job_ids: List[str]) -> Dict[str, Any]: + """Returns the overall status of a list of BigQuery jobs.""" + if not self.client: + logging.error("BigQuery client not initialized") + raise RuntimeError("BigQuery client not initialized") + + overall_status = "DONE" + failed_jobs = [] + error_message = "" + + for job_id in job_ids: + try: + job = self.client.get_job(job_id, location=self.location) + if job.error_result: + overall_status = "FAILED" + failed_jobs.append(job_id) + error_message += f"Job {job_id} failed: {job.error_result}. " + elif job.state != "DONE" and overall_status != "FAILED": + overall_status = "RUNNING" + except Exception as e: + logging.error(f"Failed to get job status for {job_id}: {e}") + overall_status = "FAILED" + failed_jobs.append(job_id) + error_message += f"Failed to get job {job_id}: {e}. " + + if overall_status == "FAILED": + return { + "status": overall_status, + "error": error_message, + "failedJobs": failed_jobs + } + else: + return {"status": overall_status} + class LinkedEdgeGenerator: """Generates and ingests linked relationship edges (e.g., transitive closures) into Spanner for faster lookup.""" - def __init__(self, executor: BigQueryExecutor, is_base_dc: bool = True) -> None: + + def __init__(self, + executor: BigQueryExecutor, + is_base_dc: bool = True) -> None: self.executor = executor self.is_base_dc = is_base_dc - def run_all(self, import_names: List[str] = None) -> None: - """Runs all global aggregations in sequence.""" + def run_all(self, + import_names: List[str] = None) -> List[bigquery.job.QueryJob]: + """Runs all global aggregations asynchronously and returns their jobs.""" if not import_names: logging.info("No imports specified. Skipping global aggregations.") - return + return [] logging.info(f"Running global aggregations for imports: {import_names}") - - # TODO: Run these methods in parallel to speed up execution since they are independent. - self.run_linked_contained_in_place(import_names) - self.run_linked_member_of(import_names) - self.run_linked_member(import_names) + jobs = [ + self.run_linked_contained_in_place(import_names), + self.run_linked_member_of(import_names), + self.run_linked_member(import_names) + ] + return jobs - def run_linked_contained_in_place(self, import_names: List[str] = None) -> None: + def run_linked_contained_in_place(self, + import_names: List[str] = None) -> None: """Expands place containment hierarchies.""" if not import_names: return @@ -96,7 +162,7 @@ def run_linked_contained_in_place(self, import_names: List[str] = None) -> None: provenances = [f"'{prefix}{name}'" for name in safe_names] provenance_filter = f" AND provenance IN ({', '.join(provenances)})" gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs' - + query = f""" -- Pull base edges needed for containedInPlace aggregation CREATE OR REPLACE TEMPORARY TABLE `temp_base_contained_in_place` AS @@ -171,7 +237,7 @@ def run_linked_contained_in_place(self, import_names: List[str] = None) -> None: FROM FilteredEdges """ - self.executor.execute(query) + return self.executor.execute(query) def run_linked_member_of(self, import_names: List[str] = None) -> None: """Expands membership hierarchies using memberOf and specializationOf.""" @@ -263,7 +329,7 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None: FROM FilteredEdges """ - self.executor.execute(query) + return self.executor.execute(query) def run_linked_member(self, import_names: List[str] = None) -> None: """Expands topic/SVGP descendants to identify leaf members.""" @@ -356,38 +422,44 @@ def run_linked_member(self, import_names: List[str] = None) -> None: FROM FilteredEdges """ - self.executor.execute(query) + return self.executor.execute(query) class ProvenanceSummaryGenerator: """Contains the SQL queries to generate ProvenanceSummary in the Cache table.""" - def __init__(self, executor: BigQueryExecutor, is_base_dc: bool = True) -> None: + + def __init__(self, + executor: BigQueryExecutor, + is_base_dc: bool = True) -> None: self.executor = executor self.is_base_dc = is_base_dc - def run_all(self, import_names: List[str]) -> None: - """Runs all provenance summary generation in sequence.""" + def run_all(self, import_names: List[str]) -> List[bigquery.job.QueryJob]: + """Runs all provenance summary generation asynchronously and returns their jobs.""" if not import_names: logging.info("No imports specified. Skipping cache aggregations.") - return + return [] - logging.info(f"Running provenance summary generation for imports: {import_names}") - self.run_provenance_summary_aggregation(import_names) + logging.info( + f"Running provenance summary generation for imports: {import_names}" + ) + return [self.run_provenance_summary_aggregation(import_names)] - def run_provenance_summary_aggregation(self, import_names: List[str]) -> None: + def run_provenance_summary_aggregation(self, + import_names: List[str]) -> None: """Calculates ProvenanceSummary for all variables and populates the Cache table.""" if not import_names: return dest = self.executor.get_spanner_destination_uri() connection_id = self.executor.connection_id - + # Escape single quotes to prevent SQL injection safe_names = [name.replace("'", "''") for name in import_names] # Format import names for the SQL IN clause imports_str = ", ".join([f"'{name}'" for name in safe_names]) provenance_dcid_expr = "CONCAT('dc/base/', raw.import_name)" if self.is_base_dc else "raw.import_name" - + query = f""" -- Step 1: Fetch Observation rows for the specific import -- We cast 'observations' to STRING to avoid the PROTO error. @@ -571,31 +643,33 @@ def run_provenance_summary_aggregation(self, import_names: List[str]) -> None: FROM facet_summaries GROUP BY variable_measured, provenance_dcid; """ - self.executor.execute(query) + return self.executor.execute(query) class AggregationUtils: """Orchestrates the overall aggregation workflow.""" - def __init__(self, + + def __init__(self, connection_id: str, project_id: str, instance_id: str, database_id: str, location: Optional[str] = None, is_base_dc: bool = True) -> None: - self.executor = BigQueryExecutor( - connection_id=connection_id, - project_id=project_id, - instance_id=instance_id, - database_id=database_id, - location=location - ) - self.linked_edge_generator = LinkedEdgeGenerator(self.executor, is_base_dc) - self.provenance_summary_generator = ProvenanceSummaryGenerator(self.executor, is_base_dc) - - def run_aggregation(self, import_list: List[Dict[str, Any]]) -> bool: + self.executor = BigQueryExecutor(connection_id=connection_id, + project_id=project_id, + instance_id=instance_id, + database_id=database_id, + location=location) + self.linked_edge_generator = LinkedEdgeGenerator( + self.executor, is_base_dc) + self.provenance_summary_generator = ProvenanceSummaryGenerator( + self.executor, is_base_dc) + + def run_aggregation(self, import_list: List[Dict[str, Any]]) -> List[str]: """ Orchestrates standard per-import aggregations and global aggregations. + Returns a list of BigQuery job IDs for async polling. """ logging.info(f"Received request for importList: {import_list}") @@ -608,17 +682,34 @@ def run_aggregation(self, import_list: List[Dict[str, Any]]) -> bool: import_names.append(import_name) query = "SELECT @import_name as import_name, CURRENT_TIMESTAMP() as execution_time" job_config = bigquery.QueryJobConfig(query_parameters=[ - bigquery.ScalarQueryParameter("import_name", "STRING", import_name), + bigquery.ScalarQueryParameter("import_name", "STRING", + import_name), ]) self.executor.execute(query, job_config=job_config) else: - logging.info('Skipping aggregation logic for empty importName') + logging.info( + 'Skipping aggregation logic for empty importName') - # 2. Run global aggregations - self.linked_edge_generator.run_all(import_names) - self.provenance_summary_generator.run_all(import_names) - - return True + # 2. Run global aggregations asynchronously + jobs = [] + jobs.extend(self.linked_edge_generator.run_all(import_names)) + jobs.extend(self.provenance_summary_generator.run_all(import_names)) + + job_ids = [job.job_id for job in jobs if job] + logging.info(f"Submitted async aggregation jobs: {job_ids}") + + return job_ids except Exception as e: logging.error(f"Aggregation failed: {e}") raise e + + def check_aggregation_status(self, job_ids: List[str]) -> Dict[str, Any]: + """ + Checks the status of the provided BigQuery job IDs. + """ + logging.info(f"Checking status for jobs: {job_ids}") + try: + return self.executor.get_jobs_status(job_ids) + except Exception as e: + logging.error(f"Failed to check aggregation status: {e}") + raise e diff --git a/import-automation/workflow/ingestion-helper/main.py b/import-automation/workflow/ingestion-helper/main.py index 511d9e7108..00e66a2d75 100644 --- a/import-automation/workflow/ingestion-helper/main.py +++ b/import-automation/workflow/ingestion-helper/main.py @@ -282,13 +282,10 @@ def ingestion_helper(request): is_base_dc=FLAGS.is_base_dc, ) try: - if aggregation.run_aggregation(import_list): - return ('OK', 200) - else: - return ('Aggregation failed', 500) + job_ids = aggregation.run_aggregation(import_list) + return jsonify({'status': 'SUBMITTED', 'jobIds': job_ids}), 200 except Exception as e: return (f"Aggregation failed: {str(e)}", 500) - elif action_type == 'clear_redis_cache': logging.info("Action: clear_redis_cache") redis_host = os.environ.get("REDIS_HOST") @@ -306,6 +303,26 @@ def ingestion_helper(request): else: logging.warning("REDIS_HOST not set, skipping cache flush.") return jsonify({'status': 'SKIPPED', 'message': 'REDIS_HOST not set'}), 200 + elif action_type == 'check_aggregation_status': + # Checks the status of submitted aggregation BigQuery jobs. + # Input: + # jobIds: list of BigQuery job IDs + job_ids = request_json.get('jobIds', []) + if not job_ids: + return ('Missing or empty jobIds', 400) + + aggregation = AggregationUtils( + connection_id=FLAGS.spanner_connection_id, + project_id=FLAGS.spanner_project_id, + instance_id=FLAGS.spanner_instance_id, + database_id=FLAGS.spanner_graph_database_id, + location=FLAGS.location, + ) + try: + status = aggregation.check_aggregation_status(job_ids) + return jsonify(status), 200 + except Exception as e: + return (f"Aggregation status check failed: {str(e)}", 500) else: return (f'Unknown actionType: {action_type}', 400) diff --git a/import-automation/workflow/spanner-ingestion-workflow.yaml b/import-automation/workflow/spanner-ingestion-workflow.yaml index 1f18d1feda..7b60bf48aa 100644 --- a/import-automation/workflow/spanner-ingestion-workflow.yaml +++ b/import-automation/workflow/spanner-ingestion-workflow.yaml @@ -14,6 +14,7 @@ main: - spanner_database_id: '${sys.get_env("SPANNER_DATABASE_ID")}' - helper_url: ${"https://ingestion-helper-service-" + sys.get_env("PROJECT_NUMBER") + "." + location + ".run.app"} - import_list: ${default(map.get(args, "importList"), [])} + - dataflow_job_id: null - execution_error: null - acquire_ingestion_lock: try: @@ -62,16 +63,11 @@ main: helper_url: ${helper_url} workflow_id: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' result: dataflow_job_id - - run_aggregation: - call: http.post + - run_aggregation_job: + call: run_aggregation_job args: - url: ${helper_url} - timeout: 1800 - auth: - type: OIDC - body: - actionType: run_aggregation - importList: ${import_info.body} + import_list: ${import_info.body} + helper_url: ${helper_url} - update_ingestion_status: call: http.post args: @@ -88,6 +84,19 @@ main: except: as: e steps: + - record_failure: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: update_ingestion_status + workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + jobId: ${default(dataflow_job_id, default(map.get(e, "job_id"), "N/A"))} + importList: '${import_info.body}' + status: 'RETRY' + result: retry_response - capture_error: assign: - execution_error: ${e} @@ -108,6 +117,45 @@ main: - return_import_info: return: '${import_info.body}' +# This sub-workflow runs aggregation jobs and waits for them to complete. +run_aggregation_job: + params: [import_list, helper_url] + steps: + - run_aggregation: + call: http.post + args: + url: ${helper_url} + timeout: 300 + auth: + type: OIDC + body: + actionType: run_aggregation + importList: ${import_list} + result: aggregation_response + - check_aggregation_status_loop: + steps: + - wait_for_aggregation: + call: sys.sleep + args: + seconds: 15 + - check_aggregation_status: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: check_aggregation_status + jobIds: ${aggregation_response.body.jobIds} + result: aggregation_status_response + - evaluate_aggregation_status: + switch: + - condition: ${aggregation_status_response.body.status == "DONE"} + return: 'OK' + - condition: ${aggregation_status_response.body.status == "FAILED"} + raise: ${aggregation_status_response.body.error} + next: check_aggregation_status_loop + # This sub-workflow launches a Dataflow job and waits for it to complete. run_dataflow_job: params: [import_list, project_id, job_name, template_gcs_path, location, spanner_project_id, spanner_instance_id, spanner_database_id, wait_period, helper_url, workflow_id] @@ -163,22 +211,8 @@ run_dataflow_job: - condition: ${job_status.currentState == "JOB_STATE_DONE"} return: ${launch_result.job.id} - condition: ${job_status.currentState == "JOB_STATE_FAILED" or job_status.currentState == "JOB_STATE_CANCELLED"} - next: record_failed_imports - next: wait_for_job_completion - - record_failed_imports: - call: http.post - args: - url: ${helper_url} - auth: - type: OIDC - body: - actionType: update_ingestion_status - workflowId: '${workflow_id}' - jobId: '${launch_result.job.id}' - importList: '${json.decode(import_list)}' - status: 'PENDING' - result: retry_response - - fail_workflow: - raise: - message: '${jobName + " dataflow job failed with status: " + job_status.currentState}' - code: 500 \ No newline at end of file + raise: + message: '${jobName + " dataflow job failed with status: " + job_status.currentState}' + code: 500 + job_id: ${launch_result.job.id} + next: wait_for_job_completion \ No newline at end of file