Skip to content

Commit 30ab7fd

Browse files
feat: add utility to normalize graphql request (#29)
1 parent f72c037 commit 30ab7fd

File tree

4 files changed

+70
-38
lines changed

4 files changed

+70
-38
lines changed

router.go

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"io/ioutil"
77
"net/http"
8-
"strings"
98
"time"
109

1110
"github.com/99designs/gqlgen/graphql/handler"
@@ -15,9 +14,7 @@ import (
1514
"github.com/gbox-proxy/gbox/admin/generated"
1615
"github.com/gorilla/handlers"
1716
"github.com/gorilla/mux"
18-
"github.com/jensneuse/graphql-go-tools/pkg/astparser"
1917
"github.com/jensneuse/graphql-go-tools/pkg/graphql"
20-
"github.com/jensneuse/graphql-go-tools/pkg/operationreport"
2118
"go.uber.org/zap"
2219
)
2320

@@ -78,7 +75,7 @@ func (h *Handler) GraphQLOverWebsocketHandle(w http.ResponseWriter, r *http.Requ
7875
}
7976

8077
n := r.Context().Value(nextHandlerCtxKey).(caddyhttp.Handler)
81-
mr := newWebsocketMetricsResponseWriter(w, h)
78+
mr := newWebsocketMetricsResponseWriter(w, h.schema, h)
8279
reporter.error = h.ReverseProxy.ServeHTTP(mr, r, n)
8380
}
8481

@@ -153,39 +150,14 @@ func (h *Handler) unmarshalHTTPRequest(r *http.Request) (*graphql.Request, error
153150
return nil, err
154151
}
155152

