-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.go
206 lines (177 loc) · 5.84 KB
/
validate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
package jwt
import (
"context"
"crypto/rsa"
"encoding/base64"
"fmt"
"log/slog"
"math/big"
"net/http"
"slices"
"strings"
jwt "github.com/golang-jwt/jwt/v5"
)
// BearerSchema is the authentication scheme JWTMiddleware expects at the
// start of the Authorization header, as in "Authorization: Bearer <token>".
const BearerSchema = "Bearer"

// authHeaderPart is the number of pieces the Authorization header is split
// into: the scheme and everything after the first space.
const authHeaderPart = 2
// JWKS is a JSON Web Key Set as decoded from JSON: an object whose "keys"
// array holds the individual keys. createKeyFunc searches Keys for the entry
// whose Kid matches the token's "kid" header.
type JWKS struct {
Keys []JSONWebKey `json:"keys"` // every key in the set, in document order
}
// JSONWebKey is a single entry of a JWKS "keys" array. Each field maps onto
// the JSON member named in its struct tag (kid, kty, use, alg, n, e, crv, x,
// y, x5c).
//
// In this file only the RSA members (N and E) are turned into a public key,
// by parseKey. The EC members and X5c are decoded but not read by any code
// visible here.
type JSONWebKey struct {
Kid string `json:"kid"` // Key ID - required; matched against the token's "kid" header
Kty string `json:"kty"` // Key Type - required; selects the branch in parseKey ("RSA", "EC")
Use string `json:"use,omitempty"` // Key Use (e.g. "sig"); not checked by this file
Alg string `json:"alg,omitempty"` // Algorithm the key is meant for; not checked by this file
// RSA-specific parameters: base64url-encoded unsigned big-endian integers
N string `json:"n,omitempty"` // Modulus
E string `json:"e,omitempty"` // Public Exponent
// EC-specific parameters (currently unused: parseKey rejects "EC" keys)
Crv string `json:"crv,omitempty"` // Curve - e.g. P-256
X string `json:"x,omitempty"` // X Coordinate
Y string `json:"y,omitempty"` // Y Coordinate
// X.509 Certificate Chain
X5c []string `json:"x5c,omitempty"` // NOTE(review): intended as a key fallback/primary source, but no visible code reads it
}
// JWTValidator holds what JWTMiddleware needs to verify a token: a source of
// signing keys, the accepted audiences and the accepted signing algorithms.
// The audience and algorithm lists are unexported, so code outside this
// package configures them through NewJWTValidator.
type JWTValidator struct {
JWKSFetcher *JWKSFetcher // supplies the key set searched by createKeyFunc (read under its mutex)
audiences []string // a token passes if any of its "aud" values is in this list
validMethods []string // allowed signing algorithms, passed to jwt.WithValidMethods
}
// NewJWTValidator builds a JWTValidator that looks up signing keys through
// fetcher, accepts tokens whose "aud" claim contains at least one entry of
// audiences, and restricts the signing algorithm to validMethods.
//
// The audiences and validMethods slices are copied. The validator is read
// concurrently by every request, so later changes the caller makes to its
// own slices must not leak into (or race with) validation.
func NewJWTValidator(fetcher *JWKSFetcher, audiences []string, validMethods []string) *JWTValidator {
	return &JWTValidator{
		JWKSFetcher: fetcher,
		// slices.Clone keeps a nil slice nil and an empty slice empty. The
		// distinction matters for validMethods: it is handed to
		// jwt.WithValidMethods, where nil and empty are not equivalent.
		audiences:    slices.Clone(audiences),
		validMethods: slices.Clone(validMethods),
	}
}
// JWTMiddleware returns HTTP middleware that authenticates requests with a
// JWT bearer token, using validator for signing keys, allowed algorithms and
// accepted audiences.
//
// For each request the returned handler:
//   - requires an Authorization header of the form "<scheme> <token>". The
//     scheme must equal BearerSchema ignoring case (RFC 7235 §2.1 makes auth
//     schemes case-insensitive), and the token must be non-empty;
//   - parses the token into a *UserClaims. The signature is verified with the
//     key chosen by validator's key function, and only validator.validMethods
//     are allowed as signing algorithms;
//   - requires at least one of the token's audiences to be listed in
//     validator.audiences.
//
// Responses on failure:
//   - a missing header, or a token that fails parsing or the audience check,
//     gets 401 Unauthorized;
//   - a header that is present but not a well-formed bearer credential gets
//     400 Bad Request.
//
// On success the claims are stored in the request context under
// userClaimsKey and the request is passed to next.
func JWTMiddleware(validator *JWTValidator) func(http.Handler) http.Handler {
	// Build the key lookup once and share it across all requests.
	keyFunc := validator.createKeyFunc()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				slog.ErrorContext(ctx, "received request with no auth header")
				http.Error(w, "auth header missing", http.StatusUnauthorized)
				return
			}

			// Split off the scheme at the first space. RFC 6750 allows one or
			// more spaces before the token, so the remainder is trimmed. An
			// empty remainder means the header carries no token at all.
			var tokenStr string
			parts := strings.SplitN(authHeader, " ", authHeaderPart)
			if len(parts) == authHeaderPart && strings.EqualFold(parts[0], BearerSchema) {
				tokenStr = strings.TrimSpace(parts[1])
			}
			if tokenStr == "" {
				slog.ErrorContext(ctx, "received request with malformed auth header")
				http.Error(w, "bad auth header format", http.StatusBadRequest)
				return
			}

			// Parse the token and verify its signature. keyFunc picks the key
			// using the token's "kid" header.
			claims := &UserClaims{}
			token, err := jwt.ParseWithClaims(tokenStr, claims, keyFunc, jwt.WithValidMethods(validator.validMethods))
			if err != nil {
				msg := "failed to parse jwt token with claims"
				http.Error(w, msg, http.StatusUnauthorized)
				slog.ErrorContext(ctx, msg, "error", err)
				return
			}
			if !token.Valid {
				msg := "token parsed but is invalid"
				http.Error(w, msg, http.StatusUnauthorized)
				slog.ErrorContext(ctx, msg)
				return
			}

			// Accept the token if any one of its audiences is one this
			// service serves. A token with no audiences never passes.
			validAud := slices.ContainsFunc(claims.Audience, func(aud string) bool {
				return slices.Contains(validator.audiences, aud)
			})
			if !validAud {
				slog.ErrorContext(ctx, "token audience validation failed", "audiences", claims.Audience)
				http.Error(w, "invalid token", http.StatusUnauthorized)
				return
			}

			// Expose the verified claims to downstream handlers.
			next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, userClaimsKey, claims)))
		})
	}
}
// parseKey converts a JWK into a public key for signature verification.
//
// Only RSA keys ("kty": "RSA") are supported. The modulus "n" and exponent
// "e" are decoded from base64url as unsigned big-endian integers and returned
// as an *rsa.PublicKey.
//
// An error is returned for:
//   - a missing, undecodable or zero parameter;
//   - an exponent that does not fit in 31 bits;
//   - an "EC" key (not yet implemented) or any other key type.
func parseKey(jwk *JSONWebKey) (interface{}, error) {
	switch jwk.Kty {
	case "RSA":
		if jwk.N == "" || jwk.E == "" {
			return nil, fmt.Errorf("missing N and/or E param")
		}

		// n and e are base64url without padding. A trailing "=" is
		// tolerated because some key servers emit padded values.
		decodeUint := func(s string) (*big.Int, error) {
			b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(s, "="))
			if err != nil {
				return nil, err
			}
			return new(big.Int).SetBytes(b), nil
		}

		n, err := decodeUint(jwk.N)
		if err != nil {
			return nil, fmt.Errorf("failed to decode RSA modulus 'n': %w", err)
		}
		e, err := decodeUint(jwk.E)
		if err != nil {
			return nil, fmt.Errorf("failed to decode RSA exponent 'e': %w", err)
		}
		if n.BitLen() == 0 || e.BitLen() == 0 {
			return nil, fmt.Errorf("RSA modulus or exponent resulted in zero value")
		}

		// rsa.PublicKey.E is an int, and crypto/rsa does not accept
		// exponents above 2^31-1. Bounding the bit length here rejects such
		// keys early. It also guarantees the int conversion below cannot
		// truncate, even where int is 32 bits.
		if e.BitLen() > 31 {
			return nil, fmt.Errorf("RSA exponent 'e' is too big to fit in an int")
		}
		return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil

	case "EC":
		// TODO: add EC support (decode crv, x and y into an *ecdsa.PublicKey).
		return nil, fmt.Errorf("EC not yet supported")

	default:
		return nil, fmt.Errorf("unsupported key type %q", jwk.Kty)
	}
}
// createKeyFunc returns the key lookup function that JWTMiddleware passes to
// jwt.ParseWithClaims. The (interface{}, error) result type allows either a
// single key or a key set; this implementation always returns a single key.
//
// For each token it:
//   - reads the "kid" header;
//   - finds the JWK with that ID in the fetcher's current key set, holding
//     the fetcher's read lock while it does so;
//   - converts that JWK with parseKey.
//
// It fails if the token has no string "kid", if no key set has been fetched
// yet, if no key carries that ID, or if the matching key cannot be parsed.
func (v *JWTValidator) createKeyFunc() func(*jwt.Token) (interface{}, error) {
	return func(token *jwt.Token) (interface{}, error) {
		// kid comes from the client-supplied token, so it is quoted (%q) in
		// every error message to keep control characters out of logs.
		kid, ok := token.Header["kid"].(string)
		if !ok {
			return nil, fmt.Errorf("no kid in token header")
		}

		fetcher := v.JWKSFetcher
		fetcher.mutex.RLock()
		defer fetcher.mutex.RUnlock()

		if fetcher.jwks == nil {
			return nil, fmt.Errorf("no keys have been fetched (initial fetch pending or failed)")
		}

		// Index into the slice rather than ranging by value, so each JWK is
		// not copied just to compare its Kid.
		keys := fetcher.jwks.Keys
		for i := range keys {
			if keys[i].Kid != kid {
				continue
			}
			// Not logged here: the error is returned to jwt.ParseWithClaims,
			// whose resulting error JWTMiddleware logs.
			pubkey, err := parseKey(&keys[i])
			if err != nil {
				return nil, fmt.Errorf("failed to parse key for kid %q: %w", kid, err)
			}
			if pubkey == nil {
				return nil, fmt.Errorf("key found for kid %q, but parsing resulted in nil key", kid)
			}
			return pubkey, nil
		}
		return nil, fmt.Errorf("signing key not found for kid %q", kid)
	}
}