From 54ad798d0689cf3d4643edd5afb0bc43ec232307 Mon Sep 17 00:00:00 2001 From: Varun Date: Mon, 1 Jun 2026 09:11:51 -0700 Subject: [PATCH] Add STATS and EXTENDED_STATS compound aggregate expansion --- .../src/main/antlr4/OpenSearchSQLParser.g4 | 2 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 + .../opensearch/sql/sql/parser/AstBuilder.java | 24 +- .../sql/parser/CompoundAggregateExpander.java | 171 ++++++++++++++ .../parser/context/QuerySpecification.java | 32 ++- .../sql/sql/parser/AstBuilderTest.java | 99 ++++++++ .../parser/CompoundAggregateExpanderTest.java | 219 ++++++++++++++++++ .../context/QuerySpecificationTest.java | 40 ++++ 8 files changed, 586 insertions(+), 3 deletions(-) create mode 100644 sql/src/main/java/org/opensearch/sql/sql/parser/CompoundAggregateExpander.java create mode 100644 sql/src/test/java/org/opensearch/sql/sql/parser/CompoundAggregateExpanderTest.java diff --git a/language-grammar/src/main/antlr4/OpenSearchSQLParser.g4 b/language-grammar/src/main/antlr4/OpenSearchSQLParser.g4 index 5f7361160b3..887ac4e4f54 100644 --- a/language-grammar/src/main/antlr4/OpenSearchSQLParser.g4 +++ b/language-grammar/src/main/antlr4/OpenSearchSQLParser.g4 @@ -500,6 +500,8 @@ aggregationFunctionName | STDDEV | STDDEV_POP | STDDEV_SAMP + | STATS + | EXTENDED_STATS ; mathematicalFunctionName diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 492f6dee9c6..bb92dfbc200 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -532,6 +532,8 @@ aggregationFunctionName | STDDEV | STDDEV_POP | STDDEV_SAMP + | STATS + | EXTENDED_STATS ; mathematicalFunctionName diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java index ee532a10ed9..edd915e6605 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java +++ 
b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java @@ -21,9 +21,11 @@ import com.google.common.collect.ImmutableList; import java.util.Collections; +import java.util.List; import java.util.Locale; import java.util.Optional; import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Function; @@ -127,10 +129,30 @@ public UnresolvedPlan visitSelectClause(SelectClauseContext ctx) { if (ctx.selectElements().star != null) { // TODO: project operator should be required? builder.add(AllFields.of()); } - ctx.selectElements().selectElement().forEach(field -> builder.add(visitSelectItem(field))); + for (SelectElementContext element : ctx.selectElements().selectElement()) { + UnresolvedExpression item = visitSelectItem(element); + if (CompoundAggregateExpander.isCompoundAggregateAlias(item)) { + builder.addAll(expandCompoundAggregate(item)); + } else { + builder.add(item); + } + } return new Project(builder.build()); } + /** Expands a compound aggregate ({@code STATS} / {@code EXTENDED_STATS}) into its primitives. */ + private List expandCompoundAggregate(UnresolvedExpression item) { + Alias alias = (Alias) item; + AggregateFunction agg = (AggregateFunction) alias.getDelegated(); + String displayPrefix = alias.getAlias() != null ? 
alias.getAlias() : alias.getName(); + return CompoundAggregateExpander.expandAliased( + agg.getFuncName(), + agg.getField(), + agg.getField().toString(), + displayPrefix, + agg.condition()); + } + @Override public UnresolvedPlan visitLimitClause(OpenSearchSQLParser.LimitClauseContext ctx) { return new Limit( diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/CompoundAggregateExpander.java b/sql/src/main/java/org/opensearch/sql/sql/parser/CompoundAggregateExpander.java new file mode 100644 index 00000000000..c722e4f1b76 --- /dev/null +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/CompoundAggregateExpander.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.UnaryOperator; +import java.util.stream.Stream; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +/** + * Parse-time expansion utility for fixed-shape compound aggregate calls. A compound aggregate + * registered in {@link #COMPOUND_AGGREGATES} expands into a fixed list of primitive aggregates, + * emitted as bare expressions ({@link #expandPrimitives}) or {@link Alias}-wrapped columns ({@link + * #expandAliased}). + * + *

<p>To register a new compound aggregate, add an entry to {@link #COMPOUND_AGGREGATES}. + * + *
<p>
Aggregates whose output column count depends on call arguments do not fit this registry. + */ +public final class CompoundAggregateExpander { + + /** + * Per-output-column descriptor for a compound aggregate's expansion. + * + * @param func primitive aggregate function name to emit + * @param suffix column-name suffix used in the generated alias + * @param operandTransform applied to the operand before passing to the primitive; defaults to + * identity. {@code sum_of_squares} uses {@link #SQUARE}. + */ + record Component( + String func, String suffix, UnaryOperator operandTransform) { + Component(String func, String suffix) { + this(func, suffix, UnaryOperator.identity()); + } + } + + private static final UnaryOperator SQUARE = + field -> new Function("*", List.of(field, field)); + + private static final List STATS_COMPONENTS = + List.of( + new Component("count", "count"), + new Component("sum", "sum"), + new Component("avg", "avg"), + new Component("min", "min"), + new Component("max", "max")); + + private static final List EXTENDED_STATS_COMPONENTS = + Stream.concat( + STATS_COMPONENTS.stream(), + Stream.of( + new Component("sum", "sumOfSquares", SQUARE), + new Component("var_pop", "variance"), + new Component("stddev_pop", "stdDeviation"))) + .toList(); + + private static final Map> COMPOUND_AGGREGATES = + Map.ofEntries( + Map.entry("STATS", STATS_COMPONENTS), + Map.entry("EXTENDED_STATS", EXTENDED_STATS_COMPONENTS)); + + private CompoundAggregateExpander() {} + + /** Returns true if {@code functionName} (case-insensitive) is a registered compound aggregate. */ + public static boolean isCompoundName(String functionName) { + return functionName != null + && COMPOUND_AGGREGATES.containsKey(functionName.toUpperCase(Locale.ROOT)); + } + + /** + * Returns true if {@code item} is an {@link Alias} wrapping an {@link AggregateFunction} with a + * compound function name — the shape SELECT-list parsing produces. 
+ */ + public static boolean isCompoundAggregateAlias(UnresolvedExpression item) { + return item instanceof Alias alias + && alias.getDelegated() instanceof AggregateFunction agg + && isCompoundName(agg.getFuncName()); + } + + /** + * Returns true if {@code expr} is an un-aliased {@link AggregateFunction} with a compound + * function name — the shape the aggregator-collector visitor sees. + */ + public static boolean isCompoundAggregate(UnresolvedExpression expr) { + return expr instanceof AggregateFunction agg && isCompoundName(agg.getFuncName()); + } + + /** + * Expands a compound name into primitive {@link AggregateFunction} expressions in registered + * order. Each primitive is built over {@code field} pre-transformed by its component's {@link + * Component#operandTransform()}. {@code condition}, if non-null, is propagated to every + * primitive. + * + *

For named columns, use {@link #expandAliased}. + */ + public static List expandPrimitives( + String compoundName, UnresolvedExpression field, UnresolvedExpression condition) { + List components = resolveComponents(compoundName); + List primitives = new ArrayList<>(components.size()); + for (Component comp : components) { + UnresolvedExpression operand = comp.operandTransform().apply(field); + AggregateFunction primitive = new AggregateFunction(comp.func, operand); + if (condition != null) { + primitive.condition(condition); + } + primitives.add(primitive); + } + return primitives; + } + + /** + * Expands a compound call into primitives wrapped in {@link Alias}es. Each output Alias has: + * + *

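<ul> + * <li>name: the internal name {@code suffix(fieldText)}, e.g. {@code count(price)} + * <li>alias: the display name {@code displayPrefix.suffix}, e.g. {@code STATS(price).count}, or {@code null} when {@code displayPrefix} is {@code null} + * </ul> + * + *
<p>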
Pass {@code displayPrefix == null} to produce internal-name-only Aliases for the + * aggregator-collection visitor, which has no display concerns. + * + * @param compoundName compound aggregate name (case-insensitive); must be registered + * @param field operand expression + * @param fieldText source-text of the operand, used in the internal name + * @param displayPrefix the part before the dot in V1's display format — typically the user's + * {@code AS} alias if given, otherwise the source text of the entire compound call (e.g. + * {@code STATS(price)}). Pass {@code null} to skip the display alias entirely. + * @param condition propagated to every primitive, or {@code null} for no filter + */ + public static List expandAliased( + String compoundName, + UnresolvedExpression field, + String fieldText, + String displayPrefix, + UnresolvedExpression condition) { + List components = resolveComponents(compoundName); + List primitives = expandPrimitives(compoundName, field, condition); + List aliasedColumns = new ArrayList<>(primitives.size()); + for (int i = 0; i < primitives.size(); i++) { + Component comp = components.get(i); + String internalName = comp.suffix() + "(" + fieldText + ")"; + String displayName = displayPrefix != null ? displayPrefix + "." 
+ comp.suffix() : null; + aliasedColumns.add(new Alias(internalName, primitives.get(i), displayName)); + } + return aliasedColumns; + } + + private static List resolveComponents(String compoundName) { + if (compoundName == null) { + throw new IllegalArgumentException("compoundName is null"); + } + List components = COMPOUND_AGGREGATES.get(compoundName.toUpperCase(Locale.ROOT)); + if (components == null) { + throw new IllegalArgumentException("Not a compound aggregate: " + compoundName); + } + return components; + } +} diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java b/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java index 5625371f058..93357d67f55 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/context/QuerySpecification.java @@ -26,6 +26,8 @@ import lombok.ToString; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.QualifiedName; @@ -38,6 +40,7 @@ import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SelectSpecContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor; import org.opensearch.sql.sql.parser.AstExpressionBuilder; +import org.opensearch.sql.sql.parser.CompoundAggregateExpander; /** * Query specification domain that collects basic info for a simple query. 
@@ -221,17 +224,42 @@ public Void visitOrderByElement(OrderByElementContext ctx) { @Override public Void visitAggregateFunctionCall(AggregateFunctionCallContext ctx) { - aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), visitAstExpression(ctx))); + UnresolvedExpression aggregateFunction = visitAstExpression(ctx); + if (CompoundAggregateExpander.isCompoundAggregate(aggregateFunction)) { + aggregators.addAll(expandCompoundAggregate((AggregateFunction) aggregateFunction)); + } else { + aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction)); + } return super.visitAggregateFunctionCall(ctx); } @Override public Void visitFilteredAggregationFunctionCall(FilteredAggregationFunctionCallContext ctx) { UnresolvedExpression aggregateFunction = visitAstExpression(ctx); - aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction)); + if (CompoundAggregateExpander.isCompoundAggregate(aggregateFunction)) { + aggregators.addAll(expandCompoundAggregate((AggregateFunction) aggregateFunction)); + } else { + aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction)); + } return super.visitFilteredAggregationFunctionCall(ctx); } + /** Expands a compound aggregate into its primitive aggregators (source-text named). 
*/ + private List expandCompoundAggregate(AggregateFunction agg) { + List primitives = new ArrayList<>(); + for (UnresolvedExpression aliased : + CompoundAggregateExpander.expandAliased( + agg.getFuncName(), + agg.getField(), + agg.getField().toString(), + null, + agg.condition())) { + Alias innerAlias = (Alias) aliased; + primitives.add(AstDSL.alias(innerAlias.getName(), innerAlias.getDelegated())); + } + return primitives; + } + private boolean isDistinct(SelectSpecContext ctx) { return (ctx != null) && (ctx.DISTINCT() != null); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 7869ba5cdad..368d83b439b 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -46,6 +46,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.SubqueryAlias; import org.opensearch.sql.ast.tree.TableFunction; @@ -770,4 +771,102 @@ public void exists_subquery_throws_syntax_check_exception() { SyntaxCheckException.class, () -> buildAST("SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)")); } + + @Test + public void can_build_compound_aggregate_stats() { + assertEquals( + project( + agg( + relation("test"), + buildCompoundAggregators("STATS"), + emptyList(), + emptyList(), + emptyList()), + buildStatsProjectColumns("STATS(age)")), + buildAST("SELECT STATS(age) FROM test")); + } + + @Test + public void can_build_compound_aggregate_with_as_alias() { + assertEquals( + project( + agg( + relation("test"), + buildCompoundAggregators("STATS"), + emptyList(), + emptyList(), + emptyList()), + buildStatsProjectColumns("p")), + 
buildAST("SELECT STATS(age) AS p FROM test")); + } + + @Test + public void can_build_compound_aggregate_dedupes_with_explicit_primitive() { + // Aggregators dedupe (5), but project keeps every SELECT-list column (6). + assertEquals( + project( + agg( + relation("test"), + buildCompoundAggregators("STATS"), + emptyList(), + emptyList(), + emptyList()), + alias("count(age)", aggregate("count", qualifiedName("age")), "STATS(age).count"), + alias("sum(age)", aggregate("sum", qualifiedName("age")), "STATS(age).sum"), + alias("avg(age)", aggregate("avg", qualifiedName("age")), "STATS(age).avg"), + alias("min(age)", aggregate("min", qualifiedName("age")), "STATS(age).min"), + alias("max(age)", aggregate("max", qualifiedName("age")), "STATS(age).max"), + alias("avg(age)", aggregate("avg", qualifiedName("age")))), + buildAST("SELECT STATS(age), avg(age) FROM test")); + } + + @Test + public void can_build_compound_aggregate_with_group_by() { + UnresolvedExpression cityGroup = alias("city", qualifiedName("city")); + assertEquals( + project( + agg( + relation("test"), + buildCompoundAggregators("STATS"), + emptyList(), + ImmutableList.of(cityGroup), + emptyList()), + cityGroup, + alias("count(age)", aggregate("count", qualifiedName("age")), "STATS(age).count"), + alias("sum(age)", aggregate("sum", qualifiedName("age")), "STATS(age).sum"), + alias("avg(age)", aggregate("avg", qualifiedName("age")), "STATS(age).avg"), + alias("min(age)", aggregate("min", qualifiedName("age")), "STATS(age).min"), + alias("max(age)", aggregate("max", qualifiedName("age")), "STATS(age).max")), + buildAST("SELECT city, STATS(age) FROM test GROUP BY city")); + } + + /** + * Builds the five primitive aggregator expressions {@code STATS(age)} expands to, all referencing + * the {@code age} field via source-text internal names ({@code count(age)}, {@code sum(age)}, …). + * Used as the {@code aggExprList} for both Project and Aggregation assertions. 
+ */ + private ImmutableList buildCompoundAggregators(String compoundName) { + return ImmutableList.of( + alias("count(age)", aggregate("count", qualifiedName("age"))), + alias("sum(age)", aggregate("sum", qualifiedName("age"))), + alias("avg(age)", aggregate("avg", qualifiedName("age"))), + alias("min(age)", aggregate("min", qualifiedName("age"))), + alias("max(age)", aggregate("max", qualifiedName("age")))); + } + + /** + * Project entries for {@code STATS(age)} expansion: same internal names as {@link + * #buildCompoundAggregators}, with display aliases formed as {@code .}. {@code + * prefix} is either the user's AS alias (e.g. {@code "p"}) or the source-text of the compound + * call (e.g. {@code "STATS(age)"}). + */ + private UnresolvedExpression[] buildStatsProjectColumns(String prefix) { + return new UnresolvedExpression[] { + alias("count(age)", aggregate("count", qualifiedName("age")), prefix + ".count"), + alias("sum(age)", aggregate("sum", qualifiedName("age")), prefix + ".sum"), + alias("avg(age)", aggregate("avg", qualifiedName("age")), prefix + ".avg"), + alias("min(age)", aggregate("min", qualifiedName("age")), prefix + ".min"), + alias("max(age)", aggregate("max", qualifiedName("age")), prefix + ".max") + }; + } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/CompoundAggregateExpanderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/CompoundAggregateExpanderTest.java new file mode 100644 index 00000000000..cf8bbc87762 --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/CompoundAggregateExpanderTest.java @@ -0,0 +1,219 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static 
org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class CompoundAggregateExpanderTest { + + @Test + void recognizes_stats_case_insensitively() { + assertTrue(CompoundAggregateExpander.isCompoundName("STATS")); + assertTrue(CompoundAggregateExpander.isCompoundName("stats")); + assertTrue(CompoundAggregateExpander.isCompoundName("Stats")); + } + + @Test + void recognizes_extended_stats_case_insensitively() { + assertTrue(CompoundAggregateExpander.isCompoundName("EXTENDED_STATS")); + assertTrue(CompoundAggregateExpander.isCompoundName("extended_stats")); + } + + @Test + void rejects_non_compound_names() { + assertFalse(CompoundAggregateExpander.isCompoundName("SUM")); + assertFalse(CompoundAggregateExpander.isCompoundName("AVG")); + assertFalse(CompoundAggregateExpander.isCompoundName("")); + assertFalse(CompoundAggregateExpander.isCompoundName(null)); + } + + @Test + void stats_expands_to_five_primitives_in_canonical_order() { + QualifiedName price = AstDSL.qualifiedName("price"); + + List primitives = + CompoundAggregateExpander.expandPrimitives("STATS", price, null); + + assertEquals( + ImmutableList.of( + new AggregateFunction("count", price), + new AggregateFunction("sum", price), + new AggregateFunction("avg", price), + new AggregateFunction("min", price), + new AggregateFunction("max", price)), + primitives); + } + + @Test + void 
extended_stats_expands_to_eight_primitives_in_canonical_order() { + QualifiedName price = AstDSL.qualifiedName("price"); + + List primitives = + CompoundAggregateExpander.expandPrimitives("EXTENDED_STATS", price, null); + + assertEquals( + ImmutableList.of( + new AggregateFunction("count", price), + new AggregateFunction("sum", price), + new AggregateFunction("avg", price), + new AggregateFunction("min", price), + new AggregateFunction("max", price), + new AggregateFunction("sum", new Function("*", List.of(price, price))), + new AggregateFunction("var_pop", price), + new AggregateFunction("stddev_pop", price)), + primitives); + } + + @Test + void aliased_without_display_prefix_yields_internal_names_only() { + QualifiedName price = AstDSL.qualifiedName("price"); + + List aliased = + CompoundAggregateExpander.expandAliased("EXTENDED_STATS", price, "price", null, null); + + assertEquals( + ImmutableList.of( + "count(price)", + "sum(price)", + "avg(price)", + "min(price)", + "max(price)", + "sumOfSquares(price)", + "variance(price)", + "stdDeviation(price)"), + aliased.stream().map(a -> ((Alias) a).getName()).toList()); + aliased.forEach(a -> assertNull(((Alias) a).getAlias())); + } + + @Test + void aliased_with_display_prefix_emits_dot_separated_display_names() { + QualifiedName price = AstDSL.qualifiedName("price"); + + List aliased = + CompoundAggregateExpander.expandAliased("EXTENDED_STATS", price, "price", "x", null); + + assertEquals( + ImmutableList.of( + "count(price)", + "sum(price)", + "avg(price)", + "min(price)", + "max(price)", + "sumOfSquares(price)", + "variance(price)", + "stdDeviation(price)"), + aliased.stream().map(a -> ((Alias) a).getName()).toList()); + assertEquals( + ImmutableList.of( + "x.count", + "x.sum", + "x.avg", + "x.min", + "x.max", + "x.sumOfSquares", + "x.variance", + "x.stdDeviation"), + aliased.stream().map(a -> ((Alias) a).getAlias()).toList()); + } + + @Test + void source_text_display_prefix_emits_call_dot_suffix_format() { + // 
No-AS scenario: displayPrefix is the source text of the call. + QualifiedName price = AstDSL.qualifiedName("price"); + + List aliased = + CompoundAggregateExpander.expandAliased("STATS", price, "price", "STATS(price)", null); + + assertEquals( + ImmutableList.of( + "STATS(price).count", + "STATS(price).sum", + "STATS(price).avg", + "STATS(price).min", + "STATS(price).max"), + aliased.stream().map(a -> ((Alias) a).getAlias()).toList()); + } + + @Test + void condition_propagates_to_each_primitive_when_provided() { + QualifiedName price = AstDSL.qualifiedName("price"); + UnresolvedExpression condition = AstDSL.function(">", price, AstDSL.intLiteral(0)); + + List primitives = + CompoundAggregateExpander.expandPrimitives("EXTENDED_STATS", price, condition); + + primitives.forEach(expr -> assertEquals(condition, ((AggregateFunction) expr).condition())); + } + + @Test + void condition_is_null_on_each_primitive_when_not_provided() { + QualifiedName price = AstDSL.qualifiedName("price"); + + List primitives = + CompoundAggregateExpander.expandPrimitives("STATS", price, null); + + primitives.forEach(expr -> assertNull(((AggregateFunction) expr).condition())); + } + + @Test + void expand_primitives_rejects_unknown_compound_name() { + QualifiedName price = AstDSL.qualifiedName("price"); + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, + () -> CompoundAggregateExpander.expandPrimitives("PERCENTILES", price, null)); + assertTrue(ex.getMessage().contains("PERCENTILES")); + } + + @Test + void expand_primitives_rejects_null_compound_name() { + QualifiedName price = AstDSL.qualifiedName("price"); + assertThrows( + IllegalArgumentException.class, + () -> CompoundAggregateExpander.expandPrimitives(null, price, null)); + } + + @Test + void is_compound_aggregate_alias_true_only_for_alias_wrapping_compound_aggregate() { + QualifiedName price = AstDSL.qualifiedName("price"); + Alias compound = new Alias("STATS(price)", new AggregateFunction("STATS", 
price)); + Alias nonCompound = new Alias("AVG(price)", new AggregateFunction("AVG", price)); + Alias nonAggregate = new Alias("price", price); + + assertTrue(CompoundAggregateExpander.isCompoundAggregateAlias(compound)); + assertFalse(CompoundAggregateExpander.isCompoundAggregateAlias(nonCompound)); + assertFalse(CompoundAggregateExpander.isCompoundAggregateAlias(nonAggregate)); + assertFalse(CompoundAggregateExpander.isCompoundAggregateAlias(price)); + } + + @Test + void is_compound_aggregate_true_only_for_compound_aggregate_function() { + QualifiedName price = AstDSL.qualifiedName("price"); + AggregateFunction compound = new AggregateFunction("EXTENDED_STATS", price); + AggregateFunction nonCompound = new AggregateFunction("SUM", price); + + assertTrue(CompoundAggregateExpander.isCompoundAggregate(compound)); + assertFalse(CompoundAggregateExpander.isCompoundAggregate(nonCompound)); + assertFalse(CompoundAggregateExpander.isCompoundAggregate(price)); + } +} diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/context/QuerySpecificationTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/context/QuerySpecificationTest.java index 6dd027a74cf..74fd5d1ab3b 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/context/QuerySpecificationTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/context/QuerySpecificationTest.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; @@ -129,6 +130,45 @@ void can_collect_filtered_aggregation() { collect("SELECT AVG(age) FILTER(WHERE age > 20) FROM test").getAggregators()); } + @Test + void can_collect_compound_aggregate_as_expanded_primitives() { + 
assertEquals( + ImmutableSet.of( + alias("count(age)", aggregate("count", qualifiedName("age"))), + alias("sum(age)", aggregate("sum", qualifiedName("age"))), + alias("avg(age)", aggregate("avg", qualifiedName("age"))), + alias("min(age)", aggregate("min", qualifiedName("age"))), + alias("max(age)", aggregate("max", qualifiedName("age")))), + collect("SELECT STATS(age) FROM test").getAggregators()); + } + + @Test + void can_collect_filtered_compound_aggregate_with_propagated_condition() { + UnresolvedExpression condition = function(">", qualifiedName("age"), intLiteral(0)); + assertEquals( + ImmutableSet.of( + alias("count(age)", filteredAggregate("count", qualifiedName("age"), condition)), + alias("sum(age)", filteredAggregate("sum", qualifiedName("age"), condition)), + alias("avg(age)", filteredAggregate("avg", qualifiedName("age"), condition)), + alias("min(age)", filteredAggregate("min", qualifiedName("age"), condition)), + alias("max(age)", filteredAggregate("max", qualifiedName("age"), condition))), + collect("SELECT STATS(age) FILTER(WHERE age > 0) FROM test").getAggregators()); + } + + @Test + void should_deduplicate_compound_aggregate_against_explicit_primitive() { + // The compound expansion's internal name format ensures the implicit avg(age) entry + // equals an explicit avg(age) — LinkedHashSet dedupes them into a single aggregator. + assertEquals( + ImmutableSet.of( + alias("count(age)", aggregate("count", qualifiedName("age"))), + alias("sum(age)", aggregate("sum", qualifiedName("age"))), + alias("avg(age)", aggregate("avg", qualifiedName("age"))), + alias("min(age)", aggregate("min", qualifiedName("age"))), + alias("max(age)", aggregate("max", qualifiedName("age")))), + collect("SELECT STATS(age), avg(age) FROM test").getAggregators()); + } + private QuerySpecification collect(String query) { QuerySpecification querySpec = new QuerySpecification(); querySpec.collect(parse(query), query);