diff --git a/import-automation/workflow/ingestion-helper/handlers/__init__.py b/import-automation/workflow/ingestion-helper/handlers/__init__.py new file mode 100644 index 0000000000..2bd3384aaf --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/__init__.py @@ -0,0 +1 @@ +# Initialize handlers package diff --git a/import-automation/workflow/ingestion-helper/handlers/aggregation.py b/import-automation/workflow/ingestion-helper/handlers/aggregation.py new file mode 100644 index 0000000000..ce1851ba1f --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/aggregation.py @@ -0,0 +1,40 @@ +import logging +import os +from absl import flags +from .helpers.aggregation_utils import AggregationUtils + +FLAGS = flags.FLAGS + +def handle_run_aggregation(request_json): + """Runs aggregation logic for the specified imports.""" + import_list = request_json.get('importList', []) + + # Validate required flags are not empty or None + missing_flags = [] + if not FLAGS.spanner_connection_id: + missing_flags.append('spanner_connection_id (BQ_SPANNER_CONN_ID)') + if not FLAGS.spanner_project_id: + missing_flags.append('spanner_project_id (SPANNER_PROJECT_ID)') + if not FLAGS.spanner_instance_id: + missing_flags.append('spanner_instance_id (SPANNER_INSTANCE_ID)') + if not FLAGS.spanner_graph_database_id: + missing_flags.append('spanner_graph_database_id (SPANNER_GRAPH_DATABASE_ID)') + + if missing_flags: + error_msg = f"Missing required configuration flags/env-vars: {', '.join(missing_flags)}" + logging.error(error_msg) + return (error_msg, 400) + + aggregation = AggregationUtils( + connection_id=FLAGS.spanner_connection_id, + project_id=FLAGS.spanner_project_id, + instance_id=FLAGS.spanner_instance_id, + database_id=FLAGS.spanner_graph_database_id, + ) + try: + if aggregation.run_aggregation(import_list): + return ('OK', 200) + else: + return ('Aggregation failed', 500) + except Exception as e: + return (f"Aggregation failed: {str(e)}", 500) diff --git a/import-automation/workflow/ingestion-helper/handlers/cache.py b/import-automation/workflow/ingestion-helper/handlers/cache.py new file mode 100644 index 0000000000..23611196a9 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/cache.py @@ -0,0 +1,22 @@ +import logging +import os +from flask import jsonify +import redis + +def handle_clear_redis_cache(request_json): + """Flushes the Redis cache.""" + logging.info("Action: clear_redis_cache") + redis_host = os.environ.get("REDIS_HOST") + redis_port = os.environ.get("REDIS_PORT", "6379") + if redis_host: + try: + r = redis.Redis(host=redis_host, port=int(redis_port)) + r.flushall(asynchronous=True) + logging.info(f"Redis cache at {redis_host}:{redis_port} flushed successfully (async).") + return jsonify({'status': 'SUCCESS', 'message': 'Cache cleared'}), 200 + except Exception as e: + logging.error(f"Failed to flush Redis cache: {e}") + return jsonify({'status': 'ERROR', 'message': str(e)}), 500 + else: + logging.warning("REDIS_HOST not set, skipping cache flush.") + return jsonify({'status': 'SKIPPED', 'message': 'REDIS_HOST not set'}), 200 diff --git a/import-automation/workflow/ingestion-helper/handlers/database.py b/import-automation/workflow/ingestion-helper/handlers/database.py new file mode 100644 index 0000000000..60222249c5 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/database.py @@ -0,0 +1,18 @@ +import logging +from absl import flags + +FLAGS = flags.FLAGS + +def handle_initialize_database(spanner, request_json): + """Initializes the database by creating all required tables and proto bundles.""" + logging.info("Action: initialize_database") + enable_embeddings = request_json.get('enableEmbeddings', + FLAGS.enable_embeddings) + spanner.initialize_database(enable_embeddings=enable_embeddings) + return ('OK', 200) + +def handle_seed_database(spanner): + """Seeds the database with base empty nodes.""" + logging.info("Action: seed_database") + spanner.seed_database() + return ('OK', 200) diff --git a/import-automation/workflow/ingestion-helper/handlers/embeddings.py b/import-automation/workflow/ingestion-helper/handlers/embeddings.py new file mode 100644 index 0000000000..d3b8644dd9 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/embeddings.py @@ -0,0 +1,26 @@ +import logging +from absl import flags +from .helpers.embedding_utils import get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned + +FLAGS = flags.FLAGS + +def handle_embedding_ingestion(spanner, request_json): + """Handles embedding ingestion.""" + logging.info("Action: embedding_ingestion") + enable_embeddings = request_json.get('enableEmbeddings', + FLAGS.enable_embeddings) + if not enable_embeddings: + logging.info("Embeddings not enabled, skipping.") + return ('Invalid request on embedding ingestion.', 400) + + node_types = FLAGS.node_types + try: + logging.info(f"Job started. Fetching all nodes for types: {node_types}") + timestamp = get_latest_lock_timestamp(spanner.database) + nodes = get_updated_nodes(spanner.database, timestamp, node_types) + converted_nodes = filter_and_convert_nodes(nodes) + affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes) + return (f"OK [Affected rows: {affected_rows}]", 200) + except Exception as e: + logging.error(f"Embedding ingestion failed: {e}") + return (f"Error: {e}", 500) diff --git a/import-automation/workflow/ingestion-helper/handlers/helpers/__init__.py b/import-automation/workflow/ingestion-helper/handlers/helpers/__init__.py new file mode 100644 index 0000000000..e6e8b3f199 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/helpers/__init__.py @@ -0,0 +1 @@ +# Initialize helpers package diff --git a/import-automation/workflow/ingestion-helper/aggregation_utils.py b/import-automation/workflow/ingestion-helper/handlers/helpers/aggregation_utils.py similarity index 100% rename from import-automation/workflow/ingestion-helper/aggregation_utils.py rename to import-automation/workflow/ingestion-helper/handlers/helpers/aggregation_utils.py diff --git a/import-automation/workflow/ingestion-helper/embedding_utils.py b/import-automation/workflow/ingestion-helper/handlers/helpers/embedding_utils.py similarity index 100% rename from import-automation/workflow/ingestion-helper/embedding_utils.py rename to import-automation/workflow/ingestion-helper/handlers/helpers/embedding_utils.py diff --git a/import-automation/workflow/ingestion-helper/import_utils.py b/import-automation/workflow/ingestion-helper/handlers/helpers/import_utils.py similarity index 100% rename from import-automation/workflow/ingestion-helper/import_utils.py rename to import-automation/workflow/ingestion-helper/handlers/helpers/import_utils.py diff --git a/import-automation/workflow/ingestion-helper/handlers/imports.py b/import-automation/workflow/ingestion-helper/handlers/imports.py new file mode 100644 index 0000000000..34305f43ed --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/imports.py @@ -0,0 +1,99 @@ +import logging +import os +from flask import jsonify +from absl import flags +from .helpers import import_utils +from .utils import validate_params + +FLAGS = flags.FLAGS + +def handle_get_import_info(spanner, request_json): + """Gets the details of imports that are ready for ingestion.""" + import_list = request_json.get('importList', []) + import_info = spanner.get_import_info(import_list) + return jsonify(import_info) + +def handle_update_ingestion_status(spanner, request_json): + """Updates the status of imports after ingestion.""" + validation_error = validate_params( + request_json, ['importList', 'workflowId', 'jobId', 'status']) + if validation_error: + return (validation_error, 400) + + import_list = request_json['importList'] + workflow_id = request_json['workflowId'] + status = request_json['status'] + job_id = request_json['jobId'] + ingested_imports = [item['importName'] for item in import_list] + + spanner.update_ingestion_status(ingested_imports, workflow_id, status) + metrics = import_utils.get_ingestion_metrics(FLAGS.project_id, + FLAGS.location, job_id) + spanner.update_ingestion_history(workflow_id, job_id, ingested_imports, + metrics) + if status == 'SUCCESS': + spanner.update_import_version_history(import_list, workflow_id) + return ('OK', 200) + +def handle_update_import_status(spanner, storage, request_json): + """Updates the status of a specific import job.""" + validation_error = validate_params(request_json, + ['importName', 'status']) + if validation_error: + return (validation_error, 400) + + import_name = request_json['importName'] + status = request_json['status'] + logging.info(f'Updating import {import_name} to status {status}') + params = import_utils.get_import_params(request_json) + next_refresh = import_utils.get_next_refresh(FLAGS.project_id, + FLAGS.location, + import_name) + if next_refresh: + params['next_refresh'] = next_refresh + if status == 'STAGING': + version = os.path.basename(request_json.get('latestVersion', '')) + if not version: + return (f'Empty version for import {import_name}', 500) + storage.update_version_file(import_name, version, is_staging=True) + storage.update_provenance_file(import_name, version) + storage.update_import_summary(params) + storage.update_version_file(import_name, version, is_staging=False) + comment = f"import-workflow:{request_json.get('jobId','')}" + spanner.update_version_history(import_name, version, comment) + spanner.update_import_status(params) + return ('OK', 200) + +def handle_update_import_version(spanner, storage, request, request_json): + """Updates the version and status of an import.""" + validation_error = validate_params( + request_json, ['importName', 'version', 'comment']) + if validation_error: + return (validation_error, 400) + + import_name = request_json['importName'] + version = request_json['version'] + comment = request_json['comment'] + logging.info( + f"Updating import {import_name} to version {version} comment:{comment}" + ) + override = request_json.get('override', False) + if version == 'STAGING': + version = storage.get_staging_version(import_name) + summary = storage.get_import_summary(import_name, version) + params = import_utils.get_import_params(summary) + if override: + params['status'] = 'STAGING' + caller = import_utils.get_caller_identity(request) + comment = f'version-override:{caller} {comment}' + if params['status'] == 'STAGING': + storage.update_provenance_file(import_name, version) + storage.update_version_file(import_name, version, is_staging=False) + spanner.update_version_history(import_name, version, comment) + logging.info(f"Updated import {import_name} to version {version}") + else: + logging.info(f"Skipping {import_name} version update") + spanner.update_import_status(params) + return ( + f"OK [Import: {import_name} Version: {version} Status: {params['status']}]", + 200) diff --git a/import-automation/workflow/ingestion-helper/handlers/lock.py b/import-automation/workflow/ingestion-helper/handlers/lock.py new file mode 100644 index 0000000000..4530798d2a --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/lock.py @@ -0,0 +1,27 @@ +import logging +from .utils import validate_params + +def handle_acquire_lock(spanner, request_json): + """Attempts to acquire the global lock for ingestion.""" + validation_error = validate_params(request_json, ['workflowId', 'timeout']) + if validation_error: + return (validation_error, 400) + + workflow = request_json['workflowId'] + timeout = request_json['timeout'] + status = spanner.acquire_lock(workflow, timeout) + if not status: + return ('Failed to acquire lock', 500) + return ('OK', 200) + +def handle_release_lock(spanner, request_json): + """Releases the global ingestion lock.""" + validation_error = validate_params(request_json, ['workflowId']) + if validation_error: + return (validation_error, 400) + + workflow = request_json['workflowId'] + status = spanner.release_lock(workflow) + if not status: + return ('Failed to release lock', 500) + return ('OK', 200) diff --git a/import-automation/workflow/ingestion-helper/handlers/utils.py b/import-automation/workflow/ingestion-helper/handlers/utils.py new file mode 100644 index 0000000000..49cc1c7277 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/handlers/utils.py @@ -0,0 +1,6 @@ +def validate_params(request_json, required_params): + """Validates that required parameters are present in the request JSON.""" + for param in required_params: + if param not in request_json: + return f"'{param}' parameter is missing" + return None diff --git a/import-automation/workflow/ingestion-helper/main.py b/import-automation/workflow/ingestion-helper/main.py index b452788875..ca967094ee 100644 --- a/import-automation/workflow/ingestion-helper/main.py +++ b/import-automation/workflow/ingestion-helper/main.py @@ -1,13 +1,13 @@ import functions_framework from spanner_client import SpannerClient from storage_client import StorageClient -from embedding_utils import get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned + import logging import os from absl import flags -import import_utils + from flask import jsonify -from aggregation_utils import AggregationUtils + logging.getLogger().setLevel(logging.INFO) @@ -56,6 +56,28 @@ def _validate_params(request_json, required_params): return f"'{param}' parameter is missing" return None +_spanner_client = None +_storage_client = None + +def _get_spanner_client(): + global _spanner_client + if _spanner_client is None: + _spanner_client = SpannerClient( + FLAGS.spanner_project_id, + FLAGS.spanner_instance_id, + FLAGS.spanner_database_id, + graph_database_id=FLAGS.spanner_graph_database_id, + location=FLAGS.location, + model_id=os.environ.get('EMBEDDING_MODEL_ID', 'text-embedding-005') + ) + return _spanner_client + +def _get_storage_client(): + global _storage_client + if _storage_client is None: + _storage_client = StorageClient(FLAGS.gcs_bucket_id) + return _storage_client + @functions_framework.http def ingestion_helper(request): @@ -71,223 +93,49 @@ def ingestion_helper(request): return (validation_error, 400) action_type = request_json['actionType'] - spanner = SpannerClient(FLAGS.spanner_project_id, - FLAGS.spanner_instance_id, - FLAGS.spanner_database_id, - graph_database_id=FLAGS.spanner_graph_database_id, - location=FLAGS.location, - model_id=os.environ.get('EMBEDDING_MODEL_ID', - 'text-embedding-005')) - storage = StorageClient(FLAGS.gcs_bucket_id) + spanner = _get_spanner_client() + storage = _get_storage_client() if action_type == 'get_import_info': - # Gets the details of imports that are ready for ingestion. - # Input: - # importList: list of import names to ingest (optional) - import_list = request_json.get('importList', []) - import_info = spanner.get_import_info(import_list) - return jsonify(import_info) + from handlers.imports import handle_get_import_info + return handle_get_import_info(spanner, request_json) elif action_type == 'acquire_ingestion_lock': - # Attempts to acquire the global lock for ingestion. - # Input: - # workflowId: ID of the workflow acquiring the lock - # timeout: lock duration in seconds - validation_error = _validate_params(request_json, - ['workflowId', 'timeout']) - if validation_error: - return (validation_error, 400) - workflow = request_json['workflowId'] - timeout = request_json['timeout'] - status = spanner.acquire_lock(workflow, timeout) - if not status: - return ('Failed to acquire lock', 500) - return ('OK', 200) + from handlers.lock import handle_acquire_lock + return handle_acquire_lock(spanner, request_json) elif action_type == 'release_ingestion_lock': - # Releases the global ingestion lock. - # Input: - # workflowId: ID of the workflow releasing the lock - validation_error = _validate_params(request_json, ['workflowId']) - if validation_error: - return (validation_error, 400) - workflow = request_json['workflowId'] - status = spanner.release_lock(workflow) - if not status: - return ('Failed to release lock', 500) - return ('OK', 200) + from handlers.lock import handle_release_lock + return handle_release_lock(spanner, request_json) elif action_type == 'update_ingestion_status': - # Updates the status of imports after ingestion. - # Input: - # importList: list of import names - # workflowId: ID of the workflow - # status: import status - # jobId: Dataflow job ID - validation_error = _validate_params( - request_json, ['importList', 'workflowId', 'jobId', 'status']) - if validation_error: - return (validation_error, 400) - import_list = request_json['importList'] - workflow_id = request_json['workflowId'] - status = request_json['status'] - job_id = request_json['jobId'] - ingested_imports = [item['importName'] for item in import_list] - - spanner.update_ingestion_status(ingested_imports, workflow_id, status) - metrics = import_utils.get_ingestion_metrics(FLAGS.project_id, - FLAGS.location, job_id) - spanner.update_ingestion_history(workflow_id, job_id, ingested_imports, - metrics) - if status == 'SUCCESS': - spanner.update_import_version_history(import_list, workflow_id) - return ('OK', 200) + from handlers.imports import handle_update_ingestion_status + return handle_update_ingestion_status(spanner, request_json) elif action_type == 'update_import_status': - # Updates the status of a specific import job. - # Input: - # importName: name of the import - # status: new status - # jobId: Batch job ID (optional) - # executionTime: execution time in seconds (optional) - # dataVolume: data volume in bytes (optional) - # latestVersion: latest version string (optional) - # graphPath: graph path regex (optional) - # nextRefresh: next refresh timestamp (optional) - validation_error = _validate_params(request_json, - ['importName', 'status']) - if validation_error: - return (validation_error, 400) - import_name = request_json['importName'] - status = request_json['status'] - logging.info(f'Updating import {import_name} to status {status}') - params = import_utils.get_import_params(request_json) - next_refresh = import_utils.get_next_refresh(FLAGS.project_id, - FLAGS.location, - import_name) - if next_refresh: - params['next_refresh'] = next_refresh - if status == 'STAGING': - version = os.path.basename(request_json.get('latestVersion', '')) - if not version: - return (f'Empty version for import {import_name}', 500) - storage.update_version_file(import_name, version, is_staging=True) - storage.update_provenance_file(import_name, version) - storage.update_import_summary(params) - storage.update_version_file(import_name, version, is_staging=False) - comment = f"import-workflow:{request_json.get('jobId','')}" - spanner.update_version_history(import_name, version, comment) - spanner.update_import_status(params) - return ('OK', 200) + from handlers.imports import handle_update_import_status + return handle_update_import_status(spanner, storage, request_json) elif action_type == 'update_import_version': - # Updates the version and status of an import. - # Input: - # importName: name of the import - # version: version string - # comment: audit log comment - # override: override status check (optional) - # triggerIngestion: trigger ingestion workflow (optional) - validation_error = _validate_params( - request_json, ['importName', 'version', 'comment']) - if validation_error: - return (validation_error, 400) - import_name = request_json['importName'] - version = request_json['version'] - comment = request_json['comment'] - logging.info( - f"Updating import {import_name} to version {version} comment:{comment}" - ) - override = request_json.get('override', False) - if version == 'STAGING': - version = storage.get_staging_version(import_name) - summary = storage.get_import_summary(import_name, version) - params = import_utils.get_import_params(summary) - if override: - params['status'] = 'STAGING' - caller = import_utils.get_caller_identity(request) - comment = f'version-override:{caller} {comment}' - if params['status'] == 'STAGING': - storage.update_provenance_file(import_name, version) - storage.update_version_file(import_name, version, is_staging=False) - spanner.update_version_history(import_name, version, comment) - logging.info(f"Updated import {import_name} to version {version}") - else: - logging.info(f"Skipping {import_name} version update") - spanner.update_import_status(params) - return ( - f"OK [Import: {import_name} Version: {version} Status: {params['status']}]", - 200) + from handlers.imports import handle_update_import_version + return handle_update_import_version(spanner, storage, request, request_json) elif action_type == 'initialize_database': - # Initializes the database by creating all required tables and proto bundles. - logging.info("Action: initialize_database") - enable_embeddings = request_json.get('enableEmbeddings', - FLAGS.enable_embeddings) - spanner.initialize_database(enable_embeddings=enable_embeddings) - return ('OK', 200) + from handlers.database import handle_initialize_database + return handle_initialize_database(spanner, request_json) elif action_type == 'seed_database': - # Seeds the database with base empty nodes. - logging.info("Action: seed_database") - spanner.seed_database() - return ('OK', 200) + from handlers.database import handle_seed_database + return handle_seed_database(spanner) elif action_type == 'embedding_ingestion': - - logging.info("Action: embedding_ingestion") - enable_embeddings = request_json.get('enableEmbeddings', - FLAGS.enable_embeddings) - if not enable_embeddings: - logging.info("Embeddings not enabled, skipping.") - return ('Invalid request on embedding ingestion.', 400) - - node_types = FLAGS.node_types - try: - logging.info(f"Job started. Fetching all nodes for types: {node_types}") - timestamp = get_latest_lock_timestamp(spanner.database) - nodes = get_updated_nodes(spanner.database, timestamp, node_types) - converted_nodes = filter_and_convert_nodes(nodes) - affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes) - return (f"OK [Affected rows: {affected_rows}]", 200) - except Exception as e: - logging.error(f"Embedding ingestion failed: {e}") - return (f"Error: {e}", 500) + from handlers.embeddings import handle_embedding_ingestion + return handle_embedding_ingestion(spanner, request_json) elif action_type == 'run_aggregation': - # Runs aggregation logic for the specified imports. - # Input: - # importList: list of imports to aggregate - import_list = request_json.get('importList', []) - - # Validate required flags are not empty or None - missing_flags = [] - if not FLAGS.spanner_connection_id: - missing_flags.append('spanner_connection_id (BQ_SPANNER_CONN_ID)') - if not FLAGS.spanner_project_id: - missing_flags.append('spanner_project_id (SPANNER_PROJECT_ID)') - if not FLAGS.spanner_instance_id: - missing_flags.append('spanner_instance_id (SPANNER_INSTANCE_ID)') - if not FLAGS.spanner_graph_database_id: - missing_flags.append('spanner_graph_database_id (SPANNER_GRAPH_DATABASE_ID)') - - if missing_flags: - error_msg = f"Missing required configuration flags/env-vars: {', '.join(missing_flags)}" - logging.error(error_msg) - return (error_msg, 400) + from handlers.aggregation import handle_run_aggregation + return handle_run_aggregation(request_json) - aggregation = AggregationUtils( - connection_id=FLAGS.spanner_connection_id, - project_id=FLAGS.spanner_project_id, - instance_id=FLAGS.spanner_instance_id, - database_id=FLAGS.spanner_graph_database_id, - location=FLAGS.location, - is_base_dc=FLAGS.is_base_dc, - ) - try: - if aggregation.run_aggregation(import_list): - return ('OK', 200) - else: - return ('Aggregation failed', 500) - except Exception as e: - return (f"Aggregation failed: {str(e)}", 500) + elif action_type == 'clear_redis_cache': + from handlers.cache import handle_clear_redis_cache + return handle_clear_redis_cache(request_json) else: return (f'Unknown actionType: {action_type}', 400) diff --git a/import-automation/workflow/ingestion-helper/pyproject.toml b/import-automation/workflow/ingestion-helper/pyproject.toml index 9fafa50fed..7abcac136a 100644 --- a/import-automation/workflow/ingestion-helper/pyproject.toml +++ b/import-automation/workflow/ingestion-helper/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "google-auth", "absl-py", "google-cloud-bigquery", + "redis", ] [tool.hatch.build.targets.wheel] diff --git a/import-automation/workflow/ingestion-helper/tests/__init__.py b/import-automation/workflow/ingestion-helper/tests/__init__.py new file mode 100644 index 0000000000..3a34395478 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/__init__.py @@ -0,0 +1 @@ +# Initialize tests package diff --git a/import-automation/workflow/ingestion-helper/tests/handlers/__init__.py b/import-automation/workflow/ingestion-helper/tests/handlers/__init__.py new file mode 100644 index 0000000000..e20cb8859a --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/handlers/__init__.py @@ -0,0 +1 @@ +# Initialize handler tests package diff --git a/import-automation/workflow/ingestion-helper/tests/handlers/test_cache.py b/import-automation/workflow/ingestion-helper/tests/handlers/test_cache.py new file mode 100644 index 0000000000..9ea6bf8646 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/handlers/test_cache.py @@ -0,0 +1,44 @@ +import unittest +from unittest.mock import MagicMock, patch +import sys +import os + +# Add parent directory to path to find handlers and spanner_client +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from handlers.cache import handle_clear_redis_cache + +class TestCacheHandlers(unittest.TestCase): + + @patch('handlers.cache.os.environ') + @patch('handlers.cache.redis.Redis') + def test_handle_clear_redis_cache_success(self, mock_redis, mock_environ): + mock_environ.get.side_effect = lambda k, d=None: 'localhost' if k == 'REDIS_HOST' else ('6379' if k == 'REDIS_PORT' else d) + + mock_r = MagicMock() + mock_redis.return_value = mock_r + + request_json = {} + + # Mock flask.jsonify + with patch('handlers.cache.jsonify', side_effect=lambda x: x): + response, status_code = handle_clear_redis_cache(request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response['status'], 'SUCCESS') + mock_r.flushall.assert_called_once_with(asynchronous=True) + + @patch('handlers.cache.os.environ') + def test_handle_clear_redis_cache_skipped(self, mock_environ): + mock_environ.get.return_value = None # REDIS_HOST not set + + request_json = {} + + with patch('handlers.cache.jsonify', side_effect=lambda x: x): + response, status_code = handle_clear_redis_cache(request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response['status'], 'SKIPPED') + +if __name__ == '__main__': + unittest.main() diff --git a/import-automation/workflow/ingestion-helper/tests/handlers/test_database.py b/import-automation/workflow/ingestion-helper/tests/handlers/test_database.py new file mode 100644 index 0000000000..bd7c6a44e7 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/handlers/test_database.py @@ -0,0 +1,49 @@ +import unittest +from unittest.mock import MagicMock, patch +import sys +import os + +# Add parent directory to path to find handlers and spanner_client +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from handlers.database import handle_initialize_database, handle_seed_database + +class TestDatabaseHandlers(unittest.TestCase): + + @patch('handlers.database.FLAGS') + def test_handle_initialize_database(self, mock_flags): + mock_flags.enable_embeddings = False + + mock_spanner = MagicMock() + request_json = {} + + response, status_code = handle_initialize_database(mock_spanner, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.initialize_database.assert_called_once_with(enable_embeddings=False) + + @patch('handlers.database.FLAGS') + def test_handle_initialize_database_enable_embeddings(self, mock_flags): + mock_flags.enable_embeddings = False + + mock_spanner = MagicMock() + request_json = {'enableEmbeddings': True} + + response, status_code = handle_initialize_database(mock_spanner, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.initialize_database.assert_called_once_with(enable_embeddings=True) + + def test_handle_seed_database(self): + mock_spanner = MagicMock() + + response, status_code = handle_seed_database(mock_spanner) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.seed_database.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/import-automation/workflow/ingestion-helper/tests/handlers/test_imports.py b/import-automation/workflow/ingestion-helper/tests/handlers/test_imports.py new file mode 100644 index 0000000000..d223e567a1 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/handlers/test_imports.py @@ -0,0 +1,75 @@ +import unittest +from unittest.mock import MagicMock, patch +import sys +import os + +# Add parent directory to path to find handlers and spanner_client +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from handlers.imports import handle_get_import_info, handle_update_ingestion_status, handle_update_import_status + +class TestImportHandlers(unittest.TestCase): + + def test_handle_get_import_info(self): + mock_spanner = MagicMock() + mock_spanner.get_import_info.return_value = [{'importName': 'test_import'}] + + request_json = {'importList': ['test_import']} + + with patch('handlers.imports.jsonify', side_effect=lambda x: x): + response = handle_get_import_info(mock_spanner, request_json) + + self.assertEqual(response, [{'importName': 'test_import'}]) + mock_spanner.get_import_info.assert_called_once_with(['test_import']) + + @patch('handlers.imports.FLAGS') + @patch('handlers.imports.import_utils') + def test_handle_update_ingestion_status_success(self, mock_import_utils, mock_flags): + mock_flags.project_id = 'test-project' + mock_flags.location = 'us-central1' + + mock_spanner = MagicMock() + mock_import_utils.get_ingestion_metrics.return_value = {'nodes': 10} + + request_json = { + 'importList': [{'importName': 'import1'}], + 'workflowId': 'wf123', + 'jobId': 'job123', + 'status': 'SUCCESS' + } + + response, status_code = handle_update_ingestion_status(mock_spanner, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.update_ingestion_status.assert_called_once_with(['import1'], 'wf123', 'SUCCESS') + mock_spanner.update_ingestion_history.assert_called_once() + mock_spanner.update_import_version_history.assert_called_once() + + @patch('handlers.imports.FLAGS') + @patch('handlers.imports.import_utils') + def test_handle_update_import_status_staging(self, mock_import_utils, mock_flags): + mock_flags.project_id = 'test-project' + mock_flags.location = 'us-central1' + + mock_spanner = MagicMock() + mock_storage = MagicMock() + + request_json = { + 'importName': 'import1', + 'status': 'STAGING', + 'latestVersion': 'gs://bucket/import1/v1' + } + + mock_import_utils.get_import_params.return_value = {'import_name': 'import1', 'status': 'STAGING'} + mock_import_utils.get_next_refresh.return_value = None + + response, status_code = handle_update_import_status(mock_spanner, mock_storage, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_storage.update_version_file.assert_called() + mock_spanner.update_import_status.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/import-automation/workflow/ingestion-helper/tests/handlers/test_lock.py b/import-automation/workflow/ingestion-helper/tests/handlers/test_lock.py new file mode 100644 index 0000000000..532c60abe2 --- /dev/null +++ b/import-automation/workflow/ingestion-helper/tests/handlers/test_lock.py @@ -0,0 +1,65 @@ +import unittest +from unittest.mock import MagicMock +import sys +import os + +# Add parent directory to path to find handlers and spanner_client +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from handlers.lock import handle_acquire_lock, handle_release_lock + +class TestLockHandlers(unittest.TestCase): + + def test_handle_acquire_lock_success(self): + mock_spanner = MagicMock() + mock_spanner.acquire_lock.return_value = True + + request_json = {'workflowId': 'wf123', 'timeout': 3600} + response, status_code = handle_acquire_lock(mock_spanner, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.acquire_lock.assert_called_once_with('wf123', 3600) + + def test_handle_acquire_lock_failure(self): + mock_spanner = MagicMock() + mock_spanner.acquire_lock.return_value = False + + request_json = {'workflowId': 'wf123', 'timeout': 3600} + response, status_code = handle_acquire_lock(mock_spanner, request_json) + + self.assertEqual(status_code, 500) + self.assertEqual(response, 'Failed to acquire lock') + + def test_handle_acquire_lock_missing_param(self): + mock_spanner = MagicMock() + + request_json = {'workflowId': 'wf123'} # missing timeout + response, status_code = handle_acquire_lock(mock_spanner, request_json) + + self.assertEqual(status_code, 400) + self.assertIn('timeout', response) + + def test_handle_release_lock_success(self): + mock_spanner = MagicMock() + mock_spanner.release_lock.return_value = True + + request_json = {'workflowId': 'wf123'} + response, status_code = handle_release_lock(mock_spanner, request_json) + + self.assertEqual(status_code, 200) + self.assertEqual(response, 'OK') + mock_spanner.release_lock.assert_called_once_with('wf123') + + def test_handle_release_lock_failure(self): + mock_spanner = MagicMock() + mock_spanner.release_lock.return_value = False + + request_json = {'workflowId': 'wf123'} + response, status_code = handle_release_lock(mock_spanner, request_json) + + self.assertEqual(status_code, 500) + self.assertEqual(response, 'Failed to release lock') + +if __name__ == '__main__': + unittest.main() diff --git a/import-automation/workflow/ingestion-helper/embedding_utils_test.py b/import-automation/workflow/ingestion-helper/tests/test_embedding_utils.py similarity index 98% rename from import-automation/workflow/ingestion-helper/embedding_utils_test.py rename to import-automation/workflow/ingestion-helper/tests/test_embedding_utils.py index 299b293dc3..3e7416f50a 100644 --- a/import-automation/workflow/ingestion-helper/embedding_utils_test.py +++ b/import-automation/workflow/ingestion-helper/tests/test_embedding_utils.py @@ -16,7 +16,7 @@ from unittest.mock import MagicMock, patch from datetime import datetime -from embedding_utils import ( +from handlers.helpers.embedding_utils import ( get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, @@ -121,7 +121,7 @@ def test_filter_and_convert_nodes(self): self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"])) self.assertEqual(converted[1], ("dc/3", "Node 3", ["Topic", "StatisticalVariable"])) - @patch('embedding_utils._BATCH_SIZE', 2) + @patch('handlers.helpers.embedding_utils._BATCH_SIZE', 2) def test_generate_embeddings_partitioned(self): mock_database = MagicMock() diff --git a/import-automation/workflow/ingestion-helper/main_test.py b/import-automation/workflow/ingestion-helper/tests/test_main.py similarity index 100% rename from import-automation/workflow/ingestion-helper/main_test.py rename to import-automation/workflow/ingestion-helper/tests/test_main.py diff --git a/import-automation/workflow/ingestion-helper/spanner_client_test.py b/import-automation/workflow/ingestion-helper/tests/test_spanner_client.py similarity index 98% rename from import-automation/workflow/ingestion-helper/spanner_client_test.py rename to import-automation/workflow/ingestion-helper/tests/test_spanner_client.py index dac0cddb51..efa883a82c 100644 --- a/import-automation/workflow/ingestion-helper/spanner_client_test.py +++ b/import-automation/workflow/ingestion-helper/tests/test_spanner_client.py @@ -17,8 +17,8 @@ import sys import os -# Add the current directory to path so we can import spanner_client -sys.path.append(os.path.dirname(__file__)) +# Add the parent directory to path so we can import spanner_client +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from spanner_client import SpannerClient class TestSpannerClient(unittest.TestCase): @@ -281,4 +281,3 @@ def run_in_transaction_side_effect(callback, *args, **kwargs): if __name__ == '__main__': unittest.main() -