Skip to content

Commit

Permalink
Fix timewheel (#53)
Browse files Browse the repository at this point in the history
* fix time wheel: remove origin task from slot when adding again

* fix param verification in TimeWheel
  • Loading branch information
teckick authored and chicliz committed May 19, 2020
1 parent d44aa62 commit b051f34
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 18 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ require (
github.com/sirupsen/logrus v1.4.2 // indirect
github.com/smartystreets/goconvey v0.0.0-20190222223459-a17d461953aa // indirect
github.com/soheilhy/cmux v0.1.4 // indirect
github.com/stretchr/testify v1.2.2
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
go.etcd.io/bbolt v1.3.2 // indirect
Expand Down
10 changes: 8 additions & 2 deletions util/time_wheel.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ type TimeWheel struct {

// NewTimeWheel create new time wheel
func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) {
if tick <= 0 || bucketsNum <= 0 {
return nil, errors.New("invalid params")
if bucketsNum <= 0 {
return nil, errors.New("bucket number must be greater than 0")
}
if int(tick.Seconds()) < 1 {
return nil, errors.New("tick cannot be less than 1s")
}

tw := &TimeWheel{
Expand Down Expand Up @@ -125,6 +128,9 @@ func (tw *TimeWheel) add(task *Task) {
round := tw.calculateRound(task.delay)
index := tw.calculateIndex(task.delay)
task.round = round
if originIndex, ok := tw.bucketIndexes[task.key]; ok {
delete(tw.buckets[originIndex], task.key)
}
tw.bucketIndexes[task.key] = index
tw.buckets[index][task.key] = task
}
Expand Down
81 changes: 65 additions & 16 deletions util/time_wheel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,26 @@
package util

import (
"fmt"
"strconv"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

type A struct {
a int
b string
a int
b string
isCallbacked int32
}

func (a *A) callback() {
atomic.StoreInt32(&a.isCallbacked, 1)
}

func callback() {
fmt.Println("timeout")
func (a *A) getCallbackValue() int32 {
return atomic.LoadInt32(&a.isCallbacked)
}

func newTimeWheel() *TimeWheel {
Expand All @@ -39,33 +46,75 @@ func newTimeWheel() *TimeWheel {
return tw
}

func TestNewTimeWheel(t *testing.T) {
tests := []struct {
name string
tick time.Duration
bucketNum int
hasErr bool
}{
{tick: time.Second, bucketNum: 0, hasErr: true},
{tick: time.Millisecond, bucketNum: 1, hasErr: true},
{tick: time.Second, bucketNum: 1, hasErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := NewTimeWheel(test.tick, test.bucketNum)
assert.Equal(t, test.hasErr, err != nil)
})
}
}

func TestAdd(t *testing.T) {
tw := newTimeWheel()
err := tw.Add(time.Second*1, "test", callback)
if err != nil {
t.Fatalf("test add failed, %v", err)
a := &A{}
err := tw.Add(time.Second*1, "test", a.callback)
assert.NoError(t, err)

time.Sleep(time.Millisecond * 500)
assert.Equal(t, int32(0), a.getCallbackValue())
time.Sleep(time.Second * 2)
assert.Equal(t, int32(1), a.getCallbackValue())
tw.Stop()
}

func TestAddMultipleTimes(t *testing.T) {
a := &A{}
tw := newTimeWheel()
for i := 0; i < 4; i++ {
err := tw.Add(time.Second, "test", a.callback)
assert.NoError(t, err)
time.Sleep(time.Millisecond * 500)
t.Logf("current: %d", i)
assert.Equal(t, int32(0), a.getCallbackValue())
}
time.Sleep(time.Second * 5)

time.Sleep(time.Second * 2)
assert.Equal(t, int32(1), a.getCallbackValue())
tw.Stop()
}

func TestRemove(t *testing.T) {
a := &A{a: 10, b: "test"}
tw := newTimeWheel()
err := tw.Add(time.Second*1, a, callback)
if err != nil {
t.Fatalf("test add failed, %v", err)
}
tw.Remove(a)
time.Sleep(time.Second * 5)
err := tw.Add(time.Second*1, a, a.callback)
assert.NoError(t, err)

time.Sleep(time.Millisecond * 500)
assert.Equal(t, int32(0), a.getCallbackValue())
err = tw.Remove(a)
assert.NoError(t, err)
time.Sleep(time.Second * 2)
assert.Equal(t, int32(0), a.getCallbackValue())
tw.Stop()
}

func BenchmarkAdd(b *testing.B) {
a := &A{}
tw := newTimeWheel()
for i := 0; i < b.N; i++ {
key := "test" + strconv.Itoa(i)
err := tw.Add(time.Second, key, callback)
err := tw.Add(time.Second, key, a.callback)
if err != nil {
b.Fatalf("benchmark Add failed, %v", err)
}
Expand Down

0 comments on commit b051f34

Please sign in to comment.