@@ -99,62 +99,114 @@ func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (inte
99
99
return nil , fmt .Errorf ("could not parse the token: %w" , err )
100
100
}
101
101
102
- if string (v .signatureAlgorithm ) != token .Headers [0 ].Algorithm {
103
- return nil , fmt .Errorf (
104
- "expected %q signing algorithm but token specified %q" ,
105
- v .signatureAlgorithm ,
106
- token .Headers [0 ].Algorithm ,
107
- )
102
+ if err = validateSigningMethod (string (v .signatureAlgorithm ), token .Headers [0 ].Algorithm ); err != nil {
103
+ return nil , fmt .Errorf ("signing method is invalid: %w" , err )
108
104
}
109
105
110
- key , err := v .keyFunc (ctx )
106
+ registeredClaims , customClaims , err := v .deserializeClaims (ctx , token )
111
107
if err != nil {
112
- return nil , fmt .Errorf ("error getting the keys from the key func: %w" , err )
113
- }
114
-
115
- claimDest := []interface {}{& jwt.Claims {}}
116
- if v .customClaims != nil && v .customClaims () != nil {
117
- claimDest = append (claimDest , v .customClaims ())
108
+ return nil , fmt .Errorf ("failed to deserialize token claims: %w" , err )
118
109
}
119
110
120
- if err = token . Claims ( key , claimDest ... ); err != nil {
121
- return nil , fmt .Errorf ("could not get token claims : %w" , err )
111
+ if err = validateClaimsWithLeeway ( registeredClaims , v . expectedClaims , v . allowedClockSkew ); err != nil {
112
+ return nil , fmt .Errorf ("expected claims not validated : %w" , err )
122
113
}
123
114
124
- registeredClaims := * claimDest [0 ].(* jwt.Claims )
125
- expectedClaims := v .expectedClaims
126
- expectedClaims .Time = time .Now ()
127
- if err = registeredClaims .ValidateWithLeeway (expectedClaims , v .allowedClockSkew ); err != nil {
128
- return nil , fmt .Errorf ("expected claims not validated: %w" , err )
115
+ if customClaims != nil {
116
+ if err = customClaims .Validate (ctx ); err != nil {
117
+ return nil , fmt .Errorf ("custom claims not validated: %w" , err )
118
+ }
129
119
}
130
120
131
121
validatedClaims := & ValidatedClaims {
132
122
RegisteredClaims : RegisteredClaims {
133
- Issuer : registeredClaims .Issuer ,
134
- Subject : registeredClaims .Subject ,
135
- Audience : registeredClaims .Audience ,
136
- ID : registeredClaims .ID ,
123
+ Issuer : registeredClaims .Issuer ,
124
+ Subject : registeredClaims .Subject ,
125
+ Audience : registeredClaims .Audience ,
126
+ ID : registeredClaims .ID ,
127
+ Expiry : numericDateToUnixTime (registeredClaims .Expiry ),
128
+ NotBefore : numericDateToUnixTime (registeredClaims .NotBefore ),
129
+ IssuedAt : numericDateToUnixTime (registeredClaims .IssuedAt ),
137
130
},
131
+ CustomClaims : customClaims ,
138
132
}
139
133
140
- if registeredClaims .Expiry != nil {
141
- validatedClaims .RegisteredClaims .Expiry = registeredClaims .Expiry .Time ().Unix ()
134
+ return validatedClaims , nil
135
+ }
136
+
137
+ func validateClaimsWithLeeway (actualClaims jwt.Claims , expected jwt.Expected , leeway time.Duration ) error {
138
+ expectedClaims := expected
139
+ expectedClaims .Time = time .Now ()
140
+
141
+ if actualClaims .Issuer != expectedClaims .Issuer {
142
+ return jwt .ErrInvalidIssuer
142
143
}
143
144
144
- if registeredClaims .NotBefore != nil {
145
- validatedClaims .RegisteredClaims .NotBefore = registeredClaims .NotBefore .Time ().Unix ()
145
+ foundAudience := false
146
+ for _ , value := range expectedClaims .Audience {
147
+ if actualClaims .Audience .Contains (value ) {
148
+ foundAudience = true
149
+ break
150
+ }
151
+ }
152
+ if ! foundAudience {
153
+ return jwt .ErrInvalidAudience
146
154
}
147
155
148
- if registeredClaims . IssuedAt != nil {
149
- validatedClaims . RegisteredClaims . IssuedAt = registeredClaims . IssuedAt . Time (). Unix ()
156
+ if actualClaims . NotBefore != nil && expectedClaims . Time . Add ( leeway ). Before ( actualClaims . NotBefore . Time ()) {
157
+ return jwt . ErrNotValidYet
150
158
}
151
159
152
- if v .customClaims != nil && v .customClaims () != nil {
153
- validatedClaims .CustomClaims = claimDest [1 ].(CustomClaims )
154
- if err = validatedClaims .CustomClaims .Validate (ctx ); err != nil {
155
- return nil , fmt .Errorf ("custom claims not validated: %w" , err )
156
- }
160
+ if actualClaims .Expiry != nil && expectedClaims .Time .Add (- leeway ).After (actualClaims .Expiry .Time ()) {
161
+ return jwt .ErrExpired
157
162
}
158
163
159
- return validatedClaims , nil
164
+ if actualClaims .IssuedAt != nil && expectedClaims .Time .Add (leeway ).Before (actualClaims .IssuedAt .Time ()) {
165
+ return jwt .ErrIssuedInTheFuture
166
+ }
167
+
168
+ return nil
169
+ }
170
+
171
+ func validateSigningMethod (validAlg , tokenAlg string ) error {
172
+ if validAlg != tokenAlg {
173
+ return fmt .Errorf ("expected %q signing algorithm but token specified %q" , validAlg , tokenAlg )
174
+ }
175
+ return nil
176
+ }
177
+
178
+ func (v * Validator ) customClaimsExist () bool {
179
+ return v .customClaims != nil && v .customClaims () != nil
180
+ }
181
+
182
+ func (v * Validator ) deserializeClaims (ctx context.Context , token * jwt.JSONWebToken ) (jwt.Claims , CustomClaims , error ) {
183
+ key , err := v .keyFunc (ctx )
184
+ if err != nil {
185
+ return jwt.Claims {}, nil , fmt .Errorf ("error getting the keys from the key func: %w" , err )
186
+ }
187
+
188
+ claims := []interface {}{& jwt.Claims {}}
189
+ if v .customClaimsExist () {
190
+ claims = append (claims , v .customClaims ())
191
+ }
192
+
193
+ if err = token .Claims (key , claims ... ); err != nil {
194
+ return jwt.Claims {}, nil , fmt .Errorf ("could not get token claims: %w" , err )
195
+ }
196
+
197
+ registeredClaims := * claims [0 ].(* jwt.Claims )
198
+
199
+ var customClaims CustomClaims
200
+ if len (claims ) > 1 {
201
+ customClaims = claims [1 ].(CustomClaims )
202
+ }
203
+
204
+ return registeredClaims , customClaims , nil
205
+ }
206
+
207
+ func numericDateToUnixTime (date * jwt.NumericDate ) int64 {
208
+ if date != nil {
209
+ return date .Time ().Unix ()
210
+ }
211
+ return 0
160
212
}
0 commit comments