def _run_spring_resolver(store: GraphStore) -> Optional[dict]:
    """Best-effort post-pass: run the Spring DI call resolver.

    Any failure (including the import itself) is logged and swallowed so
    that a graph build never fails because of this resolver.

    Returns:
        The resolver's stats dict, or None when the pass failed.
    """
    try:
        from .spring_resolver import resolve_spring_di_calls
        stats = resolve_spring_di_calls(store)
    except Exception as err:  # noqa: BLE001 - best-effort post-pass
        logger.warning("Spring DI resolver failed: %s", err)
        return None
    return stats


def _run_temporal_resolver(store: GraphStore) -> Optional[dict]:
    """Best-effort post-pass: run the Temporal workflow/activity call resolver.

    Any failure (including the import itself) is logged and swallowed so
    that a graph build never fails because of this resolver.

    Returns:
        The resolver's stats dict, or None when the pass failed.
    """
    try:
        from .temporal_resolver import resolve_temporal_calls
        stats = resolve_temporal_calls(store)
    except Exception as err:  # noqa: BLE001 - best-effort post-pass
        logger.warning("Temporal resolver failed: %s", err)
        return None
    return stats
# # `/**` patterns are matched at any depth by _should_ignore, so @@ -805,6 +829,8 @@ def full_build( store.commit() rescript_stats = _run_rescript_resolver(store) + spring_stats = _run_spring_resolver(store) + temporal_stats = _run_temporal_resolver(store) return { "files_parsed": len(files), @@ -812,6 +838,8 @@ def full_build( "total_edges": total_edges, "errors": errors, "rescript_resolution": rescript_stats, + "spring_resolution": spring_stats, + "temporal_resolution": temporal_stats, } @@ -931,8 +959,7 @@ def incremental_update( _store_vcs_metadata(repo_root, store) store.commit() - # Only re-run ReScript resolver when changed files touched .res/.resi; - # otherwise prior resolution state is unaffected. + # Only re-run language-specific resolvers when the relevant files changed. rescript_changed = any( rp.endswith((".res", ".resi")) for rp in all_files ) @@ -940,6 +967,10 @@ def incremental_update( _run_rescript_resolver(store) if rescript_changed else None ) + spring_changed = any(rp.endswith(".java") for rp in all_files) + spring_stats = _run_spring_resolver(store) if spring_changed else None + temporal_stats = _run_temporal_resolver(store) if spring_changed else None + return { "files_updated": len(all_files), "total_nodes": total_nodes, @@ -948,6 +979,8 @@ def incremental_update( "dependent_files": list(dependent_files), "errors": errors, "rescript_resolution": rescript_stats, + "spring_resolution": spring_stats, + "temporal_resolution": temporal_stats, } diff --git a/code_review_graph/parser.py b/code_review_graph/parser.py index f681263a..0851de56 100644 --- a/code_review_graph/parser.py +++ b/code_review_graph/parser.py @@ -339,6 +339,52 @@ class EdgeInfo: "org.junit.Test", "org.junit.jupiter.api.Test", }) +# Spring stereotype annotations that mark classes as managed beans +_SPRING_STEREOTYPE_ANNOTATIONS = frozenset({ + "Component", "Service", "Repository", "Controller", "RestController", + "Configuration", "Indexed", "ControllerAdvice", 
"RestControllerAdvice", + "EventListener", +}) + +# Spring DI injection annotations (field/setter/constructor-level) +_SPRING_INJECT_ANNOTATIONS = frozenset({ + "Autowired", "Inject", "Resource", +}) + +# Lombok annotations that trigger constructor injection of final fields +_LOMBOK_CONSTRUCTOR_ANNOTATIONS = frozenset({ + "RequiredArgsConstructor", "AllArgsConstructor", +}) + +# Temporal workflow/activity interface markers +_TEMPORAL_INTERFACE_ANNOTATIONS = frozenset({ + "WorkflowInterface", "ActivityInterface", +}) + +# Temporal method-level markers +_TEMPORAL_METHOD_ANNOTATIONS = frozenset({ + "WorkflowMethod", "ActivityMethod", "SignalMethod", "QueryMethod", +}) + +# Kafka consumer annotations (annotation-based pattern) +_KAFKA_LISTENER_ANNOTATIONS = frozenset({"KafkaListener", "KafkaHandler"}) + +# Kafka consumer field types (reactive / imperative) +_KAFKA_CONSUMER_TYPES = frozenset({ + "KafkaReceiver", + "ReactiveKafkaConsumerTemplate", + "MessageListenerContainer", + "ConcurrentMessageListenerContainer", +}) + +# Kafka producer field types +_KAFKA_PRODUCER_TYPES = frozenset({ + "KafkaTemplate", + "KafkaOperations", + "ReactiveKafkaProducerTemplate", + "KafkaSender", +}) + # --------------------------------------------------------------------------- # ReScript regex patterns and helpers (no tree-sitter grammar bundled) @@ -2720,6 +2766,386 @@ def _extract_js_field_function( ) return True + @staticmethod + def _get_java_annotations(class_node) -> list[str]: + """Return annotation names from the modifiers child of a Java class/method node.""" + names: list[str] = [] + for child in class_node.children: + if child.type != "modifiers": + continue + for mod in child.children: + if mod.type in ("marker_annotation", "annotation"): + for sub in mod.children: + if sub.type == "identifier": + names.append(sub.text.decode("utf-8", errors="replace")) + break + return names + + def _emit_spring_injections( + self, + class_node, + class_name: str, + class_annotations: 
def _emit_spring_field_injection(
    self,
    field_node,
    qualified_source: str,
    file_path: str,
    edges: list[EdgeInfo],
    has_lombok_constructor: bool,
) -> None:
    """Emit one INJECTS edge for a Java field_declaration, when applicable.

    A field qualifies when it carries @Autowired/@Inject/@Resource, or when
    the class uses Lombok constructor generation and the field is final.
    Static fields (loggers, constants) never qualify.
    """
    annotations: list[str] = []
    is_final = False
    is_static = False
    declared_type: Optional[str] = None
    declared_name: Optional[str] = None

    for part in field_node.children:
        if part.type == "modifiers":
            for modifier in part.children:
                label = modifier.text.decode("utf-8", errors="replace")
                if label == "final":
                    is_final = True
                elif label == "static":
                    is_static = True
                elif modifier.type in ("marker_annotation", "annotation"):
                    for sub in modifier.children:
                        if sub.type == "identifier":
                            annotations.append(
                                sub.text.decode("utf-8", errors="replace")
                            )
                            break
        elif part.type == "type_identifier":
            declared_type = part.text.decode("utf-8", errors="replace")
        elif part.type in ("generic_type", "array_type"):
            # Use the outermost type name: List for List<Foo>, Foo for Foo[].
            for sub in part.children:
                if sub.type == "type_identifier":
                    declared_type = sub.text.decode("utf-8", errors="replace")
                    break
        elif part.type == "variable_declarator":
            for sub in part.children:
                if sub.type == "identifier":
                    declared_name = sub.text.decode("utf-8", errors="replace")
                    break

    if is_static or not declared_type:
        return

    annotated = any(a in _SPRING_INJECT_ANNOTATIONS for a in annotations)
    lombok_injected = has_lombok_constructor and is_final
    if not annotated and not lombok_injected:
        return

    details: dict = {
        "injection_type": "field" if annotated else "constructor_lombok",
    }
    if declared_name:
        details["field_name"] = declared_name
    edges.append(EdgeInfo(
        kind="INJECTS",
        source=qualified_source,
        target=declared_type,
        file_path=file_path,
        line=field_node.start_point[0] + 1,
        extra=details,
    ))
