Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 145 additions & 54 deletions import-automation/workflow/ingestion-helper/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

logging.getLogger().setLevel(logging.INFO)


class BigQueryExecutor:
"""Handles BigQuery client initialization and query execution."""

def __init__(self,
connection_id: str,
project_id: str,
Expand All @@ -35,7 +37,8 @@ def __init__(self,
self.database_id = database_id
self.location = location
try:
self.client = bigquery.Client(project=self.project_id, location=self.location)
self.client = bigquery.Client(project=self.project_id,
location=self.location)
except Exception as e:
logging.warning(f"Failed to initialize BigQuery client: {e}")
self.client = None
Expand All @@ -44,47 +47,110 @@ def get_spanner_destination_uri(self) -> str:
"""Returns the Spanner destination URI for EXPORT DATA."""
return f"https://spanner.googleapis.com/projects/{self.project_id}/instances/{self.instance_id}/databases/{self.database_id}"

def execute(self, query: str, job_config: Optional[bigquery.QueryJobConfig] = None) -> bigquery.table.RowIterator:
def execute(
self,
query: str,
job_config: Optional[bigquery.QueryJobConfig] = None
) -> bigquery.table.RowIterator:
"""Executes a query and returns the result."""
if not self.client:
logging.error("BigQuery client not initialized")
raise RuntimeError("BigQuery client not initialized")

start_time = time.time()
logging.info(f"Executing query (first 100 chars): {query.strip()[:100]}...")

try:
query_job = self.client.query(query, job_config=job_config)
query_job = self.execute(query, job_config)
result = query_job.result()
duration = time.time() - start_time
logging.info(f"Query completed in {duration:.2f}s. Job ID: {query_job.job_id}")
logging.info(
f"Query completed in {duration:.2f}s. Job ID: {query_job.job_id}"
)
return result
except Exception as e:
logging.error(f"Query execution failed after {time.time() - start_time:.2f}s: {e}")
logging.error(
f"Query execution failed after {time.time() - start_time:.2f}s: {e}"
)
raise

def execute(
self,
query: str,
job_config: Optional[bigquery.QueryJobConfig] = None
) -> bigquery.job.QueryJob:
"""Submits a query asynchronously and returns the QueryJob."""
if not self.client:
logging.error("BigQuery client not initialized")
raise RuntimeError("BigQuery client not initialized")

logging.info(
f"Submitting query (first 100 chars): {query.strip()[:100]}...")

try:
query_job = self.client.query(query, job_config=job_config)
logging.info(f"Query submitted. Job ID: {query_job.job_id}")
return query_job
except Exception as e:
logging.error(f"Failed to submit query: {e}")
raise

def get_jobs_status(self, job_ids: List[str]) -> Dict[str, Any]:
"""Returns the overall status of a list of BigQuery jobs."""
if not self.client:
logging.error("BigQuery client not initialized")
raise RuntimeError("BigQuery client not initialized")

overall_status = "DONE"
failed_jobs = []
error_message = ""

for job_id in job_ids:
try:
job = self.client.get_job(job_id, location=self.location)
if job.error_result:
overall_status = "FAILED"
failed_jobs.append(job_id)
error_message += f"Job {job_id} failed: {job.error_result}. "
elif job.state != "DONE" and overall_status != "FAILED":
overall_status = "RUNNING"
except Exception as e:
logging.error(f"Failed to get job status for {job_id}: {e}")
overall_status = "FAILED"
failed_jobs.append(job_id)
error_message += f"Failed to get job {job_id}: {e}. "

Comment thread
vish-cs marked this conversation as resolved.
if overall_status == "FAILED":
return {
"status": overall_status,
"error": error_message,
"failedJobs": failed_jobs
}
else:
return {"status": overall_status}


class LinkedEdgeGenerator:
"""Generates and ingests linked relationship edges (e.g., transitive closures) into Spanner for faster lookup."""
def __init__(self, executor: BigQueryExecutor, is_base_dc: bool = True) -> None:

def __init__(self,
executor: BigQueryExecutor,
is_base_dc: bool = True) -> None:
self.executor = executor
self.is_base_dc = is_base_dc

def run_all(self, import_names: List[str] = None) -> None:
"""Runs all global aggregations in sequence."""
def run_all(self,
import_names: List[str] = None) -> List[bigquery.job.QueryJob]:
"""Runs all global aggregations asynchronously and returns their jobs."""
if not import_names:
logging.info("No imports specified. Skipping global aggregations.")
return
return []

logging.info(f"Running global aggregations for imports: {import_names}")

# TODO: Run these methods in parallel to speed up execution since they are independent.
self.run_linked_contained_in_place(import_names)
self.run_linked_member_of(import_names)
self.run_linked_member(import_names)

jobs = [
self.run_linked_contained_in_place(import_names),
self.run_linked_member_of(import_names),
self.run_linked_member(import_names)
]
return jobs

def run_linked_contained_in_place(self, import_names: List[str] = None) -> None:
def run_linked_contained_in_place(self,
import_names: List[str] = None) -> None:
"""Expands place containment hierarchies."""
if not import_names:
return
Expand All @@ -96,7 +162,7 @@ def run_linked_contained_in_place(self, import_names: List[str] = None) -> None:
provenances = [f"'{prefix}{name}'" for name in safe_names]
provenance_filter = f" AND provenance IN ({', '.join(provenances)})"
gen_graphs_prov = 'dc/base/GeneratedGraphs' if self.is_base_dc else 'GeneratedGraphs'

query = f"""
-- Pull base edges needed for containedInPlace aggregation
CREATE OR REPLACE TEMPORARY TABLE `temp_base_contained_in_place` AS
Expand Down Expand Up @@ -171,7 +237,7 @@ def run_linked_contained_in_place(self, import_names: List[str] = None) -> None:
FROM
FilteredEdges
"""
self.executor.execute(query)
return self.executor.execute(query)

def run_linked_member_of(self, import_names: List[str] = None) -> None:
"""Expands membership hierarchies using memberOf and specializationOf."""
Expand Down Expand Up @@ -263,7 +329,7 @@ def run_linked_member_of(self, import_names: List[str] = None) -> None:
FROM
FilteredEdges
"""
self.executor.execute(query)
return self.executor.execute(query)

def run_linked_member(self, import_names: List[str] = None) -> None:
"""Expands topic/SVGP descendants to identify leaf members."""
Expand Down Expand Up @@ -356,38 +422,44 @@ def run_linked_member(self, import_names: List[str] = None) -> None:
FROM
FilteredEdges
"""
self.executor.execute(query)
return self.executor.execute(query)


class ProvenanceSummaryGenerator:
"""Contains the SQL queries to generate ProvenanceSummary in the Cache table."""
def __init__(self, executor: BigQueryExecutor, is_base_dc: bool = True) -> None:

def __init__(self,
executor: BigQueryExecutor,
is_base_dc: bool = True) -> None:
self.executor = executor
self.is_base_dc = is_base_dc

def run_all(self, import_names: List[str]) -> None:
"""Runs all provenance summary generation in sequence."""
def run_all(self, import_names: List[str]) -> List[bigquery.job.QueryJob]:
"""Runs all provenance summary generation asynchronously and returns their jobs."""
if not import_names:
logging.info("No imports specified. Skipping cache aggregations.")
return
return []

logging.info(f"Running provenance summary generation for imports: {import_names}")
self.run_provenance_summary_aggregation(import_names)
logging.info(
f"Running provenance summary generation for imports: {import_names}"
)
return [self.run_provenance_summary_aggregation(import_names)]

def run_provenance_summary_aggregation(self, import_names: List[str]) -> None:
def run_provenance_summary_aggregation(self,
import_names: List[str]) -> None:
"""Calculates ProvenanceSummary for all variables and populates the Cache table."""
if not import_names:
return

dest = self.executor.get_spanner_destination_uri()
connection_id = self.executor.connection_id

# Escape single quotes to prevent SQL injection
safe_names = [name.replace("'", "''") for name in import_names]
# Format import names for the SQL IN clause
imports_str = ", ".join([f"'{name}'" for name in safe_names])
provenance_dcid_expr = "CONCAT('dc/base/', raw.import_name)" if self.is_base_dc else "raw.import_name"

query = f"""
-- Step 1: Fetch Observation rows for the specific import
-- We cast 'observations' to STRING to avoid the PROTO error.
Expand Down Expand Up @@ -571,31 +643,33 @@ def run_provenance_summary_aggregation(self, import_names: List[str]) -> None:
FROM facet_summaries
GROUP BY variable_measured, provenance_dcid;
"""
self.executor.execute(query)
return self.executor.execute(query)


class AggregationUtils:
"""Orchestrates the overall aggregation workflow."""
def __init__(self,

def __init__(self,
connection_id: str,
project_id: str,
instance_id: str,
database_id: str,
location: Optional[str] = None,
is_base_dc: bool = True) -> None:
self.executor = BigQueryExecutor(
connection_id=connection_id,
project_id=project_id,
instance_id=instance_id,
database_id=database_id,
location=location
)
self.linked_edge_generator = LinkedEdgeGenerator(self.executor, is_base_dc)
self.provenance_summary_generator = ProvenanceSummaryGenerator(self.executor, is_base_dc)

def run_aggregation(self, import_list: List[Dict[str, Any]]) -> bool:
self.executor = BigQueryExecutor(connection_id=connection_id,
project_id=project_id,
instance_id=instance_id,
database_id=database_id,
location=location)
self.linked_edge_generator = LinkedEdgeGenerator(
self.executor, is_base_dc)
self.provenance_summary_generator = ProvenanceSummaryGenerator(
self.executor, is_base_dc)

def run_aggregation(self, import_list: List[Dict[str, Any]]) -> List[str]:
"""
Orchestrates standard per-import aggregations and global aggregations.
Returns a list of BigQuery job IDs for async polling.
"""
logging.info(f"Received request for importList: {import_list}")

Expand All @@ -608,17 +682,34 @@ def run_aggregation(self, import_list: List[Dict[str, Any]]) -> bool:
import_names.append(import_name)
query = "SELECT @import_name as import_name, CURRENT_TIMESTAMP() as execution_time"
job_config = bigquery.QueryJobConfig(query_parameters=[
bigquery.ScalarQueryParameter("import_name", "STRING", import_name),
bigquery.ScalarQueryParameter("import_name", "STRING",
import_name),
])
self.executor.execute(query, job_config=job_config)
else:
logging.info('Skipping aggregation logic for empty importName')
logging.info(
'Skipping aggregation logic for empty importName')

# 2. Run global aggregations
self.linked_edge_generator.run_all(import_names)
self.provenance_summary_generator.run_all(import_names)

return True
# 2. Run global aggregations asynchronously
jobs = []
jobs.extend(self.linked_edge_generator.run_all(import_names))
jobs.extend(self.provenance_summary_generator.run_all(import_names))

job_ids = [job.job_id for job in jobs if job]
logging.info(f"Submitted async aggregation jobs: {job_ids}")

return job_ids
except Exception as e:
logging.error(f"Aggregation failed: {e}")
raise e

def check_aggregation_status(self, job_ids: List[str]) -> Dict[str, Any]:
"""
Checks the status of the provided BigQuery job IDs.
"""
logging.info(f"Checking status for jobs: {job_ids}")
try:
return self.executor.get_jobs_status(job_ids)
except Exception as e:
logging.error(f"Failed to check aggregation status: {e}")
raise e
27 changes: 22 additions & 5 deletions import-automation/workflow/ingestion-helper/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,10 @@ def ingestion_helper(request):
is_base_dc=FLAGS.is_base_dc,
)
try:
if aggregation.run_aggregation(import_list):
return ('OK', 200)
else:
return ('Aggregation failed', 500)
job_ids = aggregation.run_aggregation(import_list)
return jsonify({'status': 'SUBMITTED', 'jobIds': job_ids}), 200
except Exception as e:
return (f"Aggregation failed: {str(e)}", 500)

elif action_type == 'clear_redis_cache':
logging.info("Action: clear_redis_cache")
redis_host = os.environ.get("REDIS_HOST")
Expand All @@ -306,6 +303,26 @@ def ingestion_helper(request):
else:
logging.warning("REDIS_HOST not set, skipping cache flush.")
return jsonify({'status': 'SKIPPED', 'message': 'REDIS_HOST not set'}), 200
elif action_type == 'check_aggregation_status':
# Checks the status of submitted aggregation BigQuery jobs.
# Input:
# jobIds: list of BigQuery job IDs
job_ids = request_json.get('jobIds', [])
if not job_ids:
return ('Missing or empty jobIds', 400)

aggregation = AggregationUtils(
connection_id=FLAGS.spanner_connection_id,
project_id=FLAGS.spanner_project_id,
instance_id=FLAGS.spanner_instance_id,
database_id=FLAGS.spanner_graph_database_id,
location=FLAGS.location,
)
try:
status = aggregation.check_aggregation_status(job_ids)
return jsonify(status), 200
except Exception as e:
return (f"Aggregation status check failed: {str(e)}", 500)

else:
return (f'Unknown actionType: {action_type}', 400)
Loading
Loading