Skip to content

Commit

Permalink
Able to rate limit by custom HTTP request headers. Also, write a lot …
Browse files Browse the repository at this point in the history
…more tests.
  • Loading branch information
Didip Kerabat committed May 17, 2015
1 parent 93960d6 commit fc751b9
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 59 deletions.
38 changes: 38 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Package config provides data structure to configure rate limiter.
package config

import (
"time"
)

// NewLimiter is a constructor for Limiter.
func NewLimiter(max int64, ttl time.Duration) *Limiter {
limiter := &Limiter{Max: max, TTL: ttl}
limiter.Message = "You have reached maximum request limit."
limiter.StatusCode = 429

return limiter
}

// Limiter is a config struct to limit a particular request handler.
type Limiter struct {
// HTTP message when limit is reached.
Message string

// HTTP status code when limit is reached.
StatusCode int

// Maximum number of requests to limit per duration.
Max int64

// Duration of rate limiter.
TTL time.Duration

// List of HTTP Methods to limit (GET, POST, PUT, etc.).
// Empty means limit all methods.
Methods []string

// List of HTTP headers to limit.
// Empty means skip headers checking.
Headers map[string][]string
}
22 changes: 22 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package config

import (
"testing"
"time"
)

func TestConstructor(t *testing.T) {
limiter := NewLimiter(1, time.Second)
if limiter.Max != 1 {
t.Errorf("Max field is incorrect. Value: %v", limiter.Max)
}
if limiter.TTL != time.Second {
t.Errorf("TTL field is incorrect. Value: %v", limiter.TTL)
}
if limiter.Message != "You have reached maximum request limit." {
t.Errorf("Message field is incorrect. Value: %v", limiter.Message)
}
if limiter.StatusCode != 429 {
t.Errorf("StatusCode field is incorrect. Value: %v", limiter.StatusCode)
}
}
15 changes: 15 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Package errors provide data structure for errors.
package errors

import "fmt"

// HTTPError is an error struct that returns both message and status code.
type HTTPError struct {
Message string
StatusCode int
}

// Error returns error message.
func (httperror *HTTPError) Error() string {
return fmt.Sprintf("%v: %v", httperror.StatusCode, httperror.Message)
}
10 changes: 10 additions & 0 deletions errors/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package errors

import "testing"

func TestError(t *testing.T) {
errs := HTTPError{"blah", 429}
if errs.Error() == "" {
t.Errorf("Unable to print Error(). Value: %v", errs.Error())
}
}
8 changes: 6 additions & 2 deletions libstring/libstring.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package libstring