def _emit_spring_constructor_injection(
    self,
    ctor_node,
    qualified_source: str,
    file_path: str,
    edges: list[EdgeInfo],
) -> None:
    """Emit INJECTS edges for the parameters of an annotated constructor.

    Only constructors carrying @Autowired/@Inject/@Resource are considered;
    each simple-typed parameter becomes one INJECTS edge.
    """
    if not any(
        a in _SPRING_INJECT_ANNOTATIONS
        for a in self._get_java_annotations(ctor_node)
    ):
        return

    for child in ctor_node.children:
        if child.type != "formal_parameters":
            continue
        for param in child.children:
            if param.type != "formal_parameter":
                continue
            param_type: Optional[str] = None
            param_name: Optional[str] = None
            for sub in param.children:
                # First type_identifier is the parameter type; the last
                # plain identifier is the parameter name.
                if sub.type == "type_identifier" and param_type is None:
                    param_type = sub.text.decode("utf-8", errors="replace")
                elif sub.type == "identifier":
                    param_name = sub.text.decode("utf-8", errors="replace")
            if not param_type:
                continue
            details: dict = {"injection_type": "constructor"}
            if param_name:
                details["field_name"] = param_name
            edges.append(EdgeInfo(
                kind="INJECTS",
                source=qualified_source,
                target=param_type,
                file_path=file_path,
                line=param.start_point[0] + 1,
                extra=details,
            ))
def _emit_temporal_stub_fields(
    self,
    class_node,
    class_name: str,
    file_path: str,
    edges: list[EdgeInfo],
) -> None:
    """Emit TEMPORAL_STUB edges for Temporal activity/workflow stub fields.

    A candidate is any non-static field whose type name ends with
    'Activity' or 'Workflow' (the Temporal naming convention). The temporal
    resolver post-pass validates each candidate against nodes that actually
    carry ``temporal_role`` metadata, so same-suffix false positives are
    filtered there rather than here.
    """
    qualified_source = self._qualify(class_name, file_path, None)

    for body in class_node.children:
        if body.type != "class_body":
            continue
        for member in body.children:
            if member.type != "field_declaration":
                continue

            is_static = False
            stub_type: Optional[str] = None
            stub_name: Optional[str] = None
            for part in member.children:
                if part.type == "modifiers":
                    for mod in part.children:
                        if mod.text and mod.text.decode("utf-8", errors="replace") == "static":
                            is_static = True
                elif part.type == "type_identifier":
                    stub_type = part.text.decode("utf-8", errors="replace")
                elif part.type == "variable_declarator":
                    for sub in part.children:
                        if sub.type == "identifier":
                            stub_name = sub.text.decode("utf-8", errors="replace")
                            break

            # Skip loggers/constants (static) and anything lacking a simple
            # type name or a field name.
            if is_static or not stub_type or not stub_name:
                continue
            if not stub_type.endswith(("Activity", "Workflow")):
                continue

            edges.append(EdgeInfo(
                kind="TEMPORAL_STUB",
                source=qualified_source,
                target=stub_type,
                file_path=file_path,
                line=member.start_point[0] + 1,
                extra={
                    "field_name": stub_name,
                    "stub_type": "activity" if stub_type.endswith("Activity") else "workflow",
                },
            ))
field_type.endswith("Activity") else "workflow" + )}, + )) + + @staticmethod + def _get_kafka_annotation_topics(annotation_node) -> list[str]: + """Extract topic strings from @KafkaListener(topics = "...") or topics = {"a","b"}.""" + topics: list[str] = [] + for child in annotation_node.children: + if child.type != "annotation_argument_list": + continue + for pair in child.children: + if pair.type != "element_value_pair": + continue + key_node = next((c for c in pair.children if c.type == "identifier"), None) + if key_node is None: + continue + key = key_node.text.decode("utf-8", errors="replace") + if key not in ("topics", "topicPattern", "value"): + continue + # value can be string_literal or element_value_array_initializer + for val in pair.children: + if val.type == "string_literal": + raw = val.text.decode("utf-8", errors="replace").strip('"').strip("'") + if raw: + topics.append(raw) + elif val.type in ("array_initializer", "element_value_array_initializer"): + for item in val.children: + if item.type == "string_literal": + raw = item.text.decode("utf-8", errors="replace").strip('"').strip("'") + if raw: + topics.append(raw) + return topics + + def _emit_kafka_edges_from_class( + self, + class_node, + class_name: str, + file_path: str, + edges: list[EdgeInfo], + ) -> None: + """Emit CONSUMES/PRODUCES edges for Kafka field declarations. + + Handles: + - KafkaReceiver / ReactiveKafkaConsumerTemplate → CONSUMES + - KafkaTemplate / KafkaOperations / ReactiveKafkaProducerTemplate → PRODUCES + Generic value type (e.g. KafkaReceiver) is + stored in extra.message_type for traceability. 
def _emit_kafka_edges_from_method(
    self,
    method_node,
    method_name: str,
    class_name: Optional[str],
    file_path: str,
    edges: list[EdgeInfo],
) -> None:
    """Emit CONSUMES edges for @KafkaListener / @KafkaHandler methods.

    One edge per topic literal that can be extracted from the annotation's
    arguments; when no topic is statically resolvable (e.g. it comes from
    configuration), a single placeholder edge to ``kafka:config`` is
    emitted instead.
    """
    qualified_source = self._qualify(method_name, file_path, class_name)

    for child in method_node.children:
        if child.type != "modifiers":
            continue
        for mod in child.children:
            if mod.type not in ("annotation", "marker_annotation"):
                continue
            ann_name: Optional[str] = None
            for sub in mod.children:
                if sub.type == "identifier":
                    ann_name = sub.text.decode("utf-8", errors="replace")
                    break
            if ann_name not in _KAFKA_LISTENER_ANNOTATIONS:
                continue

            topics = self._get_kafka_annotation_topics(mod)
            if topics:
                for topic in topics:
                    edges.append(EdgeInfo(
                        kind="CONSUMES",
                        source=qualified_source,
                        target=f"kafka:{topic}",
                        file_path=file_path,
                        line=method_node.start_point[0] + 1,
                        # Fix: record the actual annotation name
                        # (@KafkaListener or @KafkaHandler) instead of the
                        # previously hard-coded "KafkaListener", matching
                        # the no-topic branch below.
                        extra={"topic": topic, "kafka_type": ann_name},
                    ))
            else:
                # @KafkaListener without resolvable topic (config placeholder)
                edges.append(EdgeInfo(
                    kind="CONSUMES",
                    source=qualified_source,
                    target="kafka:config",
                    file_path=file_path,
                    line=method_node.start_point[0] + 1,
                    extra={"kafka_type": ann_name},
                ))
