Skip to content

Commit e4ba257

Browse files
authored
fix issue #61 #66 (#87)
1 parent 105eb91 commit e4ba257

File tree

3 files changed

+190
-2
lines changed

3 files changed

+190
-2
lines changed

lang/golang/parser/utils.go

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@ package parser
1717
import (
1818
"bufio"
1919
"bytes"
20+
"container/list"
2021
"fmt"
2122
"go/ast"
23+
"go/build"
2224
"go/types"
2325
"io"
2426
"os"
2527
"path"
2628
"regexp"
2729
"strings"
30+
"sync"
2831

2932
"github.com/Knetic/govaluate"
3033
. "github.com/cloudwego/abcoder/lang/uniast"
@@ -49,8 +52,84 @@ func (c cache) Visited(val interface{}) bool {
4952
return ok
5053
}
5154

55+
// cacheEntry is the payload stored in each LRU list element. The key is kept
// next to the value so that evicting the tail of the list can also delete the
// matching map entry.
type cacheEntry struct {
	key   string
	value bool
}

// PackageCache is a mutex-guarded LRU cache that records, per import path,
// whether the path refers to a Go standard-library ("system") package.
//
// Prefer NewPackageCache to construct one. The zero value is also safe to
// use; its capacity is 0, which (like any non-positive capacity) means the
// cache retains only the single most recently stored entry.
type PackageCache struct {
	lock        sync.Mutex
	cache       map[string]*list.Element // import path -> its element in lru
	lru         *list.List               // front is the most recently used entry
	lruCapacity int                      // entry count at which set starts evicting
}

// NewPackageCache returns an empty PackageCache that starts evicting its
// least recently used entry once it holds lruCapacity entries.
func NewPackageCache(lruCapacity int) *PackageCache {
	return &PackageCache{
		cache:       make(map[string]*list.Element),
		lru:         list.New(),
		lruCapacity: lruCapacity,
	}
}

// get looks up key. The first result is the cached value and the second
// reports whether key was present. A hit marks the entry as most recently
// used.
func (pc *PackageCache) get(key string) (bool, bool) {
	pc.lock.Lock()
	defer pc.lock.Unlock()

	// Indexing a nil map is legal, so a zero-value PackageCache simply misses.
	elem, ok := pc.cache[key]
	if !ok {
		return false, false
	}
	pc.lru.MoveToFront(elem)
	return elem.Value.(*cacheEntry).value, true
}

// set stores value under key and marks the entry as most recently used.
// When key is new and the cache already holds lruCapacity entries, the least
// recently used entry is evicted first.
func (pc *PackageCache) set(key string, value bool) {
	pc.lock.Lock()
	defer pc.lock.Unlock()

	// Allocate lazily so that a zero-value PackageCache does not panic on its
	// first write (nil map assignment / nil list dereference).
	if pc.cache == nil {
		pc.cache = make(map[string]*list.Element)
	}
	if pc.lru == nil {
		pc.lru = list.New()
	}

	if elem, ok := pc.cache[key]; ok {
		pc.lru.MoveToFront(elem)
		elem.Value.(*cacheEntry).value = value
		return
	}

	if pc.lru.Len() >= pc.lruCapacity {
		// Back is nil only when the list is empty (non-positive capacity).
		if oldest := pc.lru.Back(); oldest != nil {
			pc.lru.Remove(oldest)
			delete(pc.cache, oldest.Value.(*cacheEntry).key)
		}
	}

	pc.cache[key] = pc.lru.PushFront(&cacheEntry{key: key, value: value})
}

// IsStandardPackage reports whether path is a standard-library import path.
// Results that require a go/build lookup are memoized in the cache.
func (pc *PackageCache) IsStandardPackage(path string) bool {
	// Standard-library import paths never have a dot in their first element
	// (that form is used by domain-based module paths, and by "." / ".."
	// relative imports, which build.Import rejects when no source directory
	// is given). Answer those directly: build.Import may otherwise run the go
	// command to resolve them, and they would only take up cache slots.
	if first, _, _ := strings.Cut(path, "/"); strings.Contains(first, ".") {
		return false
	}

	if isStd, found := pc.get(path); found {
		return isStd
	}

	pkg, err := build.Import(path, "", build.FindOnly)
	if err != nil {
		// The package cannot be located, so treat it as non-standard.
		pc.set(path, false)
		return false
	}

	isStd := pkg.Goroot
	pc.set(path, isStd)
	return isStd
}
127+
128+
// stdlibCache is the package-level cache recording whether an import path is
// a standard-library ("system") package; it is created with a capacity of
// 10000 entries.
129+
var stdlibCache = NewPackageCache(10000)
130+
52131
func isSysPkg(importPath string) bool {
53-
return !strings.Contains(strings.Split(importPath, "/")[0], ".")
132+
return stdlibCache.IsStandardPackage(importPath)
54133
}
55134

56135
var (

lang/golang/parser/utils_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ import (
2121
"go/token"
2222
"go/types"
2323
"slices"
24+
"sync"
2425
"testing"
2526

27+
"github.com/stretchr/testify/assert"
28+
2629
"github.com/stretchr/testify/require"
2730
)
2831

@@ -195,3 +198,90 @@ var f func() (*http.Request, error)`,
195198
})
196199
}
197200
}
201+
202+
func resetGlobals() {
203+
// 重置包缓存
204+
stdlibCache = NewPackageCache(10000)
205+
}
206+
207+
func Test_isSysPkg(t *testing.T) {
208+
// 测试在 `go env GOROOT` 可以成功执行时的行为
209+
t.Run("Group: Happy Path - GOROOT is found", func(t *testing.T) {
210+
resetGlobals()
211+
212+
testCases := []struct {
213+
name string
214+
importPath string
215+
want bool
216+
}{
217+
{"standard library package", "fmt", true},
218+
{"nested standard library package", "net/http", true},
219+
{"third-party package", "github.com/google/uuid", false},
220+
{"extended library package", "golang.org/x/sync/errgroup", false},
221+
{"local-like package name", "myproject/utils", false},
222+
{"non-existent package", "non/existent/package", false},
223+
{"root-level package with dot", "gopkg.in/yaml.v2", false},
224+
}
225+
226+
for _, tc := range testCases {
227+
t.Run(tc.name, func(t *testing.T) {
228+
if got := isSysPkg(tc.importPath); got != tc.want {
229+
t.Errorf("isSysPkg(%q) = %v, want %v", tc.importPath, got, tc.want)
230+
}
231+
})
232+
}
233+
})
234+
235+
// 测试并发调用时的行为
236+
t.Run("Group: Concurrency Test", func(t *testing.T) {
237+
resetGlobals()
238+
var wg sync.WaitGroup
239+
numGoroutines := 50
240+
numOpsPerGoroutine := 100
241+
242+
for i := 0; i < numGoroutines; i++ {
243+
wg.Add(1)
244+
go func() {
245+
defer wg.Done()
246+
for j := 0; j < numOpsPerGoroutine; j++ {
247+
isSysPkg("fmt")
248+
isSysPkg("github.com/cloudwego/abcoder")
249+
isSysPkg("net/http")
250+
isSysPkg("a/b/c")
251+
}
252+
}()
253+
}
254+
wg.Wait()
255+
})
256+
257+
// 测试 LRU 缓存的驱逐策略
258+
t.Run("Group: LRU Eviction Test", func(t *testing.T) {
259+
resetGlobals()
260+
stdlibCache.lruCapacity = 2
261+
262+
// 1. 填满 Cache
263+
isSysPkg("fmt")
264+
isSysPkg("os")
265+
assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache should be full")
266+
267+
// 2. 访问 "fmt" 使它最近被使用
268+
isSysPkg("fmt")
269+
assert.Equal(t, "fmt", stdlibCache.lru.Front().Value.(*cacheEntry).key, "fmt should be the most recently used")
270+
271+
// 3. 访问 "net" 使它最近被使用
272+
isSysPkg("net") // "os" should be evicted
273+
assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache size should remain at capacity")
274+
275+
// 4. "fmt" 应该在 Cache 中
276+
_, foundFmt := stdlibCache.get("fmt")
277+
assert.True(t, foundFmt, "fmt should still be in the cache")
278+
279+
// 5. "net" 应该在 Cache 中
280+
_, foundNet := stdlibCache.get("net")
281+
assert.True(t, foundNet, "net should be in the cache")
282+
283+
// 6. "os" 不应该在 Cache 中
284+
_, foundOs := stdlibCache.get("os")
285+
assert.False(t, foundOs, "os should have been evicted from the cache")
286+
})
287+
}

lang/golang/writer/write.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
)
3636

3737
var _ uniast.Writer = (*Writer)(nil)
38+
// testPkgPathRegex matches a whole string of the form "<head> [<inner>]".
// Group 1 captures the text before the " [" separator (non-greedy) and
// group 2 captures the text between the brackets.
var testPkgPathRegex = regexp.MustCompile(`^(.+?) \[(.+)\]$`)
3839

3940
type Options struct {
4041
// RepoDir string
@@ -81,6 +82,22 @@ func (w *Writer) WriteRepo(repo *uniast.Repository, outDir string) error {
8182
return nil
8283
}
8384

85+
// sanitizePkgPath sanitize the package path, remove the suffix in brackets
86+
func sanitizePkgPath(pkgPath string) string {
87+
matches := testPkgPathRegex.FindStringSubmatch(pkgPath)
88+
// matches should be 3 elements:
89+
// 1. The full string
90+
// 2. The package name
91+
// 3. The content inside the brackets
92+
if len(matches) == 3 {
93+
packageName := matches[1]
94+
testName := matches[2]
95+
if testName == packageName+".test" {
96+
return packageName
97+
}
98+
}
99+
return pkgPath
100+
}
84101
func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir string) error {
85102
mod := repo.Modules[modPath]
86103
if mod == nil {
@@ -94,7 +111,9 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir str
94111

95112
outdir := filepath.Join(outDir, mod.Dir)
96113
for dir, pkg := range w.visited {
97-
rel := strings.TrimPrefix(dir, mod.Name)
114+
// sanitize the package path
115+
cleanDir := sanitizePkgPath(dir)
116+
rel := strings.TrimPrefix(cleanDir, mod.Name)
98117
pkgDir := filepath.Join(outdir, rel)
99118
if err := os.MkdirAll(pkgDir, 0755); err != nil {
100119
return fmt.Errorf("mkdir %s failed: %v", pkgDir, err)

0 commit comments

Comments
 (0)