Skip to content

Commit a5f34f0

Browse files
committed
optimize tref out
1 parent 6dc7625 commit a5f34f0

File tree

11 files changed

+143
-145
lines changed

11 files changed

+143
-145
lines changed

internal/engine/link.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -481,19 +481,16 @@ func (e *Engine) getDef(ctx context.Context, rs *RequestState, c DefContainer) (
481481
func addSrcTargetToInputs(def *pluginv1.TargetDef, currentTargetAddrHash func() string) {
482482
inputs := def.GetInputs()
483483
for _, input := range inputs {
484-
if input.GetRef().GetPackage() != tref.QueryPackage {
485-
continue
486-
}
484+
refo := input.GetRef()
485+
ref := refo.GetTarget()
487486

488-
ref := input.GetRef()
489-
args := ref.GetArgs()
490-
if args == nil {
491-
args = map[string]string{}
487+
if ref.GetPackage() != tref.QueryPackage {
488+
continue
492489
}
493-
args[querySrcTargetArg] = currentTargetAddrHash()
494-
ref.SetArgs(args)
495490

496-
input.SetRef(ref)
491+
ref = tref.WithArg(ref, querySrcTargetArg, currentTargetAddrHash())
492+
refo.SetTarget(ref)
493+
refo.ClearHash()
497494
}
498495
def.SetInputs(inputs)
499496
}
@@ -760,7 +757,7 @@ func (e *Engine) innerLink(ctx context.Context, rs *RequestState, def *TargetDef
760757

761758
if input.GetRef().HasOutput() {
762759
if !slices.Contains(linkedDep.GetOutputs(), input.GetRef().GetOutput()) {
763-
return fmt.Errorf("%v doesnt have a named output %q", tref.Format(input.GetRef()), input.GetRef().GetOutput())
760+
return fmt.Errorf("%v doesnt have a named output %q", tref.FormatOut(input.GetRef()), input.GetRef().GetOutput())
764761
}
765762

766763
outputs = []string{input.GetRef().GetOutput()}

internal/tmatch/match.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func walkDirs(ctx context.Context, walkRoot, root string, filter func(path strin
2525
}
2626

2727
return func(yield func(string, error) bool) {
28+
stop := false
2829
err := fsWalkCache.Walk(walkRoot, func(path string, d fs.DirEntry, err error) error {
2930
if d == nil || !d.IsDir() {
3031
return nil
@@ -48,13 +49,14 @@ func walkDirs(ctx context.Context, walkRoot, root string, filter func(path strin
4849
}
4950

5051
if !yield(pkg, nil) {
52+
stop = true
5153
return fs.SkipAll
5254
}
5355

5456
return nil
5557
})
5658
if err != nil {
57-
if !yield("", err) {
59+
if !stop && !yield("", err) {
5860
return
5961
}
6062
}

lib/tref/format.go

Lines changed: 106 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import (
77
"strings"
88
"unsafe"
99

10-
"github.com/hephbuild/heph/internal/hproto/hashpb"
11-
1210
"github.com/hephbuild/heph/internal/hsync"
1311

1412
cache "github.com/Code-Hex/go-generics-cache"
@@ -22,28 +20,12 @@ import (
2220
type Ref = pluginv1.TargetRef
2321
type RefOut = pluginv1.TargetRefWithOutput
2422

25-
var _ Refable = (*Ref)(nil)
26-
var _ Refable = (*RefOut)(nil)
27-
var _ RefableOut = (*RefOut)(nil)
28-
29-
type Refable interface {
30-
GetArgs() map[string]string
31-
GetPackage() string
32-
GetName() string
33-
}
34-
3523
type HashStore interface {
3624
GetHash() uint64
3725
HasHash() bool
3826
SetHash(uint64)
3927
}
4028

41-
type RefableOut interface {
42-
Refable
43-
GetOutput() string
44-
GetFilters() []string
45-
}
46-
4729
func FormatFile(pkg string, file string) string {
4830
return Format(New(JoinPackage("@heph/file", pkg), "content", map[string]string{"f": file}))
4931
}
@@ -117,89 +99,84 @@ func ParseQuery(ref *pluginv1.TargetRef) (QueryOptions, error) {
11799
var formatCache = cache.New[uint64, string](cache.AsLFU[uint64, string](lfu.WithCapacity(10000)))
118100
var formatSf = hsingleflight.Group[uint64, string]{}
119101

120-
var formatHashPool = hsync.Pool[*xxh3.Hasher]{New: xxh3.New}
121-
122-
func sumRef(ref hashpb.StableWriter) uint64 {
123-
h := formatHashPool.Get()
124-
defer formatHashPool.Put(h)
125-
h.Reset()
102+
var formatOutCache = cache.New[uint64, string](cache.AsLFU[uint64, string](lfu.WithCapacity(10000)))
103+
var formatOutSf = hsingleflight.Group[uint64, string]{}
126104

127-
switch ref := ref.(type) {
128-
case *pluginv1.TargetRef:
129-
sumRefTargetRef(h, ref)
130-
case *pluginv1.TargetRefWithOutput:
131-
sumRefTargetRefWithOutput(h, ref)
132-
default:
133-
hashpb.Hash(h, ref, nil)
134-
}
105+
var formatHashPool = hsync.Pool[*xxh3.Hasher]{New: xxh3.New}
135106

136-
return h.Sum64()
137-
}
107+
func sumRefTargetRef(m *pluginv1.TargetRef) uint64 {
108+
hasher := formatHashPool.Get()
109+
defer formatHashPool.Put(hasher)
110+
hasher.Reset()
138111

139-
func sumRefTargetRef(hasher *xxh3.Hasher, m *pluginv1.TargetRef) {
140112
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetPackage()), len(m.GetPackage())))
141113
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetName()), len(m.GetName())))
142114
for k, v := range hmaps.Sorted(m.GetArgs()) {
143115
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(k), len(k)))
144116
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
145117
}
118+
119+
return hasher.Sum64()
146120
}
147121

148-
func sumRefTargetRefWithOutput(hasher *xxh3.Hasher, m *pluginv1.TargetRefWithOutput) {
149-
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetPackage()), len(m.GetPackage())))
150-
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetName()), len(m.GetName())))
151-
for k, v := range hmaps.Sorted(m.GetArgs()) {
152-
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(k), len(k)))
153-
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
154-
}
155-
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetOutput()), len(m.GetOutput())))
156-
if len(m.GetFilters()) > 0 {
157-
for _, v := range m.GetFilters() {
122+
func sumRefTargetRefWithOutput(m *pluginv1.TargetRefWithOutput) uint64 {
123+
hasher := formatHashPool.Get()
124+
defer formatHashPool.Put(hasher)
125+
hasher.Reset()
126+
127+
{
128+
m := m.GetTarget()
129+
130+
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetPackage()), len(m.GetPackage())))
131+
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetName()), len(m.GetName())))
132+
for k, v := range hmaps.Sorted(m.GetArgs()) {
133+
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(k), len(k)))
158134
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
159135
}
160136
}
137+
138+
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(m.GetOutput()), len(m.GetOutput())))
139+
for _, v := range m.GetFilters() {
140+
_, _ = hasher.Write(unsafe.Slice(unsafe.StringData(v), len(v)))
141+
}
142+
143+
return hasher.Sum64()
161144
}
162145