_TEMPORAL_INTERFACE_ANNOTATIONS + ] + if temporal_roles: + role = "workflow_interface" if "WorkflowInterface" in temporal_roles else "activity_interface" + extra["temporal_role"] = role + node = NodeInfo( kind="Class", name=name, @@ -2791,6 +3235,16 @@ def _extract_classes( line=child.start_point[0] + 1, )) + # Spring DI: emit INJECTS edges for injected dependencies + if language == "java": + self._emit_spring_injections( + child, name, class_annotations, language, file_path, edges, + ) + # Temporal: emit TEMPORAL_STUB edges for activity/workflow stub fields + self._emit_temporal_stub_fields(child, name, file_path, edges) + # Kafka: emit CONSUMES/PRODUCES edges for Kafka field declarations + self._emit_kafka_edges_from_class(child, name, file_path, edges) + # Recurse into class body self._extract_from_tree( child, source, language, file_path, nodes, edges, @@ -2854,6 +3308,20 @@ def _extract_functions( params = self._get_params(child, language, source) ret_type = self._get_return_type(child, language, source) + # Java: detect Temporal method-level annotations and Kafka listeners + method_extra: dict = {} + if language == "java" and deco_list: + temporal_method_annots = [ + a for a in deco_list if a in _TEMPORAL_METHOD_ANNOTATIONS + ] + if temporal_method_annots: + method_extra["temporal_role"] = temporal_method_annots[0].lower() + if any(a.split("(")[0] in _KAFKA_LISTENER_ANNOTATIONS for a in deco_list): + method_extra["kafka_listener"] = True + self._emit_kafka_edges_from_method( + child, name, enclosing_class, file_path, edges, + ) + node = NodeInfo( kind=kind, name=name, @@ -2865,6 +3333,7 @@ def _extract_functions( params=params, return_type=ret_type, is_test=is_test, + extra=method_extra, ) nodes.append(node) @@ -3025,20 +3494,80 @@ def _extract_calls( caller = self._qualify( enclosing_func, file_path, enclosing_class, ) - target = self._resolve_call_target( - call_name, file_path, language, - import_map or {}, defined_names or set(), - ) + + # Java 
method_invocation: extract actual method name and receiver + # separately so the Spring DI resolver can rewrite the target. + call_extra: dict = {} + if language == "java" and child.type == "method_invocation": + method_name, receiver = self._get_java_method_and_receiver(child) + if method_name: + call_name = method_name + if receiver: + call_extra["receiver"] = receiver + + # When a receiver is present, skip scope-based resolution: the method + # lives on the receiver's type, not in the current file's scope. + # The spring_resolver post-pass will do the correct cross-type lookup. + if call_extra.get("receiver"): + target = call_name + else: + target = self._resolve_call_target( + call_name, file_path, language, + import_map or {}, defined_names or set(), + ) edges.append(EdgeInfo( kind="CALLS", source=caller, target=target, file_path=file_path, line=child.start_point[0] + 1, + extra=call_extra, )) return False + @staticmethod + def _get_java_method_and_receiver(node) -> tuple[Optional[str], Optional[str]]: + """For a Java method_invocation node, return (method_name, receiver_name). + + Pattern: [receiver_identifier, '.', method_identifier, argument_list] + Chained: [inner_method_invocation, '.', method_identifier, argument_list] + + Returns (None, None) for unrecognised shapes. + """ + children = node.children + if len(children) < 3: + return None, None + + # method_identifier is always the last identifier before argument_list + method_name: Optional[str] = None + receiver_name: Optional[str] = None + + # Scan backwards for the method identifier + for i in range(len(children) - 1, -1, -1): + ch = children[i] + if ch.type == "argument_list": + continue + if ch.type == "identifier": + if method_name is None: + method_name = ch.text.decode("utf-8", errors="replace") + else: + # Second identifier scanning backwards = receiver + receiver_name = ch.text.decode("utf-8", errors="replace") + break + if ch.type == "." 
: + continue + # Chained call or complex expression as receiver — no simple receiver + break + + # Receiver is the first child if it's a plain identifier + if method_name and children[0].type == "identifier": + first_text = children[0].text.decode("utf-8", errors="replace") + if first_text != method_name: + receiver_name = first_text + + return method_name, receiver_name + def _extract_jsx_component_call( self, child, @@ -3969,6 +4498,17 @@ def _get_name(self, node, language: str, kind: str) -> Optional[str]: return child.text.decode("utf-8", errors="replace") if child.type == "package" and child.text != b"package": return child.text.decode("utf-8", errors="replace") + # Java: method_declaration has return type_identifier before the method + # identifier — skip straight to the first plain identifier child to + # avoid returning the return type as the function name. + if language == "java" and kind == "function" and node.type in ( + "method_declaration", "constructor_declaration", + ): + for child in node.children: + if child.type == "identifier": + return child.text.decode("utf-8", errors="replace") + return None + # For C/C++/Objective-C: function names are inside # function_declarator / pointer_declarator. Check these first to # avoid matching the return type_identifier as the function name. diff --git a/code_review_graph/spring_resolver.py b/code_review_graph/spring_resolver.py new file mode 100644 index 00000000..dba34b5c --- /dev/null +++ b/code_review_graph/spring_resolver.py @@ -0,0 +1,198 @@ +"""Post-build Spring DI call resolver. + +After tree-sitter parsing, Java CALLS edges whose target is a bare method +name (e.g. ``calculate``) carry ``extra.receiver`` naming the local variable +that was called on (e.g. ``invoiceCalculationService``). This module +resolves those receivers through the INJECTS map to their declared type, then +optionally to the unique concrete implementation via INHERITS edges. 
import json
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .graph import GraphStore