func FlattenMapSliceString(mapSliceString map[string][]string, prefix string) []string {
func FlattenMapSliceString(mapSliceString map[string][]string, prefix string, separator string) []string {
result := make([]string, 0)

if separator == "" {
separator = ":"
}

for key, slice := range mapSliceString {
for _, item := range slice {
result = append(result, prefix+":"+key+":"+item)
result = append(result, prefix+separator+key+separator+item)
}
}
return result
Expand Down
16 changes: 16 additions & 0 deletions libstring/libstring_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package libstring

import (
"testing"
)

func TestFlattenMapSliceString(t *testing.T) {
headersToCheck := make(map[string][]string)
headersToCheck["X-Auth-Token"] = []string{"abc123", "brotato!23"}

for i, flatten := range FlattenMapSliceString(headersToCheck, "headers", "") {
if flatten != "headers:X-Auth-Token:"+headersToCheck["X-Auth-Token"][i] {
t.Errorf("Failed to flatten map correctly. Result: %v", flatten)
}
}
}
98 changes: 41 additions & 57 deletions tollbooth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,100 +2,84 @@
package tollbooth

import (
"fmt"
"github.com/didip/tollbooth/libstring"
"github.com/didip/tollbooth/storages"
"net/http"
"strings"
"time"
)

// NewRequestLimit is a constructor for RequestLimit.
func NewRequestLimit(max int64, ttl time.Duration) *RequestLimit {
return &RequestLimit{Max: max, TTL: ttl}
}

// RequestLimit is a config struct to limit a particular request handler.
type RequestLimit struct {
// Maximum number of requests to limit per duration.
Max int64

// Duration of rate limiter.
TTL time.Duration

// List of HTTP Methods to limit (GET, POST, PUT, etc.).
// Empty means limit all methods.
Methods []string

// List of HTTP headers to limit.
// Empty means skip headers checking.
Headers map[string][]string
}

// HTTPError is an error struct that returns both message and status code.
type HTTPError struct {
Message string
StatusCode int
}
"github.com/didip/tollbooth/config"
"github.com/didip/tollbooth/errors"
"github.com/didip/tollbooth/libstring"
"github.com/didip/tollbooth/storages"
)

// Error returns error message.
func (httperror *HTTPError) Error() string {
return fmt.Sprintf("%v: %v", httperror.StatusCode, httperror.Message)
// NewLimiter is a convenience function to config.NewLimiter.
func NewLimiter(max int64, ttl time.Duration) *config.Limiter {
return config.NewLimiter(max, ttl)
}

// LimitByKeyParts keeps track number of request made by keyParts separated by pipe.
// It returns HTTPError when limit is exceeded.
func LimitByKeyParts(storage storages.ICounterStorage, reqLimit *RequestLimit, keyParts []string) *HTTPError {
// It keeps track number of request made by REMOTE_ADDR and returns HTTPError when limit is exceeded.
func LimitByKeyParts(storage storages.ICounterStorage, limiter *config.Limiter, keyParts []string) *errors.HTTPError {
key := strings.Join(keyParts, "|")

storage.IncrBy(key, int64(1), reqLimit.TTL)
storage.IncrBy(key, int64(1), limiter.TTL)
currentCount, _ := storage.Get(key)

// Check if the returned counter exceeds our limit
if currentCount > reqLimit.Max {
return &HTTPError{Message: "You have reached maximum request limit.", StatusCode: 429}
if currentCount > limiter.Max {
return &errors.HTTPError{Message: limiter.Message, StatusCode: limiter.StatusCode}
}
return nil
}

// LimitByIPHandler is a middleware that limits by IP given http.Handler struct.
// It keeps track number of request made by REMOTE_ADDR and returns HTTPError when limit is exceeded.
func LimitByIPHandler(storage storages.ICounterStorage, reqLimit *RequestLimit, next http.Handler) http.Handler {
// LimitHandler is a middleware that limits by IP given http.Handler struct.
func LimitHandler(storage storages.ICounterStorage, limiter *config.Limiter, next http.Handler) http.Handler {
middle := func(w http.ResponseWriter, r *http.Request) {
remoteIP := r.Header.Get("REMOTE_ADDR")
path := r.URL.Path
defaultKeyParts := []string{remoteIP, path}

var httpError *HTTPError
var httpError *errors.HTTPError

if reqLimit.Methods != nil && reqLimit.Headers != nil {
// Limit by HTTP methods and headers.
for _, method := range reqLimit.Methods {
if limiter.Methods != nil && limiter.Headers != nil {
// Limit by HTTP methods and HTTP headers.
for _, method := range limiter.Methods {
keyParts := append(defaultKeyParts, method)

for _, headerKeyParts := range libstring.FlattenMapSliceString(reqLimit.Headers, "headers") {
for _, headerKeyParts := range libstring.FlattenMapSliceString(limiter.Headers, "headers", ":") {
keyParts = append(keyParts, headerKeyParts)
httpError = LimitByKeyParts(storage, reqLimit, keyParts)
httpError = LimitByKeyParts(storage, limiter, keyParts)
if httpError != nil {
http.Error(w, httpError.Message, httpError.StatusCode)
return
}
}
}

} else if reqLimit.Methods != nil {
// Limit by HTTP methods.
for _, method := range reqLimit.Methods {
} else if limiter.Methods != nil {
// Limit by HTTP methods only.
for _, method := range limiter.Methods {
keyParts := append(defaultKeyParts, method)
httpError = LimitByKeyParts(storage, reqLimit, keyParts)
httpError = LimitByKeyParts(storage, limiter, keyParts)
if httpError != nil {
http.Error(w, httpError.Message, httpError.StatusCode)
return
}
}
} else if limiter.Headers != nil {
// Limit by HTTP headers only.
for _, headerKeyParts := range libstring.FlattenMapSliceString(limiter.Headers, "headers", ":") {
keyParts := append(defaultKeyParts, headerKeyParts)
httpError = LimitByKeyParts(storage, limiter, keyParts)
if httpError != nil {
http.Error(w, httpError.Message, httpError.StatusCode)
return
}
}

} else {
// Default limiter.
httpError = LimitByKeyParts(storage, reqLimit, defaultKeyParts)
// Default: Limit by remote IP and request path.
httpError = LimitByKeyParts(storage, limiter, defaultKeyParts)
if httpError != nil {
http.Error(w, httpError.Message, httpError.StatusCode)
return
Expand All @@ -108,7 +92,7 @@ func LimitByIPHandler(storage storages.ICounterStorage, reqLimit *RequestLimit,
return http.HandlerFunc(middle)
}

// LimitByIPFuncHandler is a middleware that limits by IP given request handler function.
func LimitByIPFuncHandler(storage storages.ICounterStorage, reqLimit *RequestLimit, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler {
return LimitByIPHandler(storage, reqLimit, http.HandlerFunc(nextFunc))
// LimitFuncHandler is a middleware that limits by IP given request handler function.
func LimitFuncHandler(storage storages.ICounterStorage, limiter *config.Limiter, nextFunc func(http.ResponseWriter, *http.Request)) http.Handler {
return LimitHandler(storage, limiter, http.HandlerFunc(nextFunc))
}
29 changes: 29 additions & 0 deletions tollbooth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package tollbooth

import (
"testing"
"time"

"github.com/didip/tollbooth/storages"
)

func TestLimitByKeyParts(t *testing.T) {
storage := storages.NewInMemory()
limiter := NewLimiter(1, time.Second) // Only 1 request per second is allowed.

httperror := LimitByKeyParts(storage, limiter, []string{"127.0.0.1", "/"})
if httperror != nil {
t.Errorf("First time count should not return error. Error: %v", httperror.Error())
}

httperror = LimitByKeyParts(storage, limiter, []string{"127.0.0.1", "/"})
if httperror == nil {
t.Errorf("Second time count should return error because it exceeds 1 request per second.")
}

<-time.After(1 * time.Second)
httperror = LimitByKeyParts(storage, limiter, []string{"127.0.0.1", "/"})
if httperror != nil {
t.Errorf("Third time count should not return error because the 1 second window has passed.")
}
}

0 comments on commit fc751b9

Please sign in to comment.