Skip to content

Commit

Permalink
Merge pull request #86 from vivek-ng/vivek-ng/fix-panic-caching
Browse files Browse the repository at this point in the history
Panic errors from batch function should not be cached
  • Loading branch information
pavelnikolov authored Apr 22, 2022
2 parents a7ede83 + 58f8c20 commit 4611304
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 5 deletions.
18 changes: 17 additions & 1 deletion dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package dataloader

import (
"context"
"errors"
"fmt"
"log"
"runtime"
Expand Down Expand Up @@ -48,6 +49,17 @@ type ResultMany[V any] struct {
Error []error
}

// PanicErrorWrapper marks an error that originated from a panic inside
// the batch function. Load detects this type with errors.As and evicts
// the entry from the cache, so panic results — unlike ordinary batch
// errors — are never cached.
type PanicErrorWrapper struct {
	// panicError holds the error built from the recovered panic value.
	panicError error
}

// Error returns the wrapped panic error's message, satisfying the
// error interface.
func (p *PanicErrorWrapper) Error() string {
	return p.panicError.Error()
}

// Unwrap exposes the underlying panic error so callers can inspect it
// through errors.Is / errors.As chains.
func (p *PanicErrorWrapper) Unwrap() error {
	return p.panicError
}

// Loader implements the dataloader.Interface.
type Loader[K comparable, V any] struct {
// the batch function to be used by this loader
Expand Down Expand Up @@ -219,6 +231,10 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
}
result.mu.RLock()
defer result.mu.RUnlock()
var ev *PanicErrorWrapper
if result.value.Error != nil && errors.As(result.value.Error, &ev) {
l.Clear(ctx, key)
}
return result.value.Data, result.value.Error
}
defer finish(thunk)
Expand Down Expand Up @@ -431,7 +447,7 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {

if panicErr != nil {
for _, req := range reqs {
req.channel <- &Result[V]{Error: fmt.Errorf("Panic received in batch function: %v", panicErr)}
req.channel <- &Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}}
close(req.channel)
}
return
Expand Down
99 changes: 95 additions & 4 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,30 @@ func TestLoader(t *testing.T) {
}
})

	// Verifies that ordinary (non-panic) errors returned by the batch
	// function ARE cached: a repeat Load for a previously failed key must
	// return the cached error without re-invoking the batch function.
	t.Run("test Load Method cache error", func(t *testing.T) {
		t.Parallel()
		errorCacheLoader, _ := ErrorCacheLoader[string](0)
		ctx := context.Background()
		futures := []Thunk[string]{}
		// Load two distinct keys so the batch function sees len(keys) > 1
		// and fails every key with a normal error.
		for i := 0; i < 2; i++ {
			futures = append(futures, errorCacheLoader.Load(ctx, strconv.Itoa(i)))
		}

		// Every thunk should surface the batch error.
		for _, f := range futures {
			_, err := f()
			if err == nil {
				t.Error("Error was not propagated")
			}
		}
		// Re-load a key that already failed; the cached error should be
		// returned (the single-key batch path would otherwise succeed).
		nextFuture := errorCacheLoader.Load(ctx, "1")
		_, err := nextFuture()

		// Normal errors should be cached.
		if err == nil {
			t.Error("Error from batch function was not cached")
		}
	})

t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
t.Parallel()
defer func() {
Expand All @@ -63,7 +87,7 @@ func TestLoader(t *testing.T) {
t.Error("Panic Loader's panic should have been handled'")
}
}()
panicLoader, _ := PanicLoader[string](0)
panicLoader, _ := PanicCacheLoader[string](0)
futures := []Thunk[string]{}
ctx := context.Background()
for i := 0; i < 3; i++ {
Expand All @@ -75,6 +99,18 @@ func TestLoader(t *testing.T) {
t.Error("Panic was not propagated as an error.")
}
}

futures = []Thunk[string]{}
for i := 0; i < 3; i++ {
futures = append(futures, panicLoader.Load(ctx, strconv.Itoa(1)))
}

for _, f := range futures {
_, err := f()
if err != nil {
t.Error("Panic error from batch function was cached")
}
}
})

t.Run("test LoadMany returns errors", func(t *testing.T) {
Expand Down Expand Up @@ -143,13 +179,21 @@ func TestLoader(t *testing.T) {
t.Error("Panic Loader's panic should have been handled'")
}
}()
panicLoader, _ := PanicLoader[string](0)
panicLoader, _ := PanicCacheLoader[string](0)
ctx := context.Background()
future := panicLoader.LoadMany(ctx, []string{"1"})
future := panicLoader.LoadMany(ctx, []string{"1", "2"})
_, errs := future()
if len(errs) < 1 || errs[0].Error() != "Panic received in batch function: Programming error" {
if len(errs) < 2 || errs[0].Error() != "Panic received in batch function: Programming error" {
t.Error("Panic was not propagated as an error.")
}

future = panicLoader.LoadMany(ctx, []string{"1"})
_, errs = future()

if len(errs) > 0 {
t.Error("Panic error from batch function was cached")
}

})

t.Run("test LoadMany method", func(t *testing.T) {
Expand Down Expand Up @@ -531,6 +575,53 @@ func PanicLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
return panicLoader, &loadCalls
}

// PanicCacheLoader returns a test loader whose batch function panics
// whenever it receives more than one key. It is used to verify that a
// panic in the batch function is propagated to callers as an error but
// is NOT cached, so a later single-key Load succeeds.
// The returned [][]K pointer mirrors the other test-loader helpers'
// signatures; this loader never appends to it.
func PanicCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
	var loadCalls [][]K
	panicCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
		// More than one key simulates a programming error in the batch.
		if len(keys) > 1 {
			panic("Programming error")
		}

		returnResult := make([]*Result[K], len(keys))
		for idx := range returnResult {
			// Echo each key back as its own value with no error.
			// (keys[idx], not keys[0], so the helper stays correct even
			// if the batching invariant above ever changes.)
			returnResult[idx] = &Result[K]{
				keys[idx],
				nil,
			}
		}

		return returnResult
	}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
	return panicCacheLoader, &loadCalls
}

// ErrorCacheLoader returns a test loader whose batch function fails
// every key with an ordinary (non-panic) error when it receives more
// than one key. It is used to verify that normal batch errors ARE
// cached: a repeat Load of a failed key must hit the cached error
// rather than the succeeding single-key path.
// The returned [][]K pointer mirrors the other test-loader helpers'
// signatures; this loader never appends to it.
func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
	var loadCalls [][]K
	errorCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
		if len(keys) > 1 {
			// Fail every key with a plain error (not a PanicErrorWrapper).
			results := make([]*Result[K], 0, len(keys))
			for _, key := range keys {
				results = append(results, &Result[K]{key, fmt.Errorf("this is a test error")})
			}
			return results
		}

		returnResult := make([]*Result[K], len(keys))
		for idx := range returnResult {
			// Echo each key back as its own value with no error.
			// (keys[idx], not keys[0], for correctness independent of
			// batch size.)
			returnResult[idx] = &Result[K]{
				keys[idx],
				nil,
			}
		}

		return returnResult
	}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
	return errorCacheLoader, &loadCalls
}

func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
var mu sync.Mutex
var loadCalls [][]K
Expand Down

0 comments on commit 4611304

Please sign in to comment.