logger = logging.getLogger(__name__)


def resolve_spring_di_calls(store: GraphStore) -> dict:
    """Resolve Java CALLS edges whose receiver is a Spring-injected field.

    Rewrites each CALLS edge that carries ``extra.receiver`` so its target
    points at the injected type's method — or, when exactly one concrete
    implementor exists (via INHERITS), at that implementation's method.

    Safe to call multiple times: edges already marked
    ``extra.spring_resolved`` are skipped.

    Returns:
        Dict with ``files_indexed`` and ``calls_resolved`` counts.

    Note: the previous version also built an unused ``name_to_qual``
    Class-name lookup; that dead code (and its SQL query) is removed.
    """
    conn = store._conn

    # Only Java files participate.
    java_files: set[str] = {
        row["file_path"]
        for row in conn.execute(
            "SELECT DISTINCT file_path FROM nodes WHERE language = 'java'"
        ).fetchall()
    }
    if not java_files:
        return {"files_indexed": 0, "calls_resolved": 0}

    # (enclosing class qualified name, field name) -> injected type, from
    # INJECTS edges that recorded extra.field_name at parse time.
    field_map: dict[tuple[str, str], str] = {}
    for row in conn.execute(
        "SELECT source_qualified, target_qualified, extra FROM edges WHERE kind = 'INJECTS'"
    ).fetchall():
        try:
            extra = json.loads(row["extra"] or "{}")
        except (json.JSONDecodeError, TypeError):
            extra = {}
        field_name = extra.get("field_name")
        if field_name:
            field_map[(row["source_qualified"], field_name)] = row["target_qualified"]

    if not field_map:
        logger.info("Spring resolver: no INJECTS edges with field_name found, skipping")
        return {"files_indexed": len(java_files), "calls_resolved": 0}

    # (class name, method name) -> fully qualified method node name, used to
    # emit targets in the exact "file::Class.method" form callers_of expects.
    method_to_qual: dict[tuple[str, str], str] = {}
    for row in conn.execute(
        "SELECT name, qualified_name, parent_name FROM nodes "
        "WHERE kind IN ('Function', 'Test') AND language = 'java' AND parent_name IS NOT NULL"
    ).fetchall():
        method_to_qual[(row["parent_name"], row["name"])] = row["qualified_name"]

    # bare interface name -> implementing class qualified names (Java emits
    # INHERITS for both `extends` and `implements`).
    implementors: dict[str, list[str]] = {}
    for row in conn.execute(
        "SELECT source_qualified, target_qualified FROM edges WHERE kind = 'INHERITS'"
    ).fetchall():
        impl = row["source_qualified"]
        if "::" in impl or any(impl.startswith(f) for f in java_files):
            implementors.setdefault(row["target_qualified"], []).append(impl)

    resolved = 0
    for row in conn.execute(
        "SELECT id, source_qualified, target_qualified, extra, file_path "
        "FROM edges WHERE kind = 'CALLS'"
    ).fetchall():
        if row["file_path"] not in java_files:
            continue

        try:
            extra = json.loads(row["extra"] or "{}")
        except (json.JSONDecodeError, TypeError):
            extra = {}

        receiver = extra.get("receiver")
        if not receiver or extra.get("spring_resolved"):
            continue

        # Strip any prior (possibly wrong) qualification — the receiver
        # lets us re-resolve from scratch: "file::Class.method" -> "method".
        raw_target = row["target_qualified"]
        if "::" in raw_target:
            tail = raw_target.split("::", 1)[1]
            method_name = tail.split(".")[-1] if "." in tail else tail
        else:
            method_name = raw_target

        # Recover the enclosing class from "file::Class.method" sources.
        source_qual = row["source_qualified"]
        if "::" not in source_qual:
            continue
        prefix, _, member = source_qual.partition("::")
        if "." in member:
            enclosing_class_qual = f"{prefix}::{member.split('.')[0]}"
        else:
            enclosing_class_qual = source_qual

        injected_type = field_map.get((enclosing_class_qual, receiver))
        if not injected_type:
            continue

        # Prefer the unique concrete implementation when one exists.
        impls = implementors.get(injected_type, [])
        if len(impls) == 1:
            concrete = impls[0].split("::")[-1]
            new_target = method_to_qual.get((concrete, method_name)) or f"{impls[0]}.{method_name}"
        else:
            bare = injected_type.split(".")[-1] if "." in injected_type else injected_type
            new_target = method_to_qual.get((bare, method_name)) or f"{injected_type}.{method_name}"

        extra["spring_resolved"] = True
        extra["injected_type"] = injected_type
        conn.execute(
            "UPDATE edges SET target_qualified = ?, extra = ? WHERE id = ?",
            (new_target, json.dumps(extra), row["id"]),
        )
        resolved += 1
        logger.debug(
            "Spring resolved: %s → %s (was %s, receiver=%s)",
            source_qual, new_target, method_name, receiver,
        )

    if resolved:
        conn.commit()

    logger.info("Spring DI resolver: resolved %d CALLS edges in %d Java files",
                resolved, len(java_files))
    return {"files_indexed": len(java_files), "calls_resolved": resolved}
