Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Rewrite default return values in ReactNativeFeatureFlagsDefaults.kt.

Reads the Kotlin source from --input, writes the transformed source to stdout.
Overrides are passed as a JSON object via --overrides.
Fails with a non-zero exit code if any requested flag is not found.
"""

from __future__ import annotations

import argparse
import json
import re
import sys


def kotlin_literal(value: bool | int | float) -> str:
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, (int, float)):
s = str(value)
if isinstance(value, int) or "." not in s:
s += ".0"
return s
raise ValueError(f"Unsupported value type {type(value).__name__} for override")


def rewrite(source: bytes, overrides: dict[str, object]) -> bytes:
text = source.decode("utf-8")
for name, value in overrides.items():
kotlin_type = "Boolean" if isinstance(value, bool) else "Double"
pattern = rf"""
(
override \s+ fun \s+
{re.escape(name)}
\s* \( \s* \)
\s* : \s* {kotlin_type}
\s* = \s*
)
\S+
"""
text, n = re.subn(
pattern,
rf"\g<1>{kotlin_literal(value)}",
text,
count=1,
flags=re.VERBOSE,
)
if n != 1:
raise ValueError(f"{name} not matched")

return text.encode("utf-8")


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--overrides", default="{}")
parser.add_argument("--input", required=True)
args = parser.parse_args()

overrides: dict[str, object] = json.loads(args.overrides)
with open(args.input, "rb") as f:
source = f.read()

sys.stdout.buffer.write(rewrite(source, overrides))


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import os
import unittest

from rewrite_feature_flag_defaults import kotlin_literal, rewrite


def _load_source() -> bytes:
with open(os.environ["SOURCE_PATH"], "rb") as f:
return f.read()


class RewriteFeatureFlagDefaultsTest(unittest.TestCase):
def setUp(self) -> None:
self.source = _load_source()

def test_empty_overrides_is_passthrough(self) -> None:
self.assertEqual(rewrite(self.source, {}), self.source)

def test_override_bool_to_true(self) -> None:
result = rewrite(self.source, {"commonTestFlag": True})
self.assertEqual(self._method_value(result, "commonTestFlag"), b"true")

def test_override_bool_to_false(self) -> None:
result = rewrite(self.source, {"commonTestFlag": False})
self.assertEqual(self._method_value(result, "commonTestFlag"), b"false")

def test_kotlin_literal_int_produces_double(self) -> None:
self.assertEqual(kotlin_literal(42), "42.0")

def test_kotlin_literal_float(self) -> None:
self.assertEqual(kotlin_literal(3.14), "3.14")

def test_unmatched_flag_raises(self) -> None:
with self.assertRaises(ValueError):
rewrite(self.source, {"bogusFlag": True})

def test_only_target_method_changes(self) -> None:
result = rewrite(self.source, {"commonTestFlag": True})
src_start, src_end = self._method_value_range(self.source, "commonTestFlag")
res_start, res_end = self._method_value_range(result, "commonTestFlag")
self.assertEqual(self.source[:src_start], result[:res_start])
self.assertEqual(self.source[src_end:], result[res_end:])

def _method_value_range(self, source: bytes, name: str) -> tuple[int, int]:
name_idx = source.find(name.encode())
self.assertNotEqual(name_idx, -1, f"{name} not found in output")
eq_idx = source.find(b"=", name_idx)
eol_idx = source.find(b"\n", eq_idx)
start = eq_idx + 2 # skip "= "
return (start, eol_idx)

def _method_value(self, source: bytes, name: str) -> bytes:
start, end = self._method_value_range(source, name)
return source[start:end].strip()
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,32 @@
import re
import sys

import tree_sitter_cpp
from tree_sitter import Language, Parser, Query, QueryCursor


_TARGET_CLASS = "ReactNativeFeatureFlagsDefaults"


def _method_query(names: set[str]) -> str:
alternation = "|".join(re.escape(n) for n in sorted(names))
return f"""
(class_specifier
name: (type_identifier) @class_name
body: (field_declaration_list
(function_definition
declarator: (function_declarator
declarator: (field_identifier) @method_name)
body: (compound_statement
(return_statement (_) @return_value)))
)
(#eq? @class_name "{_TARGET_CLASS}")
(#match? @method_name "^({alternation})$")
)
"""


def cxx_literal(value: object) -> str:
def cxx_literal(value: bool | int | float) -> str:
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, (int, float)):
Expand All @@ -33,33 +57,41 @@ def cxx_literal(value: object) -> str:


def rewrite(source: bytes, overrides: dict[str, object]) -> bytes:
text = source.decode("utf-8")
for name, value in overrides.items():
cxx_type = "bool" if isinstance(value, bool) else "double"
pattern = rf"""
( # group 1: everything up to the value
{cxx_type} \s+ # return type
{re.escape(name)} # method name
\s* \( \s* \) # parameter list
\s+ override # override specifier
\s* \{{ # opening brace
[^}}]*? # body before the return (non-greedy, no nested braces)
return \s+ # return keyword
lang = Language(tree_sitter_cpp.language())
tree = Parser(lang).parse(source)
matches = QueryCursor(Query(lang, _method_query(overrides.keys()))).matches(
tree.root_node
)

matched: set[str] = set()
replacements: list[tuple[int, int, bytes]] = []

for _, match in matches:
method_node = match["method_name"][0]
name = source[method_node.start_byte : method_node.end_byte].decode("utf-8")
rv_node = match["return_value"][0]
replacements.append(
(
rv_node.start_byte,
rv_node.end_byte,
cxx_literal(overrides[name]).encode("utf-8"),
)
[^;]+ # the value to replace
( \s* ; ) # group 2: semicolon
"""
text, n = re.subn(
pattern,
rf"\g<1>{cxx_literal(value)}\2",
text,
count=1,
flags=re.DOTALL | re.VERBOSE,
)
if n != 1:
raise ValueError(f"{name} not matched")
matched.add(name)

unmatched = set(overrides.keys()) - matched
if unmatched:
raise ValueError(f"Unmatched flags: {', '.join(sorted(unmatched))}")

result = bytearray()
pos = 0
for start, end, replacement in replacements:
result.extend(source[pos:start])
result.extend(replacement)
pos = end
result.extend(source[pos:])

return text.encode("utf-8")
return bytes(result)


def main() -> None:
Expand Down
Loading