14 changes: 14 additions & 0 deletions src/vuln_analysis/tools/tests/conftest.py
@@ -0,0 +1,14 @@
import pytest

from vuln_analysis.utils.functions_parsers.golang_functions_parsers import (
GoLanguageFunctionsParser,
)


@pytest.fixture(scope="module")
def go_parser() -> GoLanguageFunctionsParser:
"""
Provides a single instance of the GoLanguageFunctionsParser
for all tests in a module.
"""
return GoLanguageFunctionsParser()
78 changes: 78 additions & 0 deletions src/vuln_analysis/tools/tests/test_go_segmenter.py
@@ -0,0 +1,78 @@
import textwrap

from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended


def _extract(code: str):
seg = GoSegmenterExtended(textwrap.dedent(code))
return seg.extract_functions_classes()


def test_segmenter_extracts_type_and_generic_method():
code = """
type Box[T any] struct { value T }
func (b *Box[T]) Set(v T) { b.value = v }
"""
expected_chunks = [
"type Box[T any] struct { value T }",
"func (b *Box[T]) Set(v T) { b.value = v }",
]

actual_chunks = [c.strip() for c in _extract(code)]
assert actual_chunks == expected_chunks


def test_segmenter_extracts_toplevel_function_only_and_ignores_nested():
code = """
func makeAdder(x int) func(int) int {
return func(y int) int { return x + y }
}
"""
expected_chunks = [
textwrap.dedent("""
func makeAdder(x int) func(int) int {
return func(y int) int { return x + y }
}
""").strip()
]

actual_chunks = [c.strip() for c in _extract(code)]
assert actual_chunks == expected_chunks


def test_segmenter_handles_double_pointer_receiver():
code = """
type Conn struct{}
func (c **Conn) Reset() {}
"""
expected_chunks = [
"type Conn struct{}",
"func (c **Conn) Reset() {}",
]

actual_chunks = [c.strip() for c in _extract(code)]
assert actual_chunks == expected_chunks


def test_segmenter_handles_multiline_generic_method():
code = """
func (r *Repo[
T any,
E error,
]) Save(v T) (E, error) {
return nil, nil
}
"""
expected_chunks = [
textwrap.dedent("""
func (r *Repo[
T any,
E error,
]) Save(v T) (E, error) {
return nil, nil
}
""").strip()
]

actual_chunks = [c.strip() for c in _extract(code)]
assert actual_chunks == expected_chunks
132 changes: 132 additions & 0 deletions src/vuln_analysis/tools/tests/test_golang_functions_parsers.py
@@ -0,0 +1,132 @@

import textwrap

import pytest
from langchain_core.documents import Document

HAPPY_PATH_CASES = [
("simple_function", "func DoSomething() {}", "DoSomething"),
("with_parameters", "func DoSomething(p1 string, p2 int) {}", "DoSomething"),
("with_return_value", "func DoSomething(v int) string {}", "DoSomething"),
(
"with_named_return",
"func DoSomething(a, b float64) (q float64, e error) {}",
"DoSomething",
),
(
"method_with_receiver",
"func (p *Point) DoSomething() float64 {}",
"DoSomething",
),
]


EDGE_CASES_TEST = [
("generic_function", "func DoSomething[T any](s []T) {}", "DoSomething"),
(
"letter_or_underscores",
"func _internal_calculate_v2() {}",
"_internal_calculate_v2",
),
(
"receivers_double_pointer_function",
"func (c **Connection) Close() error {}",
"Close",
),
(
"receivers_without_the_name_function",
"func (*Point) IsOrigin() bool {}",
"IsOrigin",
),
(
"multiline_function",
"""
func (r *Repository[
T Model,
K KeyType,
]) FindByID(id K) (*T, error) {}
""",
"FindByID",
),
]

NEGATIVE_ANONYMOUS_CASES = [
(
"assigned_to_variable",
"var greeter = func(name string) { fmt.Println('Hello,', name) }",
),
(
"assigned_to_variable2",
textwrap.dedent(
"""
greet := func() { // Assigning anonymous function to a variable 'greet'
fmt.Println("Greetings from a variable-assigned anonymous function!")
}
"""
),
),
(
"go_routine",
"go func() { fmt.Println('Running in background') }()",
),
(
"defer_statement",
"defer func() { file.Close() }()",
),
(
"callback_argument",
"http.HandleFunc('/', func(w http.ResponseWriter, r *http.Request) {})",
),
]

MALFORMED_INPUT_CASES = [
("empty_string", ""),
("whitespace_only", " \n\t "),
("just_the_keyword", "func"),
("incomplete_header", "func myFunc("),
("garbage_input", "a = b + c;"),
]

@pytest.mark.parametrize("test_id, code_snippet, expected_name", HAPPY_PATH_CASES)
def test_happy_path_function_names(go_parser, test_id, code_snippet, expected_name):
doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"})
actual_name = go_parser.get_function_name(doc)
assert actual_name == expected_name, f"Test case '{test_id}' failed"

@pytest.mark.parametrize("test_id, code_snippet, expected_name", EDGE_CASES_TEST)
def test_edge_cases_function_names(go_parser, test_id, code_snippet, expected_name):
doc = Document(page_content=code_snippet.strip(), metadata={"source": "test.go"})
actual_name = go_parser.get_function_name(doc)
assert actual_name == expected_name, f"Test case '{test_id}' failed"

@pytest.mark.parametrize("test_id, code_snippet", NEGATIVE_ANONYMOUS_CASES)
def test_negative_cases_anonymous_functions(go_parser, test_id, code_snippet):
doc = Document(page_content=code_snippet.strip(), metadata={"source": "proxy.go"})
name = go_parser.get_function_name(doc)
assert name.startswith("anon_"), (
f"[{test_id}] Expected name to start with 'anon_', but got '{name}'"
)
parts = name.split("_")
assert len(parts) == 3, (
f"[{test_id}] Expected name format 'anon_<prefix>_<hash>', but got '{name}'"
)

assert parts[1] == "proxy", (
f"[{test_id}] Expected file prefix 'proxy', but got '{parts[1]}'"
)
hash_part = parts[2]
assert len(hash_part) == 8, (
f"[{test_id}] Hash part should be 8 characters, but got '{hash_part}'"
)
assert all(c in "0123456789abcdef" for c in hash_part), (
f"[{test_id}] Hash part should be hex, but got '{hash_part}'"
)

@pytest.mark.parametrize("test_id, code_snippet", MALFORMED_INPUT_CASES)
def test_malformed_input_graceful_failure(go_parser, test_id, code_snippet):
doc = Document(page_content=code_snippet, metadata={"source": "malformed.go"})
name = go_parser.get_function_name(doc)

assert name.startswith("anon_"), (
f"[{test_id}] Failed to handle malformed input gracefully. Got: {name}"
)
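
The negative and malformed-input assertions above encode the expected fallback naming scheme, anon_<file-stem>_<8-hex-char-hash>. As a reading aid, here is a minimal sketch of how such a name could be derived; the helper name and the hash choice (MD5 of the snippet) are assumptions, since the GoLanguageFunctionsParser implementation itself is not part of this diff:

import hashlib
from pathlib import Path


def fallback_anonymous_name(source_path: str, code_snippet: str) -> str:
    # Hypothetical helper: build "anon_<file-stem>_<hash>" for unnamed functions,
    # matching the format the tests above assert against.
    stem = Path(source_path).stem  # e.g. "proxy" from "proxy.go"
    digest = hashlib.md5(code_snippet.encode("utf-8")).hexdigest()[:8]  # 8 hex chars
    return f"anon_{stem}_{digest}"


# Usage: fallback_anonymous_name("proxy.go", "go func() {}()") -> "anon_proxy_<hash>"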
4 changes: 2 additions & 2 deletions src/vuln_analysis/tools/tests/test_transitive_code_search.py
@@ -86,7 +86,7 @@ async def test_transitive_search_golang_5():
(path_found, list_path) = result

assert path_found is False
assert len(list_path) is 1
assert len(list_path) == 1

# Test fix of https://issues.redhat.com/browse/APPENG-3435
@pytest.mark.asyncio
@@ -100,7 +100,7 @@ async def test_transitive_search_golang_6():
(path_found, list_path) = result
print(result)
assert path_found is True
assert len(list_path) is 2
assert len(list_path) == 2


def set_input_for_next_run(git_repository: str, git_ref: str, included_extensions: list[str],
7 changes: 5 additions & 2 deletions src/vuln_analysis/utils/chain_of_calls_retriever.py
@@ -191,14 +191,15 @@ def __init__(self, documents: List[Document], ecosystem: Ecosystem, manifest_pat
logger.debug(f"no_macro_documents len : {len(no_macro_documents)}")

if not self.language_parser.is_script_language():
# filter out types and full code documents, retaining only functions/methods documents in this attribute.
# filter out types and full code documents, retaining only functions/methods documents in this attribute.
self.documents = [doc for doc in no_macro_documents if self.language_parser.is_function(doc)]
self.documents_of_functions = self.documents
else:
self.documents = filtered_documents
self.documents_of_functions = [doc for doc in self.documents
if doc.page_content.startswith(self.language_parser.get_function_reserved_word())]

# sort documents to ensure deterministic behavior
self.documents.sort(key=lambda doc: doc.metadata.get('source', ''))
logger.debug(f"self.documents len : {len(self.documents)}")
logger.debug("Chain of Calls Retriever - retaining only types/classes docs "
"documents_of_types len %d", len(self.documents_of_types))
@@ -249,6 +250,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa
# Add same package itself to search path.
# direct_parents.extend([function_package])
# gets list of documents to search in only from parents of function' package.
# Fixes non-deterministic behavior in chain-of-calls resolution where identical inputs produced different call-chain lengths
direct_parents.sort()
function_name_to_search = self.language_parser.get_function_name(document_function)
if function_name_to_search == self.language_parser.get_constructor_method_name():
function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function)
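
The two sort() calls added above address the non-determinism noted in the comments: when the candidate packages or documents come from an unordered collection, the DFS traversal order, and therefore the resolved call chain, can differ between runs. A minimal illustration of the idea, using placeholder package names that are not taken from this codebase:

# Hypothetical illustration: Python set iteration order varies across runs
# (hash randomization), so a DFS fed directly from a set is non-deterministic.
direct_parents = list({"pkg/b", "pkg/a", "pkg/c"})
direct_parents.sort()  # stable order -> the same call chain on every run
assert direct_parents == ["pkg/a", "pkg/b", "pkg/c"]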
4 changes: 2 additions & 2 deletions src/vuln_analysis/utils/document_embedding.py
@@ -37,7 +37,7 @@
from langchain_core.document_loaders.blob_loaders import Blob

from vuln_analysis.data_models.input import SourceDocumentsInfo
from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods
from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended
from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods
from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter
from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader
@@ -144,7 +144,7 @@ class ExtendedLanguageParser(LanguageParser):
"javascript": ExtendedJavaScriptSegmenter,
"js": ExtendedJavaScriptSegmenter,
}
additional_segmenters["go"] = GoSegmenterWithMethods
additional_segmenters["go"] = GoSegmenterExtended
additional_segmenters["python"] = PythonSegmenterWithClassesMethods
additional_segmenters["c"] = CSegmenterExtended
