Skip to content

Commit c5e9b05

Browse files
committed
Merge pull request #109 from Workiva/batcher_deadlock
Avoid deadlock between Put and Get
2 parents fa486f9 + c85f8d4 commit c5e9b05

File tree

2 files changed

+89
-44
lines changed

2 files changed

+89
-44
lines changed

batcher/batcher.go

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@ package batcher
1919
import (
2020
"errors"
2121
"sync"
22+
"sync/atomic"
2223
"time"
2324
)
2425

26+
// Lifecycle states held in basicBatcher.disposed; the field is read and
// written only through sync/atomic.
const (
	batcherActive   uint32 = iota // 0: batcher accepts Put/Flush
	batcherDisposed               // 1: Dispose has been called
)
30+
2531
// Batcher provides an API for accumulating items into a batch for processing.
2632
type Batcher interface {
2733
// Put adds items to the batcher.
@@ -54,11 +60,13 @@ type basicBatcher struct {
5460
maxItems uint
5561
maxBytes uint
5662
calculateBytes CalculateBytes
57-
disposed bool
63+
disposed uint32
5864
items []interface{}
5965
lock sync.RWMutex
6066
batchChan chan []interface{}
67+
disposeChan chan struct{}
6168
availableBytes uint
69+
waiting int32
6270
}
6371

6472
// New creates a new Batcher using the provided arguments.
@@ -80,17 +88,19 @@ func New(maxTime time.Duration, maxItems, maxBytes, queueLen uint, calculate Cal
8088
calculateBytes: calculate,
8189
items: make([]interface{}, 0, maxItems),
8290
batchChan: make(chan []interface{}, queueLen),
91+
disposeChan: make(chan struct{}),
8392
}, nil
8493
}
8594

