Skip to content

Commit a9dd2d3

Browse files
authored
Replace *jwt.validateClaimsWithLeeway with custom validation func (#176)
1 parent 6f1ce16 commit a9dd2d3

File tree

4 files changed

+137
-45
lines changed

4 files changed

+137
-45
lines changed

.github/workflows/test.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,5 @@ jobs:
2626
uses: codecov/codecov-action@v3
2727
with:
2828
files: coverage.out
29-
fail_ci_if_error: true
29+
fail_ci_if_error: false
3030
verbose: true

examples/http-jwks-example/main_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ func TestHandler(t *testing.T) {
7373
t.Fatal(err)
7474
}
7575

76-
token := buildJWTForTesting(t, jwk, testServer.URL, test.subject, []string{})
76+
token := buildJWTForTesting(t, jwk, testServer.URL, test.subject, []string{"my-audience"})
7777
req.Header.Set("Authorization", "Bearer "+token)
7878

7979
rr := httptest.NewRecorder()
8080

81-
mainHandler := setupHandler(testServer.URL, []string{})
81+
mainHandler := setupHandler(testServer.URL, []string{"my-audience"})
8282
mainHandler.ServeHTTP(rr, req)
8383

8484
if want, got := test.wantStatusCode, rr.Code; want != got {

validator/validator.go

+88-36
Original file line numberDiff line numberDiff line change
@@ -99,62 +99,114 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte
9999
return nil, fmt.Errorf("could not parse the token: %w", err)
100100
}
101101

102-
if string(v.signatureAlgorithm) != token.Headers[0].Algorithm {
103-
return nil, fmt.Errorf(
104-
"expected %q signing algorithm but token specified %q",
105-
v.signatureAlgorithm,
106-
token.Headers[0].Algorithm,
107-
)
102+
if err = validateSigningMethod(string(v.signatureAlgorithm), token.Headers[0].Algorithm); err != nil {
103+
return nil, fmt.Errorf("signing method is invalid: %w", err)
108104
}
109105

110-
key, err := v.keyFunc(ctx)
106+
registeredClaims, customClaims, err := v.deserializeClaims(ctx, token)
111107
if err != nil {
112-
return nil, fmt.Errorf("error getting the keys from the key func: %w", err)
113-
}
114-
115-
claimDest := []interface{}{&jwt.Claims{}}
116-
if v.customClaims != nil && v.customClaims() != nil {
117-
claimDest = append(claimDest, v.customClaims())
108+
return nil, fmt.Errorf("failed to deserialize token claims: %w", err)
118109
}
119110

120-
if err = token.Claims(key, claimDest...); err != nil {
121-
return nil, fmt.Errorf("could not get token claims: %w", err)
111+
if err = validateClaimsWithLeeway(registeredClaims, v.expectedClaims, v.allowedClockSkew); err != nil {
112+
return nil, fmt.Errorf("expected claims not validated: %w", err)
122113
}
123114

124-
registeredClaims := *claimDest[0].(*jwt.Claims)
125-
expectedClaims := v.expectedClaims
126-
expectedClaims.Time = time.Now()
127-
if err = registeredClaims.ValidateWithLeeway(expectedClaims, v.allowedClockSkew); err != nil {
128-
return nil, fmt.Errorf("expected claims not validated: %w", err)
115+
if customClaims != nil {
116+
if err = customClaims.Validate(ctx); err != nil {
117+
return nil, fmt.Errorf("custom claims not validated: %w", err)
118+
}
129119
}
130120

131121
validatedClaims := &ValidatedClaims{
132122
RegisteredClaims: RegisteredClaims{
133-
Issuer: registeredClaims.Issuer,
134-
Subject: registeredClaims.Subject,
135-
Audience: registeredClaims.Audience,
136-
ID: registeredClaims.ID,
123+
Issuer: registeredClaims.Issuer,
124+
Subject: registeredClaims.Subject,
125+
Audience: registeredClaims.Audience,
126+
ID: registeredClaims.ID,
127+
Expiry: numericDateToUnixTime(registeredClaims.Expiry),
128+
NotBefore: numericDateToUnixTime(registeredClaims.NotBefore),
129+
IssuedAt: numericDateToUnixTime(registeredClaims.IssuedAt),
137130
},
131+
CustomClaims: customClaims,
138132
}
139133

140-
if registeredClaims.Expiry != nil {
141-
validatedClaims.RegisteredClaims.Expiry = registeredClaims.Expiry.Time().Unix()
134+
return validatedClaims, nil
135+
}
136+
137+
func validateClaimsWithLeeway(actualClaims jwt.Claims, expected jwt.Expected, leeway time.Duration) error {
138+
expectedClaims := expected
139+
expectedClaims.Time = time.Now()
140+
141+
if actualClaims.Issuer != expectedClaims.Issuer {
142+
return jwt.ErrInvalidIssuer
142143
}
143144

144-
if registeredClaims.NotBefore != nil {
145-
validatedClaims.RegisteredClaims.NotBefore = registeredClaims.NotBefore.Time().Unix()
145+
foundAudience := false
146+
for _, value := range expectedClaims.Audience {
147+
if actualClaims.Audience.Contains(value) {
148+
foundAudience = true
149+
break
150+
}
151+
}
152+
if !foundAudience {
153+
return jwt.ErrInvalidAudience
146154
}
147155

148-
if registeredClaims.IssuedAt != nil {
149-
validatedClaims.RegisteredClaims.IssuedAt = registeredClaims.IssuedAt.Time().Unix()
156+
if actualClaims.NotBefore != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.NotBefore.Time()) {
157+
return jwt.ErrNotValidYet
150158
}
151159

152-
if v.customClaims != nil && v.customClaims() != nil {
153-
validatedClaims.CustomClaims = claimDest[1].(CustomClaims)
154-
if err = validatedClaims.CustomClaims.Validate(ctx); err != nil {
155-
return nil, fmt.Errorf("custom claims not validated: %w", err)
156-
}
160+
if actualClaims.Expiry != nil && expectedClaims.Time.Add(-leeway).After(actualClaims.Expiry.Time()) {
161+
return jwt.ErrExpired
157162
}
158163

159-
return validatedClaims, nil
164+
if actualClaims.IssuedAt != nil && expectedClaims.Time.Add(leeway).Before(actualClaims.IssuedAt.Time()) {
165+
return jwt.ErrIssuedInTheFuture
166+
}
167+
168+
return nil
169+
}
170+
171+
func validateSigningMethod(validAlg, tokenAlg string) error {
172+
if validAlg != tokenAlg {
173+
return fmt.Errorf("expected %q signing algorithm but token specified %q", validAlg, tokenAlg)
174+
}
175+
return nil
176+
}
177+
178+
func (v *Validator) customClaimsExist() bool {
179+
return v.customClaims != nil && v.customClaims() != nil
180+
}
181+
182+
func (v *Validator) deserializeClaims(ctx context.Context, token *jwt.JSONWebToken) (jwt.Claims, CustomClaims, error) {
183+
key, err := v.keyFunc(ctx)
184+
if err != nil {
185+
return jwt.Claims{}, nil, fmt.Errorf("error getting the keys from the key func: %w", err)
186+
}
187+
188+
claims := []interface{}{&jwt.Claims{}}
189+
if v.customClaimsExist() {
190+
claims = append(claims, v.customClaims())
191+
}
192+
193+
if err = token.Claims(key, claims...); err != nil {
194+
return jwt.Claims{}, nil, fmt.Errorf("could not get token claims: %w", err)
195+
}
196+
197+
registeredClaims := *claims[0].(*jwt.Claims)
198+
199+
var customClaims CustomClaims
200+
if len(claims) > 1 {
201+
customClaims = claims[1].(CustomClaims)
202+
}
203+
204+
return registeredClaims, customClaims, nil
205+
}
206+
207+
func numericDateToUnixTime(date *jwt.NumericDate) int64 {
208+
if date != nil {
209+
return date.Time().Unix()
210+
}
211+
return 0
160212
}

validator/validator_test.go

+46-6
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ package validator
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"testing"
8+
"time"
79

810
"github.com/stretchr/testify/assert"
911
"github.com/stretchr/testify/require"
12+
"gopkg.in/square/go-jose.v2/jwt"
1013
)
1114

1215
type testClaims struct {
@@ -77,7 +80,7 @@ func TestValidator_ValidateToken(t *testing.T) {
7780
return []byte("secret"), nil
7881
},
7982
algorithm: RS256,
80-
expectedError: errors.New(`expected "RS256" signing algorithm but token specified "HS256"`),
83+
expectedError: errors.New(`signing method is invalid: expected "RS256" signing algorithm but token specified "HS256"`),
8184
},
8285
{
8386
name: "it throws an error when it cannot parse the token",
@@ -95,7 +98,7 @@ func TestValidator_ValidateToken(t *testing.T) {
9598
return nil, errors.New("key func error message")
9699
},
97100
algorithm: HS256,
98-
expectedError: errors.New("error getting the keys from the key func: key func error message"),
101+
expectedError: errors.New("failed to deserialize token claims: error getting the keys from the key func: key func error message"),
99102
},
100103
{
101104
name: "it throws an error when it fails to deserialize the claims because the signature is invalid",
@@ -104,7 +107,7 @@ func TestValidator_ValidateToken(t *testing.T) {
104107
return []byte("secret"), nil
105108
},
106109
algorithm: HS256,
107-
expectedError: errors.New("could not get token claims: square/go-jose: error in cryptographic primitive"),
110+
expectedError: errors.New("failed to deserialize token claims: could not get token claims: square/go-jose: error in cryptographic primitive"),
108111
},
109112
{
110113
name: "it throws an error when it fails to validate the registered claims",
@@ -150,7 +153,7 @@ func TestValidator_ValidateToken(t *testing.T) {
150153
},
151154
{
152155
name: "it successfully validates a token with exp, nbf and iat",
153-
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.36iSr7w8Q6b9iJoJo-swmfgAfm23w8SlX92NHIHGX2s",
156+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo5NjY3OTM3Njg2fQ.FKZogkm08gTfYfPU6eYu7OHCjJKnKGLiC0IfoIOPEhs",
154157
keyFunc: func(context.Context) (interface{}, error) {
155158
return []byte("secret"), nil
156159
},
@@ -160,12 +163,48 @@ func TestValidator_ValidateToken(t *testing.T) {
160163
Issuer: issuer,
161164
Subject: subject,
162165
Audience: []string{audience},
163-
Expiry: 1667937686,
166+
Expiry: 9667937686,
164167
NotBefore: 1666939000,
165168
IssuedAt: 1666937686,
166169
},
167170
},
168171
},
172+
{
173+
name: "it throws an error when token is not valid yet",
174+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6OTY2NjkzOTAwMCwiZXhwIjoxNjY3OTM3Njg2fQ.yUizJ-zK_33tv1qBVvDKO0RuCWtvJ02UQKs8gBadgGY",
175+
keyFunc: func(context.Context) (interface{}, error) {
176+
return []byte("secret"), nil
177+
},
178+
algorithm: HS256,
179+
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrNotValidYet),
180+
},
181+
{
182+
name: "it throws an error when token is expired",
183+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjE2NjY5Mzc2ODYsIm5iZiI6MTY2NjkzOTAwMCwiZXhwIjo2Njc5Mzc2ODZ9.SKvz82VOXRi_sjvZWIsPG9vSWAXKKgVS4DkGZcwFKL8",
184+
keyFunc: func(context.Context) (interface{}, error) {
185+
return []byte("secret"), nil
186+
},
187+
algorithm: HS256,
188+
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrExpired),
189+
},
190+
{
191+
name: "it throws an error when token is issued in the future",
192+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLmV1LmF1dGgwLmNvbS8iLCJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjpbImh0dHBzOi8vZ28tand0LW1pZGRsZXdhcmUtYXBpLyJdLCJpYXQiOjkxNjY2OTM3Njg2LCJuYmYiOjE2NjY5MzkwMDAsImV4cCI6ODY2NzkzNzY4Nn0.ieFV7XNJxiJyw8ARq9yHw-01Oi02e3P2skZO10ypxL8",
193+
keyFunc: func(context.Context) (interface{}, error) {
194+
return []byte("secret"), nil
195+
},
196+
algorithm: HS256,
197+
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrIssuedInTheFuture),
198+
},
199+
{
200+
name: "it throws an error when token issuer is invalid",
201+
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2hhY2tlZC1qd3QtbWlkZGxld2FyZS5ldS5hdXRoMC5jb20vIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6WyJodHRwczovL2dvLWp3dC1taWRkbGV3YXJlLWFwaS8iXSwiaWF0Ijo5MTY2NjkzNzY4NiwibmJmIjoxNjY2OTM5MDAwLCJleHAiOjg2Njc5Mzc2ODZ9.b5gXNrUNfd_jyCWZF-6IPK_UFfvTr9wBQk9_QgRQ8rA",
202+
keyFunc: func(context.Context) (interface{}, error) {
203+
return []byte("secret"), nil
204+
},
205+
algorithm: HS256,
206+
expectedError: fmt.Errorf("expected claims not validated: %s", jwt.ErrInvalidIssuer),
207+
},
169208
}
170209

171210
for _, testCase := range testCases {
@@ -177,8 +216,9 @@ func TestValidator_ValidateToken(t *testing.T) {
177216
testCase.keyFunc,
178217
testCase.algorithm,
179218
issuer,
180-
[]string{audience},
219+
[]string{audience, "another-audience"},
181220
WithCustomClaims(testCase.customClaims),
221+
WithAllowedClockSkew(time.Second),
182222
)
183223
require.NoError(t, err)
184224

0 commit comments

Comments
 (0)