163-
func Format(ref Refable) string {
164-
if refh, ok := ref.(hashpb.StableWriter); ok {
165-
var sum uint64
166-
if s, ok := ref.(HashStore); ok {
167-
if s.HasHash() {
168-
sum = s.GetHash()
169-
} else {
170-
sum = sumRef(refh)
171-
s.SetHash(sum)
172-
}
173-
} else {
174-
sum = sumRef(refh)
175-
}
146+
func Format(ref *Ref) string {
147+
var sum uint64
148+
if ref.HasHash() {
149+
sum = ref.GetHash()
150+
} else {
151+
sum = sumRefTargetRef(ref)
152+
ref.SetHash(sum)
153+
}
176154

155+
f, ok := formatCache.Get(sum)
156+
if ok {
157+
return f
158+
}
159+
160+
f, _, _ = formatSf.Do(sum, func() (string, error) {
177161
f, ok := formatCache.Get(sum)
178162
if ok {
179-
return f
163+
return f, nil
180164
}
181165

182-
f, _, _ = formatSf.Do(sum, func() (string, error) {
183-
f, ok := formatCache.Get(sum)
184-
if ok {
185-
return f, nil
186-
}
187-
188-
f = format(ref)
166+
f = format(ref)
189167

190-
formatCache.Set(sum, f)
168+
formatCache.Set(sum, f)
191169

192-
return f, nil
193-
})
170+
return f, nil
171+
})
194172

195-
return f
196-
}
173+
return f
197174

198-
return format(ref)
199175
}
200176

201-
func format(ref Refable) string {
177+
func format(ref *Ref) string {
202178
var sb strings.Builder
179+
203180
sb.WriteString("//")
204181
sb.WriteString(ref.GetPackage())
205182
sb.WriteString(":")
@@ -225,24 +202,65 @@ func format(ref Refable) string {
225202
}
226203
}
227204

228-
if ref, ok := ref.(RefableOut); ok {
229-
out := ref.GetOutput()
230-
if out != "" {
231-
sb.WriteString("|")
232-
sb.WriteString(out)
205+
return sb.String()
206+
}
207+
208+
func FormatOut(ref *RefOut) string {
209+
var sum uint64
210+
if ref.HasHash() {
211+
sum = ref.GetHash()
212+
} else {
213+
sum = sumRefTargetRefWithOutput(ref)
214+
ref.SetHash(sum)
215+
}
216+
217+
f, ok := formatOutCache.Get(sum)
218+
if ok {
219+
return f
220+
}
221+
222+
f, _, _ = formatOutSf.Do(sum, func() (string, error) {
223+
f, ok := formatOutCache.Get(sum)
224+
if ok {
225+
return f, nil
233226
}
234227

235-
if len(ref.GetFilters()) > 0 {
236-
sb.WriteString(" filters=")
237-
first := true
238-
for _, f := range ref.GetFilters() {
239-
if !first {
240-
sb.WriteString(",")
241-
} else {
242-
first = false
243-
}
244-
sb.WriteString(f)
228+
f = formatOut(ref)
229+
230+
formatOutCache.Set(sum, f)
231+
232+
return f, nil
233+
})
234+
235+
return f
236+
237+
}
238+
239+
func formatOut(ref *RefOut) string {
240+
if ref.GetOutput() == "" && len(ref.GetFilters()) == 0 {
241+
return Format(ref.GetTarget())
242+
}
243+
244+
var sb strings.Builder
245+
246+
sb.WriteString(Format(ref.GetTarget()))
247+
248+
out := ref.GetOutput()
249+
if out != "" {
250+
sb.WriteString("|")
251+
sb.WriteString(out)
252+
}
253+
254+
if len(ref.GetFilters()) > 0 {
255+
sb.WriteString(" filters=")
256+
first := true
257+
for _, f := range ref.GetFilters() {
258+
if !first {
259+
sb.WriteString(",")
260+
} else {
261+
first = false
245262
}
263+
sb.WriteString(f)
246264
}
247265
}
248266

lib/tref/utils.go

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,6 @@ func Equal(a, b *pluginv1.TargetRef) bool {
1414
return a.GetPackage() == b.GetPackage() && a.GetName() == b.GetName() && maps.Equal(a.GetArgs(), b.GetArgs())
1515
}
1616

17-
func CompareOut(a, b *pluginv1.TargetRefWithOutput) int {
18-
if v := Compare(WithoutOut(a), WithoutOut(b)); v != 0 {
19-
return v
20-
}
21-
22-
if a.HasOutput() && b.HasOutput() {
23-
if v := strings.Compare(a.GetOutput(), b.GetOutput()); v != 0 {
24-
return v
25-
}
26-
} else {
27-
if !a.HasOutput() {
28-
return 1
29-
}
30-
if !b.HasOutput() {
31-
return -1
32-
}
33-
}
34-
35-
return 0
36-
}
37-
3817
func Compare(a, b *pluginv1.TargetRef) int {
3918
if v := strings.Compare(a.GetPackage(), b.GetPackage()); v != 0 {
4019
return v
@@ -113,20 +92,23 @@ func WithOut(ref *pluginv1.TargetRef, output string) *pluginv1.TargetRefWithOutp
11392
}
11493

11594
return pluginv1.TargetRefWithOutput_builder{
116-
Package: htypes.Ptr(ref.GetPackage()),
117-
Name: htypes.Ptr(ref.GetName()),
118-
Args: ref.GetArgs(),
119-
Output: outputp,
95+
Target: ref,
96+
Output: outputp,
12097
}.Build()
12198
}
12299

123100
func WithFilters(ref *pluginv1.TargetRefWithOutput, filters []string) *pluginv1.TargetRefWithOutput {
101+
if len(filters) == 0 && len(ref.GetFilters()) == 0 {
102+
return ref
103+
}
104+
124105
ref = hproto.Clone(ref)
106+
ref.ClearHash()
125107
ref.SetFilters(filters)
126108

127109
return ref
128110
}
129111

130112
func WithoutOut(ref *pluginv1.TargetRefWithOutput) *pluginv1.TargetRef {
131-
return New(ref.GetPackage(), ref.GetName(), ref.GetArgs())
113+
return ref.GetTarget()
132114
}

plugin/pluginexec/sandbox.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func SetupSandbox(
7474
for artifact := range ArtifactsForId(results, tool.GetId(), pluginv1.Artifact_TYPE_OUTPUT) {
7575
listArtifact, err := SetupSandboxBinArtifact(ctx, artifact.GetArtifact(), binfs)
7676
if err != nil {
77-
return nil, fmt.Errorf("%v: %w", tref.Format(tool.GetRef()), err)
77+
return nil, fmt.Errorf("%v: %w", tref.FormatOut(tool.GetRef()), err)
7878
}
7979
listArtifacts = append(listArtifacts, pluginv1.ArtifactWithOrigin_builder{
8080
Artifact: listArtifact,

0 commit comments

Comments
 (0)