86-
// Put adds items to the batcher.
95+
// Put adds items to the batcher. If Put is continually called without calls to
96+
// Get, an unbounded number of goroutines will be created.
97+
// Note: there is no order guarantee for items entering/leaving the batcher.
8798
func (b *basicBatcher) Put(item interface{}) error {
88-
b.lock.Lock()
89-
if b.disposed {
90-
b.lock.Unlock()
99+
// Check to see if disposed before putting
100+
if b.IsDisposed() {
91101
return ErrDisposed
92102
}
93-
103+
b.lock.Lock()
94104
b.items = append(b.items, item)
95105
if b.calculateBytes != nil {
96106
b.availableBytes += b.calculateBytes(item)
@@ -104,7 +114,10 @@ func (b *basicBatcher) Put(item interface{}) error {
104114
}
105115

106116
// Get retrieves a batch from the batcher. This call will block until
107-
// one of the conditions for a "complete" batch is reached.
117+
// one of the conditions for a "complete" batch is reached. If Put is
118+
// continually called without calls to Get, an unbounded number of
119+
// goroutines will be created.
120+
// Note: there is no order guarantee for items entering/leaving the batcher.
108121
func (b *basicBatcher) Get() ([]interface{}, error) {
109122
// Don't check disposed yet so any items remaining in the queue
110123
// will be returned properly.
@@ -114,18 +127,25 @@ func (b *basicBatcher) Get() ([]interface{}, error) {
114127
timeout = time.After(b.maxTime)
115128
}
116129

130+
// Check to see if disposed before blocking
131+
if b.IsDisposed() {
132+
return nil, ErrDisposed
133+
}
134+
117135
select {
118-
case items, ok := <-b.batchChan:
136+
case items := <-b.batchChan:
137+
return items, nil
138+
case _, ok := <-b.disposeChan:
119139
if !ok {
120140
return nil, ErrDisposed
121141
}
122-
return items, nil
142+
return nil, nil
123143
case <-timeout:
124-
b.lock.Lock()
125-
if b.disposed {
126-
b.lock.Unlock()
144+
// Check to see if disposed before getting lock
145+
if b.IsDisposed() {
127146
return nil, ErrDisposed
128147
}
148+
b.lock.Lock()
129149
items := b.items
130150
b.items = make([]interface{}, 0, b.maxItems)
131151
b.availableBytes = 0
@@ -136,11 +156,10 @@ func (b *basicBatcher) Get() ([]interface{}, error) {
136156

137157
// Flush forcibly completes the batch currently being built
138158
func (b *basicBatcher) Flush() error {
139-
b.lock.Lock()
140-
if b.disposed {
141-
b.lock.Unlock()
159+
if b.IsDisposed() {
142160
return ErrDisposed
143161
}
162+
b.lock.Lock()
144163
b.flush()
145164
b.lock.Unlock()
146165
return nil
@@ -150,30 +169,44 @@ func (b *basicBatcher) Flush() error {
150169
// will return ErrDisposed. Calls to Get made after Dispose also return
151170
// ErrDisposed, even if completed batches are still queued.
152171
func (b *basicBatcher) Dispose() {
153-
b.lock.Lock()
154-
if b.disposed {
155-
b.lock.Unlock()
172+
// Check to see if disposed before attempting to dispose
173+
if atomic.CompareAndSwapUint32(&b.disposed, batcherActive, batcherDisposed) {
156174
return
157175
}
176+
b.lock.Lock()
158177
b.flush()
159-
b.disposed = true
160178
b.items = nil
179+
close(b.disposeChan)
180+
181+
// Drain the batch channel and all routines waiting to put on the channel
182+
for len(b.batchChan) > 0 || atomic.LoadInt32(&b.waiting) > 0 {
183+
<-b.batchChan
184+
}
161185
close(b.batchChan)
162186
b.lock.Unlock()
163187
}
164188

165189
// IsDisposed will determine if the batcher is disposed
166190
func (b *basicBatcher) IsDisposed() bool {
167-
b.lock.RLock()
168-
disposed := b.disposed
169-
b.lock.RUnlock()
170-
return disposed
191+
return atomic.LoadUint32(&b.disposed) == batcherDisposed
171192
}
172193

173194
// flush adds the batch currently being built to the queue of completed batches.
174195
// flush is not threadsafe, so should be synchronized externally.
175196
func (b *basicBatcher) flush() {
176-
b.batchChan <- b.items
197+
// Note: This needs to be in a go-routine to avoid locking out gets when
198+
// the batch channel is full.
199+
cpItems := make([]interface{}, len(b.items))
200+
for i, val := range b.items {
201+
cpItems[i] = val
202+
}
203+
// Signal one more waiter for the batch channel
204+
atomic.AddInt32(&b.waiting, 1)
205+
// Don't block on the channel put
206+
go func() {
207+
b.batchChan <- cpItems
208+
atomic.AddInt32(&b.waiting, -1)
209+
}()
177210
b.items = make([]interface{}, 0, b.maxItems)
178211
b.availableBytes = 0
179212
}

batcher/batcher_test.go

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ func TestMaxItems(t *testing.T) {
3636
})
3737
assert.Nil(err)
3838

39-
go func() {
40-
for i := 0; i < 1000; i++ {
41-
assert.Nil(b.Put("foo bar baz"))
42-
}
43-
}()
39+
for i := 0; i < 1000; i++ {
40+
assert.Nil(b.Put("foo bar baz"))
41+
}
4442

4543
batch, err := b.Get()
4644
assert.Len(batch, 100)
@@ -139,32 +137,46 @@ func TestMultiConsumer(t *testing.T) {
139137

140138
func TestDispose(t *testing.T) {
141139
assert := assert.New(t)
142-
b, err := New(0, 2, 100000, 10, func(str interface{}) uint {
140+
b, err := New(1, 2, 100000, 2, func(str interface{}) uint {
143141
return uint(len(str.(string)))
144142
})
145143
assert.Nil(err)
146144
b.Put("a")
147145
b.Put("b")
148146
b.Put("c")
149-
wait := make(chan bool)
150-
go func() {
151-
batch1, err := b.Get()
152-
assert.Equal([]interface{}{"a", "b"}, batch1)
153-
assert.Nil(err)
154-
batch2, err := b.Get()
155-
assert.Equal([]interface{}{"c"}, batch2)
156-
assert.Nil(err)
157-
_, err = b.Get()
158-
assert.Equal(ErrDisposed, err)
159-
wait <- true
160-
}()
147+
148+
possibleBatches := [][]interface{}{
149+
[]interface{}{"a", "b"},
150+
[]interface{}{"c"},
151+
}
152+
153+
// Wait for items to get to the channel
154+
for len(b.(*basicBatcher).batchChan) == 0 {
155+
time.Sleep(1 * time.Millisecond)
156+
}
157+
batch1, err := b.Get()
158+
assert.Contains(possibleBatches, batch1)
159+
assert.Nil(err)
160+
161+
batch2, err := b.Get()
162+
assert.Contains(possibleBatches, batch2)
163+
assert.Nil(err)
164+
165+
b.Put("d")
166+
b.Put("e")
167+
b.Put("f")
168+
b.Put("g")
169+
b.Put("h")
170+
b.Put("i")
161171

162172
b.Dispose()
163173

164-
assert.Equal(ErrDisposed, b.Put("d"))
174+
_, err = b.Get()
175+
assert.Equal(ErrDisposed, err)
176+
177+
assert.Equal(ErrDisposed, b.Put("j"))
165178
assert.Equal(ErrDisposed, b.Flush())
166179

167-
<-wait
168180
}
169181

170182
func TestIsDisposed(t *testing.T) {

0 commit comments

Comments
 (0)