Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 60 additions & 20 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ package krouter
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"time"

"github.com/Shopify/sarama"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/tryfix/errors"
"github.com/tryfix/kstream/data"
"github.com/tryfix/log"
traceable_context "github.com/tryfix/traceable-context"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"time"
)

type InvalidHeaderError struct {
Expand All @@ -30,12 +31,14 @@ func (i InvalidHeaderError) Error() string {
type Payload struct {
headers map[string]interface{}
params map[string]interface{}
query map[string]interface{}
Body interface{}
}

type HttpPayload struct {
headers map[string]string
params map[string]string
query map[string]string
Body interface{}
}

Expand All @@ -47,6 +50,10 @@ func (p *HttpPayload) Header(name string) string {
return p.headers[name]
}

func (p *HttpPayload) Query(name string) string {
return p.query[name]
}

type PreRouteHandleFunc func(ctx context.Context, payload HttpPayload) (interface{}, error)

type PostRouteHandleFunc func(ctx context.Context, payload Payload) error
Expand All @@ -59,6 +66,10 @@ func (p *Payload) Header(name string) interface{} {
return p.headers[name]
}

func (p *Payload) Query(name string) interface{} {
return p.query[name]
}

type handlerOption func(*Handler)

func HandlerWithValidator(v Validator) handlerOption {
Expand Down Expand Up @@ -92,6 +103,16 @@ func HandlerWithParameter(name string, typ ParamType) handlerOption {
}
}

func HandlerWithQueryParam(name string, typ ParamType, required bool) handlerOption {
return func(h *Handler) {
h.supportedQueryParams = append(h.supportedQueryParams, Param{
name: name,
typ: typ,
whenEmpty: nil,
})
}
}

func HandlerWithSuccessHandlerFunc(fn SuccessHandlerFunc) handlerOption {
return func(h *Handler) {
h.successHandlerFunc = fn
Expand Down Expand Up @@ -119,21 +140,21 @@ func HandlerWithContextExtractor(fn ContextExtractor) handlerOption {
type KeyMapper func(ctx context.Context, routeName string, body interface{}, params, headers map[string]string) (string, error)

type Handler struct {
request *Payload
logger log.Logger
postHandler PostRouteHandleFunc
preHandler PreRouteHandleFunc
validators []Validator
encode Encoder
supportedHeaders []Param
supportedParams []Param
router *Router
name string
headersFuncs map[string]func() string
keyMapper KeyMapper
successHandlerFunc SuccessHandlerFunc
errorHandlerFunc ErrorHandlerFunc
contextExtractor ContextExtractor
logger log.Logger
postHandler PostRouteHandleFunc
preHandler PreRouteHandleFunc
validators []Validator
encode Encoder
supportedHeaders []Param
supportedParams []Param
supportedQueryParams []Param
router *Router
name string
headersFuncs map[string]func() string
keyMapper KeyMapper
successHandlerFunc SuccessHandlerFunc
errorHandlerFunc ErrorHandlerFunc
contextExtractor ContextExtractor
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand All @@ -150,6 +171,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
rawParams := map[string]string{}
rawHeaders := map[string]string{}
rawQuery := map[string]string{}

// apply http request headers to re-route headers
for _, h := range h.supportedHeaders {
Expand All @@ -165,6 +187,16 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ
rawHeaders[h.name] = h.whenEmpty()
}

// apply query parameters
queryValues := r.URL.Query()
for _, q := range h.supportedQueryParams {
if value := queryValues.Get(q.name); value != "" {
rawQuery[q.name] = value
} else if q.whenEmpty != nil {
rawQuery[q.name] = q.whenEmpty()
}
}

// apply http Router params to re-route params
vars := mux.Vars(r)
for _, h := range h.supportedParams {
Expand Down Expand Up @@ -202,9 +234,16 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ
}
}

// decode query parameters
_, err = h.decodeParams(h.supportedQueryParams, rawQuery)
if err != nil {
return fmt.Errorf("failed to decode query parameters: %v", err)
}

payload := HttpPayload{
headers: rawHeaders,
params: rawParams,
query: rawQuery,
Body: v,
}

Expand All @@ -227,6 +266,7 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ
route := Route{
Params: rawParams,
Headers: rawHeaders,
Query: rawQuery,
Payload: string(byt),
Name: h.name,
}
Expand Down
2 changes: 1 addition & 1 deletion route.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "encoding/json"
type Route struct {
Params map[string]string `json:"params"`
Headers map[string]string `json:"headers"`
Query map[string]string `json:"query"`
Payload string `json:"payload"`
Name string `json:"name"`
}
Expand All @@ -16,4 +17,3 @@ func (r Route) Encode() ([]byte, error) {
func (r Route) Decode(data []byte) error {
return json.Unmarshal(data, &r)
}

14 changes: 11 additions & 3 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"

"github.com/Shopify/sarama"
"github.com/google/uuid"
"github.com/tryfix/errors"
Expand All @@ -12,9 +15,7 @@ import (
"github.com/tryfix/kstream/producer"
"github.com/tryfix/log"
"github.com/tryfix/metrics"
"github.com/tryfix/traceable-context"
"net/http"
"time"
traceable_context "github.com/tryfix/traceable-context"
)

type group struct {
Expand Down Expand Up @@ -262,9 +263,16 @@ func (r *Router) process(ctx context.Context, record *data.Record) error {
return errors.WithPrevious(err, fmt.Sprintf(`header decode error on route [%s]`, route.Name))
}

// decode query parameters
query, err := h.decodeParams(h.supportedQueryParams, route.Query)
if err != nil {
return errors.WithPrevious(err, fmt.Sprintf(`query parameter decode error on route [%s]`, route.Name))
}

payload := Payload{
params: params,
headers: headers,
query: query,
Body: nil,
}

Expand Down
30 changes: 18 additions & 12 deletions router_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/tryfix/kstream/admin"
"github.com/tryfix/kstream/consumer"
"github.com/tryfix/kstream/producer"
"github.com/tryfix/log"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"

"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/tryfix/kstream/admin"
"github.com/tryfix/kstream/consumer"
"github.com/tryfix/kstream/producer"
"github.com/tryfix/log"
)

type someError struct {
Expand Down Expand Up @@ -69,12 +70,14 @@ var somePrehandlerError = func(ctx context.Context, payload HttpPayload) (interf
}

type somehandler struct {
actualUserId uuid.UUID
actualSomeInt int
actualSomeString string
actualAccid customInt
actualPayload fooRequest
mu sync.Mutex
actualUserId uuid.UUID
actualSomeInt int
actualSomeString string
actualAccid customInt
actualPayload fooRequest
actualQueryParam string
expectedQueryParam string
mu sync.Mutex
}

func (s *somehandler) Handle(ctx context.Context, request Payload) error {
Expand All @@ -85,6 +88,9 @@ func (s *somehandler) Handle(ctx context.Context, request Payload) error {
s.actualSomeString = request.Header(`some-string`).(string)
s.actualAccid = request.Param(`acc-id`).(customInt)
s.actualPayload = request.Body.(fooRequest)
if qoo := request.Query("qoo"); qoo != nil {
s.actualQueryParam = qoo.(string)
}
return nil
}

Expand Down
29 changes: 21 additions & 8 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/tryfix/kstream/admin"
"github.com/tryfix/kstream/consumer"
"github.com/tryfix/kstream/producer"
"github.com/tryfix/log"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"
"time"

"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/tryfix/kstream/admin"
"github.com/tryfix/kstream/consumer"
"github.com/tryfix/kstream/producer"
"github.com/tryfix/log"
)

type testErrorResponse struct {
Expand Down Expand Up @@ -87,15 +88,20 @@ func TestHandler_ServeHTTP(t *testing.T) {
Id: 1000,
Name: "some name",
}
assertQueryParam := "var"

payloadhandler := new(somehandler)
payloadhandler := &somehandler{
expectedQueryParam: assertQueryParam,
}

h := router.NewHandler(`route1`, fooRequestEncoder{}, somePrehandler, payloadhandler.Handle,
HandlerWithHeader(`user-id`, ParamTypeUuid, nil),
HandlerWithHeader(`some-int`, ParamTypeInt, nil),
HandlerWithHeader(`some-string`, ParamTypeString, nil),
HandlerWithHeader(`trace-id`, ParamTypeUuid, func() string { return uuid.New().String() }),
HandlerWithParameter(`acc-id`, `my-custom-type`),
HandlerWithQueryParam(`qoo`, ParamTypeString, false),
HandlerWithQueryParam(`qoo`, ParamTypeString, false),
HandlerWithKeyMapper(func(ctx context.Context, routeName string, body interface{}, params, headers map[string]string) (s string, err error) {
return fmt.Sprint(params[`acc-id`]), nil
}),
Expand Down Expand Up @@ -128,7 +134,7 @@ func TestHandler_ServeHTTP(t *testing.T) {
t.Error(err)
}
buff := bytes.NewBuffer(testByt)
req := httptest.NewRequest("POST", "http://example.com/foo/1222/bar", buff)
req := httptest.NewRequest("POST", "http://example.com/foo/1222/bar?qoo=var", buff)
req.Header.Set(`user-id`, assertUserId.String())
req.Header.Set(`some-int`, `133`)
req.Header.Set(`some-string`, assertSomeString)
Expand All @@ -138,6 +144,9 @@ func TestHandler_ServeHTTP(t *testing.T) {
r.Handle(`/foo/{acc-id}/bar`, h).Methods(http.MethodPost)
r.ServeHTTP(w, req)

// Query parameter validation is now done through the handler's stored value
time.Sleep(1 * time.Second)

res := successResponse{}
if err := json.NewDecoder(w.Result().Body).Decode(&res); err != nil {
t.Error(err)
Expand Down Expand Up @@ -171,4 +180,8 @@ func TestHandler_ServeHTTP(t *testing.T) {
if !reflect.DeepEqual(payloadhandler.actualPayload, assertpayload) {
t.Fail()
}
if payloadhandler.actualQueryParam != payloadhandler.expectedQueryParam {
t.Errorf("Expected query param 'qoo' to be '%s', got '%s'",
payloadhandler.expectedQueryParam, payloadhandler.actualQueryParam)
}
}