Skip to content

Commit 1a051f8

Browse files
authored
Merge pull request #6 from charmbracelet/ctx-race
Ctx race
2 parents efe1ff2 + a30642c commit 1a051f8

File tree

2 files changed

+80
-12
lines changed

2 files changed

+80
-12
lines changed

context.go

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/hex"
66
"net"
77
"sync"
8+
"time"
89

910
gossh "golang.org/x/crypto/ssh"
1011
)
@@ -92,10 +93,14 @@ type Context interface {
9293
}
9394

9495
type sshContext struct {
95-
context.Context
96-
*sync.RWMutex
96+
ctx context.Context
97+
mtx *sync.RWMutex
9798
}
9899

100+
var _ context.Context = &sshContext{}
101+
102+
var _ sync.Locker = &sshContext{}
103+
99104
func newContext(srv *Server) (*sshContext, context.CancelFunc) {
100105
innerCtx, cancel := context.WithCancel(context.Background())
101106
ctx := &sshContext{innerCtx, &sync.RWMutex{}}
@@ -120,21 +125,45 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
120125
}
121126

122127
func (ctx *sshContext) SetValue(key, value interface{}) {
123-
ctx.RWMutex.Lock()
124-
defer ctx.RWMutex.Unlock()
125-
ctx.Context = context.WithValue(ctx.Context, key, value)
128+
ctx.mtx.Lock()
129+
defer ctx.mtx.Unlock()
130+
ctx.ctx = context.WithValue(ctx.ctx, key, value)
126131
}
127132

128133
func (ctx *sshContext) Value(key interface{}) interface{} {
129-
ctx.RWMutex.RLock()
130-
defer ctx.RWMutex.RUnlock()
131-
return ctx.Context.Value(key)
134+
ctx.mtx.RLock()
135+
defer ctx.mtx.RUnlock()
136+
return ctx.ctx.Value(key)
132137
}
133138

134139
func (ctx *sshContext) Done() <-chan struct{} {
135-
ctx.RWMutex.RLock()
136-
defer ctx.RWMutex.RUnlock()
137-
return ctx.Context.Done()
140+
ctx.mtx.RLock()
141+
defer ctx.mtx.RUnlock()
142+
return ctx.ctx.Done()
143+
}
144+
145+
// Deadline implements context.Context.
146+
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) {
147+
ctx.mtx.RLock()
148+
defer ctx.mtx.RUnlock()
149+
return ctx.ctx.Deadline()
150+
}
151+
152+
// Err implements context.Context.
153+
func (ctx *sshContext) Err() error {
154+
ctx.mtx.RLock()
155+
defer ctx.mtx.RUnlock()
156+
return ctx.ctx.Err()
157+
}
158+
159+
// Lock implements sync.Locker.
160+
func (ctx *sshContext) Lock() {
161+
ctx.mtx.Lock()
162+
}
163+
164+
// Unlock implements sync.Locker.
165+
func (ctx *sshContext) Unlock() {
166+
ctx.mtx.Unlock()
138167
}
139168

140169
func (ctx *sshContext) User() string {

context_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package ssh
22

3-
import "testing"
3+
import (
4+
"testing"
5+
"time"
6+
)
47

58
func TestSetPermissions(t *testing.T) {
69
t.Parallel()
@@ -69,3 +72,39 @@ func TestRaceRWIssue160(t *testing.T) {
6972
t.Fatal(err)
7073
}
7174
}
75+
76+
// Taken from https://github.com/gliderlabs/ssh/pull/211/commits/02f9d573009f8c13755b6b90fa14a4f549b17b22
77+
func TestSetValueConcurrency(t *testing.T) {
78+
ctx, cancel := newContext(nil)
79+
defer cancel()
80+
81+
go func() {
82+
for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue
83+
_, _ = ctx.Deadline()
84+
_ = ctx.Err()
85+
_ = ctx.Value("foo")
86+
select {
87+
case <-ctx.Done():
88+
break
89+
default:
90+
}
91+
}
92+
}()
93+
ctx.SetValue("bar", -1) // a context value which never changes
94+
now := time.Now()
95+
var cnt int64
96+
go func() {
97+
for time.Since(now) < 100*time.Millisecond {
98+
cnt++
99+
ctx.SetValue("foo", cnt) // a context value which changes a lot
100+
}
101+
cancel()
102+
}()
103+
<-ctx.Done()
104+
if ctx.Value("foo") != cnt {
105+
t.Fatal("context.Value(foo) doesn't match latest SetValue")
106+
}
107+
if ctx.Value("bar") != -1 {
108+
t.Fatal("context.Value(bar) doesn't match latest SetValue")
109+
}
110+
}

0 commit comments

Comments
 (0)