156-
err = graphql.UnmarshalHttpRequest(copyHTTPRequest, gqlRequest)
157-
158-
if err != nil {
153+
if err = graphql.UnmarshalHttpRequest(copyHTTPRequest, gqlRequest); err != nil {
159154
return nil, err
160155
}
161156

162-
if result, _ := gqlRequest.Normalize(h.schema); !result.Successful {
163-
return nil, result.Errors
164-
}
165-
166-
operation, _ := astparser.ParseGraphqlDocumentString(gqlRequest.Query)
167-
numOfOperations := operation.NumOfOperationDefinitions()
168-
operationName := strings.TrimSpace(gqlRequest.OperationName)
169-
report := &operationreport.Report{}
170-
171-
if operationName == "" && numOfOperations > 1 {
172-
report.AddExternalError(operationreport.ErrRequiredOperationNameIsMissing())
173-
174-
return nil, report
175-
}
176-
177-
if operationName == "" && numOfOperations == 1 {
178-
operationName = operation.OperationDefinitionNameString(0)
179-
}
180-
181-
if !operation.OperationNameExists(operationName) {
182-
report.AddExternalError(operationreport.ErrOperationWithProvidedOperationNameNotFound(operationName))
183-
184-
return nil, report
157+
if err = normalizeGraphqlRequest(h.schema, gqlRequest); err != nil {
158+
return nil, err
185159
}
186160

187-
gqlRequest.OperationName = operationName
188-
189161
return gqlRequest, nil
190162
}
191163

utils.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ package gbox
33
import (
44
"bytes"
55
"net/http"
6+
"strings"
67
"sync"
78

9+
"github.com/jensneuse/graphql-go-tools/pkg/astparser"
810
"github.com/jensneuse/graphql-go-tools/pkg/graphql"
11+
"github.com/jensneuse/graphql-go-tools/pkg/operationreport"
912
)
1013

1114
var bufferPool = sync.Pool{
@@ -24,3 +27,34 @@ func writeResponseErrors(errors error, w http.ResponseWriter) error {
2427

2528
return nil
2629
}
30+
31+
func normalizeGraphqlRequest(schema *graphql.Schema, gqlRequest *graphql.Request) error {
32+
if result, _ := gqlRequest.Normalize(schema); !result.Successful {
33+
return result.Errors
34+
}
35+
36+
operation, _ := astparser.ParseGraphqlDocumentString(gqlRequest.Query)
37+
numOfOperations := operation.NumOfOperationDefinitions()
38+
operationName := strings.TrimSpace(gqlRequest.OperationName)
39+
report := &operationreport.Report{}
40+
41+
if operationName == "" && numOfOperations > 1 {
42+
report.AddExternalError(operationreport.ErrRequiredOperationNameIsMissing())
43+
44+
return report
45+
}
46+
47+
if operationName == "" && numOfOperations == 1 {
48+
operationName = operation.OperationDefinitionNameString(0)
49+
}
50+
51+
if !operation.OperationNameExists(operationName) {
52+
report.AddExternalError(operationreport.ErrOperationWithProvidedOperationNameNotFound(operationName))
53+
54+
return report
55+
}
56+
57+
gqlRequest.OperationName = operationName
58+
59+
return nil
60+
}

ws.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ import (
1616
type wsMetricsResponseWriter struct {
1717
requestMetrics
1818
*caddyhttp.ResponseWriterWrapper
19+
schema *graphql.Schema
1920
}
2021

21-
func newWebsocketMetricsResponseWriter(w http.ResponseWriter, rm requestMetrics) *wsMetricsResponseWriter {
22+
func newWebsocketMetricsResponseWriter(w http.ResponseWriter, s *graphql.Schema, rm requestMetrics) *wsMetricsResponseWriter {
2223
return &wsMetricsResponseWriter{
23-
requestMetrics: rm,
2424
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{
2525
ResponseWriter: w,
2626
},
27+
schema: s,
28+
requestMetrics: rm,
2729
}
2830
}
2931

@@ -35,6 +37,7 @@ func (r *wsMetricsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error)
3537
c = &wsMetricsConn{
3638
Conn: c,
3739
requestMetrics: r.requestMetrics,
40+
schema: r.schema,
3841
}
3942
}
4043

@@ -45,14 +48,15 @@ type wsMetricsConn struct {
4548
net.Conn
4649
requestMetrics
4750
request *graphql.Request
51+
schema *graphql.Schema
4852
subscribeAt time.Time
4953
}
5054

5155
func (c *wsMetricsConn) Read(b []byte) (n int, err error) {
5256
n, err = c.Conn.Read(b)
5357

54-
if err != nil {
55-
if c.request != nil {
58+
if c.request != nil || err != nil {
59+
if err != nil {
5660
c.addMetricsEndRequest(c.request, time.Since(c.subscribeAt))
5761
c.request = nil
5862
}
@@ -88,6 +92,10 @@ func (c *wsMetricsConn) Read(b []byte) (n int, err error) {
8892
return n, err
8993
}
9094

95+
if e := normalizeGraphqlRequest(c.schema, request); e != nil {
96+
return n, err
97+
}
98+
9199
c.request = request
92100
c.subscribeAt = time.Now()
93101
c.addMetricsBeginRequest(request)

ws_test.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,17 @@ func (m *testRequestMetrics) addMetricsEndRequest(request *graphql.Request, dura
5858
}
5959

6060
func TestWsMetricsConn(t *testing.T) {
61+
s, _ := graphql.NewSchemaFromString(`
62+
type Query {
63+
users [User!]!
64+
}
65+
66+
type User {
67+
id: ID!
68+
}
69+
`)
6170
m := newTestRequestMetrics(t)
62-
w := newWebsocketMetricsResponseWriter(&testWsResponseWriter{}, m)
71+
w := newWebsocketMetricsResponseWriter(&testWsResponseWriter{}, s, m)
6372
conn, _, _ := w.Hijack()
6473
buff := new(bytes.Buffer)
6574
wsutil.WriteClientText(buff, []byte(`{"type": "start", "payload":{"query": "subscription { users { id } }"}}`))
@@ -76,6 +85,15 @@ func TestWsMetricsConn(t *testing.T) {
7685
}
7786

7887
func TestWsMetricsConnBadCases(t *testing.T) {
88+
s, _ := graphql.NewSchemaFromString(`
89+
type Query {
90+
users [User!]!
91+
}
92+
93+
type User {
94+
id: ID!
95+
}
96+
`)
7997
testCases := map[string]struct {
8098
message string
8199
}{
@@ -93,7 +111,7 @@ func TestWsMetricsConnBadCases(t *testing.T) {
93111

94112
for name, testCase := range testCases {
95113
m := newTestRequestMetrics(t)
96-
w := newWebsocketMetricsResponseWriter(&testWsResponseWriter{}, m)
114+
w := newWebsocketMetricsResponseWriter(&testWsResponseWriter{}, s, m)
97115
conn, _, _ := w.Hijack()
98116
buff := new(bytes.Buffer)
99117

0 commit comments

Comments
 (0)