Skip to content

Commit 595f079

Browse files
authored
Merge pull request #3 from avast/RotatingKeysSupport
Rotating Keycloak public keys support
2 parents b43f09b + 296a711 commit 595f079

File tree

11 files changed

+252
-54
lines changed

11 files changed

+252
-54
lines changed

README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ YourService.newStub(aChannel).withCallCredentials(callCredentials);
4848

4949
### Server usage
5050
This ensures that only requests with valid `JWT` in `Authorization` header are processed.
51-
The `fromConfig` method automatically downloads the Keycloak public key before the instance is actually created.
5251
```java
5352
import io.grpc.ServerServiceDefinition;
5453
import com.avast.grpc.jwt.keycloak.server.KeycloakJwtServerInterceptor;
@@ -78,4 +77,4 @@ On other hand, `Context key` is set of values that are available during request
7877

7978
So when implementing interceptors, you must be sure that you read Context values from the right thread. It's actually no issue for us because:
8079
1. The right thread is automatically handled by gRPC-core when using`CallCredentials`. So you can call `applier.apply()` method on any thread.
81-
2. Our `ServerInterceptor` implementation is fully synchronous.
80+
2. Our `ServerInterceptor` implementation handles it correctly.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.avast.grpc.jwt.server;
2+
3+
import io.grpc.ServerCall;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
7+
// https://stackoverflow.com/a/53656689/181796
8+
class DelayedServerCallListener<ReqT> extends ServerCall.Listener<ReqT> {
9+
private ServerCall.Listener<ReqT> delegate;
10+
private List<Runnable> events = new ArrayList<>();
11+
12+
@Override
13+
public synchronized void onMessage(ReqT message) {
14+
if (delegate == null) {
15+
events.add(() -> delegate.onMessage(message));
16+
} else {
17+
delegate.onMessage(message);
18+
}
19+
}
20+
21+
@Override
22+
public synchronized void onHalfClose() {
23+
if (delegate == null) {
24+
events.add(() -> delegate.onHalfClose());
25+
} else {
26+
delegate.onHalfClose();
27+
}
28+
}
29+
30+
@Override
31+
public synchronized void onCancel() {
32+
if (delegate == null) {
33+
events.add(() -> delegate.onCancel());
34+
} else {
35+
delegate.onCancel();
36+
}
37+
}
38+
39+
@Override
40+
public synchronized void onComplete() {
41+
if (delegate == null) {
42+
events.add(() -> delegate.onComplete());
43+
} else {
44+
delegate.onComplete();
45+
}
46+
}
47+
48+
@Override
49+
public synchronized void onReady() {
50+
if (delegate == null) {
51+
events.add(() -> delegate.onReady());
52+
} else {
53+
delegate.onReady();
54+
}
55+
}
56+
57+
public synchronized void setDelegate(ServerCall.Listener<ReqT> delegate) {
58+
this.delegate = delegate;
59+
for (Runnable runnable : events) {
60+
runnable.run();
61+
}
62+
events = null;
63+
}
64+
}

core/src/main/java/com/avast/grpc/jwt/server/JwtServerInterceptor.java

+30-11
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,38 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
3737
call.close(Status.UNAUTHENTICATED.withDescription(msg), new Metadata());
3838
return new ServerCall.Listener<ReqT>() {};
3939
}
40-
T token;
40+
DelayedServerCallListener<ReqT> delayedListener = new DelayedServerCallListener<>();
41+
Context context = Context.current(); // we must call this on the right thread
4142
try {
42-
token = tokenParser.parseToValid(authHeader.substring(AUTH_HEADER_PREFIX.length()));
43+
tokenParser
44+
.parseToValid(authHeader.substring(AUTH_HEADER_PREFIX.length()))
45+
.whenComplete(
46+
(token, e) ->
47+
context.run(
48+
() -> {
49+
if (e == null) {
50+
delayedListener.setDelegate(
51+
Contexts.interceptCall(
52+
Context.current().withValue(AccessTokenContextKey, token),
53+
call,
54+
headers,
55+
next));
56+
} else {
57+
delayedListener.setDelegate(handleException(e, call));
58+
}
59+
}));
4360
} catch (Exception e) {
44-
String msg =
45-
Constants.AuthorizationMetadataKey.name()
46-
+ " header validation failed: "
47-
+ e.getMessage();
48-
LOGGER.warn(msg, e);
49-
call.close(Status.UNAUTHENTICATED.withDescription(msg).withCause(e), new Metadata());
50-
return new ServerCall.Listener<ReqT>() {};
61+
return handleException(e, call);
5162
}
52-
return Contexts.interceptCall(
53-
Context.current().withValue(AccessTokenContextKey, token), call, headers, next);
63+
return delayedListener;
64+
}
65+
66+
private <ReqT, RespT> ServerCall.Listener<ReqT> handleException(
67+
Throwable e, ServerCall<ReqT, RespT> call) {
68+
String msg =
69+
Constants.AuthorizationMetadataKey.name() + " header validation failed: " + e.getMessage();
70+
LOGGER.warn(msg, e);
71+
call.close(Status.UNAUTHENTICATED.withDescription(msg).withCause(e), new Metadata());
72+
return new ServerCall.Listener<ReqT>() {};
5473
}
5574
}
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package com.avast.grpc.jwt.server;
22

3+
import java.util.concurrent.CompletableFuture;
4+
35
@FunctionalInterface
46
public interface JwtTokenParser<T> {
57

68
/** Get valid JWT token, throws an exception otherwise. */
7-
T parseToValid(String jwtToken) throws Exception;
9+
CompletableFuture<T> parseToValid(String jwtToken);
810
}

core/src/test/java/com/avast/grpc/jwt/server/JwtServerInterceptorTest.java

+19-1
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,21 @@
77
import io.grpc.Metadata;
88
import io.grpc.ServerCall;
99
import io.grpc.ServerCallHandler;
10+
import java.util.concurrent.CompletableFuture;
1011
import java.util.concurrent.atomic.AtomicReference;
1112
import org.junit.Test;
1213

1314
public class JwtServerInterceptorTest {
1415

15-
JwtTokenParser<String> jwtTokenParser = jwtToken -> jwtToken;
16+
JwtTokenParser<String> jwtTokenParser =
17+
jwtToken -> {
18+
if (jwtToken.equals("Invalid Token")) {
19+
CompletableFuture<String> res = new CompletableFuture<>();
20+
res.completeExceptionally(new RuntimeException("invalid token"));
21+
return res;
22+
}
23+
return CompletableFuture.completedFuture(jwtToken);
24+
};
1625
JwtServerInterceptor<String> target = new JwtServerInterceptor<>(jwtTokenParser);
1726
ServerCall<Object, Object> serverCall = (ServerCall<Object, Object>) mock(ServerCall.class);
1827
ServerCallHandler<Object, Object> next =
@@ -34,6 +43,15 @@ public void closesCallOnInvalidHeader() {
3443
verify(next, never()).startCall(any(), any());
3544
}
3645

46+
@Test
47+
public void closesCallOnInvalidToken() {
48+
Metadata metadata = new Metadata();
49+
metadata.put(Constants.AuthorizationMetadataKey, "Bearer Invalid Token");
50+
target.interceptCall(serverCall, metadata, next);
51+
verify(serverCall).close(any(), any());
52+
verify(next, never()).startCall(any(), any());
53+
}
54+
3755
@Test
3856
public void callNextStageWithContextKeyOnValidHeader() {
3957
Metadata metadata = new Metadata();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.avast.grpc.jwt.keycloak.server;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import java.io.IOException;
5+
import java.net.URL;
6+
import java.security.PublicKey;
7+
import java.time.Clock;
8+
import java.time.Duration;
9+
import java.time.Instant;
10+
import java.util.Map;
11+
import java.util.concurrent.ConcurrentHashMap;
12+
import org.keycloak.constants.ServiceUrlConstants;
13+
import org.keycloak.jose.jwk.JSONWebKeySet;
14+
import org.keycloak.jose.jwk.JWK;
15+
import org.keycloak.util.JWKSUtils;
16+
17+
public class DefaultKeycloakPublicKeyProvider implements KeycloakPublicKeyProvider {
18+
19+
private final String serverUrl;
20+
private final String realm;
21+
private final Duration minTimeBetweenJwksRequests;
22+
private final Duration publicKeyCacheTtl;
23+
private final Clock clock;
24+
25+
private Map<String, PublicKey> currentKeys = new ConcurrentHashMap<>();
26+
private volatile Instant lastRequestTime = Instant.MIN;
27+
28+
public DefaultKeycloakPublicKeyProvider(
29+
String serverUrl,
30+
String realm,
31+
Duration minTimeBetweenJwksRequests,
32+
Duration publicKeyCacheTtl,
33+
Clock clock) {
34+
this.serverUrl = serverUrl;
35+
this.realm = realm;
36+
this.minTimeBetweenJwksRequests = minTimeBetweenJwksRequests;
37+
this.publicKeyCacheTtl = publicKeyCacheTtl;
38+
this.clock = clock;
39+
}
40+
41+
@Override
42+
public PublicKey get(String keyId) {
43+
if (lastRequestTime.plus(publicKeyCacheTtl).isBefore(clock.instant())) {
44+
updateKeys();
45+
}
46+
PublicKey fromCache = currentKeys.get(keyId);
47+
if (fromCache != null) {
48+
return fromCache;
49+
}
50+
updateKeys();
51+
PublicKey res = currentKeys.get(keyId);
52+
if (res == null) {
53+
throw new RuntimeException("Key with following ID not found: " + keyId);
54+
}
55+
return res;
56+
}
57+
58+
protected void updateKeys() {
59+
synchronized (this) {
60+
if (clock.instant().isAfter(lastRequestTime.plus(minTimeBetweenJwksRequests))) {
61+
Map<String, PublicKey> newKeys = fetchNewKeys();
62+
currentKeys.clear();
63+
currentKeys.putAll(newKeys);
64+
lastRequestTime = clock.instant();
65+
}
66+
}
67+
}
68+
69+
protected Map<String, PublicKey> fetchNewKeys() {
70+
try {
71+
ObjectMapper om = new ObjectMapper();
72+
String jwksUrl = serverUrl + ServiceUrlConstants.JWKS_URL.replace("{realm-name}", realm);
73+
JSONWebKeySet jwks = om.readValue(new URL(jwksUrl).openStream(), JSONWebKeySet.class);
74+
return JWKSUtils.getKeysForUse(jwks, JWK.Use.SIG);
75+
} catch (IOException e) {
76+
throw new RuntimeException("Cannot fetch key from Keycloak server", e);
77+
}
78+
}
79+
}

keycloak/src/main/java/com/avast/grpc/jwt/keycloak/server/KeycloakJwtServerInterceptor.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.avast.grpc.jwt.server.JwtServerInterceptor;
66
import com.avast.grpc.jwt.server.JwtTokenParser;
77
import com.typesafe.config.Config;
8+
import java.time.Clock;
89
import org.keycloak.representations.AccessToken;
910

1011
public class KeycloakJwtServerInterceptor extends JwtServerInterceptor<AccessToken> {
@@ -14,8 +15,16 @@ public KeycloakJwtServerInterceptor(JwtTokenParser<AccessToken> tokenParser) {
1415

1516
public static KeycloakJwtServerInterceptor fromConfig(Config config) {
1617
Config fc = config.withFallback(DefaultConfig);
18+
KeycloakPublicKeyProvider publicKeyProvider =
19+
new DefaultKeycloakPublicKeyProvider(
20+
fc.getString("serverUrl"),
21+
fc.getString("realm"),
22+
fc.getDuration("minTimeBetweenJwksRequests"),
23+
fc.getDuration("publicKeyCacheTtl"),
24+
Clock.systemUTC());
1725
KeycloakJwtTokenParser tokenParser =
18-
KeycloakJwtTokenParser.create(fc.getString("serverUrl"), fc.getString("realm"));
26+
new KeycloakJwtTokenParser(
27+
fc.getString("serverUrl"), fc.getString("realm"), publicKeyProvider);
1928
tokenParser = tokenParser.withExpectedAudience(fc.getString("expectedAudience"));
2029
tokenParser = tokenParser.withExpectedIssuedFor(fc.getString("expectedIssuedFor"));
2130
return new KeycloakJwtServerInterceptor(tokenParser);
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
11
package com.avast.grpc.jwt.keycloak.server;
22

33
import com.avast.grpc.jwt.server.JwtTokenParser;
4-
import com.fasterxml.jackson.databind.ObjectMapper;
54
import com.google.common.base.Strings;
6-
import java.net.URL;
7-
import java.security.PublicKey;
5+
import java.util.concurrent.CompletableFuture;
6+
import java.util.concurrent.CompletionException;
87
import org.keycloak.TokenVerifier;
98
import org.keycloak.common.VerificationException;
109
import org.keycloak.constants.ServiceUrlConstants;
11-
import org.keycloak.jose.jwk.JSONWebKeySet;
12-
import org.keycloak.jose.jwk.JWK;
13-
import org.keycloak.jose.jwk.JWKParser;
1410
import org.keycloak.representations.AccessToken;
1511
import org.keycloak.util.TokenUtil;
1612

1713
public class KeycloakJwtTokenParser implements JwtTokenParser<AccessToken> {
1814

19-
protected final PublicKey publicKey;
15+
protected final KeycloakPublicKeyProvider publicKeyProvider;
2016
protected final TokenVerifier.Predicate<AccessToken>[] checks;
2117
protected String expectedAudience;
2218
protected String expectedIssuedFor;
2319

24-
protected KeycloakJwtTokenParser(String serverUrl, String realm, PublicKey publicKey) {
25-
this.publicKey = publicKey;
20+
public KeycloakJwtTokenParser(
21+
String serverUrl, String realm, KeycloakPublicKeyProvider publicKeyProvider) {
22+
this.publicKeyProvider = publicKeyProvider;
2623
String realmUrl =
2724
serverUrl + ServiceUrlConstants.REALM_INFO_PATH.replace("{realm-name}", realm);
2825
this.checks =
@@ -35,20 +32,36 @@ protected KeycloakJwtTokenParser(String serverUrl, String realm, PublicKey publi
3532
}
3633

3734
@Override
38-
public AccessToken parseToValid(String jwtToken) throws VerificationException {
39-
TokenVerifier<AccessToken> verifier = createTokenVerifier(jwtToken);
40-
return verifier.verify().getToken();
35+
public CompletableFuture<AccessToken> parseToValid(String jwtToken) {
36+
TokenVerifier<AccessToken> verifier;
37+
try {
38+
verifier = createTokenVerifier(jwtToken);
39+
} catch (VerificationException e) {
40+
CompletableFuture<AccessToken> r = new CompletableFuture<>();
41+
r.completeExceptionally(e);
42+
return r;
43+
}
44+
return CompletableFuture.supplyAsync(
45+
() -> {
46+
try {
47+
return verifier.verify().getToken();
48+
} catch (VerificationException e) {
49+
throw new CompletionException(e);
50+
}
51+
});
4152
}
4253

43-
protected TokenVerifier<AccessToken> createTokenVerifier(String jwtToken) {
54+
protected TokenVerifier<AccessToken> createTokenVerifier(String jwtToken)
55+
throws VerificationException {
4456
TokenVerifier<AccessToken> verifier =
45-
TokenVerifier.create(jwtToken, AccessToken.class).withChecks(checks).publicKey(publicKey);
57+
TokenVerifier.create(jwtToken, AccessToken.class).withChecks(checks);
4658
if (!Strings.isNullOrEmpty(expectedAudience)) {
4759
verifier = verifier.audience(expectedAudience);
4860
}
4961
if (!Strings.isNullOrEmpty(expectedIssuedFor)) {
5062
verifier = verifier.issuedFor(expectedIssuedFor);
5163
}
64+
verifier.publicKey(publicKeyProvider.get(verifier.getHeader().getKeyId()));
5265
return verifier;
5366
}
5467

@@ -61,20 +74,4 @@ public KeycloakJwtTokenParser withExpectedIssuedFor(String expectedIssuedFor) {
6174
this.expectedIssuedFor = expectedIssuedFor;
6275
return this;
6376
}
64-
65-
public static KeycloakJwtTokenParser create(String serverUrl, String realm) {
66-
try {
67-
ObjectMapper om = new ObjectMapper();
68-
String jwksUrl = serverUrl + ServiceUrlConstants.JWKS_URL.replace("{realm-name}", realm);
69-
JSONWebKeySet jwks = om.readValue(new URL(jwksUrl).openStream(), JSONWebKeySet.class);
70-
if (jwks.getKeys().length == 0) {
71-
throw new RuntimeException("No keys found");
72-
}
73-
JWK jwk = jwks.getKeys()[0];
74-
PublicKey publicKey = JWKParser.create(jwk).toPublicKey();
75-
return new KeycloakJwtTokenParser(serverUrl, realm, publicKey);
76-
} catch (Exception e) {
77-
throw new RuntimeException("Exception when obtaining public key from " + serverUrl, e);
78-
}
79-
}
8077
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.avast.grpc.jwt.keycloak.server;
2+
3+
import java.security.PublicKey;
4+
5+
public interface KeycloakPublicKeyProvider {
6+
PublicKey get(String keyId);
7+
}

keycloak/src/main/resources/reference.conf

+2
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ keycloakDefaults {
1010
// for server
1111
expectedAudience = ""
1212
expectedIssuedFor = ""
13+
minTimeBetweenJwksRequests = 10 seconds
14+
publicKeyCacheTtl = 1 day
1315
}

0 commit comments

Comments
 (0)