diff --git a/packages/react-native/ReactCommon/react/featureflags/rewrite_feature_flag_defaults.py b/packages/react-native/ReactCommon/react/featureflags/rewrite_feature_flag_defaults.py index ed146edd4534..1cb9582aa78e 100644 --- a/packages/react-native/ReactCommon/react/featureflags/rewrite_feature_flag_defaults.py +++ b/packages/react-native/ReactCommon/react/featureflags/rewrite_feature_flag_defaults.py @@ -20,8 +20,32 @@ import re import sys +import tree_sitter_cpp +from tree_sitter import Language, Parser, Query, QueryCursor + + +_TARGET_CLASS = "ReactNativeFeatureFlagsDefaults" + + +def _method_query(names: set[str]) -> str: + alternation = "|".join(re.escape(n) for n in sorted(names)) + return f""" +(class_specifier + name: (type_identifier) @class_name + body: (field_declaration_list + (function_definition + declarator: (function_declarator + declarator: (field_identifier) @method_name) + body: (compound_statement + (return_statement (_) @return_value))) + ) + (#eq? @class_name "{_TARGET_CLASS}") + (#match? @method_name "^({alternation})$") +) +""" + -def cxx_literal(value: object) -> str: +def cxx_literal(value: bool | int | float) -> str: if isinstance(value, bool): return "true" if value else "false" if isinstance(value, (int, float)): @@ -33,33 +57,41 @@ def cxx_literal(value: object) -> str: def rewrite(source: bytes, overrides: dict[str, object]) -> bytes: - text = source.decode("utf-8") - for name, value in overrides.items(): - cxx_type = "bool" if isinstance(value, bool) else "double" - pattern = rf""" - ( # group 1: everything up to the value - {cxx_type} \s+ # return type - {re.escape(name)} # method name - \s* \( \s* \) # parameter list - \s+ override # override specifier - \s* \{{ # opening brace - [^}}]*? # body before the return (non-greedy, no nested braces) - return \s+ # return keyword + lang = Language(tree_sitter_cpp.language()) + tree = Parser(lang).parse(source) + matches = QueryCursor(Query(lang, _method_query(overrides.keys()))).matches( + tree.root_node + ) + + matched: set[str] = set() + replacements: list[tuple[int, int, bytes]] = [] + + for _, match in matches: + method_node = match["method_name"][0] + name = source[method_node.start_byte : method_node.end_byte].decode("utf-8") + rv_node = match["return_value"][0] + replacements.append( + ( + rv_node.start_byte, + rv_node.end_byte, + cxx_literal(overrides[name]).encode("utf-8"), ) - [^;]+ # the value to replace - ( \s* ; ) # group 2: semicolon - """ - text, n = re.subn( - pattern, - rf"\g<1>{cxx_literal(value)}\2", - text, - count=1, - flags=re.DOTALL | re.VERBOSE, ) - if n != 1: - raise ValueError(f"{name} not matched") + matched.add(name) + + unmatched = set(overrides.keys()) - matched + if unmatched: + raise ValueError(f"Unmatched flags: {', '.join(sorted(unmatched))}") + + result = bytearray() + pos = 0 + for start, end, replacement in replacements: + result.extend(source[pos:start]) + result.extend(replacement) + pos = end + result.extend(source[pos:]) - return text.encode("utf-8") + return bytes(result) def main() -> None: