diff --git a/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc b/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc index d36e2056f360b..3c871a43b23c0 100644 --- a/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc +++ b/docs/src/main/asciidoc/security-openid-connect-client-reference.adoc @@ -1391,6 +1391,29 @@ public interface ProtectedResourceService { or +[source,java] +---- +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import io.quarkus.oidc.token.propagation.common.AccessToken; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +@RegisterRestClient +@Path("/") +public interface InformationService { + + @AccessToken + @GET + String getUserName(); + + @Path("/public") + @GET + String getPublicInformation(); +} +---- + +or + [source,java] ---- import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; @@ -1482,6 +1505,30 @@ public interface ProtectedResourceService { String getUserName(); } ---- + +or + +[source,java] +---- +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import io.quarkus.oidc.token.propagation.common.AccessToken; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +@RegisterRestClient +@Path("/") +public interface InformationService { + + @AccessToken + @GET + String getUserName(); + + @Path("/public") + @GET + String getPublicInformation(); +} +---- + or [source,java] diff --git a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenInstanceBuildItem.java b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenInstanceBuildItem.java index 204023cf9e5cc..1502488213862 100644 --- a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenInstanceBuildItem.java +++ b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenInstanceBuildItem.java @@ -3,6 +3,7 @@ import java.util.Objects; import org.jboss.jandex.AnnotationTarget; +import org.jboss.jandex.MethodInfo; import io.quarkus.builder.item.MultiBuildItem; @@ -14,18 +15,21 @@ public final class AccessTokenInstanceBuildItem extends MultiBuildItem { private final String clientName; private final boolean tokenExchange; private final AnnotationTarget annotationTarget; + private final MethodInfo targetMethodInfo; - AccessTokenInstanceBuildItem(String clientName, Boolean tokenExchange, AnnotationTarget annotationTarget) { + AccessTokenInstanceBuildItem(String clientName, Boolean tokenExchange, AnnotationTarget annotationTarget, + MethodInfo targetMethodInfo) { this.clientName = Objects.requireNonNull(clientName); this.tokenExchange = tokenExchange; this.annotationTarget = Objects.requireNonNull(annotationTarget); + this.targetMethodInfo = targetMethodInfo; } - public String getClientName() { + String getClientName() { return clientName; } - public boolean exchangeTokenActivated() { + boolean exchangeTokenActivated() { return tokenExchange; } @@ -34,6 +38,13 @@ public AnnotationTarget getAnnotationTarget() { } public String targetClass() { - return annotationTarget.asClass().name().toString(); + if (annotationTarget.kind() == AnnotationTarget.Kind.CLASS) { + return annotationTarget.asClass().name().toString(); + } + return annotationTarget.asMethod().declaringClass().name().toString(); + } + + MethodInfo getTargetMethodInfo() { + return targetMethodInfo; } } diff --git a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenRequestFilterGenerator.java b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenRequestFilterGenerator.java index 20f5050276a8e..dd4d8b3021bbd 100644 --- a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenRequestFilterGenerator.java +++ b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/AccessTokenRequestFilterGenerator.java @@ -8,25 +8,30 @@ import jakarta.annotation.Priority; import jakarta.inject.Singleton; +import org.jboss.jandex.MethodInfo; + import io.quarkus.arc.deployment.GeneratedBeanBuildItem; import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; import io.quarkus.gizmo.ClassCreator; +import io.quarkus.gizmo.MethodDescriptor; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.security.spi.runtime.MethodDescription; public final class AccessTokenRequestFilterGenerator { private static final int AUTHENTICATION = 1000; - private record ClientNameAndExchangeToken(String clientName, boolean exchangeTokenActivated) { + private record RequestFilterKey(String clientName, boolean exchangeTokenActivated, MethodInfo targetMethodInfo) { } private final BuildProducer unremovableBeansProducer; private final BuildProducer reflectiveClassProducer; private final BuildProducer generatedBeanProducer; private final Class requestFilterClass; - private final Map cache = new HashMap<>(); + private final Map cache = new HashMap<>(); public AccessTokenRequestFilterGenerator(BuildProducer unremovableBeansProducer, BuildProducer reflectiveClassProducer, @@ -39,7 +44,9 @@ public AccessTokenRequestFilterGenerator(BuildProducer public String generateClass(AccessTokenInstanceBuildItem instance) { return cache.computeIfAbsent( - new ClientNameAndExchangeToken(instance.getClientName(), instance.exchangeTokenActivated()), i -> { + new RequestFilterKey(instance.getClientName(), instance.exchangeTokenActivated(), + instance.getTargetMethodInfo()), + i -> { var adaptor = new GeneratedBeanGizmoAdaptor(generatedBeanProducer); String className = createUniqueClassName(i); try (ClassCreator classCreator = ClassCreator.builder() @@ -64,6 +71,37 @@ public String generateClass(AccessTokenInstanceBuildItem instance) { methodCreator.returnBoolean(true); } } + + /* + * protected MethodDescription getMethodDescription() { + * return new MethodDescription(declaringClassName, methodName, parameterTypes); + * } + */ + if (i.targetMethodInfo != null) { + try (var methodCreator = classCreator.getMethodCreator("getMethodDescription", + MethodDescription.class)) { + methodCreator.addAnnotation(Override.class.getName(), RetentionPolicy.CLASS); + methodCreator.setModifiers(Modifier.PROTECTED); + + // String methodName + var methodName = methodCreator.load(i.targetMethodInfo.name()); + // String declaringClassName + var declaringClassName = methodCreator + .load(i.targetMethodInfo.declaringClass().name().toString()); + // String[] paramTypes + var paramTypes = methodCreator.marshalAsArray(String[].class, + i.targetMethodInfo.parameterTypes().stream() + .map(pt -> pt.name().toString()).map(methodCreator::load) + .toArray(ResultHandle[]::new)); + // new MethodDescription(declaringClassName, methodName, parameterTypes) + var methodDescriptionCtor = MethodDescriptor.ofConstructor(MethodDescription.class, + String.class, String.class, String[].class); + var newMethodDescription = methodCreator.newInstance(methodDescriptionCtor, declaringClassName, + methodName, paramTypes); + // return new MethodDescription(declaringClassName, methodName, parameterTypes); + methodCreator.returnValue(newMethodDescription); + } + } } unremovableBeansProducer.produce(UnremovableBeanBuildItem.beanClassNames(className)); reflectiveClassProducer @@ -74,9 +112,13 @@ public String generateClass(AccessTokenInstanceBuildItem instance) { }); } - private String createUniqueClassName(ClientNameAndExchangeToken i) { - return "%s_%sClient_%sTokenExchange".formatted(requestFilterClass.getName(), clientName(i.clientName()), - exchangeTokenName(i.exchangeTokenActivated())); + private String createUniqueClassName(RequestFilterKey i) { + String uniqueClassName = "%s_%sClient_%sTokenExchange".formatted(requestFilterClass.getName(), + clientName(i.clientName()), exchangeTokenName(i.exchangeTokenActivated())); + if (i.targetMethodInfo != null) { + uniqueClassName = uniqueClassName + "_" + i.targetMethodInfo.name(); + } + return uniqueClassName; } private static String clientName(String clientName) { diff --git a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/OidcTokenPropagationCommonProcessor.java b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/OidcTokenPropagationCommonProcessor.java index 6194a5fa1791a..410a2a2533893 100644 --- a/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/OidcTokenPropagationCommonProcessor.java +++ b/extensions/oidc-token-propagation-common/deployment/src/main/java/io/quarkus/oidc/token/propagation/common/deployment/OidcTokenPropagationCommonProcessor.java @@ -1,9 +1,14 @@ package io.quarkus.oidc.token.propagation.common.deployment; +import static java.util.stream.Collectors.groupingBy; + import java.util.List; +import java.util.Objects; import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationTarget; import org.jboss.jandex.DotName; +import org.jboss.jandex.MethodInfo; import io.quarkus.deployment.annotations.BuildStep; import io.quarkus.deployment.builditem.CombinedIndexBuildItem; @@ -26,12 +31,36 @@ private boolean toExchangeToken() { return instance.value("exchangeTokenClient") != null; } + private MethodInfo methodInfo() { + if (instance.target().kind() == AnnotationTarget.Kind.METHOD) { + return instance.target().asMethod(); + } + return null; + } + + private String targetClassName() { + if (instance.target().kind() == AnnotationTarget.Kind.METHOD) { + return instance.target().asMethod().declaringClass().name().toString(); + } + return instance.target().asClass().name().toString(); + } + private AccessTokenInstanceBuildItem build() { - return new AccessTokenInstanceBuildItem(toClientName(), toExchangeToken(), instance.target()); + return new AccessTokenInstanceBuildItem(toClientName(), toExchangeToken(), instance.target(), methodInfo()); } } var accessTokenAnnotations = index.getIndex().getAnnotations(ACCESS_TOKEN); - return accessTokenAnnotations.stream().map(ItemBuilder::new).map(ItemBuilder::build).toList(); + var itemBuilders = accessTokenAnnotations.stream().map(ItemBuilder::new).toList(); + if (!itemBuilders.isEmpty()) { + var targetClassToBuilders = itemBuilders.stream().collect(groupingBy(ItemBuilder::targetClassName)); + targetClassToBuilders.forEach((targetClassName, classBuilders) -> { + if (classBuilders.size() > 1 && classBuilders.stream().map(ItemBuilder::methodInfo).anyMatch(Objects::isNull)) { + throw new RuntimeException( + ACCESS_TOKEN + " annotation can be applied either on class " + targetClassName + " or its methods"); + } + }); + } + return itemBuilders.stream().map(ItemBuilder::build).toList(); } } diff --git a/extensions/oidc-token-propagation-common/runtime/pom.xml b/extensions/oidc-token-propagation-common/runtime/pom.xml index 896b66ea0fe32..3b100d262778c 100644 --- a/extensions/oidc-token-propagation-common/runtime/pom.xml +++ b/extensions/oidc-token-propagation-common/runtime/pom.xml @@ -22,6 +22,10 @@ io.quarkus quarkus-arc + + io.quarkus + quarkus-security-runtime-spi + diff --git a/extensions/oidc-token-propagation-common/runtime/src/main/java/io/quarkus/oidc/token/propagation/common/AccessToken.java b/extensions/oidc-token-propagation-common/runtime/src/main/java/io/quarkus/oidc/token/propagation/common/AccessToken.java index 3fcd4ee1f0faf..eb0be9061d9cc 100644 --- a/extensions/oidc-token-propagation-common/runtime/src/main/java/io/quarkus/oidc/token/propagation/common/AccessToken.java +++ b/extensions/oidc-token-propagation-common/runtime/src/main/java/io/quarkus/oidc/token/propagation/common/AccessToken.java @@ -12,8 +12,11 @@ * The end result is that the request propagates the Bearer token present in the current active request or the token acquired * from the Authorization Code Flow, * as the HTTP {@code Authorization} header's {@code Bearer} scheme value. + *

+ * This annotation may also be placed on individual methods of the REST Client interface. + * When applied to a method, the {@link AccessTokenRequestFilter} will be registered only for that method. */ -@Target({ ElementType.TYPE }) +@Target({ ElementType.TYPE, ElementType.METHOD }) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface AccessToken { diff --git a/extensions/oidc-token-propagation-reactive/deployment/src/test/java/io/quarkus/oidc/token/propagation/reactive/deployment/test/AccessTokenAnnotationTest.java b/extensions/oidc-token-propagation-reactive/deployment/src/test/java/io/quarkus/oidc/token/propagation/reactive/deployment/test/AccessTokenAnnotationTest.java index 43040493e7353..f294497b5e5c4 100644 --- a/extensions/oidc-token-propagation-reactive/deployment/src/test/java/io/quarkus/oidc/token/propagation/reactive/deployment/test/AccessTokenAnnotationTest.java +++ b/extensions/oidc-token-propagation-reactive/deployment/src/test/java/io/quarkus/oidc/token/propagation/reactive/deployment/test/AccessTokenAnnotationTest.java @@ -14,6 +14,7 @@ import org.eclipse.microprofile.jwt.JsonWebToken; import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; import org.eclipse.microprofile.rest.client.inject.RestClient; +import org.hamcrest.Matchers; import org.jboss.shrinkwrap.api.asset.StringAsset; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; @@ -38,7 +39,9 @@ public class AccessTokenAnnotationTest { .withApplicationRoot((jar) -> jar .addClasses(DefaultClientDefaultExchange.class, DefaultClientEnabledExchange.class, NamedClientDefaultExchange.class, MultiProviderFrontendResource.class, ProtectedResource.class, - CustomAccessTokenRequestFilter.class) + CustomAccessTokenRequestFilter.class, NamedClientDefaultExchange_OnMethod.class, + DefaultClientEnabledExchange_OnMethod.class, DefaultClientDefaultExchange_OnMethod.class, + MultipleClientsAndMultipleMethods.class) .addAsResource( new StringAsset( """ @@ -74,16 +77,32 @@ public static void close() { @Test public void testDefaultClientEnabledTokenExchange() { testRestClientTokenPropagation(true, "defaultClientEnabledExchange"); + testRestClientTokenPropagation(true, "defaultClientEnabledExchange_OnMethod"); + testRestClientTokenPropagation(true, "multipleClientsAndMultipleMethods_DefaultClientEnabledExchange"); } @Test public void testDefaultClientDefaultTokenExchange() { testRestClientTokenPropagation(false, "defaultClientDefaultExchange"); + testRestClientTokenPropagation(false, "defaultClientDefaultExchange_OnMethod"); + testRestClientTokenPropagation(false, "multipleClientsAndMultipleMethods_DefaultClientDefaultExchange"); } @Test public void testNamedClientDefaultTokenExchange() { testRestClientTokenPropagation(true, "namedClientDefaultExchange"); + testRestClientTokenPropagation(true, "namedClientDefaultExchange_OnMethod"); + testRestClientTokenPropagation(true, "multipleClientsAndMultipleMethods_NamedClientDefaultExchange"); + } + + @Test + public void testNoTokenPropagation() { + RestAssured.given().auth().oauth2(getBearerAccessToken()) + .queryParam("client-key", "multipleClientsAndMultipleMethods_NoAccessToken") + .when().get("/frontend/token-propagation") + .then() + .statusCode(500) + .body(Matchers.containsString("Unauthorized, status code 401")); } private void testRestClientTokenPropagation(boolean exchangeEnabled, String clientKey) { @@ -124,6 +143,50 @@ public interface NamedClientDefaultExchange { String getUserName(); } + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface DefaultClientDefaultExchange_OnMethod { + @AccessToken + @GET + String getUserName(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface DefaultClientEnabledExchange_OnMethod { + @AccessToken(exchangeTokenClient = "Default") + @GET + String getUserName(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface MultipleClientsAndMultipleMethods { + + @AccessToken + @GET + String getUserName_DefaultClientDefaultExchange(); + + @AccessToken(exchangeTokenClient = "named") + @GET + String getUserName_NamedClientDefaultExchange(); + + @AccessToken(exchangeTokenClient = "Default") + @GET + String getUserName_DefaultClientEnabledExchange(); + + @GET + String getUserName_NoAccessToken(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface NamedClientDefaultExchange_OnMethod { + @AccessToken(exchangeTokenClient = "named") + @GET + String getUserName(); + } + // tests no AmbiguousResolutionException is raised @Singleton @Unremovable @@ -144,6 +207,22 @@ public static class MultiProviderFrontendResource { @RestClient NamedClientDefaultExchange namedClientDefaultExchange; + @Inject + @RestClient + DefaultClientDefaultExchange_OnMethod defaultClientDefaultExchange_OnMethod; + + @Inject + @RestClient + DefaultClientEnabledExchange_OnMethod defaultClientEnabledExchange_OnMethod; + + @Inject + @RestClient + NamedClientDefaultExchange_OnMethod namedClientDefaultExchange_OnMethod; + + @Inject + @RestClient + MultipleClientsAndMultipleMethods multipleClientsAndMultipleMethods; + @Inject JsonWebToken jwt; @@ -190,6 +269,17 @@ private String getUserName(String clientKey) { case "defaultClientDefaultExchange" -> defaultClientDefaultExchange.getUserName(); case "defaultClientEnabledExchange" -> defaultClientEnabledExchange.getUserName(); case "namedClientDefaultExchange" -> namedClientDefaultExchange.getUserName(); + case "defaultClientDefaultExchange_OnMethod" -> defaultClientDefaultExchange_OnMethod.getUserName(); + case "defaultClientEnabledExchange_OnMethod" -> defaultClientEnabledExchange_OnMethod.getUserName(); + case "namedClientDefaultExchange_OnMethod" -> namedClientDefaultExchange_OnMethod.getUserName(); + case "multipleClientsAndMultipleMethods_DefaultClientDefaultExchange" -> + multipleClientsAndMultipleMethods.getUserName_DefaultClientDefaultExchange(); + case "multipleClientsAndMultipleMethods_DefaultClientEnabledExchange" -> + multipleClientsAndMultipleMethods.getUserName_DefaultClientEnabledExchange(); + case "multipleClientsAndMultipleMethods_NamedClientDefaultExchange" -> + multipleClientsAndMultipleMethods.getUserName_NamedClientDefaultExchange(); + case "multipleClientsAndMultipleMethods_NoAccessToken" -> + multipleClientsAndMultipleMethods.getUserName_NoAccessToken(); default -> throw new IllegalArgumentException("Unknown client key"); }; } diff --git a/extensions/oidc-token-propagation-reactive/runtime/src/main/java/io/quarkus/oidc/token/propagation/reactive/AccessTokenRequestReactiveFilter.java b/extensions/oidc-token-propagation-reactive/runtime/src/main/java/io/quarkus/oidc/token/propagation/reactive/AccessTokenRequestReactiveFilter.java index a1405d41382ad..98e3926afe170 100644 --- a/extensions/oidc-token-propagation-reactive/runtime/src/main/java/io/quarkus/oidc/token/propagation/reactive/AccessTokenRequestReactiveFilter.java +++ b/extensions/oidc-token-propagation-reactive/runtime/src/main/java/io/quarkus/oidc/token/propagation/reactive/AccessTokenRequestReactiveFilter.java @@ -3,6 +3,7 @@ import static io.quarkus.oidc.token.propagation.common.runtime.TokenPropagationConstants.JWT_PROPAGATE_TOKEN_CREDENTIAL; import static io.quarkus.oidc.token.propagation.common.runtime.TokenPropagationConstants.OIDC_PROPAGATE_TOKEN_CREDENTIAL; +import java.lang.reflect.Method; import java.util.Collections; import java.util.function.Consumer; @@ -16,6 +17,7 @@ import org.eclipse.microprofile.config.ConfigProvider; import org.jboss.logging.Logger; +import org.jboss.resteasy.reactive.client.impl.RestClientRequestContext; import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext; import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter; @@ -27,6 +29,7 @@ import io.quarkus.oidc.common.runtime.OidcConstants; import io.quarkus.runtime.configuration.ConfigurationException; import io.quarkus.security.credential.TokenCredential; +import io.quarkus.security.spi.runtime.MethodDescription; import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; import io.smallrye.mutiny.Uni; import io.vertx.core.Vertx; @@ -77,6 +80,10 @@ protected boolean isExchangeToken() { @Override public void filter(ResteasyReactiveClientRequestContext requestContext) { + if (skipPropagation(requestContext)) { + return; + } + if (verifyTokenInstance(requestContext)) { if (exchangeTokenClient != null) { @@ -180,4 +187,30 @@ private Uni exchangeToken(String token) { protected void abortRequest(ResteasyReactiveClientRequestContext requestContext) { requestContext.abortWith(Response.status(401).build()); } + + private boolean skipPropagation(ResteasyReactiveClientRequestContext requestContext) { + if (getMethodDescription() == null) { + return false; + } + + return !getMethodDescription().equals(getInvokedRestClientMethodSignature(requestContext)); + } + + private static MethodDescription getInvokedRestClientMethodSignature(ResteasyReactiveClientRequestContext requestContext) { + Method method = (Method) requestContext.getProperty(RestClientRequestContext.INVOKED_METHOD_PROP); + if (method == null) { + throw new IllegalStateException(RestClientRequestContext.INVOKED_METHOD_PROP + " property must not be null"); + } + return MethodDescription.ofMethod(method); + } + + /** + * This method is overridden by generated filter classes if the filter should only be applied on the REST client + * method. + * + * @return REST client method description for which this filter should be applied; or null if applies to all methods + */ + protected MethodDescription getMethodDescription() { + return null; + } } diff --git a/extensions/oidc-token-propagation/deployment/src/test/java/io/quarkus/oidc/token/propagation/deployment/test/AccessTokenAnnotationTest.java b/extensions/oidc-token-propagation/deployment/src/test/java/io/quarkus/oidc/token/propagation/deployment/test/AccessTokenAnnotationTest.java index a916ff8e8bd21..3d2c1f743987a 100644 --- a/extensions/oidc-token-propagation/deployment/src/test/java/io/quarkus/oidc/token/propagation/deployment/test/AccessTokenAnnotationTest.java +++ b/extensions/oidc-token-propagation/deployment/src/test/java/io/quarkus/oidc/token/propagation/deployment/test/AccessTokenAnnotationTest.java @@ -36,7 +36,9 @@ public class AccessTokenAnnotationTest { .withApplicationRoot((jar) -> jar .addClasses(DefaultClientDefaultExchange.class, DefaultClientEnabledExchange.class, NamedClientDefaultExchange.class, MultiProviderFrontendResource.class, ProtectedResource.class, - CustomAccessTokenRequestFilter.class) + CustomAccessTokenRequestFilter.class, NamedClientDefaultExchange_OnMethod.class, + DefaultClientEnabledExchange_OnMethod.class, DefaultClientDefaultExchange_OnMethod.class, + MultipleClientsAndMultipleMethods.class) .addAsResource( new StringAsset( """ @@ -72,16 +74,31 @@ public static void close() { @Test public void testDefaultClientEnabledTokenExchange() { testRestClientTokenPropagation(true, "defaultClientEnabledExchange"); + testRestClientTokenPropagation(true, "defaultClientEnabledExchange_OnMethod"); + testRestClientTokenPropagation(true, "multipleClientsAndMultipleMethods_DefaultClientEnabledExchange"); } @Test public void testDefaultClientDefaultTokenExchange() { testRestClientTokenPropagation(false, "defaultClientDefaultExchange"); + testRestClientTokenPropagation(false, "defaultClientDefaultExchange_OnMethod"); + testRestClientTokenPropagation(false, "multipleClientsAndMultipleMethods_DefaultClientDefaultExchange"); } @Test public void testNamedClientDefaultTokenExchange() { testRestClientTokenPropagation(true, "namedClientDefaultExchange"); + testRestClientTokenPropagation(true, "namedClientDefaultExchange_OnMethod"); + testRestClientTokenPropagation(true, "multipleClientsAndMultipleMethods_NamedClientDefaultExchange"); + } + + @Test + public void testNoTokenPropagation() { + RestAssured.given().auth().oauth2(getBearerAccessToken()) + .queryParam("client-key", "multipleClientsAndMultipleMethods_NoAccessToken") + .when().get("/frontend/token-propagation") + .then() + .statusCode(401); } private void testRestClientTokenPropagation(boolean exchangeEnabled, String clientKey) { @@ -122,6 +139,50 @@ public interface NamedClientDefaultExchange { String getUserName(); } + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface DefaultClientDefaultExchange_OnMethod { + @AccessToken + @GET + String getUserName(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface DefaultClientEnabledExchange_OnMethod { + @AccessToken(exchangeTokenClient = "Default") + @GET + String getUserName(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface MultipleClientsAndMultipleMethods { + + @AccessToken + @GET + String getUserName_DefaultClientDefaultExchange(); + + @AccessToken(exchangeTokenClient = "named") + @GET + String getUserName_NamedClientDefaultExchange(); + + @AccessToken(exchangeTokenClient = "Default") + @GET + String getUserName_DefaultClientEnabledExchange(); + + @GET + String getUserName_NoAccessToken(); + } + + @RegisterRestClient(baseUri = "http://localhost:8081/protected") + @Path("/") + public interface NamedClientDefaultExchange_OnMethod { + @AccessToken(exchangeTokenClient = "named") + @GET + String getUserName(); + } + // tests no AmbiguousResolutionException is raised @Singleton @Unremovable @@ -142,6 +203,22 @@ public static class MultiProviderFrontendResource { @RestClient NamedClientDefaultExchange namedClientDefaultExchange; + @Inject + @RestClient + DefaultClientDefaultExchange_OnMethod defaultClientDefaultExchange_OnMethod; + + @Inject + @RestClient + DefaultClientEnabledExchange_OnMethod defaultClientEnabledExchange_OnMethod; + + @Inject + @RestClient + NamedClientDefaultExchange_OnMethod namedClientDefaultExchange_OnMethod; + + @Inject + @RestClient + MultipleClientsAndMultipleMethods multipleClientsAndMultipleMethods; + @Inject JsonWebToken jwt; @@ -172,6 +249,17 @@ private String getUserName(String clientKey) { case "defaultClientDefaultExchange" -> defaultClientDefaultExchange.getUserName(); case "defaultClientEnabledExchange" -> defaultClientEnabledExchange.getUserName(); case "namedClientDefaultExchange" -> namedClientDefaultExchange.getUserName(); + case "defaultClientDefaultExchange_OnMethod" -> defaultClientDefaultExchange_OnMethod.getUserName(); + case "defaultClientEnabledExchange_OnMethod" -> defaultClientEnabledExchange_OnMethod.getUserName(); + case "namedClientDefaultExchange_OnMethod" -> namedClientDefaultExchange_OnMethod.getUserName(); + case "multipleClientsAndMultipleMethods_DefaultClientDefaultExchange" -> + multipleClientsAndMultipleMethods.getUserName_DefaultClientDefaultExchange(); + case "multipleClientsAndMultipleMethods_DefaultClientEnabledExchange" -> + multipleClientsAndMultipleMethods.getUserName_DefaultClientEnabledExchange(); + case "multipleClientsAndMultipleMethods_NamedClientDefaultExchange" -> + multipleClientsAndMultipleMethods.getUserName_NamedClientDefaultExchange(); + case "multipleClientsAndMultipleMethods_NoAccessToken" -> + multipleClientsAndMultipleMethods.getUserName_NoAccessToken(); default -> throw new IllegalArgumentException("Unknown client key"); }; } diff --git a/extensions/oidc-token-propagation/runtime/src/main/java/io/quarkus/oidc/token/propagation/AccessTokenRequestFilter.java b/extensions/oidc-token-propagation/runtime/src/main/java/io/quarkus/oidc/token/propagation/AccessTokenRequestFilter.java index 1ba269f42f3bb..d94366a36cdb8 100644 --- a/extensions/oidc-token-propagation/runtime/src/main/java/io/quarkus/oidc/token/propagation/AccessTokenRequestFilter.java +++ b/extensions/oidc-token-propagation/runtime/src/main/java/io/quarkus/oidc/token/propagation/AccessTokenRequestFilter.java @@ -4,6 +4,7 @@ import static io.quarkus.oidc.token.propagation.common.runtime.TokenPropagationConstants.OIDC_PROPAGATE_TOKEN_CREDENTIAL; import java.io.IOException; +import java.lang.reflect.Method; import java.util.Collections; import jakarta.annotation.PostConstruct; @@ -21,6 +22,7 @@ import io.quarkus.oidc.token.propagation.runtime.AbstractTokenRequestFilter; import io.quarkus.runtime.configuration.ConfigurationException; import io.quarkus.security.credential.TokenCredential; +import io.quarkus.security.spi.runtime.MethodDescription; import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; import io.vertx.core.Vertx; @@ -28,6 +30,7 @@ public class AccessTokenRequestFilter extends AbstractTokenRequestFilter { // note: We can't use constructor injection for these fields because they are registered by RESTEasy // which doesn't know about CDI at the point of registration + private static final String INVOKED_METHOD_PROP = "org.eclipse.microprofile.rest.client.invokedMethod"; private static final String ERROR_MSG = "OIDC Token Propagation requires a safe (isolated) Vert.x sub-context because configuration property 'quarkus.resteasy-client-oidc-token-propagation.enabled-during-authentication' has been set to true, but the current context hasn't been flagged as such."; private final boolean enabledDuringAuthentication; private final Instance accessToken; @@ -70,6 +73,10 @@ protected boolean isExchangeToken() { @Override public void filter(ClientRequestContext requestContext) throws IOException { + if (skipPropagation(requestContext)) { + return; + } + if (acquireTokenCredentialFromCtx(requestContext)) { propagateToken(requestContext, exchangeTokenIfNeeded(getTokenCredentialFromContext().getToken())); } else { @@ -114,4 +121,30 @@ private static TokenCredential getTokenCredentialFromContext() { VertxContextSafetyToggle.validateContextIfExists(ERROR_MSG, ERROR_MSG); return Vertx.currentContext().getLocal(TokenCredential.class.getName()); } + + private boolean skipPropagation(ClientRequestContext requestContext) { + if (getMethodDescription() == null) { + return false; + } + + return !getMethodDescription().equals(getInvokedRestClientMethodSignature(requestContext)); + } + + private static MethodDescription getInvokedRestClientMethodSignature(ClientRequestContext requestContext) { + Method method = (Method) requestContext.getProperty(INVOKED_METHOD_PROP); + if (method == null) { + throw new IllegalStateException(INVOKED_METHOD_PROP + " property must not be null"); + } + return MethodDescription.ofMethod(method); + } + + /** + * This method is overridden by generated filter classes if the filter should only be applied on the REST client + * method. + * + * @return REST client method description for which this filter should be applied; or null if applies to all methods + */ + protected MethodDescription getMethodDescription() { + return null; + } }