Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions language-grammar/src/main/antlr4/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,8 @@ aggregationFunctionName
| STDDEV
| STDDEV_POP
| STDDEV_SAMP
| STATS
| EXTENDED_STATS
;

mathematicalFunctionName
Expand Down
2 changes: 2 additions & 0 deletions sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,8 @@ aggregationFunctionName
| STDDEV
| STDDEV_POP
| STDDEV_SAMP
| STATS
| EXTENDED_STATS
;

mathematicalFunctionName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

import com.google.common.collect.ImmutableList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -127,10 +129,30 @@ public UnresolvedPlan visitSelectClause(SelectClauseContext ctx) {
if (ctx.selectElements().star != null) { // TODO: project operator should be required?
builder.add(AllFields.of());
}
ctx.selectElements().selectElement().forEach(field -> builder.add(visitSelectItem(field)));
for (SelectElementContext element : ctx.selectElements().selectElement()) {
UnresolvedExpression item = visitSelectItem(element);
if (CompoundAggregateExpander.isCompoundAggregateAlias(item)) {
builder.addAll(expandCompoundAggregate(item));
} else {
builder.add(item);
}
}
return new Project(builder.build());
}

/** Expands a compound aggregate ({@code STATS} / {@code EXTENDED_STATS}) into its primitives. */
private List<UnresolvedExpression> expandCompoundAggregate(UnresolvedExpression item) {
Alias alias = (Alias) item;
AggregateFunction agg = (AggregateFunction) alias.getDelegated();
String displayPrefix = alias.getAlias() != null ? alias.getAlias() : alias.getName();
return CompoundAggregateExpander.expandAliased(
agg.getFuncName(),
agg.getField(),
agg.getField().toString(),
displayPrefix,
agg.condition());
}

