Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust-Axum] Implement support for Basic and Bearer auth in Claims #20584

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/generators/rust-axum.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,16 @@ These options may be applied as additional-properties (cli) or configOptions (pl
|Union|✗|OAS3
|allOf|✗|OAS2,OAS3
|anyOf|✗|OAS3
|oneOf||OAS3
|oneOf||OAS3
|not|✗|OAS3

### Security Feature
| Name | Supported | Defined By |
| ---- | --------- | ---------- |
|BasicAuth||OAS2,OAS3
|BasicAuth||OAS2,OAS3
|ApiKey|✓|OAS2,OAS3
|OpenIDConnect|✗|OAS3
|BearerToken||OAS3
|BearerToken||OAS3
|OAuth2_Implicit|✗|OAS2,OAS3
|OAuth2_Password|✗|OAS2,OAS3
|OAuth2_ClientCredentials|✗|OAS2,OAS3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class RustAxumServerCodegen extends AbstractRustCodegen implements Codege
// Grouping (Method, Operation) by Path.
private final Map<String, ArrayList<MethodOperation>> pathMethodOpMap = new HashMap<>();
private boolean havingAuthMethods = false;
private boolean havingBasicAuthMethods = false;

// Logger
private final Logger LOGGER = LoggerFactory.getLogger(RustAxumServerCodegen.class);
Expand All @@ -98,7 +99,14 @@ public RustAxumServerCodegen() {
WireFormatFeature.Custom
))
.securityFeatures(EnumSet.of(
SecurityFeature.ApiKey
SecurityFeature.ApiKey,
SecurityFeature.BasicAuth,
SecurityFeature.BearerToken
))
.schemaSupportFeatures(EnumSet.of(
SchemaSupportFeature.Simple,
SchemaSupportFeature.Composite,
SchemaSupportFeature.oneOf
))
.excludeGlobalFeatures(
GlobalFeature.Info,
Expand Down Expand Up @@ -777,6 +785,16 @@ private boolean postProcessOperationWithModels(final CodegenOperation op) {

op.vendorExtensions.put("x-has-auth-methods", true);
hasAuthMethod = true;
} else if (s.isBasic) {
op.vendorExtensions.put("x-has-basic-auth-methods", true);
op.vendorExtensions.put("x-is-basic-bearer", s.isBasicBearer);
op.vendorExtensions.put("x-api-auth-header-name", "authorization");

op.vendorExtensions.put("x-has-auth-methods", true);
hasAuthMethod = true;

if (!this.havingBasicAuthMethods)
this.havingBasicAuthMethods = true;
}
}
}
Expand Down Expand Up @@ -878,6 +896,7 @@ public Map<String, Object> postProcessSupportingFileData(Map<String, Object> bun
.collect(Collectors.toList());
bundle.put("pathMethodOps", pathMethodOps);
if (havingAuthMethods) bundle.put("havingAuthMethods", true);
if (havingBasicAuthMethods) bundle.put("havingBasicAuthMethods", true);

return super.postProcessSupportingFileData(bundle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,43 @@ pub mod {{classFilename}};
pub trait CookieAuthentication {
type Claims;

/// Extracting Claims from Cookie. Return None if the Claims is invalid.
/// Extracting Claims from Cookie. Return None if the Claims are invalid.
async fn extract_claims_from_cookie(&self, cookies: &axum_extra::extract::CookieJar, key: &str) -> Option<Self::Claims>;
}

{{/isKeyInCookie}}
{{#isKeyInHeader}}
/// API Key Authentication - Header.
#[async_trait::async_trait]
pub trait ApiKeyAuthHeader {
type Claims;

/// Extracting Claims from Header. Return None if the Claims is invalid.
/// Extracting Claims from Header. Return None if the Claims are invalid.
async fn extract_claims_from_header(&self, headers: &axum::http::header::HeaderMap, key: &str) -> Option<Self::Claims>;
}

{{/isKeyInHeader}}
{{/isApiKey}}
{{/authMethods}}
{{#havingBasicAuthMethods}}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BasicAuthKind {
Basic,
Bearer,
}

/// API Key Authentication - Authentication Header.
/// For `Basic token` and `Bearer token`
#[async_trait::async_trait]
pub trait ApiAuthBasic {
type Claims;

/// Extracting Claims from Header. Return None if the Claims are invalid.
async fn extract_claims_from_auth_header(&self, kind: BasicAuthKind, headers: &axum::http::header::HeaderMap, key: &str) -> Option<Self::Claims>;
}

{{/havingBasicAuthMethods}}

// Error handler for unhandled errors.
#[async_trait::async_trait]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A, E{
{{#x-has-header-auth-methods}}
headers: HeaderMap,
{{/x-has-header-auth-methods}}
{{^x-has-header-auth-methods}}
{{#x-has-basic-auth-methods}}
headers: HeaderMap,
{{/x-has-basic-auth-methods}}
{{/x-has-header-auth-methods}}
{{/vendorExtensions}}
{{/headerParams.size}}
{{#pathParams.size}}
Expand Down Expand Up @@ -54,7 +59,7 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A, E{
) -> Result<Response, StatusCode>
where
I: AsRef<A> + Send + Sync,
A: apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}>{{#vendorExtensions}}{{#x-has-cookie-auth-methods}}+ apis::CookieAuthentication<Claims = C>{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}+ apis::ApiKeyAuthHeader<Claims = C>{{/x-has-header-auth-methods}}{{/vendorExtensions}} + Send + Sync,
A: apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}>{{#vendorExtensions}}{{#x-has-cookie-auth-methods}}+ apis::CookieAuthentication<Claims = C>{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}+ apis::ApiKeyAuthHeader<Claims = C>{{/x-has-header-auth-methods}}{{#x-has-basic-auth-methods}}+ apis::ApiAuthBasic<Claims = C>{{/x-has-basic-auth-methods}}{{/vendorExtensions}} + Send + Sync,
E: std::fmt::Debug + Send + Sync + 'static,
{
{{#vendorExtensions}}
Expand All @@ -67,14 +72,20 @@ where
{{#x-has-header-auth-methods}}
let claims_in_header = api_impl.as_ref().extract_claims_from_header(&headers, "{{x-api-key-header-name}}").await;
{{/x-has-header-auth-methods}}
{{#x-has-basic-auth-methods}}
let claims_in_auth_header = api_impl.as_ref().extract_claims_from_auth_header(apis::BasicAuthKind::{{#x-is-basic-bearer}}Bearer{{/x-is-basic-bearer}}{{^x-is-basic-bearer}}Basic{{/x-is-basic-bearer}}, &headers, "{{x-api-auth-header-name}}").await;
{{/x-has-basic-auth-methods}}
{{#x-has-auth-methods}}
let claims = None
{{#x-has-cookie-auth-methods}}
.or(claims_in_cookie)
{{/x-has-cookie-auth-methods}}
{{#x-has-header-auth-methods}}
.or(claims_in_header)
.or(claims_in_header)
{{/x-has-header-auth-methods}}
{{#x-has-basic-auth-methods}}
.or(claims_in_auth_header)
{{/x-has-basic-auth-methods}}
;
let Some(claims) = claims else {
return Response::builder()
Expand Down Expand Up @@ -346,7 +357,6 @@ where
Err(why) => {
// Application code returned an error. This should not happen, as the implementation should
// return a valid response.

return api_impl.as_ref().handle_error(&method, &host, &cookies, why).await;
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pub fn new<I, A, E{{#havingAuthMethods}}, C{{/havingAuthMethods}}>(api_impl: I) -> Router
where
I: AsRef<A> + Clone + Send + Sync + 'static,
A: {{#apiInfo}}{{#apis}}{{#operations}}apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}> + {{/operations}}{{/apis}}{{/apiInfo}}{{#authMethods}}{{#isApiKey}}{{#isKeyInCookie}}apis::CookieAuthentication<Claims = C> + {{/isKeyInCookie}}{{#isKeyInHeader}}apis::ApiKeyAuthHeader<Claims = C> + {{/isKeyInHeader}}{{/isApiKey}}{{/authMethods}}Send + Sync + 'static,
A: {{#apiInfo}}{{#apis}}{{#operations}}apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}> + {{/operations}}{{/apis}}{{/apiInfo}}{{#authMethods}}{{#isApiKey}}{{#isKeyInCookie}}apis::CookieAuthentication<Claims = C> + {{/isKeyInCookie}}{{#isKeyInHeader}}apis::ApiKeyAuthHeader<Claims = C> + {{/isKeyInHeader}}{{/isApiKey}}{{#isBasic}}apis::ApiAuthBasic<Claims = C> + {{/isBasic}}{{/authMethods}}Send + Sync + 'static,
E: std::fmt::Debug + Send + Sync + 'static,
{{#havingAuthMethods}}C: Send + Sync + 'static,{{/havingAuthMethods}}
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,49 @@ pub mod payments;
pub trait ApiKeyAuthHeader {
type Claims;

/// Extracting Claims from Header. Return None if the Claims is invalid.
/// Extracting Claims from Header. Return None if the Claims are invalid.
async fn extract_claims_from_header(
&self,
headers: &axum::http::header::HeaderMap,
key: &str,
) -> Option<Self::Claims>;
}

/// Cookie Authentication.
#[async_trait::async_trait]
pub trait CookieAuthentication {
type Claims;

/// Extracting Claims from Cookie. Return None if the Claims is invalid.
/// Extracting Claims from Cookie. Return None if the Claims are invalid.
async fn extract_claims_from_cookie(
&self,
cookies: &axum_extra::extract::CookieJar,
key: &str,
) -> Option<Self::Claims>;
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BasicAuthKind {
Basic,
Bearer,
}

/// API Key Authentication - Authentication Header.
/// For `Basic token` and `Bearer token`
#[async_trait::async_trait]
pub trait ApiAuthBasic {
type Claims;

/// Extracting Claims from Header. Return None if the Claims are invalid.
async fn extract_claims_from_auth_header(
&self,
kind: BasicAuthKind,
headers: &axum::http::header::HeaderMap,
key: &str,
) -> Option<Self::Claims>;
}

// Error handler for unhandled errors.
#[async_trait::async_trait]
pub trait ErrorHandler<E: std::fmt::Debug + Send + Sync + 'static = ()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub trait Payments<E: std::fmt::Debug + Send + Sync + 'static = ()>:
method: &Method,
host: &Host,
cookies: &CookieJar,
claims: &Self::Claims,
path_params: &models::GetPaymentMethodByIdPathParams,
) -> Result<GetPaymentMethodByIdResponse, E>;

Expand All @@ -62,6 +63,7 @@ pub trait Payments<E: std::fmt::Debug + Send + Sync + 'static = ()>:
method: &Method,
host: &Host,
cookies: &CookieJar,
claims: &Self::Claims,
) -> Result<GetPaymentMethodsResponse, E>;

/// Make a payment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub fn new<I, A, E, C>(api_impl: I) -> Router
where
I: AsRef<A> + Clone + Send + Sync + 'static,
A: apis::payments::Payments<E, Claims = C>
+ apis::ApiAuthBasic<Claims = C>
+ apis::ApiAuthBasic<Claims = C>
Comment on lines +20 to +21
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is odd, but it works. This happens if both a basic and bearer auth are defined.

+ apis::ApiKeyAuthHeader<Claims = C>
+ apis::CookieAuthentication<Claims = C>
+ Send
Expand Down Expand Up @@ -53,14 +55,28 @@ async fn get_payment_method_by_id<I, A, E, C>(
method: Method,
host: Host,
cookies: CookieJar,
headers: HeaderMap,
Path(path_params): Path<models::GetPaymentMethodByIdPathParams>,
State(api_impl): State<I>,
) -> Result<Response, StatusCode>
where
I: AsRef<A> + Send + Sync,
A: apis::payments::Payments<E, Claims = C> + Send + Sync,
A: apis::payments::Payments<E, Claims = C> + apis::ApiAuthBasic<Claims = C> + Send + Sync,
E: std::fmt::Debug + Send + Sync + 'static,
{
// Authentication
let claims_in_auth_header = api_impl
.as_ref()
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
.await;
let claims = None.or(claims_in_auth_header);
let Some(claims) = claims else {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
};

#[allow(clippy::redundant_closure)]
let validation =
tokio::task::spawn_blocking(move || get_payment_method_by_id_validation(path_params))
Expand All @@ -76,7 +92,7 @@ where

let result = api_impl
.as_ref()
.get_payment_method_by_id(&method, &host, &cookies, &path_params)
.get_payment_method_by_id(&method, &host, &cookies, &claims, &path_params)
.await;

let mut response = Response::builder();
Expand Down Expand Up @@ -133,7 +149,6 @@ where
Err(why) => {
// Application code returned an error. This should not happen, as the implementation should
// return a valid response.

return api_impl
.as_ref()
.handle_error(&method, &host, &cookies, why)
Expand All @@ -157,13 +172,27 @@ async fn get_payment_methods<I, A, E, C>(
method: Method,
host: Host,
cookies: CookieJar,
headers: HeaderMap,
State(api_impl): State<I>,
) -> Result<Response, StatusCode>
where
I: AsRef<A> + Send + Sync,
A: apis::payments::Payments<E, Claims = C> + Send + Sync,
A: apis::payments::Payments<E, Claims = C> + apis::ApiAuthBasic<Claims = C> + Send + Sync,
E: std::fmt::Debug + Send + Sync + 'static,
{
// Authentication
let claims_in_auth_header = api_impl
.as_ref()
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
.await;
let claims = None.or(claims_in_auth_header);
let Some(claims) = claims else {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
};

#[allow(clippy::redundant_closure)]
let validation = tokio::task::spawn_blocking(move || get_payment_methods_validation())
.await
Expand All @@ -178,7 +207,7 @@ where

let result = api_impl
.as_ref()
.get_payment_methods(&method, &host, &cookies)
.get_payment_methods(&method, &host, &cookies, &claims)
.await;

let mut response = Response::builder();
Expand Down Expand Up @@ -212,7 +241,6 @@ where
Err(why) => {
// Application code returned an error. This should not happen, as the implementation should
// return a valid response.

return api_impl
.as_ref()
.handle_error(&method, &host, &cookies, why)
Expand Down Expand Up @@ -250,13 +278,15 @@ async fn post_make_payment<I, A, E, C>(
method: Method,
host: Host,
cookies: CookieJar,
headers: HeaderMap,
State(api_impl): State<I>,
Json(body): Json<Option<models::Payment>>,
) -> Result<Response, StatusCode>
where
I: AsRef<A> + Send + Sync,
A: apis::payments::Payments<E, Claims = C>
+ apis::CookieAuthentication<Claims = C>
+ apis::ApiAuthBasic<Claims = C>
+ Send
+ Sync,
E: std::fmt::Debug + Send + Sync + 'static,
Expand All @@ -266,7 +296,11 @@ where
.as_ref()
.extract_claims_from_cookie(&cookies, "X-API-Key")
.await;
let claims = None.or(claims_in_cookie);
let claims_in_auth_header = api_impl
.as_ref()
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
.await;
let claims = None.or(claims_in_cookie).or(claims_in_auth_header);
let Some(claims) = claims else {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
Expand Down Expand Up @@ -345,7 +379,6 @@ where
Err(why) => {
// Application code returned an error. This should not happen, as the implementation should
// return a valid response.

return api_impl
.as_ref()
.handle_error(&method, &host, &cookies, why)
Expand Down
Loading
Loading