Skip to content

Commit 4b877d9

Browse files
committed
refactor(go-segmenter): replace custom GoSegmenter with native Tree-Sitter implementation
Signed-off-by: Vladimir Belousov <[email protected]>
1 parent 3964f2d commit 4b877d9

File tree

6 files changed

+224
-66
lines changed

6 files changed

+224
-66
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended
2+
3+
4+
def _extract(code: str):
5+
seg = GoSegmenterExtended(code)
6+
return [s.strip() for s in seg.extract_functions_classes()]
7+
8+
9+
def test_generic_method_basic():
10+
code = """
11+
type Box[T any] struct { value T }
12+
func (b *Box[T]) Set(v T) { b.value = v }
13+
"""
14+
chunks = _extract(code)
15+
assert any("Set" in c for c in chunks), "generic method not extracted"
16+
17+
18+
def test_generic_multiple_type_params():
19+
code = """
20+
func MapKeys[K comparable, V any](m map[K]V) []K {
21+
keys := make([]K, 0, len(m))
22+
for k := range m {
23+
keys = append(keys, k)
24+
}
25+
return keys
26+
}
27+
"""
28+
chunks = _extract(code)
29+
assert any("MapKeys" in c for c in chunks), "multiple generics not parsed"
30+
31+
32+
def test_function_returning_func():
33+
code = """
34+
func makeAdder(x int) func(int) int {
35+
return func(y int) int { return x + y }
36+
}
37+
"""
38+
chunks = _extract(code)
39+
assert any("makeAdder" in c for c in chunks), "failed to parse func returning func"
40+
41+
42+
def test_inline_anonymous_func():
43+
code = """
44+
func Worker() {
45+
defer func() { cleanup() }()
46+
go func() { runTask() }()
47+
}
48+
"""
49+
chunks = _extract(code)
50+
assert any("Worker" in c for c in chunks), "missed inline anonymous func"
51+
52+
53+
def test_double_pointer_receiver():
54+
code = """
55+
type Conn struct{}
56+
func (c **Conn) Reset() {}
57+
"""
58+
chunks = _extract(code)
59+
assert any("Reset" in c for c in chunks), "failed to detect pointer receiver"
60+
61+
62+
def test_multiline_generic_method():
63+
code = """
64+
func (r *Repo[
65+
T any,
66+
E error,
67+
]) Save(v T) (E, error) {
68+
return nil, nil
69+
}
70+
"""
71+
chunks = _extract(code)
72+
assert any("Save" in c for c in chunks), "multiline generic method not parsed"

src/vuln_analysis/tools/tests/test_transitive_code_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def test_transitive_search_golang_5():
8686
(path_found, list_path) = result
8787

8888
assert path_found is False
89-
assert len(list_path) is 1
89+
assert len(list_path) == 1
9090

9191
# Test fix of https://issues.redhat.com/browse/APPENG-3435
9292
@pytest.mark.asyncio
@@ -100,7 +100,7 @@ async def test_transitive_search_golang_6():
100100
(path_found, list_path) = result
101101
print(result)
102102
assert path_found is False
103-
assert len(list_path) is 0
103+
assert len(list_path) == 0
104104

105105

106106
def set_input_for_next_run(git_repository: str, git_ref: str, included_extensions: list[str],

src/vuln_analysis/utils/chain_of_calls_retriever.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,15 @@ def __init__(self, documents: List[Document], ecosystem: Ecosystem, manifest_pat
190190
logger.debug(f"no_macro_documents len : {len(no_macro_documents)}")
191191

192192
if not self.language_parser.is_script_language():
193-
# filter out types and full code documents, retaining only functions/methods documents in this attribute.
193+
# filter out types and full code documents, retaining only functions/methods documents in this attribute.
194194
self.documents = [doc for doc in no_macro_documents if self.language_parser.is_function(doc)]
195195
self.documents_of_functions = self.documents
196196
else:
197197
self.documents = filtered_documents
198198
self.documents_of_functions = [doc for doc in self.documents
199199
if doc.page_content.startswith(self.language_parser.get_function_reserved_word())]
200-
200+
# sort documents to ensure deterministic behavior
201+
self.documents.sort(key=lambda doc: doc.metadata.get('source', ''))
201202
logger.debug(f"self.documents len : {len(self.documents)}")
202203
logger.debug("Chain of Calls Retriever - retaining only types/classes docs "
203204
"documents_of_types len %d", len(self.documents_of_types))
@@ -248,6 +249,8 @@ def __find_caller_function_dfs(self, document_function: Document, function_packa
248249
# Add same package itself to search path.
249250
# direct_parents.extend([function_package])
250251
# gets list of documents to search in only from parents of function' package.
252+
# Fixes non-deterministic behavior in chain-of-calls resolution where identical inputs produced different call-chain lengths
253+
direct_parents.sort()
251254
function_name_to_search = self.language_parser.get_function_name(document_function)
252255
if function_name_to_search == self.language_parser.get_constructor_method_name():
253256
function_name_to_search = self.language_parser.get_class_name_from_class_function(document_function)

src/vuln_analysis/utils/document_embedding.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,37 @@
1515

1616
import copy
1717
import json
18-
import logging
1918
import os
2019
import pickle
2120
import sys
2221
import time
2322
import typing
2423
from hashlib import sha512
25-
from pathlib import Path
26-
from pathlib import PurePath
24+
from pathlib import Path, PurePath
2725

2826
from langchain.docstore.document import Document
29-
from langchain.text_splitter import Language
30-
from langchain.text_splitter import RecursiveCharacterTextSplitter
27+
from langchain.text_splitter import Language, RecursiveCharacterTextSplitter
3128
from langchain_community.document_loaders.generic import GenericLoader
32-
from langchain_community.document_loaders.parsers.language.code_segmenter import CodeSegmenter
33-
from langchain_community.document_loaders.parsers.language.language_parser import LANGUAGE_EXTENSIONS
34-
from langchain_community.document_loaders.parsers.language.language_parser import LANGUAGE_SEGMENTERS
35-
from langchain_community.document_loaders.parsers.language.language_parser import LanguageParser
29+
from langchain_community.document_loaders.parsers.language.code_segmenter import (
30+
CodeSegmenter,
31+
)
32+
from langchain_community.document_loaders.parsers.language.language_parser import (
33+
LANGUAGE_EXTENSIONS,
34+
LANGUAGE_SEGMENTERS,
35+
LanguageParser,
36+
)
3637
from langchain_community.vectorstores import FAISS
3738
from langchain_core.document_loaders.blob_loaders import Blob
3839

3940
from vuln_analysis.data_models.input import SourceDocumentsInfo
40-
from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods
41-
from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods
41+
from vuln_analysis.logging.loggers_factory import LoggingFactory
42+
from vuln_analysis.utils.git_utils import sanitize_git_url_for_path
43+
from vuln_analysis.utils.go_segmenter_extended import GoSegmenterExtended
4244
from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter
45+
from vuln_analysis.utils.python_segmenters_with_classes_methods import (
46+
PythonSegmenterWithClassesMethods,
47+
)
4348
from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader
44-
from vuln_analysis.utils.git_utils import sanitize_git_url_for_path
45-
from vuln_analysis.utils.transitive_code_searcher_tool import TransitiveCodeSearcher
46-
from vuln_analysis.logging.loggers_factory import LoggingFactory
4749

4850
from vuln_analysis.utils.c_segmenter_custom import CSegmenterExtended
4951

@@ -144,7 +146,8 @@ class ExtendedLanguageParser(LanguageParser):
144146
"javascript": ExtendedJavaScriptSegmenter,
145147
"js": ExtendedJavaScriptSegmenter,
146148
}
147-
additional_segmenters["go"] = GoSegmenterWithMethods
149+
# additional_segmenters["go"] = GoSegmenterWithMethods
150+
additional_segmenters["go"] = GoSegmenterExtended
148151
additional_segmenters["python"] = PythonSegmenterWithClassesMethods
149152
additional_segmenters["c"] = CSegmenterExtended
150153

0 commit comments

Comments
 (0)