Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show amr claim in ID token #902

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.core.JsonProcessingException;
Expand All @@ -33,7 +32,6 @@ public class Jackson2AuditDataSerializer implements AuditDataSerializer {

final ObjectMapper mapper;

@Autowired
public Jackson2AuditDataSerializer(ObjectMapper mapper) {
this.mapper = mapper;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
package it.infn.mw.iam.authn.multi_factor_authentication;

import java.io.IOException;
import java.util.Iterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import javax.servlet.FilterChain;
Expand All @@ -29,6 +30,9 @@
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.GenericFilterBean;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import it.infn.mw.iam.core.ExtendedAuthenticationToken;

/**
Expand Down Expand Up @@ -70,24 +74,20 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
}

/**
* Convert a set of authentication method references into a request parameter string Values are
* separated with a + symbol
* Convert a set of authentication method references into a JSON array of strings.
*
* @param amrSet the set of authentication method references
* @return the parsed string
* @return the parsed JSON array as a string
* @throws JsonProcessingException if an error occurs while converting to JSON
*/
private String parseAuthenticationMethodReferences(Set<IamAuthenticationMethodReference> amrSet) {
String amrClaim = "";
Iterator<IamAuthenticationMethodReference> it = amrSet.iterator();
while (it.hasNext()) {
IamAuthenticationMethodReference current = it.next();
StringBuilder amrClaimBuilder = new StringBuilder(amrClaim);
amrClaimBuilder.append(current.getName()).append("+");
amrClaim = amrClaimBuilder.toString();
private String parseAuthenticationMethodReferences(Set<IamAuthenticationMethodReference> amrSet)
throws JsonProcessingException {
List<String> amrList = new ArrayList<>();
for (IamAuthenticationMethodReference amr : amrSet) {
amrList.add(amr.getName());
}

// Remove trailing + symbol at end of string
amrClaim = amrClaim.substring(0, amrClaim.length() - 1);
return amrClaim;
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(amrList);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import static it.infn.mw.iam.core.oauth.granters.TokenExchangeTokenGranter.TOKEN_EXCHANGE_GRANT_TYPE;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -36,8 +38,12 @@
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.TokenRequest;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Joiner;

import it.infn.mw.iam.authn.multi_factor_authentication.IamAuthenticationMethodReference;
import it.infn.mw.iam.core.ExtendedAuthenticationToken;
import it.infn.mw.iam.core.oauth.profile.JWTProfileResolver;
import it.infn.mw.iam.core.oauth.scope.pdp.ScopeFilter;

Expand Down Expand Up @@ -79,6 +85,16 @@ public AuthorizationRequest createAuthorizationRequest(Map<String, String> input

AuthorizationRequest authzRequest = super.createAuthorizationRequest(inputParams);

if (authn instanceof ExtendedAuthenticationToken extendedAuthenticationToken) {
Set<IamAuthenticationMethodReference> amrSet =
extendedAuthenticationToken.getAuthenticationMethodReferences();
try {
authzRequest.getExtensions().put("amr", parseAuthenticationMethodReferences(amrSet));
} catch (JsonProcessingException e) {
LOG.error("Failed to convert amr set to JSON array", e);
}
}

for (String audienceKey : AUDIENCE_KEYS) {
if (inputParams.containsKey(audienceKey)) {
if (!authzRequest.getExtensions().containsKey(AUD)) {
Expand All @@ -93,6 +109,17 @@ public AuthorizationRequest createAuthorizationRequest(Map<String, String> input

}

private String parseAuthenticationMethodReferences(Set<IamAuthenticationMethodReference> amrSet)
throws JsonProcessingException {
List<String> amrList = new ArrayList<>();
for (IamAuthenticationMethodReference amr : amrSet) {
amrList.add(amr.getName());
}

ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(amrList);
}

private void handlePasswordGrantAuthenticationTimestamp(OAuth2Request request) {
if (PASSWORD_GRANT.equals(request.getGrantType())) {
String now = Long.toString(System.currentTimeMillis());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@

import static it.infn.mw.iam.core.oauth.profile.iam.ClaimValueHelper.ADDITIONAL_CLAIMS;

import java.util.List;
import java.util.Set;

import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
import org.mitre.openid.connect.service.ScopeClaimTranslationService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.oauth2.provider.OAuth2Request;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jwt.JWTClaimsSet.Builder;

import it.infn.mw.iam.config.IamProperties;
Expand All @@ -35,6 +40,8 @@
@SuppressWarnings("deprecation")
public class IamJWTProfileIdTokenCustomizer extends BaseIdTokenCustomizer {

public static final Logger LOG = LoggerFactory.getLogger(IamJWTProfileIdTokenCustomizer.class);

protected final ScopeClaimTranslationService scopeClaimConverter;
protected final ClaimValueHelper claimValueHelper;

Expand All @@ -46,7 +53,6 @@ public IamJWTProfileIdTokenCustomizer(IamAccountRepository accountRepo,
this.claimValueHelper = claimValueHelper;
}


@Override
public void customizeIdTokenClaims(Builder idClaims, ClientDetailsEntity client,
OAuth2Request request, String sub, OAuth2AccessTokenEntity accessToken, IamAccount account) {
Expand All @@ -59,7 +65,20 @@ public void customizeIdTokenClaims(Builder idClaims, ClientDetailsEntity client,
.filter(ADDITIONAL_CLAIMS::contains)
.forEach(c -> idClaims.claim(c, claimValueHelper.getClaimValueFromUserInfo(c, info)));

Object amrClaim = request.getExtensions().get("amr");

if (amrClaim instanceof String amrString) {
try {
ObjectMapper objectMapper = new ObjectMapper();
List<String> amrList =
objectMapper.readValue(amrString, new TypeReference<List<String>>() {});

idClaims.claim("amr", amrList);
} catch (Exception e) {
LOG.error("Failed to deserialize amr claim", e);
}
}

includeLabelsInIdToken(idClaims, account);
}

}