+ """ + conn = store._conn + + java_files: set[str] = { + row["file_path"] + for row in conn.execute( + "SELECT DISTINCT file_path FROM nodes WHERE language = 'java'" + ).fetchall() + } + if not java_files: + return {"files_indexed": 0, "calls_resolved": 0} + + # ----------------------------------------------------------------------- + # Collect Temporal interface nodes: bare name → qualified_name + # (nodes whose extra contains temporal_role = workflow_interface|activity_interface) + # ----------------------------------------------------------------------- + temporal_interfaces: dict[str, str] = {} # bare_name → qualified_name + for row in conn.execute( + "SELECT name, qualified_name, extra FROM nodes " + "WHERE language = 'java' AND extra IS NOT NULL AND extra LIKE '%temporal_role%'" + ).fetchall(): + try: + ex = json.loads(row["extra"] or "{}") + except (json.JSONDecodeError, TypeError): + ex = {} + if ex.get("temporal_role") in ("workflow_interface", "activity_interface"): + temporal_interfaces[row["name"]] = row["qualified_name"] + + if not temporal_interfaces: + logger.info("Temporal resolver: no @WorkflowInterface/@ActivityInterface nodes found, skipping") + return {"files_indexed": len(java_files), "calls_resolved": 0} + + # ----------------------------------------------------------------------- + # Build field_map: (source_qualified_class, field_name) → interface_type + # from TEMPORAL_STUB edges whose target is a known Temporal interface + # ----------------------------------------------------------------------- + field_map: dict[tuple[str, str], str] = {} + for row in conn.execute( + "SELECT source_qualified, target_qualified, extra FROM edges WHERE kind = 'TEMPORAL_STUB'" + ).fetchall(): + bare_target = row["target_qualified"] + if bare_target not in temporal_interfaces: + continue + try: + extra = json.loads(row["extra"] or "{}") + except (json.JSONDecodeError, TypeError): + extra = {} + fname = extra.get("field_name") + if not fname: + continue + 
field_map[(row["source_qualified"], fname)] = bare_target + + if not field_map: + logger.info("Temporal resolver: no TEMPORAL_STUB edges for known Temporal interfaces, skipping") + return {"files_indexed": len(java_files), "calls_resolved": 0} + + # ----------------------------------------------------------------------- + # method_to_qual: (class_name, method_name) → full qualified_name + # ----------------------------------------------------------------------- + method_to_qual: dict[tuple[str, str], str] = {} + for row in conn.execute( + "SELECT name, qualified_name, parent_name FROM nodes " + "WHERE kind IN ('Function', 'Test') AND language = 'java' AND parent_name IS NOT NULL" + ).fetchall(): + method_to_qual[(row["parent_name"], row["name"])] = row["qualified_name"] + + # ----------------------------------------------------------------------- + # implementors: bare interface name → list of implementing class quals + # ----------------------------------------------------------------------- + implementors: dict[str, list[str]] = {} + for row in conn.execute( + "SELECT source_qualified, target_qualified FROM edges WHERE kind = 'INHERITS'" + ).fetchall(): + iface = row["target_qualified"] + impl = row["source_qualified"] + if any(impl.startswith(f) for f in java_files) or "::" in impl: + implementors.setdefault(iface, []).append(impl) + + # ----------------------------------------------------------------------- + # Resolve CALLS edges + # ----------------------------------------------------------------------- + calls_rows = conn.execute( + "SELECT id, source_qualified, target_qualified, extra, file_path " + "FROM edges WHERE kind = 'CALLS'" + ).fetchall() + + resolved = 0 + + for row in calls_rows: + if row["file_path"] not in java_files: + continue + + try: + extra = json.loads(row["extra"] or "{}") + except (json.JSONDecodeError, TypeError): + extra = {} + + receiver = extra.get("receiver") + if not receiver: + continue + + if extra.get("temporal_resolved") or 
extra.get("spring_resolved"): + continue + + raw_target = row["target_qualified"] + if "::" in raw_target: + after = raw_target.split("::", 1)[1] + method_name = after.split(".")[-1] if "." in after else after + else: + method_name = raw_target + + source_qual = row["source_qualified"] + + # Derive enclosing class qualified name + enclosing_class_qual: str | None = None + if "::" in source_qual: + after_sep = source_qual.split("::", 1)[1] + if "." in after_sep: + class_part = after_sep.split(".")[0] + prefix = source_qual.split("::")[0] + enclosing_class_qual = f"{prefix}::{class_part}" + else: + enclosing_class_qual = source_qual + + if not enclosing_class_qual: + continue + + interface_bare = field_map.get((enclosing_class_qual, receiver)) + if not interface_bare: + continue + + interface_qual = temporal_interfaces.get(interface_bare, interface_bare) + + impls = implementors.get(interface_qual, []) + if len(impls) == 1: + concrete_class = impls[0].split("::")[-1] + new_target = method_to_qual.get((concrete_class, method_name)) or f"{impls[0]}.{method_name}" + else: + new_target = method_to_qual.get((interface_bare, method_name)) or f"{interface_qual}.{method_name}" + + extra["temporal_resolved"] = True + extra["temporal_interface"] = interface_bare + new_extra = json.dumps(extra) + + conn.execute( + "UPDATE edges SET target_qualified = ?, extra = ? 
WHERE id = ?", + (new_target, new_extra, row["id"]), + ) + resolved += 1 + logger.debug( + "Temporal resolved: %s → %s (receiver=%s, interface=%s)", + source_qual, new_target, receiver, interface_bare, + ) + + if resolved: + conn.commit() + + logger.info("Temporal resolver: resolved %d CALLS edges in %d Java files", + resolved, len(java_files)) + return {"files_indexed": len(java_files), "calls_resolved": resolved} diff --git a/tests/fixtures/KafkaPatterns.java b/tests/fixtures/KafkaPatterns.java new file mode 100644 index 00000000..227130ff --- /dev/null +++ b/tests/fixtures/KafkaPatterns.java @@ -0,0 +1,47 @@ +package com.example.kafka; + +import org.springframework.kafka.annotation.KafkaListener; +import org.springframework.kafka.annotation.KafkaHandler; +import org.springframework.kafka.core.KafkaTemplate; +import org.springframework.kafka.core.KafkaOperations; +import org.springframework.stereotype.Service; +import org.springframework.stereotype.Component; +import lombok.RequiredArgsConstructor; +import reactor.kafka.receiver.KafkaReceiver; + +// ── Annotation-based consumer ───────────────────────────────────────────── + +@Service +class OrderEventConsumer { + + @KafkaListener(topics = "order-events") + public void onOrder(String payload) {} + + @KafkaListener(topics = {"order-dlq", "order-retry"}) + public void onDlq(String payload) {} +} + +// ── Annotation-based producer (KafkaTemplate field) ─────────────────────── + +@Service +@RequiredArgsConstructor +class NotificationProducer { + private final KafkaTemplate kafkaTemplate; + // static field — should NOT produce edge + private static final String TOPIC = "notifications"; +} + +// ── Reactive consumer (KafkaReceiver field) ─────────────────────────────── + +@Service +@RequiredArgsConstructor +class ReactiveOrderConsumer { + private final KafkaReceiver kafkaReceiver; + private final KafkaOperations kafkaOps; +} + +// ── plain class with no Kafka ───────────────────────────────────────────── + +class 
OrderEvent { + private String id; +} diff --git a/tests/fixtures/SpringDI.java b/tests/fixtures/SpringDI.java new file mode 100644 index 00000000..402c213f --- /dev/null +++ b/tests/fixtures/SpringDI.java @@ -0,0 +1,78 @@ +package com.example.shop; + +import org.springframework.stereotype.Service; +import org.springframework.stereotype.Repository; +import org.springframework.stereotype.Component; +import org.springframework.stereotype.Controller; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import lombok.RequiredArgsConstructor; + +// Plain interface — not a Spring bean +public interface OrderRepository { + void save(Order order); + Order findById(Long id); +} + +// @Repository stereotype — Spring-managed bean +@Repository +class JpaOrderRepository implements OrderRepository { + @Override + public void save(Order order) {} + + @Override + public Order findById(Long id) { return null; } +} + +// @Service with @Autowired field injection +@Service +class NotificationService { + @Autowired + private OrderRepository orderRepository; + + public void notify(Long orderId) { + Order o = orderRepository.findById(orderId); + } +} + +// @Service with Lombok @RequiredArgsConstructor (constructor injection via final fields) +@Service +@RequiredArgsConstructor +class OrderService { + private final OrderRepository orderRepository; + private final NotificationService notificationService; + private static final String TAG = "OrderService"; // static final — NOT injected + + public void placeOrder(Order order) { + orderRepository.save(order); + notificationService.notify(order.getId()); + } +} + +// @Component with explicit @Autowired constructor +@Component +class AuditLogger { + private final OrderRepository orderRepository; + + @Autowired + public AuditLogger(OrderRepository orderRepository) { + this.orderRepository = orderRepository; + } + + 
public void log(String msg) {} +} + +// @Configuration with @Bean factory methods +@Configuration +class AppConfig { + @Bean + public OrderRepository orderRepository() { + return new JpaOrderRepository(); + } +} + +class Order { + private Long id; + public Long getId() { return id; } +} diff --git a/tests/fixtures/TemporalWorkflow.java b/tests/fixtures/TemporalWorkflow.java new file mode 100644 index 00000000..f329de9b --- /dev/null +++ b/tests/fixtures/TemporalWorkflow.java @@ -0,0 +1,72 @@ +package com.example.temporal; + +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import io.temporal.workflow.SignalMethod; +import io.temporal.workflow.QueryMethod; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; + +// ── Interfaces ─────────────────────────────────────────────────────────────── + +@WorkflowInterface +public interface OrderWorkflow { + @WorkflowMethod + String processOrder(String orderId); + + @SignalMethod + void cancelOrder(String reason); + + @QueryMethod + String getStatus(); +} + +@ActivityInterface +public interface PaymentActivity { + @ActivityMethod + boolean chargeCard(String orderId, double amount); +} + +@ActivityInterface +public interface ShippingActivity { + @ActivityMethod + String shipOrder(String orderId); +} + +// ── Implementations ────────────────────────────────────────────────────────── + +// Workflow impl holds activity stubs as fields +class OrderWorkflowImpl implements OrderWorkflow { + + // These fields are assigned via Workflow.newActivityStub() at runtime + private PaymentActivity paymentActivity; + private ShippingActivity shippingActivity; + + // Static fields should NOT produce TEMPORAL_STUB edges + private static final String TAG = "OrderWorkflowImpl"; + + @Override + public String processOrder(String orderId) { + boolean paid = paymentActivity.chargeCard(orderId, 100.0); + if (!paid) return "FAILED"; + String trackingId = 
shippingActivity.shipOrder(orderId); + return trackingId; + } + + @Override + public void cancelOrder(String reason) {} + + @Override + public String getStatus() { return "OK"; } +} + +// Activity impls +class PaymentActivityImpl implements PaymentActivity { + @Override + public boolean chargeCard(String orderId, double amount) { return true; } +} + +class ShippingActivityImpl implements ShippingActivity { + @Override + public String shipOrder(String orderId) { return "TRACK-001"; } +} diff --git a/tests/test_multilang.py b/tests/test_multilang.py index 9d45f434..1e9caa5e 100644 --- a/tests/test_multilang.py +++ b/tests/test_multilang.py @@ -1916,3 +1916,380 @@ def test_resolver_is_idempotent(self, tmp_path): # Second run should find nothing new — all already resolved. assert second["calls_resolved"] == 0 assert second["imports_resolved"] == 0 + + +class TestSpringDIParsing: + """Tests for Spring DI annotation detection and INJECTS edge generation.""" + + def setup_method(self): + self.parser = CodeParser() + self.nodes, self.edges = self.parser.parse_file(FIXTURES / "SpringDI.java") + + def test_detects_spring_stereotype_on_repository(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "JpaOrderRepository" in classes + assert classes["JpaOrderRepository"].extra.get("spring_stereotype") == "Repository" + + def test_detects_spring_stereotype_on_service(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "NotificationService" in classes + assert classes["NotificationService"].extra.get("spring_stereotype") == "Service" + assert "OrderService" in classes + assert classes["OrderService"].extra.get("spring_stereotype") == "Service" + + def test_detects_spring_stereotype_on_configuration(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "AppConfig" in classes + assert classes["AppConfig"].extra.get("spring_stereotype") == "Configuration" + + def 
test_no_stereotype_on_plain_interface(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "OrderRepository" in classes + assert "spring_stereotype" not in classes["OrderRepository"].extra + + def test_spring_annotations_list_stored(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + annotations = classes["OrderService"].extra.get("spring_annotations", []) + assert "Service" in annotations + assert "RequiredArgsConstructor" in annotations + + def test_autowired_field_injection_edge(self): + injects = [e for e in self.edges if e.kind == "INJECTS"] + # NotificationService has @Autowired OrderRepository field + field_edges = [e for e in injects if e.extra.get("injection_type") == "field"] + targets = {e.target for e in field_edges} + assert "OrderRepository" in targets + + def test_autowired_field_source_is_class(self): + injects = [e for e in self.edges if e.kind == "INJECTS" + and e.extra.get("injection_type") == "field"] + sources = {e.source for e in injects} + assert any("NotificationService" in s for s in sources) + + def test_lombok_required_args_constructor_injection(self): + injects = [e for e in self.edges if e.kind == "INJECTS"] + lombok_edges = [e for e in injects + if e.extra.get("injection_type") == "constructor_lombok"] + targets = {e.target for e in lombok_edges} + # OrderService has two final injected fields + assert "OrderRepository" in targets + assert "NotificationService" in targets + + def test_static_final_field_not_injected(self): + """static final String TAG should NOT produce an INJECTS edge.""" + injects = [e for e in self.edges if e.kind == "INJECTS"] + targets = {e.target for e in injects} + assert "String" not in targets + + def test_explicit_autowired_constructor_injection(self): + injects = [e for e in self.edges if e.kind == "INJECTS"] + ctor_edges = [e for e in injects + if e.extra.get("injection_type") == "constructor"] + targets = {e.target for e in ctor_edges} + # AuditLogger has 
@Autowired constructor with OrderRepository param + assert "OrderRepository" in targets + + def test_autowired_constructor_source_is_class(self): + injects = [e for e in self.edges if e.kind == "INJECTS" + and e.extra.get("injection_type") == "constructor"] + sources = {e.source for e in injects} + assert any("AuditLogger" in s for s in sources) + + def test_total_injects_edge_count(self): + """Sanity check: total INJECTS edges matches known injection points.""" + injects = [e for e in self.edges if e.kind == "INJECTS"] + # NotificationService: 1 field + # OrderService: 2 lombok (orderRepository + notificationService) + # AuditLogger: 1 constructor + assert len(injects) >= 4 + + def test_field_name_stored_in_injects_extra(self): + """INJECTS edges must carry extra.field_name for the resolver.""" + injects = [e for e in self.edges if e.kind == "INJECTS"] + names = {e.extra.get("field_name") for e in injects} + # @Autowired field in NotificationService + assert "orderRepository" in names + # @RequiredArgsConstructor final fields in OrderService + assert "orderRepository" in names + assert "notificationService" in names + # @Autowired constructor param in AuditLogger + assert "orderRepository" in names + + def test_java_method_call_target_is_method_not_receiver(self): + """Java receiver.method() must emit CALLS with method as target, not receiver.""" + calls = [e for e in self.edges if e.kind == "CALLS"] + targets = {e.target for e in calls} + # placeOrder calls orderRepository.save() — target must end in "save" + # (possibly qualified to "::OrderRepository.save" if same-file resolution kicks in) + assert any("save" in t for t in targets), f"expected 'save' in targets, got {targets}" + # receiver variable names must NOT appear as CALLS targets + assert "orderRepository" not in targets + assert "notificationService" not in targets + + def test_java_receiver_stored_in_calls_extra(self): + """CALLS edges for Java method calls must carry extra.receiver.""" + calls = [e 
for e in self.edges if e.kind == "CALLS" and e.extra.get("receiver")] + receivers = {e.extra["receiver"] for e in calls} + assert "orderRepository" in receivers or "notificationService" in receivers + + +class TestSpringDIResolver: + """Integration tests for the Spring DI post-build resolver.""" + + def _build(self, tmp_path): + """Build a mini Spring repo and run the resolver.""" + pkg = tmp_path / "src/main/java/com/example" + pkg.mkdir(parents=True) + + (pkg / "OrderRepository.java").write_text( + "package com.example;\n" + "public interface OrderRepository {\n" + " void save(Order o);\n" + "}\n" + ) + (pkg / "JpaOrderRepository.java").write_text( + "package com.example;\n" + "import org.springframework.stereotype.Repository;\n" + "@Repository\n" + "public class JpaOrderRepository implements OrderRepository {\n" + " public void save(Order o) {}\n" + "}\n" + ) + (pkg / "OrderService.java").write_text( + "package com.example;\n" + "import org.springframework.stereotype.Service;\n" + "import lombok.RequiredArgsConstructor;\n" + "@Service\n" + "@RequiredArgsConstructor\n" + "public class OrderService {\n" + " private final OrderRepository orderRepository;\n" + " public void place(Order o) {\n" + " orderRepository.save(o);\n" + " }\n" + "}\n" + ) + + from code_review_graph.graph import GraphStore + from code_review_graph.incremental import full_build + from code_review_graph.postprocessing import run_post_processing + + store = GraphStore(str(tmp_path / "graph.db")) + result = full_build(tmp_path, store) + run_post_processing(store) + return store, result + + def test_resolver_runs_and_reports(self, tmp_path): + _, result = self._build(tmp_path) + stats = result.get("spring_resolution") + assert stats is not None + assert stats["files_indexed"] > 0 + + def test_calls_resolved_through_field(self, tmp_path): + store, result = self._build(tmp_path) + stats = result.get("spring_resolution", {}) + assert stats.get("calls_resolved", 0) >= 1 + + def 
test_resolved_target_includes_method_name(self, tmp_path): + store, _ = self._build(tmp_path) + cur = store._conn.cursor() + rows = cur.execute( + "SELECT target_qualified FROM edges WHERE kind='CALLS' " + "AND extra LIKE '%spring_resolved%'" + ).fetchall() + assert rows, "Expected at least one spring-resolved CALLS edge" + for (target,) in rows: + assert "." in target or "::" in target, ( + f"Resolved target should contain type.method or ::, got: {target!r}" + ) + + +class TestTemporalParsing: + """Tests for Temporal @WorkflowInterface / @ActivityInterface detection.""" + + def setup_method(self): + self.parser = CodeParser() + self.nodes, self.edges = self.parser.parse_file(FIXTURES / "TemporalWorkflow.java") + + def test_workflow_interface_gets_temporal_role(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "OrderWorkflow" in classes + assert classes["OrderWorkflow"].extra.get("temporal_role") == "workflow_interface" + + def test_activity_interface_gets_temporal_role(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "PaymentActivity" in classes + assert classes["PaymentActivity"].extra.get("temporal_role") == "activity_interface" + assert "ShippingActivity" in classes + assert classes["ShippingActivity"].extra.get("temporal_role") == "activity_interface" + + def test_impl_class_has_no_temporal_role(self): + classes = {n.name: n for n in self.nodes if n.kind == "Class"} + assert "OrderWorkflowImpl" in classes + assert "temporal_role" not in classes["OrderWorkflowImpl"].extra + + def test_temporal_stub_edges_emitted_for_activity_fields(self): + stubs = [e for e in self.edges if e.kind == "TEMPORAL_STUB"] + targets = {e.target for e in stubs} + assert "PaymentActivity" in targets + assert "ShippingActivity" in targets + + def test_temporal_stub_field_name_stored(self): + stubs = [e for e in self.edges if e.kind == "TEMPORAL_STUB"] + field_names = {e.extra.get("field_name") for e in stubs} + assert 
"paymentActivity" in field_names + assert "shippingActivity" in field_names + + def test_static_field_not_in_temporal_stubs(self): + stubs = [e for e in self.edges if e.kind == "TEMPORAL_STUB"] + field_names = {e.extra.get("field_name") for e in stubs} + assert "TAG" not in field_names + + def test_temporal_stub_source_is_workflow_impl(self): + stubs = [e for e in self.edges if e.kind == "TEMPORAL_STUB"] + sources = {e.source for e in stubs} + assert any("OrderWorkflowImpl" in s for s in sources) + + def test_workflow_method_annotation_stored_on_method(self): + interface_methods = [ + n for n in self.nodes if n.kind == "Function" and n.parent_name == "OrderWorkflow" + ] + names = {n.name: n for n in interface_methods} + assert "processOrder" in names + assert names["processOrder"].extra.get("temporal_role") == "workflowmethod" + + def test_signal_method_annotation_stored(self): + interface_methods = [ + n for n in self.nodes if n.kind == "Function" and n.parent_name == "OrderWorkflow" + ] + names = {n.name: n for n in interface_methods} + assert "cancelOrder" in names + assert names["cancelOrder"].extra.get("temporal_role") == "signalmethod" + + def test_activity_method_annotation_stored(self): + activity_methods = [ + n for n in self.nodes if n.kind == "Function" and n.parent_name == "PaymentActivity" + ] + names = {n.name: n for n in activity_methods} + assert "chargeCard" in names + assert names["chargeCard"].extra.get("temporal_role") == "activitymethod" + + +class TestTemporalResolver: + """Integration tests for the Temporal post-build call resolver.""" + + def _build(self, tmp_path): + pkg = tmp_path / "src/main/java/com/example" + pkg.mkdir(parents=True) + + (pkg / "PaymentActivity.java").write_text( + "package com.example;\n" + "import io.temporal.activity.ActivityInterface;\n" + "import io.temporal.activity.ActivityMethod;\n" + "@ActivityInterface\n" + "public interface PaymentActivity {\n" + " @ActivityMethod\n" + " boolean charge(String orderId);\n" + 
"}\n" + ) + (pkg / "PaymentActivityImpl.java").write_text( + "package com.example;\n" + "public class PaymentActivityImpl implements PaymentActivity {\n" + " public boolean charge(String orderId) { return true; }\n" + "}\n" + ) + (pkg / "OrderWorkflowImpl.java").write_text( + "package com.example;\n" + "public class OrderWorkflowImpl {\n" + " private PaymentActivity paymentActivity;\n" + " public String process(String id) {\n" + " return paymentActivity.charge(id) ? \"OK\" : \"FAIL\";\n" + " }\n" + "}\n" + ) + + from code_review_graph.graph import GraphStore + from code_review_graph.incremental import full_build + + store = GraphStore(str(tmp_path / "graph.db")) + result = full_build(tmp_path, store) + return store, result + + def test_temporal_resolver_runs_and_reports(self, tmp_path): + _, result = self._build(tmp_path) + stats = result.get("temporal_resolution") + assert stats is not None + assert stats["files_indexed"] > 0 + + def test_calls_resolved_through_activity_stub(self, tmp_path): + _, result = self._build(tmp_path) + stats = result.get("temporal_resolution", {}) + assert stats.get("calls_resolved", 0) >= 1 + + def test_resolved_target_is_fully_qualified(self, tmp_path): + store, _ = self._build(tmp_path) + rows = store._conn.execute( + "SELECT target_qualified FROM edges WHERE kind='CALLS' " + "AND extra LIKE '%temporal_resolved%'" + ).fetchall() + assert rows, "Expected at least one temporal-resolved CALLS edge" + for (target,) in rows: + assert "." 
in target or "::" in target, ( + f"Resolved target should be qualified, got: {target!r}" + ) + + +class TestKafkaParsing: + """Tests for Kafka CONSUMES / PRODUCES edge detection.""" + + def setup_method(self): + self.parser = CodeParser() + self.nodes, self.edges = self.parser.parse_file(FIXTURES / "KafkaPatterns.java") + + def test_kafka_listener_annotation_emits_consumes_edge(self): + consumes = [e for e in self.edges if e.kind == "CONSUMES"] + targets = {e.target for e in consumes} + assert "kafka:order-events" in targets + + def test_kafka_listener_multiple_topics(self): + consumes = [e for e in self.edges if e.kind == "CONSUMES"] + targets = {e.target for e in consumes} + assert "kafka:order-dlq" in targets + assert "kafka:order-retry" in targets + + def test_kafka_listener_topic_in_extra(self): + consumes = [e for e in self.edges if e.kind == "CONSUMES" + and e.target == "kafka:order-events"] + assert consumes + assert consumes[0].extra.get("topic") == "order-events" + + def test_kafka_template_field_emits_produces_edge(self): + produces = [e for e in self.edges if e.kind == "PRODUCES"] + sources = {e.source for e in produces} + assert any("NotificationProducer" in s for s in sources) + + def test_kafka_receiver_field_emits_consumes_edge(self): + consumes = [e for e in self.edges if e.kind == "CONSUMES"] + sources = {e.source for e in consumes} + assert any("ReactiveOrderConsumer" in s for s in sources) + + def test_kafka_receiver_message_type_stored(self): + consumes = [e for e in self.edges if e.kind == "CONSUMES" + and "ReactiveOrderConsumer" in e.source] + assert consumes + assert consumes[0].extra.get("message_type") == "OrderEvent" + + def test_kafka_operations_field_emits_produces_edge(self): + produces = [e for e in self.edges if e.kind == "PRODUCES"] + sources = {e.source for e in produces} + assert any("ReactiveOrderConsumer" in s for s in sources) + + def test_static_field_not_in_kafka_edges(self): + all_kafka = [e for e in self.edges if e.kind in 
("CONSUMES", "PRODUCES")] + field_names = {e.extra.get("field_name") for e in all_kafka} + assert "TOPIC" not in field_names + + def test_no_kafka_edges_for_plain_class(self): + # OrderEvent (plain class, no Kafka) should not appear as a source + kafka = [e for e in self.edges if e.kind in ("CONSUMES", "PRODUCES")] + bare_sources = {e.source.split("::")[-1].split(".")[0] for e in kafka} + assert "OrderEvent" not in bare_sources