Skip to content

Commit cd0d7da

Browse files
committed
Keyclock public key rotation support Closes #2
1 parent 1b1095d commit cd0d7da

File tree

5 files changed

+114
-29
lines changed

5 files changed

+114
-29
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.avast.grpc.jwt.keycloak.server;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import java.io.IOException;
5+
import java.net.URL;
6+
import java.security.PublicKey;
7+
import java.time.Clock;
8+
import java.time.Duration;
9+
import java.time.Instant;
10+
import java.util.Map;
11+
import java.util.concurrent.ConcurrentHashMap;
12+
import org.keycloak.constants.ServiceUrlConstants;
13+
import org.keycloak.jose.jwk.JSONWebKeySet;
14+
import org.keycloak.jose.jwk.JWK;
15+
import org.keycloak.util.JWKSUtils;
16+
17+
public class DefaultKeycloakPublicKeyProvider implements KeycloakPublicKeyProvider {
18+
19+
private final String serverUrl;
20+
private final String realm;
21+
private final Duration minTimeBetweenJwksRequests;
22+
private final Duration publicKeyCacheTtl;
23+
private final Clock clock;
24+
25+
private Map<String, PublicKey> currentKeys = new ConcurrentHashMap<>();
26+
private volatile Instant lastRequestTime = Instant.MIN;
27+
28+
public DefaultKeycloakPublicKeyProvider(
29+
String serverUrl,
30+
String realm,
31+
Duration minTimeBetweenJwksRequests,
32+
Duration publicKeyCacheTtl,
33+
Clock clock) {
34+
this.serverUrl = serverUrl;
35+
this.realm = realm;
36+
this.minTimeBetweenJwksRequests = minTimeBetweenJwksRequests;
37+
this.publicKeyCacheTtl = publicKeyCacheTtl;
38+
this.clock = clock;
39+
}
40+
41+
@Override
42+
public PublicKey get(String keyId) {
43+
if (lastRequestTime.plus(publicKeyCacheTtl).isBefore(clock.instant())) {
44+
updateKeys();
45+
}
46+
PublicKey fromCache = currentKeys.get(keyId);
47+
if (fromCache != null) {
48+
return fromCache;
49+
}
50+
updateKeys();
51+
PublicKey res = currentKeys.get(keyId);
52+
if (res == null) {
53+
throw new RuntimeException("Key with following ID not found: " + keyId);
54+
}
55+
return res;
56+
}
57+
58+
protected void updateKeys() {
59+
synchronized (this) {
60+
if (clock.instant().isAfter(lastRequestTime.plus(minTimeBetweenJwksRequests))) {
61+
Map<String, PublicKey> newKeys = fetchNewKeys();
62+
currentKeys.clear();
63+
currentKeys.putAll(newKeys);
64+
lastRequestTime = clock.instant();
65+
}
66+
}
67+
}
68+
69+
protected Map<String, PublicKey> fetchNewKeys() {
70+
try {
71+
ObjectMapper om = new ObjectMapper();
72+
String jwksUrl = serverUrl + ServiceUrlConstants.JWKS_URL.replace("{realm-name}", realm);
73+
JSONWebKeySet jwks = om.readValue(new URL(jwksUrl).openStream(), JSONWebKeySet.class);
74+
return JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG);
75+
} catch (IOException e) {
76+
throw new RuntimeException("Cannot fetch key from Keycloak server", e);
77+
}
78+
}
79+
}

keycloak/src/main/java/com/avast/grpc/jwt/keycloak/server/KeycloakJwtServerInterceptor.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.avast.grpc.jwt.server.JwtServerInterceptor;
66
import com.avast.grpc.jwt.server.JwtTokenParser;
77
import com.typesafe.config.Config;
8+
import java.time.Clock;
89
import org.keycloak.representations.AccessToken;
910