@Override
public UnresolvedPlan visitLimitClause(OpenSearchSQLParser.LimitClauseContext ctx) {
return new Limit(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.sql.parser;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.UnaryOperator;
import java.util.stream.Stream;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

/**
* Parse-time expansion utility for fixed-shape compound aggregate calls. A compound aggregate
* registered in {@link #COMPOUND_AGGREGATES} expands into a fixed list of primitive aggregates,
* emitted as bare expressions ({@link #expandPrimitives}) or {@link Alias}-wrapped columns ({@link
* #expandAliased}).
*
* <p>To register a new compound aggregate, add an entry to {@link #COMPOUND_AGGREGATES}.
*
* <p>Aggregates whose output column count depends on call arguments do not fit this registry.
*/
public final class CompoundAggregateExpander {

/**
* Per-output-column descriptor for a compound aggregate's expansion.
*
* @param func primitive aggregate function name to emit
* @param suffix column-name suffix used in the generated alias
* @param operandTransform applied to the operand before passing to the primitive; defaults to
* identity. {@code sum_of_squares} uses {@link #SQUARE}.
*/
record Component(
String func, String suffix, UnaryOperator<UnresolvedExpression> operandTransform) {
Component(String func, String suffix) {
this(func, suffix, UnaryOperator.identity());
}
}

private static final UnaryOperator<UnresolvedExpression> SQUARE =
field -> new Function("*", List.of(field, field));

private static final List<Component> STATS_COMPONENTS =
List.of(
new Component("count", "count"),
new Component("sum", "sum"),
new Component("avg", "avg"),
new Component("min", "min"),
new Component("max", "max"));

private static final List<Component> EXTENDED_STATS_COMPONENTS =
Stream.concat(
STATS_COMPONENTS.stream(),
Stream.of(
new Component("sum", "sumOfSquares", SQUARE),
new Component("var_pop", "variance"),
new Component("stddev_pop", "stdDeviation")))
.toList();

private static final Map<String, List<Component>> COMPOUND_AGGREGATES =
Map.ofEntries(
Map.entry("STATS", STATS_COMPONENTS),
Map.entry("EXTENDED_STATS", EXTENDED_STATS_COMPONENTS));

private CompoundAggregateExpander() {}

/** Returns true if {@code functionName} (case-insensitive) is a registered compound aggregate. */
public static boolean isCompoundName(String functionName) {
return functionName != null
&& COMPOUND_AGGREGATES.containsKey(functionName.toUpperCase(Locale.ROOT));
}

/**
* Returns true if {@code item} is an {@link Alias} wrapping an {@link AggregateFunction} with a
* compound function name — the shape SELECT-list parsing produces.
*/
public static boolean isCompoundAggregateAlias(UnresolvedExpression item) {
return item instanceof Alias alias
&& alias.getDelegated() instanceof AggregateFunction agg
&& isCompoundName(agg.getFuncName());
}

/**
* Returns true if {@code expr} is an un-aliased {@link AggregateFunction} with a compound
* function name — the shape the aggregator-collector visitor sees.
*/
public static boolean isCompoundAggregate(UnresolvedExpression expr) {
return expr instanceof AggregateFunction agg && isCompoundName(agg.getFuncName());
}

/**
* Expands a compound name into primitive {@link AggregateFunction} expressions in registered
* order. Each primitive is built over {@code field} pre-transformed by its component's {@link
* Component#operandTransform()}. {@code condition}, if non-null, is propagated to every
* primitive.
*
* <p>For named columns, use {@link #expandAliased}.
*/
public static List<UnresolvedExpression> expandPrimitives(
String compoundName, UnresolvedExpression field, UnresolvedExpression condition) {
List<Component> components = resolveComponents(compoundName);
List<UnresolvedExpression> primitives = new ArrayList<>(components.size());
for (Component comp : components) {
UnresolvedExpression operand = comp.operandTransform().apply(field);
AggregateFunction primitive = new AggregateFunction(comp.func, operand);
if (condition != null) {
primitive.condition(condition);
}
primitives.add(primitive);
}
return primitives;
}

/**
* Expands a compound call into primitives wrapped in {@link Alias}es. Each output Alias has:
*
* <ul>
* <li>{@code name} — internal lookup key {@code <suffix>(<fieldText>)} (e.g. {@code
* count(price)}). Used by aggregator dedup against explicit primitive aggregates the user
* might also have written.
* <li>{@code alias} — user-visible display name in V1 format {@code <displayPrefix>.<suffix>}
* (e.g. {@code STATS(price).count}, {@code p.sumOfSquares}). Set only when {@code
* displayPrefix} is non-null.
* </ul>
*
* <p>Pass {@code displayPrefix == null} to produce internal-name-only Aliases for the
* aggregator-collection visitor, which has no display concerns.
*
* @param compoundName compound aggregate name (case-insensitive); must be registered
* @param field operand expression
* @param fieldText source-text of the operand, used in the internal name
* @param displayPrefix the part before the dot in V1's display format — typically the user's
* {@code AS} alias if given, otherwise the source text of the entire compound call (e.g.
* {@code STATS(price)}). Pass {@code null} to skip the display alias entirely.
* @param condition propagated to every primitive, or {@code null} for no filter
*/
public static List<UnresolvedExpression> expandAliased(
String compoundName,
UnresolvedExpression field,
String fieldText,
String displayPrefix,
UnresolvedExpression condition) {
List<Component> components = resolveComponents(compoundName);
List<UnresolvedExpression> primitives = expandPrimitives(compoundName, field, condition);
List<UnresolvedExpression> aliasedColumns = new ArrayList<>(primitives.size());
for (int i = 0; i < primitives.size(); i++) {
Component comp = components.get(i);
String internalName = comp.suffix() + "(" + fieldText + ")";
String displayName = displayPrefix != null ? displayPrefix + "." + comp.suffix() : null;
aliasedColumns.add(new Alias(internalName, primitives.get(i), displayName));
}
return aliasedColumns;
}

private static List<Component> resolveComponents(String compoundName) {
if (compoundName == null) {
throw new IllegalArgumentException("compoundName is null");
}
List<Component> components = COMPOUND_AGGREGATES.get(compoundName.toUpperCase(Locale.ROOT));
if (components == null) {
throw new IllegalArgumentException("Not a compound aggregate: " + compoundName);
}
return components;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import lombok.ToString;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.AggregateFunction;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.QualifiedName;
Expand All @@ -38,6 +40,7 @@
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SelectSpecContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParserBaseVisitor;
import org.opensearch.sql.sql.parser.AstExpressionBuilder;
import org.opensearch.sql.sql.parser.CompoundAggregateExpander;

/**
* Query specification domain that collects basic info for a simple query.
Expand Down Expand Up @@ -221,17 +224,42 @@ public Void visitOrderByElement(OrderByElementContext ctx) {

@Override
public Void visitAggregateFunctionCall(AggregateFunctionCallContext ctx) {
aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), visitAstExpression(ctx)));
UnresolvedExpression aggregateFunction = visitAstExpression(ctx);
if (CompoundAggregateExpander.isCompoundAggregate(aggregateFunction)) {
aggregators.addAll(expandCompoundAggregate((AggregateFunction) aggregateFunction));
} else {
aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction));
}
return super.visitAggregateFunctionCall(ctx);
}

@Override
public Void visitFilteredAggregationFunctionCall(FilteredAggregationFunctionCallContext ctx) {
UnresolvedExpression aggregateFunction = visitAstExpression(ctx);
aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction));
if (CompoundAggregateExpander.isCompoundAggregate(aggregateFunction)) {
aggregators.addAll(expandCompoundAggregate((AggregateFunction) aggregateFunction));
} else {
aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), aggregateFunction));
}
return super.visitFilteredAggregationFunctionCall(ctx);
}

/** Expands a compound aggregate into its primitive aggregators (source-text named). */
private List<UnresolvedExpression> expandCompoundAggregate(AggregateFunction agg) {
List<UnresolvedExpression> primitives = new ArrayList<>();
for (UnresolvedExpression aliased :
CompoundAggregateExpander.expandAliased(
agg.getFuncName(),
agg.getField(),
agg.getField().toString(),
null,
agg.condition())) {
Alias innerAlias = (Alias) aliased;
primitives.add(AstDSL.alias(innerAlias.getName(), innerAlias.getDelegated()));
}
return primitives;
}

private boolean isDistinct(SelectSpecContext ctx) {
return (ctx != null) && (ctx.DISTINCT() != null);
}
Expand Down
Loading
Loading