benchmarks/tpc/generate-comparison.py (13 additions, 9 deletions)
@@ -20,10 +20,13 @@
 import logging
 import matplotlib.pyplot as plt
 import numpy as np
+import re
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+QUERY_KEY_RE = re.compile(r'^(\d+)([a-z]*)$')
+
 def geomean(data):
     return np.prod(data) ** (1 / len(data))
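A quick illustration (not part of the diff): the new regex accepts plain and letter-suffixed query labels, rejects anything else, and splits a label into its numeric part and suffix.

    import re

    QUERY_KEY_RE = re.compile(r'^(\d+)([a-z]*)$')

    print(QUERY_KEY_RE.match("14").groups())   # ('14', '')
    print(QUERY_KEY_RE.match("14a").groups())  # ('14', 'a')
    print(QUERY_KEY_RE.match("q14"))           # None (rejected)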

@@ -34,19 +37,20 @@ def get_durations(result, query_key):
         return value["durations"]
     return value
 
+def query_sort_key(key):
+    """Sort key for query labels like "14", "14a", "14b" so sub-queries sit between 14 and 15."""
+    m = QUERY_KEY_RE.match(str(key))
+    if m:
+        return (int(m.group(1)), m.group(2))
+    return (float('inf'), str(key))
+
 def get_all_queries(results):
-    """Return the sorted union of all query keys across all result sets."""
+    """Return the sorted union of query keys across all result sets, as strings."""
     all_keys = set()
     for result in results:
         all_keys.update(result.keys())
-    # Filter to numeric query keys and sort numerically
-    numeric_keys = []
-    for k in all_keys:
-        try:
-            numeric_keys.append(int(k))
-        except ValueError:
-            pass
-    return sorted(numeric_keys)
+    query_keys = [str(k) for k in all_keys if QUERY_KEY_RE.match(str(k))]
+    return sorted(query_keys, key=query_sort_key)
 
 def get_common_queries(results, labels):
     """Return queries present in ALL result sets, warning about queries missing from some files."""
benchmarks/tpc/tpcbench.py (58 additions, 31 deletions)
@@ -56,6 +56,21 @@ def result_hash(rows):
     return h.hexdigest()
 
 
+def first_keyword(sql):
+    """Return the first non-comment, non-whitespace keyword in lowercase."""
+    for line in sql.splitlines():
+        stripped = line.strip()
+        if not stripped or stripped.startswith("--"):
+            continue
+        return stripped.split(None, 1)[0].lstrip("(").lower()
+    return ""
+
+
+def is_select_statement(sql):
+    """Classify a SQL statement: True if it is a SELECT/WITH query, False if DDL."""
+    return first_keyword(sql) in ("select", "with")
+
+
 def main(
     benchmark: str,
     data_path: str,
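Illustrative behavior of the new classifier (not part of the diff; assumes the two helpers above are in scope):

    print(is_select_statement("-- q14 part 1\nwith v as (select 1) select * from v"))  # True
    print(is_select_statement("create view item_sales as select 1 as x"))              # False (DDL)
    print(is_select_statement("(select 1) union (select 2)"))                          # True: leading '(' stripped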
@@ -147,44 +162,56 @@ def main(
         queries_to_run = range(1, num_queries + 1)
 
         for query in queries_to_run:
-            spark.sparkContext.setJobDescription(f"{benchmark} q{query}")
-
             path = f"{query_path}/q{query}.sql"
             print(f"\nRunning query {query} from {path}")
 
             with open(path, "r") as f:
                 text = f.read()
-                queries = text.split(";")
 
+            statements = [s.strip() for s in text.split(";") if s.strip()]
+            select_indices = [i for i, s in enumerate(statements) if is_select_statement(s)]
+            multi = len(select_indices) > 1
+
+            for stmt_idx, sql in enumerate(statements):
+                sql = sql.replace("create view", "create temp view")
+                is_query = stmt_idx in select_indices
+                if is_query:
+                    if multi:
+                        suffix = chr(ord("a") + select_indices.index(stmt_idx))
+                        query_label = f"{query}{suffix}"
+                    else:
+                        query_label = str(query)
+                    spark.sparkContext.setJobDescription(f"{benchmark} q{query_label}")
+                    print(f"Executing query {query_label}: {sql[:100]}...")
+                else:
+                    print(f"Executing DDL (not timed): {sql[:100]}...")
+
                 start_time = time.time()
-                for sql in queries:
-                    sql = sql.strip().replace("create view", "create temp view")
-                    if len(sql) > 0:
-                        print(f"Executing: {sql[:100]}...")
-                        df = spark.sql(sql)
-                        df.explain("formatted")
-
-                        if write_path is not None:
-                            if len(df.columns) > 0:
-                                output_path = f"{write_path}/q{query}"
-                                deduped = dedup_columns(df)
-                                deduped.orderBy(*deduped.columns).coalesce(1).write.mode("overwrite").parquet(output_path)
-                                print(f"Results written to {output_path}")
-                        else:
-                            rows = df.collect()
-                            row_count = len(rows)
-                            row_hash = result_hash(rows)
-                            print(f"Query {query} returned {row_count} rows, hash={row_hash}")
-
-                end_time = time.time()
-                elapsed = end_time - start_time
-                print(f"Query {query} took {elapsed:.2f} seconds")
-
-                query_result = results.setdefault(query, {"durations": []})
-                query_result["durations"].append(round(elapsed, 3))
-                if "row_count" not in query_result and not write_path:
-                    query_result["row_count"] = row_count
-                    query_result["result_hash"] = row_hash
+                df = spark.sql(sql)
+                df.explain("formatted")
+
+                if is_query and write_path is not None:
+                    if len(df.columns) > 0:
+                        output_path = f"{write_path}/q{query_label}"
+                        deduped = dedup_columns(df)
+                        deduped.orderBy(*deduped.columns).coalesce(1).write.mode("overwrite").parquet(output_path)
+                        print(f"Results written to {output_path}")
+                    elapsed = time.time() - start_time
+                    print(f"Query {query_label} took {elapsed:.2f} seconds")
+                    query_result = results.setdefault(query_label, {"durations": []})
+                    query_result["durations"].append(round(elapsed, 3))
+                else:
+                    rows = df.collect()
+                    elapsed = time.time() - start_time
+                    if is_query:
+                        row_count = len(rows)
+                        row_hash = result_hash(rows)
+                        print(f"Query {query_label} took {elapsed:.2f} seconds, returned {row_count} rows, hash={row_hash}")
+                        query_result = results.setdefault(query_label, {"durations": []})
+                        query_result["durations"].append(round(elapsed, 3))
+                        if "row_count" not in query_result:
+                            query_result["row_count"] = row_count
+                            query_result["result_hash"] = row_hash
 
         iter_end_time = time.time()
         print(f"\nIteration {iteration + 1} took {iter_end_time - iter_start_time:.2f} seconds")