1011
public class KeycloakJwtServerInterceptor extends JwtServerInterceptor<AccessToken> {
@@ -14,8 +15,16 @@ public KeycloakJwtServerInterceptor(JwtTokenParser<AccessToken> tokenParser) {
1415

1516
public static KeycloakJwtServerInterceptor fromConfig(Config config) {
1617
Config fc = config.withFallback(DefaultConfig);
18+
KeycloakPublicKeyProvider publicKeyProvider =
19+
new DefaultKeycloakPublicKeyProvider(
20+
fc.getString("serverUrl"),
21+
fc.getString("realm"),
22+
fc.getDuration("minTimeBetweenJwksRequests"),
23+
fc.getDuration("publicKeyCacheTtl"),
24+
Clock.systemUTC());
1725
KeycloakJwtTokenParser tokenParser =
18-
KeycloakJwtTokenParser.create(fc.getString("serverUrl"), fc.getString("realm"));
26+
new KeycloakJwtTokenParser(
27+
fc.getString("serverUrl"), fc.getString("realm"), publicKeyProvider);
1928
tokenParser = tokenParser.withExpectedAudience(fc.getString("expectedAudience"));
2029
tokenParser = tokenParser.withExpectedIssuedFor(fc.getString("expectedIssuedFor"));
2130
return new KeycloakJwtServerInterceptor(tokenParser);
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,25 @@
11
package com.avast.grpc.jwt.keycloak.server;
22

33
import com.avast.grpc.jwt.server.JwtTokenParser;
4-
import com.fasterxml.jackson.databind.ObjectMapper;
54
import com.google.common.base.Strings;
6-
import java.net.URL;
7-
import java.security.PublicKey;
85
import java.util.concurrent.CompletableFuture;
96
import java.util.concurrent.CompletionException;
107
import org.keycloak.TokenVerifier;
118
import org.keycloak.common.VerificationException;
129
import org.keycloak.constants.ServiceUrlConstants;
13-
import org.keycloak.jose.jwk.JSONWebKeySet;
14-
import org.keycloak.jose.jwk.JWK;
15-
import org.keycloak.jose.jwk.JWKParser;
1610
import org.keycloak.representations.AccessToken;
1711
import org.keycloak.util.TokenUtil;
1812

1913
public class KeycloakJwtTokenParser implements JwtTokenParser<AccessToken> {
2014

21-
protected final PublicKey publicKey;
15+
protected final KeycloakPublicKeyProvider publicKeyProvider;
2216
protected final TokenVerifier.Predicate<AccessToken>[] checks;
2317
protected String expectedAudience;
2418
protected String expectedIssuedFor;
2519

26-
protected KeycloakJwtTokenParser(String serverUrl, String realm, PublicKey publicKey) {
27-
this.publicKey = publicKey;
20+
public KeycloakJwtTokenParser(
21+
String serverUrl, String realm, KeycloakPublicKeyProvider publicKeyProvider) {
22+
this.publicKeyProvider = publicKeyProvider;
2823
String realmUrl =
2924
serverUrl + ServiceUrlConstants.REALM_INFO_PATH.replace("{realm-name}", realm);
3025
this.checks =
@@ -38,7 +33,14 @@ protected KeycloakJwtTokenParser(String serverUrl, String realm, PublicKey publi
3833

3934
@Override
4035
public CompletableFuture<AccessToken> parseToValid(String jwtToken) {
41-
TokenVerifier<AccessToken> verifier = createTokenVerifier(jwtToken);
36+
TokenVerifier<AccessToken> verifier;
37+
try {
38+
verifier = createTokenVerifier(jwtToken);
39+
} catch (VerificationException e) {
40+
CompletableFuture<AccessToken> r = new CompletableFuture<>();
41+
r.completeExceptionally(e);
42+
return r;
43+
}
4244
return CompletableFuture.supplyAsync(
4345
() -> {
4446
try {
@@ -49,15 +51,17 @@ public CompletableFuture<AccessToken> parseToValid(String jwtToken) {
4951
});
5052
}
5153

52-
protected TokenVerifier<AccessToken> createTokenVerifier(String jwtToken) {
54+
protected TokenVerifier<AccessToken> createTokenVerifier(String jwtToken)
55+
throws VerificationException {
5356
TokenVerifier<AccessToken> verifier =
54-
TokenVerifier.create(jwtToken, AccessToken.class).withChecks(checks).publicKey(publicKey);
57+
TokenVerifier.create(jwtToken, AccessToken.class).withChecks(checks);
5558
if (!Strings.isNullOrEmpty(expectedAudience)) {
5659
verifier = verifier.audience(expectedAudience);
5760
}
5861
if (!Strings.isNullOrEmpty(expectedIssuedFor)) {
5962
verifier = verifier.issuedFor(expectedIssuedFor);
6063
}
64+
verifier.publicKey(publicKeyProvider.get(verifier.getHeader().getKeyId()));
6165
return verifier;
6266
}
6367

@@ -70,20 +74,4 @@ public KeycloakJwtTokenParser withExpectedIssuedFor(String expectedIssuedFor) {
7074
this.expectedIssuedFor = expectedIssuedFor;
7175
return this;
7276
}
73-
74-
public static KeycloakJwtTokenParser create(String serverUrl, String realm) {
75-
try {
76-
ObjectMapper om = new ObjectMapper();
77-
String jwksUrl = serverUrl + ServiceUrlConstants.JWKS_URL.replace("{realm-name}", realm);
78-
JSONWebKeySet jwks = om.readValue(new URL(jwksUrl).openStream(), JSONWebKeySet.class);
79-
if (jwks.getKeys().length == 0) {
80-
throw new RuntimeException("No keys found");
81-
}
82-
JWK jwk = jwks.getKeys()[0];
83-
PublicKey publicKey = JWKParser.create(jwk).toPublicKey();
84-
return new KeycloakJwtTokenParser(serverUrl, realm, publicKey);
85-
} catch (Exception e) {
86-
throw new RuntimeException("Exception when obtaining public key from " + serverUrl, e);
87-
}
88-
}
8977
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.avast.grpc.jwt.keycloak.server;
2+
3+
import java.security.PublicKey;
4+
5+
public interface KeycloakPublicKeyProvider {
6+
PublicKey get(String keyId);
7+
}

keycloak/src/main/resources/reference.conf

+2
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ keycloakDefaults {
1010
// for server
1111
expectedAudience = ""
1212
expectedIssuedFor = ""
13+
minTimeBetweenJwksRequests = 10 seconds
14+
publicKeyCacheTtl = 1 day
1315
}

0 commit comments

Comments
 (0)