Skip to content

Commit ed6cab8

Browse files
committed
begin big refactor
1 parent a9bd8e4 commit ed6cab8

File tree

16 files changed

+445
-214
lines changed

16 files changed

+445
-214
lines changed

abcmiddleware/errors.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func (m *errorManager) Errors(ctrl AppHandler) http.HandlerFunc {
128128
}
129129

130130
// Get the Request ID scoped logger
131-
log := Log(r)
131+
log := Logger(r)
132132

133133
fields := []zapcore.Field{
134134
zap.String("method", r.Method),

abcmiddleware/log.go

Lines changed: 29 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,73 @@
11
package abcmiddleware
22

33
import (
4-
"context"
54
"fmt"
65
"net/http"
76
"time"
87

9-
chimiddleware "github.com/go-chi/chi/middleware"
108
"go.uber.org/zap"
119
)
1210

13-
// zapResponseWriter is a wrapper that includes that http status and size for logging
14-
type zapResponseWriter struct {
15-
http.ResponseWriter
16-
status int
17-
size int
11+
type zapLogMiddleware struct {
12+
logger *zap.Logger
1813
}
1914

20-
// Zap middleware handles web request logging using Zap
21-
func (m Middleware) Zap(next http.Handler) http.Handler {
22-
fn := func(w http.ResponseWriter, r *http.Request) {
23-
t := time.Now()
24-
zw := &zapResponseWriter{ResponseWriter: w}
25-
26-
// Serve the request
27-
next.ServeHTTP(zw, r)
15+
// ZapLog returns a logging middleware that outputs details about a request
16+
func ZapLog(logger *zap.Logger) MW {
17+
return zapLogMiddleware{logger: logger}
18+
}
2819

29-
// Write the request log line
30-
writeZap(m.Log, r, t, zw.status, zw.size)
31-
}
20+
// Zap middleware handles web request logging using Zap
21+
func (z zapLogMiddleware) Wrap(next http.Handler) http.Handler {
22+
return zapLogger{mid: z, next: next}
23+
}
3224

33-
return http.HandlerFunc(fn)
25+
type zapLogger struct {
26+
mid zapLogMiddleware
27+
next http.Handler
3428
}
3529

36-
// RequestIDLogger middleware creates a derived logger to include logging of the
37-
// Request ID, and inserts it into the context object
38-
func (m Middleware) RequestIDLogger(next http.Handler) http.Handler {
39-
fn := func(w http.ResponseWriter, r *http.Request) {
40-
requestID := chimiddleware.GetReqID(r.Context())
41-
derivedLogger := m.Log.With(zap.String("request_id", requestID))
42-
r = r.WithContext(context.WithValue(r.Context(), CtxLoggerKey, derivedLogger))
43-
next.ServeHTTP(w, r)
44-
}
30+
func (z zapLogger) ServeHTTP(w http.ResponseWriter, r *http.Request) {
31+
startTime := time.Now()
32+
zw := &zapResponseWriter{ResponseWriter: w}
4533

46-
return http.HandlerFunc(fn)
47-
}
34+
// Serve the request
35+
z.next.ServeHTTP(zw, r)
4836

49-
// Log returns the Request ID scoped logger from the request Context
50-
// and panics if it cannot be found. This function is only ever used
51-
// by your controllers if your app uses the RequestID middlewares,
52-
// otherwise you should use the controller's receiver logger directly.
53-
func Log(r *http.Request) *zap.Logger {
54-
v := r.Context().Value(CtxLoggerKey)
55-
log, ok := v.(*zap.Logger)
56-
if !ok {
57-
panic("cannot get derived request id logger from context object")
58-
}
59-
return log
37+
// Write the request log line
38+
z.writeZap(zw, r, startTime)
6039
}
6140

62-
func writeZap(log *zap.Logger, r *http.Request, t time.Time, status int, size int) {
63-
elapsed := time.Now().Sub(t)
41+
func (z zapLogger) writeZap(zw *zapResponseWriter, r *http.Request, startTime time.Time) {
42+
elapsed := time.Now().Sub(startTime)
6443
var protocol string
6544
if r.TLS == nil {
6645
protocol = "http"
6746
} else {
6847
protocol = "https"
6948
}
7049

71-
v := r.Context().Value(CtxLoggerKey)
50+
logger := z.mid.logger
51+
v := r.Context().Value(CTXKeyLogger)
7252
if v != nil {
7353
var ok bool
74-
log, ok = v.(*zap.Logger)
54+
logger, ok = v.(*zap.Logger)
7555
if !ok {
7656
panic("cannot get derived request id logger from context object")
7757
}
7858
}
7959

8060
// log all the fields
81-
log.Info(fmt.Sprintf("%s request", protocol),
82-
zap.Int("status", status),
61+
logger.Info(fmt.Sprintf("%s request", protocol),
62+
zap.Int("status", zw.status),
63+
zap.Int("size", zw.size),
64+
zap.Bool("hijacked", zw.hijacked),
8365
zap.String("method", r.Method),
8466
zap.String("uri", r.RequestURI),
8567
zap.Bool("tls", r.TLS != nil),
8668
zap.String("protocol", r.Proto),
8769
zap.String("host", r.Host),
8870
zap.String("remote_addr", r.RemoteAddr),
89-
zap.Int("size", size),
9071
zap.Duration("elapsed", elapsed),
9172
)
9273
}
93-
94-
func (z *zapResponseWriter) WriteHeader(code int) {
95-
z.status = code
96-
z.ResponseWriter.WriteHeader(code)
97-
}
98-
99-
func (z *zapResponseWriter) Write(b []byte) (int, error) {
100-
size, err := z.ResponseWriter.Write(b)
101-
z.size += size
102-
return size, err
103-
}

abcmiddleware/middleware.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ import (
66
"go.uber.org/zap"
77
)
88

9-
// CtxLoggerKey is the http.Request Context lookup key for the request ID logger
10-
const CtxLoggerKey = "request_id_logger"
11-
129
// MiddlewareFunc is the function signature for Chi's Use() middleware
1310
type MiddlewareFunc func(http.Handler) http.Handler
1411

@@ -18,3 +15,8 @@ type Middleware struct {
1815
// create a derived logger that includes the request ID.
1916
Log *zap.Logger
2017
}
18+
19+
// MW is an interface defining middleware wrapping
20+
type MW interface {
21+
Wrap(http.Handler) http.Handler
22+
}

abcmiddleware/recover.go

Lines changed: 73 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,79 @@ import (
77
"go.uber.org/zap"
88
)
99

10-
// Recover middleware recovers panics that occur and gracefully logs their error
11-
func (m Middleware) Recover(next http.Handler) http.Handler {
12-
fn := func(w http.ResponseWriter, r *http.Request) {
13-
defer func() {
14-
if err := recover(); err != nil {
15-
var protocol string
16-
if r.TLS == nil {
17-
protocol = "http"
18-
} else {
19-
protocol = "https"
20-
}
21-
22-
var log *zap.Logger
23-
v := r.Context().Value(CtxLoggerKey)
24-
if v != nil {
25-
var ok bool
26-
log, ok = v.(*zap.Logger)
27-
if !ok {
28-
panic("cannot get derived request id logger from context object")
29-
}
30-
// log with the request_id scoped logger
31-
log.Error(fmt.Sprintf("%s request error", protocol),
32-
zap.String("method", r.Method),
33-
zap.String("uri", r.RequestURI),
34-
zap.Bool("tls", r.TLS != nil),
35-
zap.String("protocol", r.Proto),
36-
zap.String("host", r.Host),
37-
zap.String("remote_addr", r.RemoteAddr),
38-
zap.String("error", fmt.Sprintf("%+v", err)),
39-
)
40-
} else {
41-
// log with the logger attached to middleware struct if
42-
// cannot find request_id scoped logger
43-
m.Log.Error(fmt.Sprintf("%s request error", protocol),
44-
zap.String("method", r.Method),
45-
zap.String("uri", r.RequestURI),
46-
zap.Bool("tls", r.TLS != nil),
47-
zap.String("protocol", r.Proto),
48-
zap.String("host", r.Host),
49-
zap.String("remote_addr", r.RemoteAddr),
50-
zap.String("error", fmt.Sprintf("%+v", err)),
51-
)
52-
}
53-
54-
// Return a http 500 with the HTTP body of "Internal Server Error"
55-
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
56-
}
57-
}()
58-
59-
next.ServeHTTP(w, r)
10+
// ZapRecover will attempt to log a panic, as well as produce a reasonable
11+
// error for the client by calling the passed in errorHandler function.
12+
//
13+
// It uses the zap logger and attempts to look up a request-scoped logger
14+
// created with this package before using the passed in logger.
15+
//
16+
// The zap logger that's used here should be careful to enable stacktrace
17+
// logging for any levels that they require it for.
18+
func ZapRecover(fallback *zap.Logger, errorHandler http.HandlerFunc) MW {
19+
return zapRecoverMiddleware{
20+
fallback: fallback,
21+
eh: errorHandler,
22+
}
23+
}
24+
25+
type zapRecoverMiddleware struct {
26+
fallback *zap.Logger
27+
eh http.HandlerFunc
28+
}
29+
30+
func (z zapRecoverMiddleware) Wrap(next http.Handler) http.Handler {
31+
return zapRecoverer{
32+
zr: z,
33+
next: next,
34+
}
35+
}
36+
37+
type zapRecoverer struct {
38+
zr zapRecoverMiddleware
39+
next http.Handler
40+
}
41+
42+
// recoverPanic was mostly adapted from abcweb
43+
func (z zapRecoverer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
44+
defer z.recoverNicely(w, r)
45+
z.next.ServeHTTP(w, r)
46+
}
47+
48+
func (z zapRecoverer) recoverNicely(w http.ResponseWriter, r *http.Request) {
49+
err := recover()
50+
if err == nil {
51+
return
52+
}
53+
54+
var protocol string
55+
if r.TLS == nil {
56+
protocol = "http"
57+
} else {
58+
protocol = "https"
59+
}
60+
61+
if z.zr.eh != nil {
62+
z.zr.eh(w, r)
63+
}
64+
65+
logger := z.zr.fallback
66+
v := r.Context().Value(CTXKeyLogger)
67+
if v != nil {
68+
var ok bool
69+
logger, ok = v.(*zap.Logger)
70+
if !ok {
71+
panic("cannot get derived request id logger from context object")
72+
}
6073
}
6174

62-
return http.HandlerFunc(fn)
75+
logger.Error(fmt.Sprintf("%s request error", protocol),
76+
zap.String("method", r.Method),
77+
zap.String("uri", r.RequestURI),
78+
zap.Bool("tls", r.TLS != nil),
79+
zap.String("protocol", r.Proto),
80+
zap.String("host", r.Host),
81+
zap.String("remote_addr", r.RemoteAddr),
82+
zap.String("panic", fmt.Sprintf("%+v", err)),
83+
zap.Stack("stacktrace"),
84+
)
6385
}

abcmiddleware/request_id.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package abcmiddleware
2+
3+
import (
4+
"context"
5+
"net/http"
6+
7+
chimiddleware "github.com/go-chi/chi/middleware"
8+
"go.uber.org/zap"
9+
)
10+
11+
type ctxKey int
12+
13+
const (
14+
// CTXKeyLogger is the key under which the request scoped logger is placed
15+
CTXKeyLogger ctxKey = iota
16+
)
17+
18+
// RequestIDHeader sets the X-Request-ID header to the chi request id
19+
// This must be used after the chi request id middleware.
20+
func RequestIDHeader(next http.Handler) http.Handler {
21+
return reqIDInserter{next: next}
22+
}
23+
24+
type reqIDInserter struct {
25+
next http.Handler
26+
}
27+
28+
func (re reqIDInserter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
29+
if reqID := chimiddleware.GetReqID(r.Context()); len(reqID) != 0 {
30+
w.Header().Set("X-Request-ID", reqID)
31+
}
32+
re.next.ServeHTTP(w, r)
33+
}
34+
35+
// ZapRequestIDLogger returns a request id logger middleware. This only works
36+
// if chi has inserted a request id into the stack first.
37+
func ZapRequestIDLogger(logger *zap.Logger) MW {
38+
return zapReqLoggerMiddleware{logger: logger}
39+
}
40+
41+
type zapReqLoggerMiddleware struct {
42+
logger *zap.Logger
43+
}
44+
45+
func (z zapReqLoggerMiddleware) Wrap(next http.Handler) http.Handler {
46+
return zapReqLoggerInserter{logger: z.logger, next: next}
47+
}
48+
49+
type zapReqLoggerInserter struct {
50+
logger *zap.Logger
51+
next http.Handler
52+
}
53+
54+
func (z zapReqLoggerInserter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
55+
requestID := chimiddleware.GetReqID(r.Context())
56+
57+
derivedLogger := z.logger.With(zap.String("request_id", requestID))
58+
59+
r = r.WithContext(context.WithValue(r.Context(), CTXKeyLogger, derivedLogger))
60+
z.next.ServeHTTP(w, r)
61+
}
62+
63+
// Logger returns the Request ID scoped logger from the request Context
64+
// and panics if it cannot be found. This function is only ever used
65+
// by your controllers if your app uses the RequestID middlewares,
66+
// otherwise you should use the controller's receiver logger directly.
67+
func Logger(r *http.Request) *zap.Logger {
68+
return LoggerCTX(r.Context())
69+
}
70+
71+
// LoggerCTX retrieves a logger from a context.
72+
func LoggerCTX(ctx context.Context) *zap.Logger {
73+
v := ctx.Value(CTXKeyLogger)
74+
log, ok := v.(*zap.Logger)
75+
if !ok {
76+
panic("cannot get derived request id logger from context object")
77+
}
78+
return log
79+
}

0 commit comments

Comments
 (0)