diff --git a/src/vuln_analysis/tools/tests/conftest.py b/src/vuln_analysis/tools/tests/conftest.py new file mode 100644 index 00000000..648bf7bc --- /dev/null +++ b/src/vuln_analysis/tools/tests/conftest.py @@ -0,0 +1,14 @@ +import pytest + +from vuln_analysis.utils.functions_parsers.golang_functions_parsers import ( + GoLanguageFunctionsParser, +) + + +@pytest.fixture(scope="module") +def go_parser() -> GoLanguageFunctionsParser: + """ + Provides a single instance of the GoLanguageFunctionsParser + for all tests in a module. + """ + return GoLanguageFunctionsParser() diff --git a/src/vuln_analysis/tools/tests/test_go_segmenter.py b/src/vuln_analysis/tools/tests/test_go_segmenter.py new file mode 100644 index 00000000..1f86cedf --- /dev/null +++ b/src/vuln_analysis/tools/tests/test_go_segmenter.py @@ -0,0 +1,78 @@ +import textwrap + +from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended + + +def _extract(code: str): + seg = GoSegmenterExtended(textwrap.dedent(code)) + return seg.extract_functions_classes() + + +def test_segmenter_extracts_type_and_generic_method(): + code = """ + type Box[T any] struct { value T } + func (b *Box[T]) Set(v T) { b.value = v } + """ + expected_chunks = [ + "type Box[T any] struct { value T }", + "func (b *Box[T]) Set(v T) { b.value = v }", + ] + + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks + + +def test_segmenter_extracts_toplevel_function_only_and_ignores_nested(): + code = """ + func makeAdder(x int) func(int) int { + return func(y int) int { return x + y } + } + """ + expected_chunks = [ + textwrap.dedent(""" + func makeAdder(x int) func(int) int { + return func(y int) int { return x + y } + } + """).strip() + ] + + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks + + +def test_segmenter_handles_double_pointer_receiver(): + code = """ + type Conn struct{} + func (c **Conn) Reset() {} + """ + expected_chunks = [ + "type Conn struct{}", + "func (c **Conn) Reset() {}", + ] + + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks + + +def test_segmenter_handles_multiline_generic_method(): + code = """ + func (r *Repo[ + T any, + E error, + ]) Save(v T) (E, error) { + return nil, nil + } + """ + expected_chunks = [ + textwrap.dedent(""" + func (r *Repo[ + T any, + E error, + ]) Save(v T) (E, error) { + return nil, nil + } + """).strip() + ] + + actual_chunks = [c.strip() for c in _extract(code)] + assert actual_chunks == expected_chunks diff --git a/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py b/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py new file mode 100644 index 00000000..ef024078 --- /dev/null +++ b/src/vuln_analysis/tools/tests/test_golang_functions_parsers.py @@ -0,0 +1,132 @@ + +import textwrap + +import pytest +from langchain_core.documents import Document + +HAPPY_PATH_CASES = [ + ("simple_function", "func DoSomething() {}", "DoSomething"), + ("with_parameters", "func DoSomething(p1 string, p2 int) {}", "DoSomething"), + ("with_return_value", "func DoSomething(v int) string {}", "DoSomething"), + ( + "with_named_return", + "func DoSomething(a, b float64) (q float64, e error) {}", + "DoSomething", + ), + ( + "method_with_receiver", + "func (p *Point) DoSomething() float64 {}", + "DoSomething", + ), +] + + +EDGE_CASES_TEST = [ + ("generic_function", "func DoSomething[T any](s []T) {}", "DoSomething"), + ( + "letter_or_underscores", + "func _internal_calculate_v2() {}", + "_internal_calculate_v2", + ), + ( + "receivers_double_pointer_function", + "func (c **Connection) Close() error {}", + "Close", + ), + ( + "receivers_without_the_name_function", + "func (*Point) IsOrigin() bool {}", + "IsOrigin", + ), + ( + "multiline_function", + """ + func (r *Repository[ + T Model, + K KeyType, + ]) FindByID(id K) (*T, error) {} + """, + "FindByID", + ), +] + +NEGATIVE_ANONYMOUS_CASES = [ + ( + "assigned_to_variable", + "var greeter = func(name string) { fmt.Println('Hello,', name) }", + ), + ( + "assigned_to_variable2", + textwrap.dedent( + """ + greet := func() { // Assigning anonymous function to a variable 'greet' + fmt.Println("Greetings from a variable-assigned anonymous function!") + } + """ + ), + ), + ( + "go_routine", + "go func() { fmt.Println('Running in background') }()", + ), + ( + "defer_statement", + "defer func() { file.Close() }()", + ), + ( + "callback_argument", + "http.HandleFunc('/', func(w http.ResponseWriter, r *http.Request) {})", + ), +] + +MALFORMED_INPUT_CASES = [ + ("empty_string", ""), + ("whitespace_only", " \n\t "), + ("just_the_keyword", "func"), + ("incomplete_header", "func myFunc("), + ("garbage_input", "a = b + c;"), +] + +@pytest.mark.parametrize("test_id, code_snippet, expected_name", HAPPY_PATH_CASES) +def test_happy_path_function_names(go_parser, test_id,code_snippet, expected_name): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) + actual_name = go_parser.get_function_name(doc) + assert actual_name == expected_name, f"Test case '{test_id}' failed" + +@pytest.mark.parametrize("test_id, code_snippet, expected_name", EDGE_CASES_TEST) +def test_edge_cases_function_names(go_parser, test_id, code_snippet, expected_name): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"}) + actual_name = go_parser.get_function_name(doc) + assert actual_name == expected_name, f"Test case '{test_id}' failed" + +@pytest.mark.parametrize("test_id, code_snippet", NEGATIVE_ANONYMOUS_CASES) +def test_negative_cases_anonymous_functions(go_parser, test_id, code_snippet): + doc = Document(page_content=code_snippet.strip(), metadata={"source": "proxy.go"}) + name = go_parser.get_function_name(doc) + assert name.startswith("anon_"), ( + f"[{test_id}] Expected name to start with 'anon_', but got '{name}'" + ) + parts = name.split("_") + assert len(parts) == 3, ( + f"[{test_id}] Expected name format 'anon__', but got '{name}'" + ) + + assert parts[1] == "proxy", ( + f"[{test_id}] Expected file prefix 'proxy', but got '{parts[1]}'" + ) + hash_part = parts[2] + assert len(hash_part) == 8, ( + f"[{test_id}] Hash part should be 8 characters, but got '{hash_part}'" + ) + assert all(c in "0123456789abcdef" for c in hash_part), ( + f"[{test_id}] Hash part should be hex, but got '{hash_part}'" + ) + +@pytest.mark.parametrize("test_id, code_snippet", MALFORMED_INPUT_CASES) +def test_malformed_input_graceful_failure(go_parser, test_id, code_snippet): + doc = Document(page_content=code_snippet, metadata={"source": "malformed.go"}) + name = go_parser.get_function_name(doc) + + assert name.startswith("anon_"), ( + f"[{test_id}] Failed to handle malformed input gracefully. Got: {name}" + ) diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index 5622ed20..55bd19cd 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -86,7 +86,7 @@ async def test_transitive_search_golang_5(): (path_found, list_path) = result assert path_found is False - assert len(list_path) is 1 + assert len(list_path) == 1 # Test fix of https://issues.redhat.com/browse/APPENG-3435 @pytest.mark.asyncio @@ -100,7 +100,7 @@ async def test_transitive_search_golang_6(): (path_found, list_path) = result print(result) assert path_found is True - assert len(list_path) is 2 + assert len(list_path) == 2 def set_input_for_next_run(git_repository: str, git_ref: str, included_extensions: list[str], diff --git a/src/vuln_analysis/utils/chain_of_calls_retriever.py b/src/vuln_analysis/utils/chain_of_calls_retriever.py index a9957957..f700add1 100644 --- a/src/vuln_analysis/utils/chain_of_calls_retriever.py +++ b/src/vuln_analysis/utils/chain_of_calls_retriever.py @@ -191,14 +191,15 @@ def __init__(self, documents: List[Document], ecosystem: Ecosystem, manifest_pat logger.debug(f"no_macro_documents len : {len(no_macro_documents)}") if not self.language_parser.is_script_language(): - # filter out types and full code documents, retaining only functions/methods documents in this attribute. + # filter out types and full code documents, retaining only functions/methods documents in this attribute. self.documents = [doc for doc in no_macro_documents if self.language_parser.is_function(doc)] self.documents_of_functions = self.documents else: self.documents = filtered_documents self.documents_of_functions = [doc for doc in self.documents if doc.page_content.startswith(self.language_parser.get_function_reserved_word())] - + # sort documents to ensure deterministic behavior + self.documents.sort(key=lambda doc: doc.metadata.get('source', '')) logger.debug(f"self.documents len : {len(self.documents)}") logger.debug("Chain of Calls Retriever - retaining only types/classes docs " "documents_of_types len %d", len(self.documents_of_types)) @@ -249,6 +250,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa # Add same package itself to search path. # direct_parents.extend([function_package]) # gets list of documents to search in only from parents of function' package. + # Fixes non-deterministic behavior in chain-of-calls resolution where identical inputs produced different call-chain lengths + direct_parents.sort() function_name_to_search = self.language_parser.get_function_name(document_function) if function_name_to_search == self.language_parser.get_constructor_method_name(): function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function) diff --git a/src/vuln_analysis/utils/document_embedding.py b/src/vuln_analysis/utils/document_embedding.py index e01a600a..dae2e639 100644 --- a/src/vuln_analysis/utils/document_embedding.py +++ b/src/vuln_analysis/utils/document_embedding.py @@ -37,7 +37,7 @@ from langchain_core.document_loaders.blob_loaders import Blob from vuln_analysis.data_models.input import SourceDocumentsInfo -from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods +from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader @@ -144,7 +144,7 @@ class ExtendedLanguageParser(LanguageParser): "javascript": ExtendedJavaScriptSegmenter, "js": ExtendedJavaScriptSegmenter, } - additional_segmenters["go"] = GoSegmenterWithMethods + additional_segmenters["go"] = GoSegmenterExtended additional_segmenters["python"] = PythonSegmenterWithClassesMethods additional_segmenters["c"] = CSegmenterExtended diff --git a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py index 94a677b0..b17db9d2 100644 --- a/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py +++ b/src/vuln_analysis/utils/functions_parsers/golang_functions_parsers.py @@ -1,11 +1,11 @@ +import hashlib import os import re +from typing import Callable, final from langchain_core.documents import Document from .lang_functions_parsers import LanguageFunctionsParser -from ..dep_tree import Ecosystem -from ..standard_library_cache import StandardLibraryCache EMBEDDED_TYPE = "embedded_type" @@ -45,7 +45,7 @@ def check_types_from_callee_package(params: list[tuple], type_documents: list[Do # Only type without package if len(parts) == 1: for the_type in type_documents: - if the_type.page_content.startwith(f"type {parts[0]}"): + if the_type.page_content.startswith(f"type {parts[0]}"): code_with_type_file = code_documents.get(the_type.metadata['source']) type_file_package_name = get_package_name_file(code_with_type_file) if type_file_package_name == callee_function_file_package_name: @@ -53,7 +53,7 @@ def check_types_from_callee_package(params: list[tuple], type_documents: list[Do # type with package qualifier else: for the_type in type_documents: - if the_type.page_content.startwith(f"type {parts[1]}"): + if the_type.page_content.startswith(f"type {parts[1]}"): code_with_type_file = code_documents.get(the_type.metadata['source']) package_match = handle_imports(code_with_type_file, parts[0], callee_package) return package_match @@ -115,11 +115,56 @@ def handle_imports(code_content: str, identifier: str, callee_package: str) -> b return False +@final +class GoConstants: + UNKNOWN_FUNCTION = "" + ANON_FUNCTION_PREFIX = "anon" + FUNC_KEYWORD = "func" + class GoLanguageFunctionsParser(LanguageFunctionsParser): def get_dummy_function(self, function_name): return f"{self.get_function_reserved_word()} {function_name}() {{}}" - + def _generate_fallback_name(self, document: Document) -> str: + """ + Generates a deterministic name for anonymous or unparsable functions. + """ + try: + source = document.metadata.get("source", "unknown_source") + prefix = source.split("/")[-1].split(".")[0] + content_bytes = document.page_content.encode("utf-8") + short_hash = hashlib.sha256(content_bytes).hexdigest()[:8] + return f"{GoConstants.ANON_FUNCTION_PREFIX}_{prefix}_{short_hash}" + except Exception: + return GoConstants.UNKNOWN_FUNCTION + def _try_parse_method(self, header: str) -> str | None: + """ + Tries to parse a function name assuming it's a method with a receiver. + Returns the name or None if it doesn't match the pattern. + """ + if not header.startswith("func ("): + return None + receiver_end_idx = header.find(")") + if receiver_end_idx == -1: + return None + + after_receiver = header[receiver_end_idx + 1 :].lstrip() + + if "(" in after_receiver and ")" not in after_receiver: + return None + + name_match = re.match(r"([a-zA-Z0-9_]+)", after_receiver) + + return name_match.group(1) if name_match else None + def _try_parse_regular_function(self, header: str) -> str | None: + """ + Tries to parse a regular or generic function name. + Returns the name or None if it's an anonymous function or doesn't match. + """ + match = re.search(r"^func\s+([a-zA-Z0-9_]+)\s*[\[\(](?=.*[\]\)])", header) + if match: + return match.group(1) + return None def __trace_down_package(self, expression: str, code_documents: dict[str, Document], type_documents: list[Document], callee_package: str, fields_of_types: dict[tuple, list[tuple]], functions_local_variables_index: dict[str, dict], @@ -171,7 +216,9 @@ def __prepare_package_lookup(self, parts, variables_mappings, the_part: int): if var_properties is not None: resolved_type = var_properties.get("type") value = var_properties.get("value") - struct_initializer_expression = re.search(r"(&|\\*)?\w+\s*{", value) + struct_initializer_expression = None + if value: + struct_initializer_expression = re.search(r"(&|\\*)?\w+\s*{", value) resolved_type = str(resolved_type).replace("&", "").replace("*", "") return resolved_type, struct_initializer_expression, value, var_properties else: @@ -202,10 +249,15 @@ def __get_type_docs_matched_with_callee_package(self, callee_package, checked_ty (self.get_type_name(a_type) == checked_type or self.get_type_name(a_type) in checked_type)] def create_map_of_local_vars(self, functions_methods_documents: list[Document]) -> dict[str, dict]: + """ + Builds a mapping of function identifiers to their local variables and parameters. + Includes support for anonymous functions with deterministic names. + """ mappings = dict() for func_method in functions_methods_documents: - func_key = f"{self.get_function_name(func_method)}@{func_method.metadata['source']}" all_vars = dict() + func_name = self.get_function_name(func_method) + func_key = f"{func_name}@{func_method.metadata['source']}" for row in func_method.page_content.splitlines(): if not self.is_comment_line(row): # Extract arguments and receiver argument of type as parameters @@ -262,7 +314,7 @@ def create_map_of_local_vars(self, functions_methods_documents: list[Document]) right_side = func_method.page_content[index_of_start_if + 2: index_of_start_if + end_of_assignment - 1].strip() - all_vars[(left_side.strip())] = {" value": right_side.strip().replace + all_vars[(left_side.strip())] = {"value": right_side.strip().replace ("\n\t", "").replace("\t", ""), "type": LOCAL_IMPLICIT } @@ -284,7 +336,8 @@ def create_map_of_local_vars(self, functions_methods_documents: list[Document]) else: pass - mappings[func_key] = all_vars + if func_name != "": + mappings[func_key] = all_vars return mappings @@ -327,20 +380,26 @@ def parse_all_type_struct_class_to_fields(self, types: list[Document]) -> dict[t next_struct = current_line_stripped.find("struct") if next_eol > - 1 and (next_struct == -1 or next_struct > next_eol): if not self.is_comment_line(current_line_stripped[:next_eol + 1]): - declaration_parts = current_line_stripped[:next_eol + 1].split() + # remove inline comments + line = current_line_stripped[:next_eol + 1] + line = line.split("//", 1)[0].split("/*", 1)[0].strip() # If row inside block contains func, then it's a function type and need to parse it in a # special way. - if current_line_stripped[:next_eol + 1].__contains__("func"): - declaration_parts = current_line_stripped[:next_eol + 1].split("func") + if line.__contains__("func"): + declaration_parts = line.split("func") declaration_parts = [part.strip() for part in declaration_parts] if len(declaration_parts) == 2: declaration_parts[1] = f"func {declaration_parts[1]}" + else: + # For regular types, split by whitespace + declaration_parts = line.split() # ignore alias' "equals" notation - if len(declaration_parts) == 3: - [name, _, type_name] = declaration_parts - elif len(declaration_parts) == 2: - [name, type_name] = declaration_parts - if len(declaration_parts) == (2 or 3): + if len(declaration_parts) in [2, 3]: + if len(declaration_parts) == 3: + [name, _, type_name] = declaration_parts + elif len(declaration_parts) == 2: + [name, type_name] = declaration_parts + self.parse_one_type(Document(page_content=f"type {name} {type_name}", metadata={"source": the_type.metadata['source']}), types_mapping) @@ -428,38 +487,32 @@ def dir_name_for_3rd_party_packages(self) -> str: def is_exported_function(self, function: Document) -> bool: function_name = self.get_function_name(function) - return re.search("[A-Z][a-z0-9-]*", function_name) + return bool(re.search("[A-Z][a-z0-9-]*", function_name)) def get_function_name(self, function: Document) -> str: - try: - index_of_function_opening = function.page_content.index("{") - except ValueError as e: - function_line = function.page_content.find(os.linesep) - # print(f"function {function.page_content[:function_line]} => contains no body ") - return function.page_content[:function_line] - - function_header = function.page_content[:index_of_function_opening] - # function is a method of a type - if function_header.startswith("func ("): - index_of_first_right_bracket = function_header.index(")") - skip_receiver_arg = function_header[index_of_first_right_bracket + 1:] - index_of_first_left_bracket = skip_receiver_arg.index("(") - return skip_receiver_arg[:index_of_first_left_bracket].strip() - # regular function not tied to a certain type - else: - try: - index_of_first_left_bracket = function_header.index("(") - # Go Generic function - except ValueError: - try: - index_of_first_left_bracket = function_header.index("[") - except ValueError: - raise ValueError(f"Invalid function header - {function_header}") - func_with_name = function_header[:index_of_first_left_bracket] - if len(func_with_name.split(" ")) > 1: - return func_with_name.split(" ")[1] - # TODO Try to extract anonymous function var - # else: + """ + Extracts the function name from its Go definition. + """ + if function is None or getattr(function, "page_content", None) is None: + return GoConstants.UNKNOWN_FUNCTION + + content = function.page_content.strip() + if not content: + return self._generate_fallback_name(function) + + body_start_idx = content.find("{") + header = content if body_start_idx == -1 else content[:body_start_idx] + + parsing_strategies: list[Callable[[str], str | None]] = [ + self._try_parse_method, + self._try_parse_regular_function, + ] + + for strategy in parsing_strategies: + name = strategy(header) + if name: + return name + return self._generate_fallback_name(function) def search_for_called_function(self, caller_function: Document, callee_function_name: str, callee_function: Document, @@ -508,7 +561,7 @@ def __check_identifier_resolved_to_callee_function_package(self, function: Docum try: callee_function = code_documents[callee_function_file_name] callee_function_package = get_package_name_file(callee_function).strip() - except KeyError as e: + except KeyError: # Standard library function , there is no function code, thus the source name is the package name in # this case callee_function_package = callee_function_file_name @@ -579,31 +632,34 @@ def __check_identifier_resolved_to_callee_function_package(self, function: Docum return False def get_package_names(self, function: Document) -> list[str]: - package_names = list() - full_doc_path = str(function.metadata['source']) - parts = full_doc_path.split("/") + package_names = [] + full_doc_path = str(function.metadata.get("source") or "").strip() + if not full_doc_path: + return [""] + + parts = [p for p in full_doc_path.split("/") if p] version = "" + if len(parts) > 4: match = re.search(r"[vV][0-9]{1,2}", parts[4]) - if match and match.group(0): + if match: version = f"/{match.group(0)}" - if parts[0].startswith(self.dir_name_for_3rd_party_packages()) and len(parts) > 3: + if parts and parts[0].startswith(self.dir_name_for_3rd_party_packages()) and len(parts) > 3: package_names.append(f"{parts[1]}/{parts[2]}{version}") package_names.append(f"{parts[1]}/{parts[2]}/{parts[3]}{version}") - else: - try: - package_names.append(f"{parts[0]}/{parts[1]}{version}") + elif len(parts) >= 2: + package_names.append(f"{parts[0]}/{parts[1]}{version}") + if len(parts) >= 3: package_names.append(f"{parts[0]}/{parts[1]}/{parts[2]}{version}") - # Standard library package - except IndexError as index_excp: - if len(parts) > 1: - package_names.append(f"{parts[0]}/{parts[1]}{version}") - else: - package_names.append(f"{parts[0]}{version}") + elif len(parts) == 1: + package_names.append(f"{parts[0]}{version}") + else: + package_names.append("") return package_names + def is_root_package(self, function: Document) -> bool: return not function.metadata['source'].startswith(self.dir_name_for_3rd_party_packages()) diff --git a/src/vuln_analysis/utils/go_segmenter_extended.py b/src/vuln_analysis/utils/go_segmenter_extended.py new file mode 100644 index 00000000..e449ff31 --- /dev/null +++ b/src/vuln_analysis/utils/go_segmenter_extended.py @@ -0,0 +1,36 @@ +from langchain_community.document_loaders.parsers.language.go import GoSegmenter + +CHUNK_QUERY_EXT = """ +(source_file + [ + (function_declaration) @function + (method_declaration) @method + (type_declaration) @type + ] +) +""".strip() + + +class GoSegmenterExtended(GoSegmenter): + def get_chunk_query(self) -> str: + return CHUNK_QUERY_EXT + + def extract_functions_classes(self): + """ + Extracts all TOP-LEVEL functions, methods, and types. + Nested anonymous functions are kept inside their parent functions. + """ + language = self.get_language() + query = language.query(self.get_chunk_query()) + + parser = self.get_parser() + tree = parser.parse(bytes(self.code, "UTF-8")) + + captures = query.captures(tree.root_node) + + chunks = [] + for node, _ in captures: + chunk_text = node.text.decode("UTF-8") + chunks.append(chunk_text) + + return chunks