diff --git a/README.md b/README.md index 8326dff..206af0f 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,36 @@ translator can't handle — typically a REGEX pattern the active dialect's regex flavour can't compile. Callers should fall back to `flag_engine.is_context_in_segment` for those segments. +## Bound parameters + +By default the translator inlines each segment value as an escaped SQL string literal. Pass a `Binder` on the `TranslateContext` to bind value-bearing literals as query parameters instead. + +```python +from flagsmith_sql_flag_engine import ( + Binder, + PyformatParamStyle, + TranslateContext, + translate_segment, +) +from flagsmith_sql_flag_engine.dialects import ClickHouseDialect + +binder = Binder(PyformatParamStyle()) +ctx = TranslateContext( + evaluation_context=eval_context, + dialect=ClickHouseDialect(), + binder=binder, +) +where_expr = translate_segment(segment, ctx) +``` + +Hand the translated predicate and `binder.params` to the driver together: + +```python + cursor.execute(f"... WHERE ({where_expr})", binder.params) +``` + +Two parameter styles are currently supported: `PyformatParamStyle` (`%(name)s`, substituted client-side by `clickhouse-driver`) and `ClickHouseServerParamStyle` (`{name:String}`, ClickHouse's native server-side syntax used by `clickhouse-connect`). + ## Schema Each dialect publishes the table layout it expects via a `schema_ddl` diff --git a/src/flagsmith_sql_flag_engine/__init__.py b/src/flagsmith_sql_flag_engine/__init__.py index 1f2b1d3..a22fff1 100644 --- a/src/flagsmith_sql_flag_engine/__init__.py +++ b/src/flagsmith_sql_flag_engine/__init__.py @@ -4,11 +4,22 @@ translate_segment(segment, ctx) -> str | None TranslateContext +By default the translator inlines each segment value as an escaped SQL +string literal. Pass a `Binder` on the `TranslateContext` to bind +value-bearing literals as query parameters instead — read its params off +`Binder.params` after translation. See `flagsmith_sql_flag_engine.binder`. + See README.md for usage. 
The translator is dialect-aware via the `Dialect` protocol; `flagsmith_sql_flag_engine.dialects.clickhouse.ClickHouseDialect` is the only implementation today. """ +from flagsmith_sql_flag_engine.binder import ( + Binder, + ClickHouseServerParamStyle, + ParamStyle, + PyformatParamStyle, +) from flagsmith_sql_flag_engine.dialect import Dialect from flagsmith_sql_flag_engine.translator import ( TRANSLATABLE_OPERATORS, @@ -20,7 +31,11 @@ __all__ = [ "TRANSLATABLE_OPERATORS", + "Binder", + "ClickHouseServerParamStyle", "Dialect", + "ParamStyle", + "PyformatParamStyle", "TranslateContext", "translate_condition", "translate_rule", diff --git a/src/flagsmith_sql_flag_engine/binder.py b/src/flagsmith_sql_flag_engine/binder.py new file mode 100644 index 0000000..f26e3e6 --- /dev/null +++ b/src/flagsmith_sql_flag_engine/binder.py @@ -0,0 +1,50 @@ +from typing import Protocol + + +class ParamStyle(Protocol): + """A driver's placeholder syntax for a named bound parameter.""" + + def placeholder(self, name: str) -> str: + """The placeholder token referencing bound parameter `name`.""" + ... + + +class PyformatParamStyle: + """`%(name)s` + + Used by `clickhouse-driver` which substitutes parameters + client-side via `query % params`.""" + + def placeholder(self, name: str) -> str: + return f"%({name})s" + + +class ClickHouseServerParamStyle: + """`{name:String}` + + ClickHouse's native server-side parameter syntax, + used by `clickhouse-connect`.""" + + def placeholder(self, name: str) -> str: + return "{" + name + ":String}" + + +class Binder: + """Collects bound parameter values and mints their placeholders. + + Not thread-safe; use one `Binder` per predicate translation. 
+ """ + + def __init__(self, style: ParamStyle, prefix: str = "") -> None: + self.params: dict[str, str] = {} + self._style = style + self._prefix = prefix + self._count = 0 + + def add(self, value: str) -> str: + """Record `value` under a fresh namespaced name and return its + placeholder token for the active paramstyle.""" + name = f"{self._prefix}p{self._count}" + self._count += 1 + self.params[name] = value + return self._style.placeholder(name) diff --git a/src/flagsmith_sql_flag_engine/dialect.py b/src/flagsmith_sql_flag_engine/dialect.py index 86802ff..b8de2d4 100644 --- a/src/flagsmith_sql_flag_engine/dialect.py +++ b/src/flagsmith_sql_flag_engine/dialect.py @@ -3,6 +3,8 @@ from typing import Protocol +from flagsmith_sql_flag_engine.binder import Binder + class Dialect(Protocol): """Per-dialect SQL fragments. @@ -10,6 +12,10 @@ class Dialect(Protocol): Methods return SQL string fragments. Inputs are already-formatted SQL strings (column refs, string literals); the dialect only chooses the right syntax for the operation. + + Methods that embed a segment- or context-derived value take an + optional `binder`: when provided, the value is emitted as a bound + query parameter rather than an inline literal. """ name: str # human-readable, used in test ids and error messages @@ -35,7 +41,14 @@ def trait_path(self, alias: str, trait_key: str) -> str: """ ... - def trait_eq(self, alias: str, trait_key: str, value: object, negate: bool) -> str: + def trait_eq( + self, + alias: str, + trait_key: str, + value: object, + negate: bool, + binder: Binder | None = None, + ) -> str: """Type-aware EQUAL / NOT_EQUAL predicate on a trait, mirroring `flag_engine`'s per-type coercion: the segment value is cast to the trait's runtime type before compare, and a cast failure @@ -45,7 +58,13 @@ def trait_eq(self, alias: str, trait_key: str, value: object, negate: bool) -> s """ ... 
- def trait_in(self, alias: str, trait_key: str, items: list[str]) -> str: + def trait_in( + self, + alias: str, + trait_key: str, + items: list[str], + binder: Binder | None = None, + ) -> str: """Type-aware IN predicate on a trait, mirroring engine semantics: string trait does direct lookup; integer trait stringifies and looks up; other trait types never match. `items` is the parsed @@ -77,14 +96,19 @@ def regex_supports(self, pattern: str) -> bool: to `flag_engine`.""" ... - def regexp_anchored_match(self, value_expr: str, pattern: str) -> str: + def regexp_anchored_match( + self, + value_expr: str, + pattern: str, + binder: Binder | None = None, + ) -> str: """Boolean: equivalent to Python `re.match(pattern, value)` — anchored at position 0, may be a prefix of the value, not a full-match. - `pattern` is the raw Python regex string; the dialect handles - its own escaping into a SQL literal, since regex flavours - differ in how backslashes are treated.""" + `pattern` is the raw Python regex string. With no `binder`, the + dialect handles its own escaping into a SQL literal, since regex + flavours differ in how backslashes are treated.""" ... def regexp_nth_digit_run(self, value_expr: str, n: int) -> str: diff --git a/src/flagsmith_sql_flag_engine/dialects/clickhouse.py b/src/flagsmith_sql_flag_engine/dialects/clickhouse.py index 2e7df12..184c8b0 100644 --- a/src/flagsmith_sql_flag_engine/dialects/clickhouse.py +++ b/src/flagsmith_sql_flag_engine/dialects/clickhouse.py @@ -76,7 +76,8 @@ ClickHouse Cloud as of 25.12 (no longer experimental on OSS 25.x). 
Callers should apply this setting at session creation.""" -from flagsmith_sql_flag_engine.utils import re2_safe, string_literal +from flagsmith_sql_flag_engine.binder import Binder +from flagsmith_sql_flag_engine.utils import bind_or_inline, re2_safe SCHEMA_DDL = """\ CREATE TABLE IF NOT EXISTS IDENTITIES ( @@ -154,10 +155,17 @@ def trait_path(self, alias: str, trait_key: str) -> str: sub = self._sub(alias, trait_key) return f"if({sub} IS NULL, NULL, toString({sub}))" - def trait_eq(self, alias: str, trait_key: str, value: object, negate: bool) -> str: + def trait_eq( + self, + alias: str, + trait_key: str, + value: object, + negate: bool, + binder: Binder | None = None, + ) -> str: sub = self._sub(alias, trait_key) str_value = str(value) - str_lit = string_literal(str_value) + str_lit = bind_or_inline(binder, str_value) # Engine bool cast: `v not in ("False", "false")`. A JSON true matches # every segment value except literal "False" / "false"; those two coerce # to False and match a JSON false. @@ -223,7 +231,13 @@ def trait_eq(self, alias: str, trait_key: str, value: object, negate: bool) -> s f"(({str_sub} IS NOT NULL AND {str_sub} <> {str_lit}) OR {bool_branch} OR {num_branch})" ) - def trait_in(self, alias: str, trait_key: str, items: list[str]) -> str: + def trait_in( + self, + alias: str, + trait_key: str, + items: list[str], + binder: Binder | None = None, + ) -> str: # `toString()` returns the canonical string form for any JSON # value type in a single subcolumn read. 
Engine semantics only # match String and integer trait types — bool / float / array @@ -235,7 +249,7 @@ def trait_in(self, alias: str, trait_key: str, items: list[str]) -> str: bool_sub = f"{sub}.:Bool" float_sub = f"{sub}.:Float64" str_path = f"toString({sub})" - item_lits = ",".join(string_literal(v) for v in items) + item_lits = ",".join(bind_or_inline(binder, v) for v in items) return f"({bool_sub} IS NULL AND {float_sub} IS NULL AND {str_path} IN ({item_lits}))" # ----- string operations ----- @@ -267,13 +281,21 @@ def _regex_literal(pattern: str) -> str: doubled = pattern.replace("\\", "\\\\").replace("'", "''") return f"'{doubled}'" - def regexp_anchored_match(self, value_expr: str, pattern: str) -> str: + def regexp_anchored_match( + self, value_expr: str, pattern: str, binder: Binder | None = None + ) -> str: # `match` is RE2 but unanchored — equivalent to `re.search`. Prepend # `^` to get `re.match` semantics (start-anchored, prefix-allowed). # Wrapping in `(...)` keeps the user's top-level alternation from # binding tighter than the anchor. anchored = "^(" + pattern + ")" - return f"match({_non_null(value_expr)}, {self._regex_literal(anchored)})" + # Bind the raw pattern when a binder is active: the driver escapes + # it, and — crucially — no `%` from a character class like + # `[a-z%]` lands in the query text to trip a `%`-substituting + # driver. Inline, `_regex_literal` doubles backslashes so RE2 sees + # the pattern the segment author wrote. 
+ pattern_lit = binder.add(anchored) if binder is not None else self._regex_literal(anchored) + return f"match({_non_null(value_expr)}, {pattern_lit})" def regexp_nth_digit_run(self, value_expr: str, n: int) -> str: # `extractAll` returns the matches array; subscript is 1-indexed diff --git a/src/flagsmith_sql_flag_engine/translator.py b/src/flagsmith_sql_flag_engine/translator.py index 1ba73f0..0435b61 100644 --- a/src/flagsmith_sql_flag_engine/translator.py +++ b/src/flagsmith_sql_flag_engine/translator.py @@ -22,12 +22,12 @@ from flag_engine.segments.evaluator import is_context_in_segment from flag_engine.segments.types import ConditionOperator +from flagsmith_sql_flag_engine.binder import Binder from flagsmith_sql_flag_engine.dialect import Dialect from flagsmith_sql_flag_engine.utils import ( - escape_string, + bind_or_inline, modulo_literal, numeric_literal, - string_literal, ) TRANSLATABLE_OPERATORS: frozenset[ConditionOperator] = frozenset( @@ -75,7 +75,10 @@ class TranslateContext: being configured here. `identities_alias` is the table alias for `IDENTITIES` in the surrounding query — defaults to `i`. `segment_key` salts `PERCENTAGE_SPLIT` and is auto-injected from - the segment's `key` field by `translate_segment`. + the segment's `key` field by `translate_segment`. `binder`, when + provided, promotes value-bearing literals to bound query parameters + instead of inlining them — see `flagsmith_sql_flag_engine.binder`; + the default `None` inlines them as escaped SQL string literals. 
""" def __init__( @@ -84,11 +87,13 @@ def __init__( dialect: Dialect, identities_alias: str = "i", segment_key: str | None = None, + binder: Binder | None = None, ) -> None: self.evaluation_context = evaluation_context self.dialect = dialect self.identities_alias = identities_alias self.segment_key = segment_key + self.binder = binder @property def identity_key_expr(self) -> str: @@ -114,6 +119,7 @@ def with_segment_key(self, key: str) -> "TranslateContext": dialect=self.dialect, identities_alias=self.identities_alias, segment_key=key, + binder=self.binder, ) @@ -133,7 +139,7 @@ def _percentage_split_expr( the engine recurses with doubled input; we don't. """ d = ctx.dialect - seg_lit = string_literal(seg_key) + seg_lit = bind_or_inline(ctx.binder, seg_key) hash_subject = f"{seg_lit} || ',' || ({ctx_value_sql})" h = d.md5_hex(hash_subject) s1 = d.parse_hex_chunk(h, 1) @@ -289,18 +295,19 @@ def _comparison( if value is None: return "FALSE" d = ctx.dialect - lit = string_literal(str(value)) str_expr = expr if is_jsonpath else d.cast_string(expr) + # Bind the operand lazily if op == "EQUAL": - return f"{str_expr} = {lit}" + return f"{str_expr} = {bind_or_inline(ctx.binder, str(value))}" if op == "NOT_EQUAL": - return f"{str_expr} <> {lit}" + return f"{str_expr} <> {bind_or_inline(ctx.binder, str(value))}" if op == "IN": - items = "','".join(escape_string(v.strip()) for v in str(value).split(",")) - return f"{str_expr} IN ('{items}')" + items = ",".join(bind_or_inline(ctx.binder, v.strip()) for v in str(value).split(",")) + return f"{str_expr} IN ({items})" if op == "CONTAINS": - return d.position(lit, str_expr) + return d.position(bind_or_inline(ctx.binder, str(value)), str_expr) if op == "NOT_CONTAINS": + lit = bind_or_inline(ctx.binder, str(value)) return f"({expr} IS NOT NULL AND NOT ({d.position(lit, str_expr)}))" if op in {"GREATER_THAN", "LESS_THAN", "GREATER_THAN_INCLUSIVE", "LESS_THAN_INCLUSIVE"}: numeric_lit = numeric_literal(value) @@ -327,7 +334,7 @@ def 
_comparison( pattern = str(value) if not d.regex_supports(pattern): return None - return f"({expr} IS NOT NULL AND {d.regexp_anchored_match(str_expr, pattern)})" + return f"({expr} IS NOT NULL AND {d.regexp_anchored_match(str_expr, pattern, ctx.binder)})" raise AssertionError( # pragma: no cover - all TRANSLATABLE_OPERATORS handled above f"unhandled translatable operator in _comparison: {op}" ) @@ -369,7 +376,7 @@ def _translate_trait_op( # content, which is what the fall-through handlers already do. if isinstance(val, str) and val.endswith(":semver") and op in _SEMVER_OPS: bare = val[:-7] - bare_lit = string_literal(bare) + bare_lit = bind_or_inline(ctx.binder, bare) col_str = ctx.dialect.cast_string(path) return ( f"({path} IS NOT NULL AND " @@ -382,7 +389,13 @@ def _translate_trait_op( # casts, and short-circuit pitfalls are all engine-specific. if op in {"EQUAL", "NOT_EQUAL"} and val is not None: negate = op == "NOT_EQUAL" - eq_pred = ctx.dialect.trait_eq(ctx.identities_alias, trait_key, val, negate=negate) + eq_pred = ctx.dialect.trait_eq( + ctx.identities_alias, + trait_key, + val, + negate=negate, + binder=ctx.binder, + ) return f"({path} IS NOT NULL AND {eq_pred})" if op == "IN": items = _engine_in_values(val) @@ -390,7 +403,12 @@ def _translate_trait_op( # Bad IN value — neither a string nor a list. Engine returns # False. 
return "FALSE" - in_pred = ctx.dialect.trait_in(ctx.identities_alias, trait_key, items) + in_pred = ctx.dialect.trait_in( + ctx.identities_alias, + trait_key, + items, + binder=ctx.binder, + ) return f"({path} IS NOT NULL AND {in_pred})" return _comparison(ctx, op, path, val, is_jsonpath=False) diff --git a/src/flagsmith_sql_flag_engine/utils.py b/src/flagsmith_sql_flag_engine/utils.py index b7c6800..cbe6d95 100644 --- a/src/flagsmith_sql_flag_engine/utils.py +++ b/src/flagsmith_sql_flag_engine/utils.py @@ -21,6 +21,21 @@ """ import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from flagsmith_sql_flag_engine.binder import Binder + + +def bind_or_inline(binder: "Binder | None", value: str) -> str: + """Return a bound-parameter placeholder for `value` when a `binder` is + active, else a single-quoted SQL string literal. + + The single seam that lets a value-bearing literal either be inlined + (the default) or promoted to a query parameter, chosen per call site + by whether the caller threaded a `Binder` onto the context. 
+ """ + return binder.add(value) if binder is not None else string_literal(value) def escape_string(value: str) -> str: diff --git a/tests/test_binder_clickhouse.py b/tests/test_binder_clickhouse.py new file mode 100644 index 0000000..9b75d4d --- /dev/null +++ b/tests/test_binder_clickhouse.py @@ -0,0 +1,79 @@ +import uuid +from collections.abc import Iterator +from typing import Any + +import pytest +from clickhouse_connect.driver import Client +from flag_engine.context.types import EvaluationContext, SegmentContext + +from flagsmith_sql_flag_engine import ( + Binder, + ClickHouseServerParamStyle, + TranslateContext, + translate_segment, +) +from flagsmith_sql_flag_engine.dialects.clickhouse import ClickHouseDialect +from tests.harnesses.clickhouse import ClickHouseHarness + + +@pytest.fixture +def clickhouse_session() -> Iterator[Client]: + with ClickHouseHarness().session() as session: + yield session + + +@pytest.fixture +def scratch_identities_table(clickhouse_session: Client) -> Iterator[str]: + table = f"TEST_BINDER_{uuid.uuid4().hex[:8]}" + clickhouse_session.command( + f"CREATE TABLE {table} (environment_id String, id UInt64, identifier String," + f" identity_key String, traits JSON) ENGINE = Memory" + ) + try: + yield table + finally: + clickhouse_session.command(f"DROP TABLE IF EXISTS {table}") + + +def test_translate_segment__bound_percent_regex__matches_in_clickhouse( + clickhouse_session: Any, + scratch_identities_table: str, +) -> None: + # Given + table = scratch_identities_table + clickhouse_session.command( + f"INSERT INTO {table} (environment_id, id, identifier, identity_key, traits) VALUES" + f" ('e', 1, 'match', 'k1', '{{\"email\": \"ada%lovelace@example.com\"}}'::JSON)," + f" ('e', 2, 'no-match', 'k2', '{{\"email\": \"123@example.com\"}}'::JSON)" + ) + seg: SegmentContext = { + "key": "1", + "name": "s", + "rules": [ + { + "type": "ALL", + "conditions": [ + { + "operator": "REGEX", + "property": "email", + "value": r"[a-z%]+@example\.com", + } 
+ ], + } + ], + } + binder = Binder(ClickHouseServerParamStyle()) + eval_ctx: EvaluationContext = {"environment": {"key": "e", "name": "Test"}} + predicate = translate_segment( + seg, TranslateContext(eval_ctx, ClickHouseDialect(), binder=binder) + ) + assert predicate is not None + + # When + rows = clickhouse_session.query( + f"SELECT i.identifier FROM {table} i WHERE ({predicate}) ORDER BY i.identifier", + parameters=binder.params, + ).result_rows + + # Then + assert [row[0] for row in rows] == ["match"] diff --git a/tests/test_binder_unit.py b/tests/test_binder_unit.py new file mode 100644 index 0000000..1ece51a --- /dev/null +++ b/tests/test_binder_unit.py @@ -0,0 +1,258 @@ +import re +from collections.abc import Callable +from typing import cast + +import pytest +from flag_engine.context.types import EvaluationContext, SegmentContext + +from flagsmith_sql_flag_engine import ( + Binder, + PyformatParamStyle, + TranslateContext, + translate_segment, +) +from flagsmith_sql_flag_engine.binder import ( + ClickHouseServerParamStyle, +) +from flagsmith_sql_flag_engine.dialects.clickhouse import ClickHouseDialect + +MakeContextFixture = Callable[[Binder | None], TranslateContext] +MakeSegmentFixture = Callable[[str, str, object], SegmentContext] + + +@pytest.fixture +def make_ctx() -> MakeContextFixture: + """Factory for a ClickHouse `TranslateContext` with the given binder.""" + + def _make(binder: Binder | None) -> TranslateContext: + eval_ctx: EvaluationContext = {"environment": {"key": "e", "name": "Test"}} + return TranslateContext( + evaluation_context=eval_ctx, + dialect=ClickHouseDialect(), + binder=binder, + ) + + return _make + + +@pytest.fixture +def make_segment() -> MakeSegmentFixture: + """Factory for a single-condition segment over one `ALL` rule.""" + + def _make(operator: str, prop: str, value: object) -> SegmentContext: + return cast( + SegmentContext, + { + "key": "1", + "name": "s", + "rules": [ + { + "type": "ALL", + "conditions": [{"operator": 
operator, "property": prop, "value": value}], + } + ], + }, + ) + + return _make + + +def test_binder__pyformat_style__mints_sequential_placeholders_and_records_values() -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + first = binder.add("growth") + second = binder.add("scale") + + # Then + assert first == "%(p0)s" + assert second == "%(p1)s" + assert binder.params == {"p0": "growth", "p1": "scale"} + + +def test_binder__clickhouse_server_style__mints_typed_placeholders() -> None: + # Given + binder = Binder(ClickHouseServerParamStyle()) + + # When + placeholder = binder.add("growth") + + # Then + assert placeholder == "{p0:String}" + assert binder.params == {"p0": "growth"} + + +def test_binder__prefix__namespaces_parameter_names() -> None: + # Given + binder_a = Binder(PyformatParamStyle(), prefix="s13_") + binder_b = Binder(PyformatParamStyle(), prefix="s14_") + + # When + a = binder_a.add("x") + b = binder_b.add("y") + + # Then + assert a == "%(s13_p0)s" + assert b == "%(s14_p0)s" + assert binder_a.params.keys().isdisjoint(binder_b.params.keys()) + + +def test_binder__value_with_percent__stored_verbatim() -> None: + # Given + # a value containing a `%` + value = "[a-z%]+@example.com" + binder = Binder(PyformatParamStyle()) + + # When + placeholder = binder.add(value) + + # Then + assert placeholder == "%(p0)s" + assert binder.params == {"p0": "[a-z%]+@example.com"} + + +def test_translate_segment__equal_with_binder__binds_operand( + make_segment: MakeSegmentFixture, + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + sql = translate_segment(make_segment("EQUAL", "plan", "growth"), make_ctx(binder)) + + # Then + assert sql is not None + assert "toString(i.traits.`plan`) = %(p0)s" in sql + assert "'growth'" not in sql + assert binder.params == {"p0": "growth"} + + +def test_translate_segment__in_with_binder__binds_each_item( + make_segment: MakeSegmentFixture, + make_ctx: 
MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + sql = translate_segment(make_segment("IN", "country", "GB,US,DE"), make_ctx(binder)) + + # Then + assert sql is not None + assert "IN (%(p0)s,%(p1)s,%(p2)s)" in sql + assert binder.params == {"p0": "GB", "p1": "US", "p2": "DE"} + + +def test_translate_segment__contains_with_binder__binds_needle( + make_segment: MakeSegmentFixture, + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + sql = translate_segment(make_segment("CONTAINS", "country", "G"), make_ctx(binder)) + + # Then + assert sql is not None + assert "%(p0)s) > 0" in sql + assert binder.params == {"p0": "G"} + + +def test_translate_segment__not_equal_trait_with_binder__binds_string_operand( + make_segment: MakeSegmentFixture, + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + sql = translate_segment(make_segment("NOT_EQUAL", "plan", "growth"), make_ctx(binder)) + + # Then + assert sql is not None + assert "%(p0)s" in sql + assert "'growth'" not in sql + assert binder.params == {"p0": "growth"} + + +def test_translate_segment__semver_with_binder__binds_bare_version( + make_segment: MakeSegmentFixture, + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + + # When + sql = translate_segment(make_segment("EQUAL", "version", "1.2.3:semver"), make_ctx(binder)) + + # Then + assert sql is not None + assert "%(p0)s" in sql + assert binder.params == {"p0": "1.2.3"} + + +def test_translate_segment__percentage_split_with_binder__binds_segment_key_salt( + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + seg: SegmentContext = { + "key": "cohort-42", + "name": "s", + "rules": [ + { + "type": "ALL", + "conditions": [{"operator": "PERCENTAGE_SPLIT", "property": "", "value": "50"}], + } + ], + } + + # When + sql = translate_segment(seg, 
make_ctx(binder)) + + # Then + assert sql is not None + assert "%(p0)s" in sql + assert "<= 50.0" in sql + assert binder.params == {"p0": "cohort-42"} + + +def test_translate_segment__prefix__namespaces_bound_names( + make_segment: MakeSegmentFixture, + make_ctx: MakeContextFixture, +) -> None: + # Given + binder = Binder(PyformatParamStyle(), prefix="s13_") + + # When + sql = translate_segment(make_segment("EQUAL", "plan", "growth"), make_ctx(binder)) + + # Then + assert sql is not None + assert "%(s13_p0)s" in sql + assert binder.params == {"s13_p0": "growth"} + + +def test_translate_segment__regex_with_percent__binder_survives_pyformat_substitution( + make_segment: MakeSegmentFixture, make_ctx: MakeContextFixture +) -> None: + # Given + binder = Binder(PyformatParamStyle()) + seg = make_segment("REGEX", "email", r"[a-z%]+@example\.com") + + # When + param_sql = translate_segment(seg, make_ctx(binder)) + assert param_sql is not None + + # Then + # no stray `%` remains in the query text + assert re.compile(r"%\([^)]+\)s").sub("", param_sql).find("%") == -1 + assert binder.params == {"p0": r"^([a-z%]+@example\.com)"} + + # and the full `query % params` substitution succeeds + query = f"i.environment_id IN %(env_keys)s AND ({param_sql})" + rendered = query % {"env_keys": (1, 2), **binder.params} + assert r"[a-z%]+@example\.com" in rendered diff --git a/tests/test_clickhouse_dialect_unit.py b/tests/test_clickhouse_dialect_unit.py index dd35d34..1f5469c 100644 --- a/tests/test_clickhouse_dialect_unit.py +++ b/tests/test_clickhouse_dialect_unit.py @@ -37,3 +37,30 @@ def test_translate_segment__contains_on_trait__emits_clickhouse_position() -> No assert sql is not None assert "position(toString(" in sql assert ", 'growth') > 0" in sql + + +def test_translate_segment__regex_inline__emits_anchored_match_with_escaped_literal() -> None: + # Given an RE2-safe REGEX on a trait and no binder (the default inline path) + seg: SegmentContext = { + "key": "ch2", + "name": "s", + 
"rules": [ + { + "type": "ALL", + "conditions": [ + {"operator": "REGEX", "property": "email", "value": r"[a-z]+@example\.com"} + ], + } + ], + } + eval_ctx: EvaluationContext = {"environment": {"key": "e", "name": "Test"}} + ctx = TranslateContext(evaluation_context=eval_ctx, dialect=ClickHouseDialect()) + + # When we translate the segment + sql = translate_segment(seg, ctx) + + # Then the pattern is inlined as an anchored, backslash-doubled `match` + # literal — the branch a binder would replace with a bound parameter + assert sql is not None + assert r"match(ifNull(toString(" in sql + assert r"'^([a-z]+@example\\.com)')" in sql