diff --git a/handler.go b/handler.go index dbcad19..85b1293 100644 --- a/handler.go +++ b/handler.go @@ -3,6 +3,12 @@ package krouter import ( "context" "fmt" + "io/ioutil" + "net/http" + "reflect" + "strconv" + "time" + "github.com/Shopify/sarama" "github.com/google/uuid" "github.com/gorilla/mux" @@ -10,11 +16,6 @@ import ( "github.com/tryfix/kstream/data" "github.com/tryfix/log" traceable_context "github.com/tryfix/traceable-context" - "io/ioutil" - "net/http" - "reflect" - "strconv" - "time" ) type InvalidHeaderError struct { @@ -30,12 +31,14 @@ func (i InvalidHeaderError) Error() string { type Payload struct { headers map[string]interface{} params map[string]interface{} + query map[string]interface{} Body interface{} } type HttpPayload struct { headers map[string]string params map[string]string + query map[string]string Body interface{} } @@ -47,6 +50,10 @@ func (p *HttpPayload) Header(name string) string { return p.headers[name] } +func (p *HttpPayload) Query(name string) string { + return p.query[name] +} + type PreRouteHandleFunc func(ctx context.Context, payload HttpPayload) (interface{}, error) type PostRouteHandleFunc func(ctx context.Context, payload Payload) error @@ -59,6 +66,10 @@ func (p *Payload) Header(name string) interface{} { return p.headers[name] } +func (p *Payload) Query(name string) interface{} { + return p.query[name] +} + type handlerOption func(*Handler) func HandlerWithValidator(v Validator) handlerOption { @@ -92,6 +103,16 @@ func HandlerWithParameter(name string, typ ParamType) handlerOption { } } +func HandlerWithQueryParam(name string, typ ParamType, required bool) handlerOption { + return func(h *Handler) { + h.supportedQueryParams = append(h.supportedQueryParams, Param{ + name: name, + typ: typ, + whenEmpty: nil, + }) + } +} + func HandlerWithSuccessHandlerFunc(fn SuccessHandlerFunc) handlerOption { return func(h *Handler) { h.successHandlerFunc = fn @@ -119,21 +140,21 @@ func HandlerWithContextExtractor(fn ContextExtractor) handlerOption { type KeyMapper func(ctx context.Context, routeName string, body interface{}, params, headers map[string]string) (string, error) type Handler struct { - request *Payload - logger log.Logger - postHandler PostRouteHandleFunc - preHandler PreRouteHandleFunc - validators []Validator - encode Encoder - supportedHeaders []Param - supportedParams []Param - router *Router - name string - headersFuncs map[string]func() string - keyMapper KeyMapper - successHandlerFunc SuccessHandlerFunc - errorHandlerFunc ErrorHandlerFunc - contextExtractor ContextExtractor + logger log.Logger + postHandler PostRouteHandleFunc + preHandler PreRouteHandleFunc + validators []Validator + encode Encoder + supportedHeaders []Param + supportedParams []Param + supportedQueryParams []Param + router *Router + name string + headersFuncs map[string]func() string + keyMapper KeyMapper + successHandlerFunc SuccessHandlerFunc + errorHandlerFunc ErrorHandlerFunc + contextExtractor ContextExtractor } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -150,6 +171,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { rawParams := map[string]string{} rawHeaders := map[string]string{} + rawQuery := map[string]string{} // apply http request headers to re-route headers for _, h := range h.supportedHeaders { @@ -165,6 +187,16 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ rawHeaders[h.name] = h.whenEmpty() } + // apply query parameters + queryValues := r.URL.Query() + for _, q := range h.supportedQueryParams { + if value := queryValues.Get(q.name); value != "" { + rawQuery[q.name] = value + } else if q.whenEmpty != nil { + rawQuery[q.name] = q.whenEmpty() + } + } + // apply http Router params to re-route params vars := mux.Vars(r) for _, h := range h.supportedParams { @@ -202,9 +234,16 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ } } + // decode query parameters + _, err = h.decodeParams(h.supportedQueryParams, rawQuery) + if err != nil { + return fmt.Errorf("failed to decode query parameters: %v", err) + } + payload := HttpPayload{ headers: rawHeaders, params: rawParams, + query: rawQuery, Body: v, } @@ -227,6 +266,7 @@ func (h *Handler) serve(ctx context.Context, w http.ResponseWriter, r *http.Requ route := Route{ Params: rawParams, Headers: rawHeaders, + Query: rawQuery, Payload: string(byt), Name: h.name, } diff --git a/route.go b/route.go index 0ef0359..79e597d 100644 --- a/route.go +++ b/route.go @@ -5,6 +5,7 @@ import "encoding/json" type Route struct { Params map[string]string `json:"params"` Headers map[string]string `json:"headers"` + Query map[string]string `json:"query"` Payload string `json:"payload"` Name string `json:"name"` } @@ -16,4 +17,3 @@ func (r Route) Encode() ([]byte, error) { func (r Route) Decode(data []byte) error { return json.Unmarshal(data, &r) } - diff --git a/router.go b/router.go index 3ccf299..bfd0816 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,9 @@ import ( "context" "encoding/json" "fmt" + "net/http" + "time" + "github.com/Shopify/sarama" "github.com/google/uuid" "github.com/tryfix/errors" @@ -12,9 +15,7 @@ import ( "github.com/tryfix/kstream/producer" "github.com/tryfix/log" "github.com/tryfix/metrics" - "github.com/tryfix/traceable-context" - "net/http" - "time" + traceable_context "github.com/tryfix/traceable-context" ) type group struct { @@ -262,9 +263,16 @@ func (r *Router) process(ctx context.Context, record *data.Record) error { return errors.WithPrevious(err, fmt.Sprintf(`header decode error on route [%s]`, route.Name)) } + // decode query parameters + query, err := h.decodeParams(h.supportedQueryParams, route.Query) + if err != nil { + return errors.WithPrevious(err, fmt.Sprintf(`query parameter decode error on route [%s]`, route.Name)) + } + payload := Payload{ params: params, headers: headers, + query: query, Body: nil, } diff --git a/router_error_test.go b/router_error_test.go index eed7669..1bc9f3d 100644 --- a/router_error_test.go +++ b/router_error_test.go @@ -5,18 +5,19 @@ import ( "context" "encoding/json" "fmt" - "github.com/google/uuid" - "github.com/gorilla/mux" - "github.com/tryfix/kstream/admin" - "github.com/tryfix/kstream/consumer" - "github.com/tryfix/kstream/producer" - "github.com/tryfix/log" "net/http" "net/http/httptest" "strconv" "sync" "testing" "time" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/tryfix/kstream/admin" + "github.com/tryfix/kstream/consumer" + "github.com/tryfix/kstream/producer" + "github.com/tryfix/log" ) type someError struct { @@ -69,12 +70,14 @@ var somePrehandlerError = func(ctx context.Context, payload HttpPayload) (interf } type somehandler struct { - actualUserId uuid.UUID - actualSomeInt int - actualSomeString string - actualAccid customInt - actualPayload fooRequest - mu sync.Mutex + actualUserId uuid.UUID + actualSomeInt int + actualSomeString string + actualAccid customInt + actualPayload fooRequest + actualQueryParam string + expectedQueryParam string + mu sync.Mutex } func (s *somehandler) Handle(ctx context.Context, request Payload) error { @@ -85,6 +88,9 @@ func (s *somehandler) Handle(ctx context.Context, request Payload) error { s.actualSomeString = request.Header(`some-string`).(string) s.actualAccid = request.Param(`acc-id`).(customInt) s.actualPayload = request.Body.(fooRequest) + if qoo := request.Query("qoo"); qoo != nil { + s.actualQueryParam = qoo.(string) + } return nil } diff --git a/router_test.go b/router_test.go index 5f4f6e3..9091ad7 100644 --- a/router_test.go +++ b/router_test.go @@ -5,18 +5,19 @@ import ( "context" "encoding/json" "fmt" - "github.com/google/uuid" - "github.com/gorilla/mux" - "github.com/tryfix/kstream/admin" - "github.com/tryfix/kstream/consumer" - "github.com/tryfix/kstream/producer" - "github.com/tryfix/log" "net/http" "net/http/httptest" "reflect" "strconv" "testing" "time" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/tryfix/kstream/admin" + "github.com/tryfix/kstream/consumer" + "github.com/tryfix/kstream/producer" + "github.com/tryfix/log" ) type testErrorResponse struct { @@ -87,8 +88,11 @@ func TestHandler_ServeHTTP(t *testing.T) { Id: 1000, Name: "some name", } + assertQueryParam := "var" - payloadhandler := new(somehandler) + payloadhandler := &somehandler{ + expectedQueryParam: assertQueryParam, + } h := router.NewHandler(`route1`, fooRequestEncoder{}, somePrehandler, payloadhandler.Handle, HandlerWithHeader(`user-id`, ParamTypeUuid, nil), @@ -96,6 +100,8 @@ func TestHandler_ServeHTTP(t *testing.T) { HandlerWithHeader(`some-string`, ParamTypeString, nil), HandlerWithHeader(`trace-id`, ParamTypeUuid, func() string { return uuid.New().String() }), HandlerWithParameter(`acc-id`, `my-custom-type`), + HandlerWithQueryParam(`qoo`, ParamTypeString, false), + HandlerWithQueryParam(`qoo`, ParamTypeString, false), HandlerWithKeyMapper(func(ctx context.Context, routeName string, body interface{}, params, headers map[string]string) (s string, err error) { return fmt.Sprint(params[`acc-id`]), nil }), @@ -128,7 +134,7 @@ func TestHandler_ServeHTTP(t *testing.T) { t.Error(err) } buff := bytes.NewBuffer(testByt) - req := httptest.NewRequest("POST", "http://example.com/foo/1222/bar", buff) + req := httptest.NewRequest("POST", "http://example.com/foo/1222/bar?qoo=var", buff) req.Header.Set(`user-id`, assertUserId.String()) req.Header.Set(`some-int`, `133`) req.Header.Set(`some-string`, assertSomeString) @@ -138,6 +144,9 @@ func TestHandler_ServeHTTP(t *testing.T) { r.Handle(`/foo/{acc-id}/bar`, h).Methods(http.MethodPost) r.ServeHTTP(w, req) + // Query parameter validation is now done through the handler's stored value + time.Sleep(1 * time.Second) + res := successResponse{} if err := json.NewDecoder(w.Result().Body).Decode(&res); err != nil { t.Error(err) @@ -171,4 +180,8 @@ func TestHandler_ServeHTTP(t *testing.T) { if !reflect.DeepEqual(payloadhandler.actualPayload, assertpayload) { t.Fail() } + if payloadhandler.actualQueryParam != payloadhandler.expectedQueryParam { + t.Errorf("Expected query param 'qoo' to be '%s', got '%s'", + payloadhandler.expectedQueryParam, payloadhandler.actualQueryParam) + } }