diff --git a/.gitignore b/.gitignore index 8cc62b0..77be770 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,16 @@ examples/*/main examples/*/main.exe +# Example binaries built without -o (named after the directory) +examples/basic/basic +examples/comprehensions/comprehensions +examples/context/context +examples/index_analysis/index_analysis +examples/load_table_schema/load_table_schema +examples/logging/logging +examples/parameterized/parameterized +examples/string_extensions/string_extensions + # Claude Code settings .claude/settings.local.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c5e910..669f0ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,22 @@ # Changelog ## [Unreleased] +### Fixed +- **JSON array membership (`in`) now generates a correct boolean predicate on + every dialect.** Each dialect now owns the full predicate, emitting both the + element and the array expression, instead of relying on the caller to prepend + `elem = `. This resolves semantically wrong SQL on several dialects: + - **MySQL**: switched from `JSON_CONTAINS(arr, CAST(? AS JSON))` (which + emitted a stray `?` and ignored the element) to + `JSON_OVERLAPS(JSON_ARRAY(elem), arr)`. + - **SQLite/DuckDB**: switched from a bare `(SELECT value FROM json_each(arr))` + scalar subquery to `EXISTS (SELECT 1 FROM json_each(arr) WHERE value = elem)`. + - **BigQuery**: switched from the invalid `= UNNEST(...)` form to + `elem IN UNNEST(JSON_VALUE_ARRAY(arr))`. + - **PostgreSQL**: unchanged semantics (`elem = ANY(ARRAY(SELECT jsonFunc(arr)))`). + - **Spark**: `array_contains(from_json(arr, 'ARRAY<STRING>'), elem)`. + + Ported from cel2sql4j ([SPANDigital/cel2sql4j@1835215](https://github.com/SPANDigital/cel2sql4j/commit/1835215bb1244b3b15c82315f264354566cfa499)). 
## [3.8.4] - 2026-06-08 ### Changed diff --git a/bigquery/provider.go b/bigquery/provider.go index ce6dd57..5981843 100644 --- a/bigquery/provider.go +++ b/bigquery/provider.go @@ -10,9 +10,9 @@ import ( bq "cloud.google.com/go/bigquery" "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/internal/celprovider" "github.com/spandigital/cel2sql/v3/schema" ) @@ -40,14 +40,14 @@ type TypeProvider interface { } type typeProvider struct { - schemas map[string]Schema + celprovider.Base client *bq.Client datasetID string } // NewTypeProvider creates a new BigQuery type provider with pre-defined schemas. func NewTypeProvider(schemas map[string]Schema) TypeProvider { - return &typeProvider{schemas: schemas} + return &typeProvider{Base: celprovider.Base{Schemas: schemas, Mapper: bigqueryTypeToCELExprType}} } // NewTypeProviderWithClient creates a new BigQuery type provider that can introspect database schemas. @@ -61,7 +61,7 @@ func NewTypeProviderWithClient(_ context.Context, client *bq.Client, datasetID s } return &typeProvider{ - schemas: make(map[string]Schema), + Base: celprovider.Base{Schemas: make(map[string]Schema), Mapper: bigqueryTypeToCELExprType}, client: client, datasetID: datasetID, }, nil @@ -83,7 +83,7 @@ func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) e return fmt.Errorf("%w: table %q has no columns", ErrInvalidSchema, tableName) } - tp.schemas[tableName] = NewSchema(fields) + tp.Schemas[tableName] = NewSchema(fields) return nil } @@ -127,75 +127,6 @@ func bigqueryFieldTypeToString(ft bq.FieldType) string { return strings.ToLower(string(ft)) } -// Close is a no-op since we don't own the *bigquery.Client. 
-func (tp *typeProvider) Close() { - // No-op: caller owns the *bigquery.Client connection -} - -// GetSchemas returns the schemas known to this type provider. -func (tp *typeProvider) GetSchemas() map[string]Schema { - return tp.schemas -} - -// EnumValue implements types.Provider. -func (tp *typeProvider) EnumValue(_ string) ref.Val { - return types.NewErr("unknown enum value") -} - -// FindIdent implements types.Provider. -func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { - return nil, false -} - -// FindStructType implements types.Provider. -func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { - if _, ok := tp.schemas[structType]; ok { - return types.NewObjectType(structType), true - } - return nil, false -} - -// FindStructFieldNames implements types.Provider. -func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - fields := s.Fields() - names := make([]string, len(fields)) - for i, f := range fields { - names[i] = f.Name - } - return names, true -} - -// FindStructFieldType implements types.Provider. -func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - field, found := s.FindField(fieldName) - if !found { - return nil, false - } - - exprType := bigqueryTypeToCELExprType(field) - celType, err := types.ExprTypeToType(exprType) - if err != nil { - return nil, false - } - - return &types.FieldType{ - Type: celType, - }, true -} - -// NewValue implements types.Provider. -func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { - return types.NewErr("unknown type in schema") -} - // bigqueryTypeToCELExprType converts a BigQuery field schema to a CEL expression type. 
func bigqueryTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { baseType := bigqueryBaseTypeToCEL(field.Type) diff --git a/cel2sql.go b/cel2sql.go index 9a7032c..81e61fc 100644 --- a/cel2sql.go +++ b/cel2sql.go @@ -702,12 +702,25 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { ) } - // Handle array membership (IN operator with list) via dialect before writing LHS. - // This allows dialects like SQLite to use a fundamentally different pattern - // (e.g., "elem IN (SELECT value FROM json_each(array))") instead of "elem = ANY(array)". - if fun == operators.In && isListType(rhsType) { - // Non-JSON list membership - if !isFieldAccessExpression(rhs) || !con.isJSONArrayField(rhs) { + // Handle array membership (IN operator) via dialect before writing LHS. + // This allows each dialect to own the complete boolean predicate, using a + // fundamentally different pattern (e.g., SQLite's + // "EXISTS (SELECT 1 FROM json_each(array) WHERE value = elem)") instead of + // the caller emitting "elem = ANY(array)". + if fun == operators.In && (isListType(rhsType) || isFieldAccessExpression(rhs)) { + // JSON array membership: the dialect emits both the element and array. + if isFieldAccessExpression(rhs) && con.isJSONArrayField(rhs) { + writeElem := func() error { return con.visitMaybeNested(lhs, lhsParen) } + if con.isNestedJSONAccess(rhs) { + return con.dialect.WriteNestedJSONArrayMembership(&con.str, writeElem, + func() error { return con.visitNestedJSONForArray(rhs) }) + } + jsonFunc := con.getJSONArrayFunction(rhs) + return con.dialect.WriteJSONArrayMembership(&con.str, jsonFunc, writeElem, + func() error { return con.visitMaybeNested(rhs, rhsParen) }) + } + // Non-JSON list membership. 
+ if isListType(rhsType) { return con.dialect.WriteArrayMembership(&con.str, func() error { return con.visitMaybeNested(lhs, lhsParen) }, func() error { return con.visitMaybeNested(rhs, rhsParen) }, @@ -769,40 +782,18 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { con.str.WriteString(" ") con.str.WriteString(operator) con.str.WriteString(" ") - if fun == operators.In && (isListType(rhsType) || isFieldAccessExpression(rhs)) { - // Check if we're dealing with a JSON array - if isFieldAccessExpression(rhs) && con.isJSONArrayField(rhs) { - // For JSON arrays, use dialect-specific JSON array membership - jsonFunc := con.getJSONArrayFunction(rhs) - - // For nested JSON access like settings.permissions, we need to handle differently - if con.isNestedJSONAccess(rhs) { - // Use dialect-specific nested JSON array membership - if err := con.dialect.WriteNestedJSONArrayMembership(&con.str, func() error { - return con.visitNestedJSONForArray(rhs) - }); err != nil { - return err - } - return nil - } - // For direct JSON array access - if err := con.dialect.WriteJSONArrayMembership(&con.str, jsonFunc, func() error { - return con.visitMaybeNested(rhs, rhsParen) - }); err != nil { - return err - } - return nil - } + // Remaining membership case: field access on a non-JSON, non-list-typed + // column (e.g. a Dyn-typed array column) wraps the RHS in ANY(). + // JSON arrays and list literals are handled by the dialect before the LHS + // is written. 
+ if fun == operators.In && isFieldAccessExpression(rhs) { con.str.WriteString("ANY(") } if err := con.visitMaybeNested(rhs, rhsParen); err != nil { return err } - if fun == operators.In && (isListType(rhsType) || isFieldAccessExpression(rhs)) { - // Check if we're dealing with a JSON array - already handled above for JSON arrays - if !isFieldAccessExpression(rhs) || !con.isJSONArrayField(rhs) { - con.str.WriteString(")") - } + if fun == operators.In && isFieldAccessExpression(rhs) { + con.str.WriteString(")") } return nil } diff --git a/dialect/bigquery/dialect.go b/dialect/bigquery/dialect.go index 5314237..ffdfafe 100644 --- a/dialect/bigquery/dialect.go +++ b/dialect/bigquery/dialect.go @@ -261,20 +261,28 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes BigQuery JSON array membership. -func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { - w.WriteString("UNNEST(JSON_VALUE_ARRAY(") - if err := writeExpr(); err != nil { +// WriteJSONArrayMembership writes BigQuery JSON array membership using +// elem IN UNNEST(JSON_VALUE_ARRAY(arr)). +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" IN UNNEST(JSON_VALUE_ARRAY(") + if err := writeArray(); err != nil { return err } w.WriteString("))") return nil } -// WriteNestedJSONArrayMembership writes BigQuery nested JSON array membership. -func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("UNNEST(JSON_VALUE_ARRAY(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes BigQuery nested JSON array membership using +// elem IN UNNEST(JSON_VALUE_ARRAY(arr)). 
+func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" IN UNNEST(JSON_VALUE_ARRAY(") + if err := writeArray(); err != nil { return err } w.WriteString("))") diff --git a/dialect/bigquery/regex.go b/dialect/bigquery/regex.go index 2fcfbf7..cdee30d 100644 --- a/dialect/bigquery/regex.go +++ b/dialect/bigquery/regex.go @@ -1,137 +1,28 @@ package bigquery import ( - "errors" - "fmt" - "regexp" "strings" -) -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). -const ( - maxRegexPatternLength = 500 - maxRegexGroups = 20 - maxRegexNestingDepth = 10 + "github.com/spandigital/cel2sql/v3/dialect/internal/regexsafe" ) // convertRE2ToBigQuery converts an RE2 regex pattern to BigQuery-compatible format. -// BigQuery uses RE2 natively, so most patterns pass through unchanged. +// BigQuery uses RE2 natively, so \d, \w, \s and \b pass through unchanged. Shared +// ReDoS / unsupported-feature validation lives in regexsafe.Validate. // Returns the converted pattern, whether it's case-insensitive, and any error. func convertRE2ToBigQuery(re2Pattern string) (string, bool, error) { - // 1. Pattern length validation - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) - } - - // 2. Validate pattern compiles - if _, err := regexp.Compile(re2Pattern); err != nil { - return "", false, fmt.Errorf("invalid regex pattern: %w", err) - } - - // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported") - } - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { - return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported") - } - if strings.Contains(re2Pattern, "(?P<") { - return "", false, errors.New("named capture groups (?P<name>...) are not supported") - } - - // 4. 
Detect catastrophic nested quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - - // 5. Check for nested quantifiers in groups - depth := 0 - groupHasQuantifier := make([]bool, 0) - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - } - } - if len(groupHasQuantifier) > 0 { - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?', '{': - for j := range groupHasQuantifier { - groupHasQuantifier[j] = true - } - } - } - - // 6. Check group count limit - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) - } - - // 7. Check for quantified alternation - quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if quantifiedAlternation.MatchString(re2Pattern) { - return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") - } - - // 8. 
Check nesting depth - maxDepthVal := 0 - currentDepth := 0 - for i := 0; i < len(re2Pattern); i++ { - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch re2Pattern[i] { - case '(': - currentDepth++ - if currentDepth > maxDepthVal { - maxDepthVal = currentDepth - } - case ')': - if currentDepth > 0 { - currentDepth-- - } - } - } - if maxDepthVal > maxRegexNestingDepth { - return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) - } - - // Process pattern: BigQuery uses RE2 natively, so minimal conversion needed caseInsensitive := false - pattern := re2Pattern - - // Handle (?i) flag - BigQuery REGEXP_CONTAINS embeds the flag in the pattern - if strings.HasPrefix(pattern, "(?i)") { + if strings.HasPrefix(re2Pattern, "(?i)") { caseInsensitive = true - pattern = pattern[4:] + re2Pattern = re2Pattern[4:] } - // Handle inline flags other than (?i) at start - if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { - return "", false, errors.New("inline flags other than (?i) are not supported in BigQuery regex") + if err := regexsafe.Validate(re2Pattern); err != nil { + return "", false, err } - // Convert non-capturing groups (?:...) to regular groups (...) - pattern = strings.ReplaceAll(pattern, "(?:", "(") - - // BigQuery RE2 supports \d, \w, \s, \b natively - no conversion needed + // Convert non-capturing groups (?:...) to regular groups (...). 
+ pattern := strings.ReplaceAll(re2Pattern, "(?:", "(") return pattern, caseInsensitive, nil } diff --git a/dialect/bigquery/validation.go b/dialect/bigquery/validation.go index 0ae982d..40df12b 100644 --- a/dialect/bigquery/validation.go +++ b/dialect/bigquery/validation.go @@ -1,10 +1,9 @@ package bigquery import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) var ( @@ -40,17 +39,5 @@ var ( // validateFieldName validates that a field name follows BigQuery naming conventions. func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "BigQuery", 0, fieldNameRegexp, reservedSQLKeywords) } diff --git a/dialect/dialect.go b/dialect/dialect.go index 618f7a5..a083d04 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -132,12 +132,16 @@ type Dialect interface { WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error // WriteJSONArrayMembership writes a JSON array membership test for the IN operator. - // For PostgreSQL: ANY(ARRAY(SELECT jsonb_func(expr))). - WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error + // The dialect owns the complete boolean predicate, emitting both the element + // and the array expressions. + // For PostgreSQL: elem = ANY(ARRAY(SELECT jsonb_func(arr))). 
+ WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeElem func() error, writeArray func() error) error // WriteNestedJSONArrayMembership writes a nested JSON array membership test. - // For PostgreSQL: ANY(ARRAY(SELECT jsonb_array_elements_text(expr))). - WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error + // The dialect owns the complete boolean predicate, emitting both the element + // and the array expressions. + // For PostgreSQL: elem = ANY(ARRAY(SELECT jsonb_array_elements_text(arr))). + WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error // --- Timestamps --- diff --git a/dialect/duckdb/dialect.go b/dialect/duckdb/dialect.go index 5a22235..f034c4c 100644 --- a/dialect/duckdb/dialect.go +++ b/dialect/duckdb/dialect.go @@ -261,23 +261,33 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes DuckDB JSON array membership using json_each. -func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { - w.WriteString("(SELECT value FROM json_each(") - if err := writeExpr(); err != nil { +// WriteJSONArrayMembership writes DuckDB JSON array membership using +// EXISTS (SELECT 1 FROM json_each(arr) WHERE value = elem). +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeElem func() error, writeArray func() error) error { + w.WriteString("EXISTS (SELECT 1 FROM json_each(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(") WHERE value = ") + if err := writeElem(); err != nil { return err } - w.WriteString("))") + w.WriteString(")") return nil } -// WriteNestedJSONArrayMembership writes DuckDB nested JSON array membership. 
-func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("(SELECT value FROM json_each(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes DuckDB nested JSON array membership using +// EXISTS (SELECT 1 FROM json_each(arr) WHERE value = elem). +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + w.WriteString("EXISTS (SELECT 1 FROM json_each(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(") WHERE value = ") + if err := writeElem(); err != nil { return err } - w.WriteString("))") + w.WriteString(")") return nil } diff --git a/dialect/duckdb/regex.go b/dialect/duckdb/regex.go index 582f83d..3a397fc 100644 --- a/dialect/duckdb/regex.go +++ b/dialect/duckdb/regex.go @@ -1,137 +1,28 @@ package duckdb import ( - "errors" - "fmt" - "regexp" "strings" -) -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). -const ( - maxRegexPatternLength = 500 - maxRegexGroups = 20 - maxRegexNestingDepth = 10 + "github.com/spandigital/cel2sql/v3/dialect/internal/regexsafe" ) // convertRE2ToDuckDB converts an RE2 regex pattern to DuckDB-compatible format. -// DuckDB uses RE2 natively, so most patterns pass through unchanged. +// DuckDB uses RE2 natively, so \d, \w, \s and \b pass through unchanged. Shared +// ReDoS / unsupported-feature validation lives in regexsafe.Validate. // Returns the converted pattern, whether it's case-insensitive, and any error. func convertRE2ToDuckDB(re2Pattern string) (string, bool, error) { - // 1. Pattern length validation - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) - } - - // 2. 
Validate pattern compiles - if _, err := regexp.Compile(re2Pattern); err != nil { - return "", false, fmt.Errorf("invalid regex pattern: %w", err) - } - - // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported") - } - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { - return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported") - } - if strings.Contains(re2Pattern, "(?P<") { - return "", false, errors.New("named capture groups (?P<name>...) are not supported") - } - - // 4. Detect catastrophic nested quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - - // 5. Check for nested quantifiers in groups - depth := 0 - groupHasQuantifier := make([]bool, 0) - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - } - } - if len(groupHasQuantifier) > 0 { - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?', '{': - for j := range groupHasQuantifier { - groupHasQuantifier[j] = true - } - } - } - - // 6. Check group count limit - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) - } - - // 7. 
Check for quantified alternation - quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if quantifiedAlternation.MatchString(re2Pattern) { - return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") - } - - // 8. Check nesting depth - maxDepthVal := 0 - currentDepth := 0 - for i := 0; i < len(re2Pattern); i++ { - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch re2Pattern[i] { - case '(': - currentDepth++ - if currentDepth > maxDepthVal { - maxDepthVal = currentDepth - } - case ')': - if currentDepth > 0 { - currentDepth-- - } - } - } - if maxDepthVal > maxRegexNestingDepth { - return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) - } - - // Process pattern: DuckDB uses RE2 natively, so minimal conversion needed caseInsensitive := false - pattern := re2Pattern - - // Handle (?i) flag - if strings.HasPrefix(pattern, "(?i)") { + if strings.HasPrefix(re2Pattern, "(?i)") { caseInsensitive = true - pattern = pattern[4:] + re2Pattern = re2Pattern[4:] } - // Handle inline flags other than (?i) at start - if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { - return "", false, errors.New("inline flags other than (?i) are not supported in DuckDB regex") + if err := regexsafe.Validate(re2Pattern); err != nil { + return "", false, err } - // Convert non-capturing groups (?:...) to regular groups (...) - pattern = strings.ReplaceAll(pattern, "(?:", "(") - - // DuckDB RE2 supports \d, \w, \s, \b natively - no conversion needed + // Convert non-capturing groups (?:...) to regular groups (...). 
+ pattern := strings.ReplaceAll(re2Pattern, "(?:", "(") return pattern, caseInsensitive, nil } diff --git a/dialect/duckdb/validation.go b/dialect/duckdb/validation.go index 976e304..8815651 100644 --- a/dialect/duckdb/validation.go +++ b/dialect/duckdb/validation.go @@ -1,10 +1,9 @@ package duckdb import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) var ( @@ -39,17 +38,5 @@ var ( // validateFieldName validates that a field name follows DuckDB naming conventions. func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "DuckDB", 0, fieldNameRegexp, reservedSQLKeywords) } diff --git a/dialect/internal/identsafe/identsafe.go b/dialect/internal/identsafe/identsafe.go new file mode 100644 index 0000000..430e1d7 --- /dev/null +++ b/dialect/internal/identsafe/identsafe.go @@ -0,0 +1,44 @@ +// Package identsafe centralises the SQL identifier (field name) validation +// shared by every dialect. +// +// Each dialect previously reimplemented the same validateFieldName skeleton — +// empty check, optional length cap, identifier-format regexp, reserved-keyword +// lookup — differing only in the per-dialect data (max length, keyword set). +// Sharing the skeleton keeps the validation logic in one place while each +// dialect retains its own identifier rules. +package identsafe + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// ValidateFieldName checks that name is a safe unquoted SQL identifier for a +// dialect. 
+// +// - dialectName appears in the length-limit error message (e.g. "PostgreSQL"). +// - maxLen is the maximum identifier length; pass 0 (or negative) to skip the +// length check for dialects that impose no practical limit. +// - nameRE is the dialect's identifier-format pattern. +// - reserved is the dialect's set of reserved keywords (lowercased keys). +func ValidateFieldName(name, dialectName string, maxLen int, nameRE *regexp.Regexp, reserved map[string]bool) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if maxLen > 0 && len(name) > maxLen { + return fmt.Errorf("field name %q exceeds %s maximum identifier length of %d characters", name, dialectName, maxLen) + } + + if !nameRE.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reserved[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/internal/regexsafe/regexsafe.go b/dialect/internal/regexsafe/regexsafe.go new file mode 100644 index 0000000..a153e74 --- /dev/null +++ b/dialect/internal/regexsafe/regexsafe.go @@ -0,0 +1,168 @@ +// Package regexsafe centralises the RE2 ReDoS / unsupported-feature validation +// shared by every SQL dialect's regex conversion. +// +// Each dialect previously reimplemented this validation in its own regex.go, +// and the copies had drifted (some compiled the pattern first, the +// nested-quantifier and nesting-depth loops were written differently, error +// wording diverged). A single implementation guarantees that a check tightened +// for one dialect protects all of them. Only the dialect-specific parts — +// case-insensitivity handling and the final character-class / non-capturing +// group transform — remain in each dialect's regex.go. 
+package regexsafe + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + // MaxPatternLength caps the raw pattern length. + MaxPatternLength = 500 + // MaxGroups caps the number of capture groups. + MaxGroups = 20 + // MaxNestingDepth caps how deeply groups may nest. + MaxNestingDepth = 10 +) + +// quantifiedAlternation matches an alternation group immediately followed by a +// quantifier, e.g. (a|a)*, a classic exponential-backtracking shape. +var quantifiedAlternation = regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + +// Validate runs the dialect-agnostic security checks on an RE2 pattern. +// +// Callers should strip any leading (?i) case-insensitivity flag before calling +// Validate (the flag's handling is dialect-specific); Validate still rejects +// inline flags other than (?i). It returns a non-nil error describing the first +// problem found, or nil if the pattern is safe to convert. +func Validate(pattern string) error { + // 1. Length cap to bound the cost of every subsequent scan. + if len(pattern) > MaxPatternLength { + return fmt.Errorf("pattern length %d exceeds limit of %d characters", len(pattern), MaxPatternLength) + } + + // 2. Reject features RE2 forbids or that cel2sql does not translate. + // This runs before the compile check below so these patterns get a + // descriptive message rather than Go's generic "invalid Perl syntax". + if strings.Contains(pattern, "(?=") || strings.Contains(pattern, "(?!") { + return errors.New("lookahead assertions (?=...), (?!...) are not supported") + } + if strings.Contains(pattern, "(?<=") || strings.Contains(pattern, "(?<!") { + return errors.New("lookbehind assertions (?<=...), (?<!...) are not supported") + } + if strings.Contains(pattern, "(?P<") { + return errors.New("named capture groups (?P<name>...) are not supported") + } + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return errors.New("inline flags other than (?i) are not supported") + } + + // 3. Catastrophic adjacent quantifiers, e.g. a++ / a*+. 
+ if matched, _ := regexp.MatchString(`[*+][*+]`, pattern); matched { + return errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // 4. A quantified group whose body is itself quantified, e.g. (a+)+. + if err := checkNestedQuantifiers(pattern); err != nil { + return err + } + + // 5. Capture-group count. + groupCount := strings.Count(pattern, "(") - strings.Count(pattern, `\(`) + if groupCount > MaxGroups { + return fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, MaxGroups) + } + + // 6. Quantified alternation, e.g. (a|a)*b. + if quantifiedAlternation.MatchString(pattern) { + return errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 7. Group nesting depth. + if depth := maxGroupDepth(pattern); depth > MaxNestingDepth { + return fmt.Errorf("nesting depth %d exceeds limit of %d", depth, MaxNestingDepth) + } + + // 8. Final catch-all: the pattern must compile under RE2. Runs last so the + // heuristic checks above can return their descriptive messages for the + // patterns Go's regexp rejects with a generic parse error. + if _, err := regexp.Compile(pattern); err != nil { + return fmt.Errorf("invalid regex pattern: %w", err) + } + + return nil +} + +// checkNestedQuantifiers rejects a quantified group that itself contains a +// quantifier (the (a+)+ catastrophic-backtracking shape). It tracks, per open +// group, whether a quantifier has been seen inside it, and when a group closes +// with a trailing quantifier while already containing one, flags it. +func checkNestedQuantifiers(pattern string) error { + depth := 0 + groupHasQuantifier := make([]bool, 0) + + for i := 0; i < len(pattern); i++ { + char := pattern[i] + + // Skip escaped characters. 
+ if i > 0 && pattern[i-1] == '\\' { + continue + } + + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(pattern) { + next := pattern[i+1] + if next == '*' || next == '+' || next == '?' || next == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + // Propagate "contains a quantifier" up to the enclosing group. + if len(groupHasQuantifier) > 1 && groupHasQuantifier[len(groupHasQuantifier)-1] { + groupHasQuantifier[len(groupHasQuantifier)-2] = true + } + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?', '{': + if len(groupHasQuantifier) > 0 { + groupHasQuantifier[len(groupHasQuantifier)-1] = true + } + } + } + + return nil +} + +// maxGroupDepth returns the deepest level of nested (unescaped) groups. 
+func maxGroupDepth(pattern string) int { + maxDepth := 0 + current := 0 + for i := 0; i < len(pattern); i++ { + if i > 0 && pattern[i-1] == '\\' { + continue + } + switch pattern[i] { + case '(': + current++ + if current > maxDepth { + maxDepth = current + } + case ')': + if current > 0 { + current-- + } + } + } + return maxDepth +} diff --git a/dialect/internal/regexsafe/regexsafe_test.go b/dialect/internal/regexsafe/regexsafe_test.go new file mode 100644 index 0000000..88d11a1 --- /dev/null +++ b/dialect/internal/regexsafe/regexsafe_test.go @@ -0,0 +1,67 @@ +package regexsafe + +import ( + "strings" + "testing" +) + +func TestValidate_SafePatterns(t *testing.T) { + safe := []string{ + "a+", + "^[0-9]+$", + `\btest\b`, + `\d{3}-\d{4}`, + `[a-z]+@[a-z]+\.[a-z]+`, + "(abc)+", + "(?:abc)", + strings.Repeat("a", MaxPatternLength), // exactly at the length limit + } + for _, p := range safe { + if err := Validate(p); err != nil { + t.Errorf("Validate(%q) = %v, want nil", p, err) + } + } +} + +func TestValidate_RejectsWithDescriptiveMessage(t *testing.T) { + cases := []struct { + name string + pattern string + wantSub string + }{ + {"lookahead", "(?=foo)", "lookahead assertions"}, + {"neg_lookahead", "(?!foo)", "lookahead assertions"}, + {"lookbehind", "(?<=foo)bar", "lookbehind assertions"}, + {"neg_lookbehind", "(?x)", "named capture groups"}, + {"inline_flag_m", "(?m)^x", "inline flags other than"}, + {"inline_flag_s", "(?s).x", "inline flags other than"}, + {"double_star", "a**", "nested quantifiers"}, + {"double_plus", "a++", "nested quantifiers"}, + {"nested_quantified_group", "(a+)+", "nested quantifiers"}, + {"too_long", strings.Repeat("a", MaxPatternLength+1), "pattern length"}, + {"too_many_groups", strings.Repeat("()", MaxGroups+1), "capture groups"}, + {"quantified_alternation", "(a|a)*b", "quantified alternation"}, + {"too_deep", strings.Repeat("(", MaxNestingDepth+1) + "a" + strings.Repeat(")", MaxNestingDepth+1), "nesting depth"}, + } + for _, tc := 
range cases { + t.Run(tc.name, func(t *testing.T) { + err := Validate(tc.pattern) + if err == nil { + t.Fatalf("Validate(%q) = nil, want error containing %q", tc.pattern, tc.wantSub) + } + if !strings.Contains(err.Error(), tc.wantSub) { + t.Errorf("Validate(%q) error = %q, want substring %q", tc.pattern, err.Error(), tc.wantSub) + } + }) + } +} + +// TestValidate_InvalidPatternCaughtByCompile covers a malformed pattern that the +// heuristic checks do not flag but RE2 cannot compile. +func TestValidate_InvalidPatternCaughtByCompile(t *testing.T) { + err := Validate("[unterminated") + if err == nil || !strings.Contains(err.Error(), "invalid regex pattern") { + t.Errorf("Validate(unterminated class) = %v, want 'invalid regex pattern'", err) + } +} diff --git a/dialect/mysql/dialect.go b/dialect/mysql/dialect.go index 01cd97d..f23aab0 100644 --- a/dialect/mysql/dialect.go +++ b/dialect/mysql/dialect.go @@ -266,23 +266,33 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes MySQL JSON array membership using JSON_CONTAINS. -func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { - w.WriteString("JSON_CONTAINS(") - if err := writeExpr(); err != nil { +// WriteJSONArrayMembership writes MySQL JSON array membership using +// JSON_OVERLAPS(JSON_ARRAY(elem), arr). +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeElem func() error, writeArray func() error) error { + w.WriteString("JSON_OVERLAPS(JSON_ARRAY(") + if err := writeElem(); err != nil { return err } - w.WriteString(", CAST(? AS JSON))") + w.WriteString("), ") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") return nil } -// WriteNestedJSONArrayMembership writes MySQL nested JSON array membership. 
-func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("JSON_CONTAINS(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes MySQL nested JSON array membership using +// JSON_OVERLAPS(JSON_ARRAY(elem), arr). +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + w.WriteString("JSON_OVERLAPS(JSON_ARRAY(") + if err := writeElem(); err != nil { return err } - w.WriteString(", CAST(? AS JSON))") + w.WriteString("), ") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") return nil } diff --git a/dialect/mysql/regex.go b/dialect/mysql/regex.go index 7965097..81ca61b 100644 --- a/dialect/mysql/regex.go +++ b/dialect/mysql/regex.go @@ -1,139 +1,29 @@ package mysql import ( - "errors" - "fmt" - "regexp" "strings" -) -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). -const ( - maxRegexPatternLength = 500 - maxRegexGroups = 20 - maxRegexNestingDepth = 10 + "github.com/spandigital/cel2sql/v3/dialect/internal/regexsafe" ) // convertRE2ToMySQL converts an RE2 regex pattern to MySQL-compatible format. -// MySQL 8.0+ uses ICU regex which supports most RE2 features. +// MySQL 8.0+ uses ICU regex, which supports \d, \w, \s and \b natively, so the +// pattern passes through almost unchanged. Shared ReDoS / unsupported-feature +// validation lives in regexsafe.Validate. // Returns the converted pattern, whether it's case-insensitive, and any error. func convertRE2ToMySQL(re2Pattern string) (string, bool, error) { - // 1. Pattern length validation - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) - } - - // 2. 
Validate pattern compiles - if _, err := regexp.Compile(re2Pattern); err != nil { - return "", false, fmt.Errorf("invalid regex pattern: %w", err) - } - - // 3. Detect unsupported RE2 features - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in MySQL regex") - } - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported in MySQL regex") - } - - // 4. Detect catastrophic nested quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - - // 5. Check for nested quantifiers in groups - depth := 0 - groupHasQuantifier := make([]bool, 0) - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - } - } - if len(groupHasQuantifier) > 0 { - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?', '{': - for j := range groupHasQuantifier { - groupHasQuantifier[j] = true - } - } - } - - // 6. Check group count limit - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) - } - - // 7. 
Check for quantified alternation - quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if quantifiedAlternation.MatchString(re2Pattern) { - return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") - } - - // 8. Check nesting depth - maxDepthVal := 0 - currentDepth := 0 - for i := 0; i < len(re2Pattern); i++ { - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch re2Pattern[i] { - case '(': - currentDepth++ - if currentDepth > maxDepthVal { - maxDepthVal = currentDepth - } - case ')': - if currentDepth > 0 { - currentDepth-- - } - } - } - if maxDepthVal > maxRegexNestingDepth { - return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) - } - - // Process pattern: extract case-insensitivity, convert features caseInsensitive := false - pattern := re2Pattern - - // Handle (?i) flag - if strings.HasPrefix(pattern, "(?i)") { + if strings.HasPrefix(re2Pattern, "(?i)") { caseInsensitive = true - pattern = pattern[4:] + re2Pattern = re2Pattern[4:] } - // Handle inline flags other than (?i) at start - if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { - return "", false, errors.New("inline flags other than (?i) are not supported in MySQL regex") + if err := regexsafe.Validate(re2Pattern); err != nil { + return "", false, err } - // Convert non-capturing groups (?:...) to regular groups (...) - pattern = strings.ReplaceAll(pattern, "(?:", "(") - - // MySQL ICU regex supports \d, \w, \s natively - no conversion needed - // Convert \b word boundary to MySQL's \b (same syntax in ICU) - // No conversion needed for MySQL 8.0+ + // Convert non-capturing groups (?:...) to regular groups (...). 
+ pattern := strings.ReplaceAll(re2Pattern, "(?:", "(") return pattern, caseInsensitive, nil } diff --git a/dialect/mysql/validation.go b/dialect/mysql/validation.go index 15a20d4..9eaaf7f 100644 --- a/dialect/mysql/validation.go +++ b/dialect/mysql/validation.go @@ -1,10 +1,9 @@ package mysql import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) const ( @@ -71,21 +70,5 @@ var ( // validateFieldName validates that a field name follows MySQL naming conventions. func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - if len(name) > maxMySQLIdentifierLength { - return fmt.Errorf("field name %q exceeds MySQL maximum identifier length of %d characters", name, maxMySQLIdentifierLength) - } - - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "MySQL", maxMySQLIdentifierLength, fieldNameRegexp, reservedSQLKeywords) } diff --git a/dialect/postgres/dialect.go b/dialect/postgres/dialect.go index 8aee559..7855edd 100644 --- a/dialect/postgres/dialect.go +++ b/dialect/postgres/dialect.go @@ -275,22 +275,28 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes ANY(ARRAY(SELECT json_func(expr))) for PostgreSQL. -func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error { - w.WriteString("ANY(ARRAY(SELECT ") +// WriteJSONArrayMembership writes elem = ANY(ARRAY(SELECT json_func(arr))) for PostgreSQL. 
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" = ANY(ARRAY(SELECT ") w.WriteString(jsonFunc) w.WriteString("(") - if err := writeExpr(); err != nil { + if err := writeArray(); err != nil { return err } w.WriteString(")))") return nil } -// WriteNestedJSONArrayMembership writes ANY(ARRAY(SELECT jsonb_array_elements_text(expr))) for PostgreSQL. -func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("ANY(ARRAY(SELECT jsonb_array_elements_text(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes elem = ANY(ARRAY(SELECT jsonb_array_elements_text(arr))) for PostgreSQL. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" = ANY(ARRAY(SELECT jsonb_array_elements_text(") + if err := writeArray(); err != nil { return err } w.WriteString(")))") diff --git a/dialect/postgres/regex.go b/dialect/postgres/regex.go index 1cc4a4b..aef7b1c 100644 --- a/dialect/postgres/regex.go +++ b/dialect/postgres/regex.go @@ -1,134 +1,30 @@ package postgres import ( - "errors" - "fmt" - "regexp" "strings" -) -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). -const ( - maxRegexPatternLength = 500 - maxRegexGroups = 20 - maxRegexNestingDepth = 10 + "github.com/spandigital/cel2sql/v3/dialect/internal/regexsafe" ) // convertRE2ToPOSIX converts an RE2 regex pattern to POSIX ERE format for PostgreSQL. -// It performs security validation to prevent ReDoS attacks (CWE-1333). +// Shared ReDoS / unsupported-feature validation lives in regexsafe.Validate; +// this function adds only PostgreSQL's case-insensitivity handling and the +// RE2 → POSIX character-class translation. 
// Returns: (posixPattern, caseInsensitive, error) func convertRE2ToPOSIX(re2Pattern string) (string, bool, error) { - // 1. Check pattern length to prevent processing extremely long patterns - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) - } - - // 2. Extract case-insensitive flag if present + // Extract the case-insensitive flag; PostgreSQL signals it via the ~* operator. caseInsensitive := false if strings.HasPrefix(re2Pattern, "(?i)") { caseInsensitive = true re2Pattern = strings.TrimPrefix(re2Pattern, "(?i)") } - // 3. Detect unsupported RE2 features and return errors - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in PostgreSQL POSIX regex") - } - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported in PostgreSQL POSIX regex") - } - if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") { - return "", false, errors.New("inline flags other than (?i) are not supported in PostgreSQL POSIX regex") + if err := regexsafe.Validate(re2Pattern); err != nil { + return "", false, err } - // 4. 
Detect catastrophic nested quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - - // Check for groups that contain quantifiers and are themselves quantified - depth := 0 - groupHasQuantifier := make([]bool, 0) - - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - - // Skip escaped characters - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - } - } - if len(groupHasQuantifier) > 0 { - if len(groupHasQuantifier) > 1 { - if groupHasQuantifier[len(groupHasQuantifier)-1] { - groupHasQuantifier[len(groupHasQuantifier)-2] = true - } - } - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?': - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - case '{': - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - } - } - - // 5. Count and limit capture groups - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, `\(`) - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) - } - - // 6. 
Detect exponential alternation patterns - alternationPattern := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if alternationPattern.MatchString(re2Pattern) { - return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") - } - - // 7. Check nesting depth - maxDepthVal := 0 - currentDepth := 0 - for _, char := range re2Pattern { - if char == '(' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth++ - if currentDepth > maxDepthVal { - maxDepthVal = currentDepth - } - } else if char == ')' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth-- - } - } - if maxDepthVal > maxRegexNestingDepth { - return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) - } - - // Passed all security checks - proceed with conversion + // Convert RE2 patterns to POSIX equivalents. posixPattern := re2Pattern - - // Convert RE2 patterns to POSIX equivalents posixPattern = strings.ReplaceAll(posixPattern, `\b`, `\y`) posixPattern = strings.ReplaceAll(posixPattern, `\B`, `[^[:alnum:]_]`) posixPattern = strings.ReplaceAll(posixPattern, `\d`, `[[:digit:]]`) diff --git a/dialect/postgres/validation.go b/dialect/postgres/validation.go index 162da68..9440b9e 100644 --- a/dialect/postgres/validation.go +++ b/dialect/postgres/validation.go @@ -1,10 +1,9 @@ package postgres import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) const ( @@ -46,21 +45,5 @@ var ( // validateFieldName validates that a field name follows PostgreSQL naming conventions // and is safe to use in SQL queries without quoting. 
func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - if len(name) > maxPostgreSQLIdentifierLength { - return fmt.Errorf("field name %q exceeds PostgreSQL maximum identifier length of %d characters", name, maxPostgreSQLIdentifierLength) - } - - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "PostgreSQL", maxPostgreSQLIdentifierLength, fieldNameRegexp, reservedSQLKeywords) } diff --git a/dialect/spark/dialect.go b/dialect/spark/dialect.go index c66ff67..110e33c 100644 --- a/dialect/spark/dialect.go +++ b/dialect/spark/dialect.go @@ -277,28 +277,34 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes Spark JSON array membership as a scalar -// subquery that scans elements. The converter writes `lhs = ` before this, -// so the result is `lhs = (SELECT col FROM (SELECT EXPLODE(from_json(rhs, -// 'ARRAY')) AS col) t)`. This mirrors SQLite's `lhs = (SELECT value -// FROM json_each(...))` pattern; both dialects rely on the subquery -// returning at most one match for the comparison to succeed. -func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { - w.WriteString("(SELECT col FROM (SELECT EXPLODE(from_json(") - if err := writeExpr(); err != nil { +// WriteJSONArrayMembership writes Spark JSON array membership using +// array_contains(from_json(arr, 'ARRAY'), elem). The dialect owns the +// full boolean predicate, parsing the JSON array and testing for the element. 
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeElem func() error, writeArray func() error) error { + w.WriteString("array_contains(from_json(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", 'ARRAY'), ") + if err := writeElem(); err != nil { return err } - w.WriteString(", 'ARRAY')) AS col) t)") + w.WriteString(")") return nil } -// WriteNestedJSONArrayMembership writes Spark nested JSON array membership. -func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("(SELECT col FROM (SELECT EXPLODE(from_json(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes Spark nested JSON array membership using +// array_contains(from_json(arr, 'ARRAY'), elem). +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + w.WriteString("array_contains(from_json(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", 'ARRAY'), ") + if err := writeElem(); err != nil { return err } - w.WriteString(", 'ARRAY')) AS col) t)") + w.WriteString(")") return nil } diff --git a/dialect/spark/regex.go b/dialect/spark/regex.go index b7d1e3b..a2a39d4 100644 --- a/dialect/spark/regex.go +++ b/dialect/spark/regex.go @@ -1,127 +1,26 @@ package spark import ( - "errors" - "fmt" - "regexp" "strings" -) -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). -const ( - maxRegexPatternLength = 500 - maxRegexGroups = 20 - maxRegexNestingDepth = 10 + "github.com/spandigital/cel2sql/v3/dialect/internal/regexsafe" ) // convertRE2ToSpark converts an RE2 regex pattern to Spark/Java regex format. -// Spark uses java.util.regex.Pattern, which is a superset of RE2 for the safe -// subset cel2sql accepts (we reject lookahead/lookbehind, named captures, and -// non-(?i) inline flags). 
Java natively supports \d, \w, \s, \b, (?:...), and -// inline (?i), so most patterns pass through unchanged. +// Spark uses java.util.regex.Pattern, a superset of the safe RE2 subset +// cel2sql accepts: it natively supports \d, \w, \s, \b, (?:...) and inline +// (?i), so patterns pass through unchanged. Shared ReDoS / unsupported-feature +// validation lives in regexsafe.Validate. +// +// Unlike the other dialects, Spark keeps the (?i) flag inline (its regex engine +// honours it), so caseInsensitive is always reported as false; the flag is only +// stripped for the purpose of validation counting. func convertRE2ToSpark(re2Pattern string) (string, bool, error) { - // 1. Pattern length validation. - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) - } - - // 2. Validate the pattern compiles under RE2 (cel2sql input contract). - if _, err := regexp.Compile(re2Pattern); err != nil { - return "", false, fmt.Errorf("invalid regex pattern: %w", err) - } - - // 3. Reject features RE2 forbids but users sometimes ask for. - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported") - } - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported") - } - - // 4. Detect catastrophic nested quantifiers. - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - - // 5. Check for nested quantifiers in groups. 
- depth := 0 - groupHasQuantifier := make([]bool, 0) - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") - } - } - } - if len(groupHasQuantifier) > 0 { - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?', '{': - for j := range groupHasQuantifier { - groupHasQuantifier[j] = true - } - } - } - - // 6. Check group count limit. - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) - } - - // 7. Check for quantified alternation. - quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if quantifiedAlternation.MatchString(re2Pattern) { - return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") - } - - // 8. Check nesting depth. - maxDepthVal := 0 - currentDepth := 0 - for i := 0; i < len(re2Pattern); i++ { - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - switch re2Pattern[i] { - case '(': - currentDepth++ - if currentDepth > maxDepthVal { - maxDepthVal = currentDepth - } - case ')': - if currentDepth > 0 { - currentDepth-- - } - } - } - if maxDepthVal > maxRegexNestingDepth { - return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) - } + toValidate := strings.TrimPrefix(re2Pattern, "(?i)") - // 9. 
Reject inline flags other than (?i) anywhere in the pattern. - if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") { - return "", false, errors.New("inline flags other than (?i) are not supported in Spark regex") + if err := regexsafe.Validate(toValidate); err != nil { + return "", false, err } - // Java/Spark regex supports (?i) inline, \d/\w/\s, \b, and (?:...) natively. - // Pass the pattern through unchanged; the inline (?i) is honoured by Spark's - // regex engine, so we report caseInsensitive=false here. return re2Pattern, false, nil } diff --git a/dialect/spark/validation.go b/dialect/spark/validation.go index 19173cb..a0cdd5b 100644 --- a/dialect/spark/validation.go +++ b/dialect/spark/validation.go @@ -1,10 +1,9 @@ package spark import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) var ( @@ -47,17 +46,5 @@ var ( // validateFieldName validates that a field name follows Spark naming conventions. func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "Spark", 0, fieldNameRegexp, reservedSQLKeywords) } diff --git a/dialect/sqlite/dialect.go b/dialect/sqlite/dialect.go index 5e97d1f..e04dab7 100644 --- a/dialect/sqlite/dialect.go +++ b/dialect/sqlite/dialect.go @@ -247,23 +247,33 @@ func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string return nil } -// WriteJSONArrayMembership writes SQLite JSON array membership using json_each. 
-func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { - w.WriteString("(SELECT value FROM json_each(") - if err := writeExpr(); err != nil { +// WriteJSONArrayMembership writes SQLite JSON array membership using +// EXISTS (SELECT 1 FROM json_each(arr) WHERE value = elem). +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeElem func() error, writeArray func() error) error { + w.WriteString("EXISTS (SELECT 1 FROM json_each(") + if err := writeArray(); err != nil { return err } - w.WriteString("))") + w.WriteString(") WHERE value = ") + if err := writeElem(); err != nil { + return err + } + w.WriteString(")") return nil } -// WriteNestedJSONArrayMembership writes SQLite nested JSON array membership. -func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { - w.WriteString("(SELECT value FROM json_each(") - if err := writeExpr(); err != nil { +// WriteNestedJSONArrayMembership writes SQLite nested JSON array membership using +// EXISTS (SELECT 1 FROM json_each(arr) WHERE value = elem). +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + w.WriteString("EXISTS (SELECT 1 FROM json_each(") + if err := writeArray(); err != nil { return err } - w.WriteString("))") + w.WriteString(") WHERE value = ") + if err := writeElem(); err != nil { + return err + } + w.WriteString(")") return nil } diff --git a/dialect/sqlite/validation.go b/dialect/sqlite/validation.go index 805c06c..6fb42c9 100644 --- a/dialect/sqlite/validation.go +++ b/dialect/sqlite/validation.go @@ -1,10 +1,9 @@ package sqlite import ( - "errors" - "fmt" "regexp" - "strings" + + "github.com/spandigital/cel2sql/v3/dialect/internal/identsafe" ) var ( @@ -50,19 +49,7 @@ var ( ) // validateFieldName validates that a field name follows SQLite naming conventions. 
+// SQLite imposes no practical identifier-length limit, so no maximum is enforced. func validateFieldName(name string) error { - if len(name) == 0 { - return errors.New("field name cannot be empty") - } - - // SQLite has no hard limit on identifier length but we use a reasonable limit - if !fieldNameRegexp.MatchString(name) { - return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) - } - - if reservedSQLKeywords[strings.ToLower(name)] { - return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) - } - - return nil + return identsafe.ValidateFieldName(name, "SQLite", 0, fieldNameRegexp, reservedSQLKeywords) } diff --git a/duckdb/provider.go b/duckdb/provider.go index e1b73b9..dfa6449 100644 --- a/duckdb/provider.go +++ b/duckdb/provider.go @@ -10,9 +10,9 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/internal/celprovider" "github.com/spandigital/cel2sql/v3/schema" ) @@ -40,13 +40,13 @@ type TypeProvider interface { } type typeProvider struct { - schemas map[string]Schema - db *sql.DB + celprovider.Base + db *sql.DB } // NewTypeProvider creates a new DuckDB type provider with pre-defined schemas. func NewTypeProvider(schemas map[string]Schema) TypeProvider { - return &typeProvider{schemas: schemas} + return &typeProvider{Base: celprovider.Base{Schemas: schemas, Mapper: duckdbTypeToCELExprType}} } // NewTypeProviderWithConnection creates a new DuckDB type provider that can introspect database schemas. 
@@ -58,8 +58,8 @@ func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, } return &typeProvider{ - schemas: make(map[string]Schema), - db: db, + Base: celprovider.Base{Schemas: make(map[string]Schema), Mapper: duckdbTypeToCELExprType}, + db: db, }, nil } @@ -102,7 +102,7 @@ func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) e return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) } - tp.schemas[tableName] = NewSchema(fields) + tp.Schemas[tableName] = NewSchema(fields) return nil } @@ -150,75 +150,6 @@ func normalizeDuckDBType(dataType string) string { return strings.ToLower(dataType) } -// Close is a no-op since we don't own the *sql.DB. -func (tp *typeProvider) Close() { - // No-op: caller owns the *sql.DB connection -} - -// GetSchemas returns the schemas known to this type provider. -func (tp *typeProvider) GetSchemas() map[string]Schema { - return tp.schemas -} - -// EnumValue implements types.Provider. -func (tp *typeProvider) EnumValue(_ string) ref.Val { - return types.NewErr("unknown enum value") -} - -// FindIdent implements types.Provider. -func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { - return nil, false -} - -// FindStructType implements types.Provider. -func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { - if _, ok := tp.schemas[structType]; ok { - return types.NewObjectType(structType), true - } - return nil, false -} - -// FindStructFieldNames implements types.Provider. -func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - fields := s.Fields() - names := make([]string, len(fields)) - for i, f := range fields { - names[i] = f.Name - } - return names, true -} - -// FindStructFieldType implements types.Provider. 
-func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - field, found := s.FindField(fieldName) - if !found { - return nil, false - } - - exprType := duckdbTypeToCELExprType(field) - celType, err := types.ExprTypeToType(exprType) - if err != nil { - return nil, false - } - - return &types.FieldType{ - Type: celType, - }, true -} - -// NewValue implements types.Provider. -func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { - return types.NewErr("unknown type in schema") -} - // duckdbTypeToCELExprType converts a DuckDB field schema to a CEL expression type. func duckdbTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { baseType := duckdbBaseTypeToCEL(field.Type) diff --git a/examples/basic/basic b/examples/basic/basic deleted file mode 100755 index e184486..0000000 Binary files a/examples/basic/basic and /dev/null differ diff --git a/examples/load_table_schema/load_table_schema b/examples/load_table_schema/load_table_schema deleted file mode 100755 index d927b76..0000000 Binary files a/examples/load_table_schema/load_table_schema and /dev/null differ diff --git a/internal/celprovider/base.go b/internal/celprovider/base.go new file mode 100644 index 0000000..545a901 --- /dev/null +++ b/internal/celprovider/base.go @@ -0,0 +1,100 @@ +// Package celprovider provides the shared cel-go types.Provider implementation +// embedded by the flat (non-nested-schema) SQL dialect type providers. +// +// The MySQL, SQLite, DuckDB, BigQuery and Spark providers previously each +// reimplemented the same types.Provider boilerplate over a map[string]Schema, +// differing only in their type-name → CEL mapping. Base centralises that +// boilerplate; a dialect embeds Base, supplies its Mapper and LoadTableSchema, +// and overrides Close only if it owns a connection. 
+// +// The PostgreSQL provider (pg) is intentionally not built on Base: it resolves +// nested/composite schemas through a dotted-path lookup and owns a connection +// pool, so it keeps its own implementation. +package celprovider + +import ( + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// TypeMapper converts a dialect field schema to a CEL expression type. +type TypeMapper func(*schema.FieldSchema) *exprpb.Type + +// Base implements the portion of cel-go's types.Provider shared by the flat +// dialect type providers. Embed it and set Schemas and Mapper at construction. +type Base struct { + // Schemas holds the table schemas known to the provider, keyed by table name. + Schemas map[string]schema.Schema + // Mapper converts a field schema to its CEL expression type. + Mapper TypeMapper +} + +// GetSchemas returns the schemas known to this provider. +func (b *Base) GetSchemas() map[string]schema.Schema { + return b.Schemas +} + +// EnumValue implements types.Provider. Schemas declare no enums. +func (b *Base) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. Schemas declare no identifiers. +func (b *Base) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (b *Base) FindStructType(structType string) (*types.Type, bool) { + if _, ok := b.Schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. 
+func (b *Base) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := b.Schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider, mapping the field's SQL type +// to a CEL type via Mapper. +func (b *Base) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := b.Schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + celType, err := types.ExprTypeToType(b.Mapper(field)) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. Schema-described types are not constructible. +func (b *Base) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// Close is a no-op. Providers that own a connection override it. +func (b *Base) Close() {} diff --git a/json_escaping_test.go b/json_escaping_test.go index f8db840..b55bef0 100644 --- a/json_escaping_test.go +++ b/json_escaping_test.go @@ -69,8 +69,8 @@ func TestJSONFieldNameEscaping_SingleQuote(t *testing.T) { func TestJSONFieldNameEscaping_Documentation(t *testing.T) { t.Log("This test documents that JSON field names are escaped in generated SQL") t.Log("Single quotes in field names would be escaped by doubling them: ' -> ''") - t.Log("The escapeJSONFieldName() function in utils.go handles this escaping") - t.Log("All JSON path operators (->, ->>, ?) use escapeJSONFieldName() for security") + t.Log("Each dialect's escapeJSONFieldName() (dialect//dialect.go) handles this escaping") + t.Log("All JSON path operators (->, ->>, ?) 
escape field names for security") } // TestJSONFieldNameEscaping_HasFunction tests escaping in has() macro for JSON existence checks @@ -120,22 +120,6 @@ func TestJSONFieldNameEscaping_HasFunction(t *testing.T) { } } -// TestEscapeJSONFieldNameFunction tests the escapeJSONFieldName utility function directly -func TestEscapeJSONFieldNameFunction(t *testing.T) { - // Note: We can't directly test the unexported function, but we verify its behavior - // through the integration tests above. This test documents the expected behavior. - - t.Log("The escapeJSONFieldName() function in utils.go escapes single quotes") - t.Log("Example: \"user's name\" -> \"user''s name\"") - t.Log("This prevents SQL injection when field names contain single quotes") - t.Log("The function is used in:") - t.Log(" - cel2sql.go — visitSelect() for -> and ->> operators") - t.Log(" - cel2sql.go — visitHasFunction() for ? and -> operators") - t.Log(" - cel2sql.go — visitNestedJSONHas() for jsonb_extract_path_text()") - t.Log(" - json.go — buildJSONPathForArray() for nested JSON paths") - t.Log(" - json.go — buildJSONPathInternal() for all JSON path construction") -} - // TestJSONFieldNameEscaping_SecurityImplications tests security aspects func TestJSONFieldNameEscaping_SecurityImplications(t *testing.T) { t.Log("Security Impact: SQL Injection Prevention") diff --git a/mysql/provider.go b/mysql/provider.go index 781f40b..45f04fb 100644 --- a/mysql/provider.go +++ b/mysql/provider.go @@ -10,9 +10,9 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/internal/celprovider" "github.com/spandigital/cel2sql/v3/schema" "github.com/spandigital/cel2sql/v3/sqltypes" ) @@ -41,13 +41,13 @@ type TypeProvider interface { } type typeProvider struct { - schemas map[string]Schema - db *sql.DB + celprovider.Base + db *sql.DB } // 
NewTypeProvider creates a new MySQL type provider with pre-defined schemas. func NewTypeProvider(schemas map[string]Schema) TypeProvider { - return &typeProvider{schemas: schemas} + return &typeProvider{Base: celprovider.Base{Schemas: schemas, Mapper: mysqlTypeToCELExprType}} } // NewTypeProviderWithConnection creates a new MySQL type provider that can introspect database schemas. @@ -58,8 +58,8 @@ func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, } return &typeProvider{ - schemas: make(map[string]Schema), - db: db, + Base: celprovider.Base{Schemas: make(map[string]Schema), Mapper: mysqlTypeToCELExprType}, + db: db, }, nil } @@ -102,7 +102,7 @@ func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) e return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) } - tp.schemas[tableName] = NewSchema(fields) + tp.Schemas[tableName] = NewSchema(fields) return nil } @@ -120,76 +120,6 @@ func mysqlColumnToFieldSchema(columnName, dataType, _ string) FieldSchema { } } -// Close is a no-op since we don't own the *sql.DB. -func (tp *typeProvider) Close() { - // No-op: caller owns the *sql.DB connection -} - -// GetSchemas returns the schemas known to this type provider. -func (tp *typeProvider) GetSchemas() map[string]Schema { - return tp.schemas -} - -// EnumValue implements types.Provider. -func (tp *typeProvider) EnumValue(_ string) ref.Val { - return types.NewErr("unknown enum value") -} - -// FindIdent implements types.Provider. -func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { - return nil, false -} - -// FindStructType implements types.Provider. -func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { - if _, ok := tp.schemas[structType]; ok { - return types.NewObjectType(structType), true - } - return nil, false -} - -// FindStructFieldNames implements types.Provider. 
-func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - fields := s.Fields() - names := make([]string, len(fields)) - for i, f := range fields { - names[i] = f.Name - } - return names, true -} - -// FindStructFieldType implements types.Provider. -func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - field, found := s.FindField(fieldName) - if !found { - return nil, false - } - - exprType := mysqlTypeToCELExprType(field) - - celType, err := types.ExprTypeToType(exprType) - if err != nil { - return nil, false - } - - return &types.FieldType{ - Type: celType, - }, true -} - -// NewValue implements types.Provider. -func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { - return types.NewErr("unknown type in schema") -} - // mysqlTypeToCELExprType converts a MySQL field schema to a CEL expression type. func mysqlTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { baseType := mysqlBaseTypeToCEL(field.Type) diff --git a/spark/provider.go b/spark/provider.go index 19df186..a7b2677 100644 --- a/spark/provider.go +++ b/spark/provider.go @@ -11,9 +11,9 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/internal/celprovider" "github.com/spandigital/cel2sql/v3/schema" "github.com/spandigital/cel2sql/v3/sqltypes" ) @@ -47,13 +47,13 @@ type TypeProvider interface { } type typeProvider struct { - schemas map[string]Schema - db *sql.DB + celprovider.Base + db *sql.DB } // NewTypeProvider creates a new Spark SQL type provider with pre-defined schemas. 
func NewTypeProvider(schemas map[string]Schema) TypeProvider { - return &typeProvider{schemas: schemas} + return &typeProvider{Base: celprovider.Base{Schemas: schemas, Mapper: sparkTypeToCELExprType}} } // NewTypeProviderWithConnection creates a new Spark SQL type provider that can @@ -65,8 +65,8 @@ func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema) } return &typeProvider{ - schemas: make(map[string]Schema), - db: db, + Base: celprovider.Base{Schemas: make(map[string]Schema), Mapper: sparkTypeToCELExprType}, + db: db, }, nil } @@ -114,7 +114,7 @@ func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) e return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) } - tp.schemas[tableName] = NewSchema(fields) + tp.Schemas[tableName] = NewSchema(fields) return nil } @@ -198,9 +198,9 @@ func parseSparkStruct(dt string) ([]FieldSchema, bool) { // like decimal(10,2) appearing inside struct<…> definitions. func splitTopLevel(s string, sep byte) []string { var ( - out []string + out []string angle, paren int - start int + start int ) for i := 0; i < len(s); i++ { switch s[i] { @@ -264,71 +264,6 @@ func normalizeSparkType(t string) string { return t } -// Close is a no-op since we don't own the *sql.DB. -func (tp *typeProvider) Close() { - // No-op: caller owns the *sql.DB connection. -} - -// GetSchemas returns the schemas known to this type provider. -func (tp *typeProvider) GetSchemas() map[string]Schema { - return tp.schemas -} - -// EnumValue implements types.Provider. -func (tp *typeProvider) EnumValue(_ string) ref.Val { - return types.NewErr("unknown enum value") -} - -// FindIdent implements types.Provider. -func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { - return nil, false -} - -// FindStructType implements types.Provider. 
-func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { - if _, ok := tp.schemas[structType]; ok { - return types.NewObjectType(structType), true - } - return nil, false -} - -// FindStructFieldNames implements types.Provider. -func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - fields := s.Fields() - names := make([]string, len(fields)) - for i, f := range fields { - names[i] = f.Name - } - return names, true -} - -// FindStructFieldType implements types.Provider. -func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - field, found := s.FindField(fieldName) - if !found { - return nil, false - } - exprType := sparkTypeToCELExprType(field) - celType, err := types.ExprTypeToType(exprType) - if err != nil { - return nil, false - } - return &types.FieldType{Type: celType}, true -} - -// NewValue implements types.Provider. -func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { - return types.NewErr("unknown type in schema") -} - // sparkTypeToCELExprType converts a Spark field schema to a CEL expression type. 
func sparkTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { if field.Repeated { diff --git a/sqlite/provider.go b/sqlite/provider.go index 0877074..c8e21c8 100644 --- a/sqlite/provider.go +++ b/sqlite/provider.go @@ -11,9 +11,9 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/internal/celprovider" "github.com/spandigital/cel2sql/v3/schema" ) @@ -44,13 +44,13 @@ type TypeProvider interface { } type typeProvider struct { - schemas map[string]Schema - db *sql.DB + celprovider.Base + db *sql.DB } // NewTypeProvider creates a new SQLite type provider with pre-defined schemas. func NewTypeProvider(schemas map[string]Schema) TypeProvider { - return &typeProvider{schemas: schemas} + return &typeProvider{Base: celprovider.Base{Schemas: schemas, Mapper: sqliteTypeToCELExprType}} } // NewTypeProviderWithConnection creates a new SQLite type provider that can introspect database schemas. @@ -61,8 +61,8 @@ func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, } return &typeProvider{ - schemas: make(map[string]Schema), - db: db, + Base: celprovider.Base{Schemas: make(map[string]Schema), Mapper: sqliteTypeToCELExprType}, + db: db, }, nil } @@ -111,7 +111,7 @@ func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) e return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) } - tp.schemas[tableName] = NewSchema(fields) + tp.Schemas[tableName] = NewSchema(fields) return nil } @@ -179,75 +179,6 @@ func normalizeSQLiteType(colType string) string { return sqliteTypeText } -// Close is a no-op since we don't own the *sql.DB. -func (tp *typeProvider) Close() { - // No-op: caller owns the *sql.DB connection -} - -// GetSchemas returns the schemas known to this type provider. 
-func (tp *typeProvider) GetSchemas() map[string]Schema { - return tp.schemas -} - -// EnumValue implements types.Provider. -func (tp *typeProvider) EnumValue(_ string) ref.Val { - return types.NewErr("unknown enum value") -} - -// FindIdent implements types.Provider. -func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { - return nil, false -} - -// FindStructType implements types.Provider. -func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { - if _, ok := tp.schemas[structType]; ok { - return types.NewObjectType(structType), true - } - return nil, false -} - -// FindStructFieldNames implements types.Provider. -func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - fields := s.Fields() - names := make([]string, len(fields)) - for i, f := range fields { - names[i] = f.Name - } - return names, true -} - -// FindStructFieldType implements types.Provider. -func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { - s, ok := tp.schemas[structType] - if !ok { - return nil, false - } - field, found := s.FindField(fieldName) - if !found { - return nil, false - } - - exprType := sqliteTypeToCELExprType(field) - celType, err := types.ExprTypeToType(exprType) - if err != nil { - return nil, false - } - - return &types.FieldType{ - Type: celType, - }, true -} - -// NewValue implements types.Provider. -func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { - return types.NewErr("unknown type in schema") -} - // sqliteTypeToCELExprType converts a SQLite field schema to a CEL expression type. 
func sqliteTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { baseType := sqliteBaseTypeToCEL(field.Type) diff --git a/testcases/json_tests.go b/testcases/json_tests.go index 37256e6..2025a93 100644 --- a/testcases/json_tests.go +++ b/testcases/json_tests.go @@ -34,6 +34,20 @@ func JSONTests() []ConvertTestCase { dialect.Spark: "get_json_object(get_json_object(product.metadata, '$.specs'), '$.color') = 'red'", }, }, + { + Name: "json_array_membership", + CELExpr: `"electronics" in product.tags`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "'electronics' = ANY(ARRAY(SELECT jsonb_array_elements_text(product.tags)))", + dialect.MySQL: "JSON_OVERLAPS(JSON_ARRAY('electronics'), product.tags)", + dialect.SQLite: "EXISTS (SELECT 1 FROM json_each(product.tags) WHERE value = 'electronics')", + dialect.DuckDB: "EXISTS (SELECT 1 FROM json_each(product.tags) WHERE value = 'electronics')", + dialect.BigQuery: "'electronics' IN UNNEST(JSON_VALUE_ARRAY(product.tags))", + dialect.Spark: "array_contains(from_json(product.tags, 'ARRAY'), 'electronics')", + }, + }, { Name: "json_has_field", CELExpr: `has(product.metadata.brand)`, diff --git a/utils.go b/utils.go index cf70fee..15988c7 100644 --- a/utils.go +++ b/utils.go @@ -180,9 +180,3 @@ func escapeLikePattern(pattern string) string { escaped = strings.ReplaceAll(escaped, `'`, `''`) return escaped } - -// escapeJSONFieldName escapes single quotes in JSON field names for safe use in PostgreSQL JSON path operators -// In PostgreSQL, single quotes within string literals must be escaped by doubling them -func escapeJSONFieldName(fieldName string) string { - return strings.ReplaceAll(fieldName, "'", "''") -} diff --git a/utils_test.go b/utils_test.go index 03e9844..658c307 100644 --- a/utils_test.go +++ b/utils_test.go @@ -326,63 +326,6 @@ func TestValidateFieldName_AllReservedKeywords(t *testing.T) { } } -func TestEscapeJSONFieldName(t *testing.T) { - tests := 
[]struct { - name string - input string - expected string - }{ - { - name: "no quotes", - input: "fieldname", - expected: "fieldname", - }, - { - name: "single quote at start", - input: "'field", - expected: "''field", - }, - { - name: "single quote in middle", - input: "field'name", - expected: "field''name", - }, - { - name: "single quote at end", - input: "field'", - expected: "field''", - }, - { - name: "multiple single quotes", - input: "field'name'test", - expected: "field''name''test", - }, - { - name: "SQL injection attempt", - input: "' OR '1'='1", - expected: "'' OR ''1''=''1", - }, - { - name: "empty string", - input: "", - expected: "", - }, - { - name: "only single quote", - input: "'", - expected: "''", - }, - { - name: "double quotes not affected", - input: "field\"name", - expected: "field\"name", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := escapeJSONFieldName(tt.input) - require.Equal(t, tt.expected, result, "escapeJSONFieldName should properly escape: %s", tt.input) - }) - } -} +// Note: JSON field-name escaping is now implemented per-dialect (see +// dialect/<name>/dialect.go escapeJSONFieldName + their tests). The former +// top-level escapeJSONFieldName in utils.go was unused and has been removed.