From 114358c9bc85d5a9736fca73956ec1d05a243c4e Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Fri, 8 May 2026 17:26:49 -0600
Subject: [PATCH] chore: record per-query timings for multi-statement TPC
 files

Some TPC-DS files (q14, q23, q24, q39) contain two SELECT statements
that were previously timed as a single unit. TPC-H q15 wraps a SELECT
in CREATE / DROP VIEW statements that should keep executing but should
not be treated as separate queries.

Classify each `;`-split statement as SELECT/WITH (timed) or DDL
(executed only). Multi-SELECT files record per-query timings under
keys like `14a` and `14b`; single-SELECT files keep their existing
key. As a side effect, q15's row_count and result_hash now come from
the SELECT rather than the trailing DROP VIEW.

generate-comparison.py is updated to accept alphanumeric query keys
and sort them so `14`, `14a`, `14b`, `15` appear in natural order;
otherwise the new sub-query timings would be silently filtered out of
comparison charts.
---
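
The new ordering can be checked with a doctest-style sketch
(illustrative only; query_sort_key is the helper added in
generate-comparison.py below):

    >>> sorted(["15", "14b", "2", "14", "14a"], key=query_sort_key)
    ['2', '14', '14a', '14b', '15']

The numeric part compares as an integer and the letter suffix breaks
ties, so sub-queries sort between their parent query and the next one.
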
 benchmarks/tpc/generate-comparison.py | 22 ++++---
 benchmarks/tpc/tpcbench.py            | 89 +++++++++++++++++----------
 2 files changed, 71 insertions(+), 40 deletions(-)

diff --git a/benchmarks/tpc/generate-comparison.py b/benchmarks/tpc/generate-comparison.py
index e5058a3bfa..fb6fe3b112 100644
--- a/benchmarks/tpc/generate-comparison.py
+++ b/benchmarks/tpc/generate-comparison.py
@@ -20,10 +20,13 @@
 import logging
 import matplotlib.pyplot as plt
 import numpy as np
+import re
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+QUERY_KEY_RE = re.compile(r'^(\d+)([a-z]*)$')
+
 
 def geomean(data):
     return np.prod(data) ** (1 / len(data))
@@ -34,19 +37,20 @@ def get_durations(result, query_key):
         return value["durations"]
     return value
 
+def query_sort_key(key):
+    """Sort key for query labels like "14", "14a", "14b" so sub-queries sit between 14 and 15."""
+    m = QUERY_KEY_RE.match(str(key))
+    if m:
+        return (int(m.group(1)), m.group(2))
+    return (float('inf'), str(key))
+
 def get_all_queries(results):
-    """Return the sorted union of all query keys across all result sets."""
+    """Return the sorted union of query keys across all result sets, as strings."""
     all_keys = set()
     for result in results:
         all_keys.update(result.keys())
-    # Filter to numeric query keys and sort numerically
-    numeric_keys = []
-    for k in all_keys:
-        try:
-            numeric_keys.append(int(k))
-        except ValueError:
-            pass
-    return sorted(numeric_keys)
+    query_keys = [str(k) for k in all_keys if QUERY_KEY_RE.match(str(k))]
+    return sorted(query_keys, key=query_sort_key)
 
 def get_common_queries(results, labels):
     """Return queries present in ALL result sets, warning about queries missing from some files."""
diff --git a/benchmarks/tpc/tpcbench.py b/benchmarks/tpc/tpcbench.py
index 036d7b0e9a..a654474a6d 100644
--- a/benchmarks/tpc/tpcbench.py
+++ b/benchmarks/tpc/tpcbench.py
@@ -56,6 +56,21 @@ def result_hash(rows):
     return h.hexdigest()
 
 
+def first_keyword(sql):
+    """Return the first non-comment, non-whitespace keyword in lowercase."""
+    for line in sql.splitlines():
+        stripped = line.strip()
+        if not stripped or stripped.startswith("--"):
+            continue
+        return stripped.split(None, 1)[0].lstrip("(").lower()
+    return ""
+
+
+def is_select_statement(sql):
+    """Classify a SQL statement: True if it is a SELECT/WITH query, False if DDL."""
+    return first_keyword(sql) in ("select", "with")
+
+
 def main(
     benchmark: str,
     data_path: str,
@@ -147,44 +162,56 @@ def main(
         queries_to_run = range(1, num_queries + 1)
 
         for query in queries_to_run:
-            spark.sparkContext.setJobDescription(f"{benchmark} q{query}")
-
             path = f"{query_path}/q{query}.sql"
             print(f"\nRunning query {query} from {path}")
 
             with open(path, "r") as f:
                 text = f.read()
-                queries = text.split(";")
+
+            statements = [s.strip() for s in text.split(";") if s.strip()]
+            select_indices = [i for i, s in enumerate(statements) if is_select_statement(s)]
+            multi = len(select_indices) > 1
+
+            for stmt_idx, sql in enumerate(statements):
+                sql = sql.replace("create view", "create temp view")
+                is_query = stmt_idx in select_indices
+                if is_query:
+                    if multi:
+                        suffix = chr(ord("a") + select_indices.index(stmt_idx))
+                        query_label = f"{query}{suffix}"
+                    else:
+                        query_label = str(query)
+                    spark.sparkContext.setJobDescription(f"{benchmark} q{query_label}")
+                    print(f"Executing query {query_label}: {sql[:100]}...")
+                else:
+                    print(f"Executing DDL (not timed): {sql[:100]}...")
 
                 start_time = time.time()
-                for sql in queries:
-                    sql = sql.strip().replace("create view", "create temp view")
-                    if len(sql) > 0:
-                        print(f"Executing: {sql[:100]}...")
-                        df = spark.sql(sql)
-                        df.explain("formatted")
-
-                        if write_path is not None:
-                            if len(df.columns) > 0:
-                                output_path = f"{write_path}/q{query}"
-                                deduped = dedup_columns(df)
-                                deduped.orderBy(*deduped.columns).coalesce(1).write.mode("overwrite").parquet(output_path)
-                                print(f"Results written to {output_path}")
-                        else:
-                            rows = df.collect()
-                            row_count = len(rows)
-                            row_hash = result_hash(rows)
-                            print(f"Query {query} returned {row_count} rows, hash={row_hash}")
-
-                end_time = time.time()
-                elapsed = end_time - start_time
-                print(f"Query {query} took {elapsed:.2f} seconds")
-
-                query_result = results.setdefault(query, {"durations": []})
-                query_result["durations"].append(round(elapsed, 3))
-                if "row_count" not in query_result and not write_path:
-                    query_result["row_count"] = row_count
-                    query_result["result_hash"] = row_hash
+                df = spark.sql(sql)
+                df.explain("formatted")
+
+                if is_query and write_path is not None:
+                    if len(df.columns) > 0:
+                        output_path = f"{write_path}/q{query_label}"
+                        deduped = dedup_columns(df)
+                        deduped.orderBy(*deduped.columns).coalesce(1).write.mode("overwrite").parquet(output_path)
+                        print(f"Results written to {output_path}")
+                    elapsed = time.time() - start_time
+                    print(f"Query {query_label} took {elapsed:.2f} seconds")
+                    query_result = results.setdefault(query_label, {"durations": []})
+                    query_result["durations"].append(round(elapsed, 3))
+                else:
+                    rows = df.collect()
+                    elapsed = time.time() - start_time
+                    if is_query:
+                        row_count = len(rows)
+                        row_hash = result_hash(rows)
+                        print(f"Query {query_label} took {elapsed:.2f} seconds, returned {row_count} rows, hash={row_hash}")
+                        query_result = results.setdefault(query_label, {"durations": []})
+                        query_result["durations"].append(round(elapsed, 3))
+                        if "row_count" not in query_result:
+                            query_result["row_count"] = row_count
+                            query_result["result_hash"] = row_hash
 
         iter_end_time = time.time()
         print(f"\nIteration {iteration + 1} took {iter_end_time - iter_start_time:.2f} seconds")
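
A doctest-style sketch of how the new classifier treats q15-style
statements (illustrative only; the SELECT bodies are elided with ...):

    >>> is_select_statement("create view revenue0 as select ...")
    False
    >>> is_select_statement("-- TPC-H q15\nselect s_suppkey from supplier")
    True
    >>> is_select_statement("drop view revenue0")
    False

Leading `--` comment lines are skipped and a parenthesized `(select`
is treated as a query, so only the bare SELECT in q15 gets timed.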