Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit 7b20fb1

Browse files
authored
per-namespace authentication (#685)
1 parent 983b9a8 commit 7b20fb1

File tree

11 files changed

+188
-65
lines changed

11 files changed

+188
-65
lines changed

sqld/src/auth.rs

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use anyhow::{bail, Context as _, Result};
22
use axum::http::HeaderValue;
33
use tonic::Status;
44

5+
use crate::{namespace::NamespaceName, rpc::NAMESPACE_METADATA_KEY};
6+
57
static GRPC_AUTH_HEADER: &str = "x-authorization";
68
static GRPC_PROXY_AUTH_HEADER: &str = "x-proxy-authorization";
79

@@ -42,16 +44,22 @@ pub enum AuthError {
4244
Other,
4345
}
4446

47+
#[derive(Clone, Debug, PartialEq, Eq)]
48+
pub struct Authorized {
49+
pub namespace: Option<NamespaceName>,
50+
pub permission: Permission,
51+
}
52+
4553
#[non_exhaustive]
4654
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
47-
pub enum Authorized {
55+
pub enum Permission {
4856
FullAccess,
4957
ReadOnly,
5058
}
5159

5260
/// A witness that the user has been authenticated.
5361
#[non_exhaustive]
54-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
62+
#[derive(Clone, Debug, PartialEq, Eq)]
5563
pub enum Authenticated {
5664
Anonymous,
5765
Authorized(Authorized),
@@ -61,9 +69,13 @@ impl Auth {
6169
pub fn authenticate_http(
6270
&self,
6371
auth_header: Option<&hyper::header::HeaderValue>,
72+
disable_namespaces: bool,
6473
) -> Result<Authenticated, AuthError> {
6574
if self.disabled {
66-
return Ok(Authenticated::Authorized(Authorized::FullAccess));
75+
return Ok(Authenticated::Authorized(Authorized {
76+
namespace: None,
77+
permission: Permission::FullAccess,
78+
}));
6779
}
6880

6981
let Some(auth_header) = auth_header else {
@@ -80,57 +92,98 @@ impl Auth {
8092
let actual_value = actual_value.trim_end_matches('=');
8193
let expected_value = expected_value.trim_end_matches('=');
8294
if actual_value == expected_value {
83-
Ok(Authenticated::Authorized(Authorized::FullAccess))
95+
Ok(Authenticated::Authorized(Authorized {
96+
namespace: None,
97+
permission: Permission::FullAccess,
98+
}))
8499
} else {
85100
Err(AuthError::BasicRejected)
86101
}
87102
}
88-
HttpAuthHeader::Bearer(token) => self.validate_jwt(&token),
103+
HttpAuthHeader::Bearer(token) => self.validate_jwt(&token, disable_namespaces),
89104
}
90105
}
91106

92-
pub fn authenticate_grpc<T>(&self, req: &tonic::Request<T>) -> Result<Authenticated, Status> {
107+
pub fn authenticate_grpc<T>(
108+
&self,
109+
req: &tonic::Request<T>,
110+
disable_namespaces: bool,
111+
) -> Result<Authenticated, Status> {
93112
let metadata = req.metadata();
94113

95114
let auth = metadata
96115
.get(GRPC_AUTH_HEADER)
97116
.map(|v| v.to_bytes().expect("Auth should always be ASCII"))
98117
.map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header"));
99118

100-
self.authenticate_http(auth.as_ref()).map_err(Into::into)
119+
self.authenticate_http(auth.as_ref(), disable_namespaces)
120+
.map_err(Into::into)
101121
}
102122

103-
pub fn authenticate_jwt(&self, jwt: Option<&str>) -> Result<Authenticated, AuthError> {
123+
pub fn authenticate_jwt(
124+
&self,
125+
jwt: Option<&str>,
126+
disable_namespaces: bool,
127+
) -> Result<Authenticated, AuthError> {
104128
if self.disabled {
105-
return Ok(Authenticated::Authorized(Authorized::FullAccess));
129+
return Ok(Authenticated::Authorized(Authorized {
130+
namespace: None,
131+
permission: Permission::FullAccess,
132+
}));
106133
}
107134

108135
let Some(jwt) = jwt else {
109136
return Err(AuthError::JwtMissing)
110137
};
111138

112-
self.validate_jwt(jwt)
139+
self.validate_jwt(jwt, disable_namespaces)
113140
}
114141

115-
fn validate_jwt(&self, jwt: &str) -> Result<Authenticated, AuthError> {
142+
fn validate_jwt(
143+
&self,
144+
jwt: &str,
145+
disable_namespaces: bool,
146+
) -> Result<Authenticated, AuthError> {
116147
let Some(jwt_key) = self.jwt_key.as_ref() else {
117148
return Err(AuthError::JwtNotAllowed)
118149
};
119-
validate_jwt(jwt_key, jwt)
150+
validate_jwt(jwt_key, jwt, disable_namespaces)
120151
}
121152
}
122153

123154
impl Authenticated {
124-
pub fn from_proxy_grpc_request<T>(req: &tonic::Request<T>) -> Result<Self, Status> {
155+
pub fn from_proxy_grpc_request<T>(
156+
req: &tonic::Request<T>,
157+
disable_namespace: bool,
158+
) -> Result<Self, Status> {
159+
let namespace = if disable_namespace {
160+
None
161+
} else {
162+
req.metadata()
163+
.get_bin(NAMESPACE_METADATA_KEY)
164+
.map(|c| c.to_bytes())
165+
.transpose()
166+
.map_err(|_| Status::invalid_argument("failed to parse namespace header"))?
167+
.map(NamespaceName::from_bytes)
168+
.transpose()
169+
.map_err(|_| Status::invalid_argument("invalid namespace name"))?
170+
};
171+
125172
let auth = match req
126173
.metadata()
127174
.get(GRPC_PROXY_AUTH_HEADER)
128175
.map(|v| v.to_str())
129176
.transpose()
130177
.map_err(|_| Status::invalid_argument("missing authorization header"))?
131178
{
132-
Some("full_access") => Authenticated::Authorized(Authorized::FullAccess),
133-
Some("read_only") => Authenticated::Authorized(Authorized::ReadOnly),
179+
Some("full_access") => Authenticated::Authorized(Authorized {
180+
namespace,
181+
permission: Permission::FullAccess,
182+
}),
183+
Some("read_only") => Authenticated::Authorized(Authorized {
184+
namespace,
185+
permission: Permission::ReadOnly,
186+
}),
134187
Some("anonymous") => Authenticated::Anonymous,
135188
Some(level) => {
136189
return Err(Status::permission_denied(format!(
@@ -149,14 +202,34 @@ impl Authenticated {
149202

150203
let auth = match self {
151204
Authenticated::Anonymous => "anonymous",
152-
Authenticated::Authorized(Authorized::FullAccess) => "full_access",
153-
Authenticated::Authorized(Authorized::ReadOnly) => "read_only",
205+
Authenticated::Authorized(Authorized {
206+
permission: Permission::FullAccess,
207+
..
208+
}) => "full_access",
209+
Authenticated::Authorized(Authorized {
210+
permission: Permission::ReadOnly,
211+
..
212+
}) => "read_only",
154213
};
155214

156215
let value = tonic::metadata::AsciiMetadataValue::try_from(auth).unwrap();
157216

158217
req.metadata_mut().insert(key, value);
159218
}
219+
220+
pub fn is_namespace_authorized(&self, namespace: &NamespaceName) -> bool {
221+
match self {
222+
Authenticated::Anonymous => true,
223+
Authenticated::Authorized(Authorized {
224+
namespace: Some(ns),
225+
..
226+
}) => ns == namespace,
227+
// we threat the absence of a specific namespace has a permission to any namespace
228+
Authenticated::Authorized(Authorized {
229+
namespace: None, ..
230+
}) => true,
231+
}
232+
}
160233
}
161234

162235
#[derive(Debug)]
@@ -188,6 +261,7 @@ fn parse_http_auth_header(
188261
fn validate_jwt(
189262
jwt_key: &jsonwebtoken::DecodingKey,
190263
jwt: &str,
264+
disable_namespace: bool,
191265
) -> Result<Authenticated, AuthError> {
192266
use jsonwebtoken::errors::ErrorKind;
193267

@@ -197,13 +271,26 @@ fn validate_jwt(
197271
match jsonwebtoken::decode::<serde_json::Value>(jwt, jwt_key, &validation).map(|t| t.claims) {
198272
Ok(serde_json::Value::Object(claims)) => {
199273
tracing::trace!("Claims: {claims:#?}");
200-
Ok(match claims.get("a").and_then(|s| s.as_str()) {
201-
Some("ro") => Authenticated::Authorized(Authorized::ReadOnly),
202-
Some("rw") => Authenticated::Authorized(Authorized::FullAccess),
203-
Some(_) => Authenticated::Anonymous,
274+
let namespace = if disable_namespace {
275+
None
276+
} else {
277+
claims
278+
.get("id")
279+
.and_then(|ns| NamespaceName::from_string(ns.as_str()?.into()).ok())
280+
};
281+
282+
let permission = match claims.get("a").and_then(|s| s.as_str()) {
283+
Some("ro") => Permission::ReadOnly,
284+
Some("rw") => Permission::FullAccess,
285+
Some(_) => return Ok(Authenticated::Anonymous),
204286
// Backward compatibility - no access claim means full access
205-
None => Authenticated::Authorized(Authorized::FullAccess),
206-
})
287+
None => Permission::FullAccess,
288+
};
289+
290+
Ok(Authenticated::Authorized(Authorized {
291+
namespace,
292+
permission,
293+
}))
207294
}
208295
Ok(_) => Err(AuthError::JwtInvalid),
209296
Err(error) => Err(match error.kind() {
@@ -280,7 +367,7 @@ mod tests {
280367
use hyper::header::HeaderValue;
281368

282369
fn authenticate_http(auth: &Auth, header: &str) -> Result<Authenticated, AuthError> {
283-
auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()))
370+
auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()), false)
284371
}
285372

286373
const VALID_JWT_KEY: &str = "zaMv-aFGmB7PXkjM4IrMdF6B5zCYEiEGXW3RgMjNAtc";
@@ -312,9 +399,9 @@ mod tests {
312399
#[test]
313400
fn test_default() {
314401
let auth = Auth::default();
315-
assert_err!(auth.authenticate_http(None));
402+
assert_err!(auth.authenticate_http(None, false));
316403
assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWJlYXI="));
317-
assert_err!(auth.authenticate_jwt(Some(VALID_JWT)));
404+
assert_err!(auth.authenticate_jwt(Some(VALID_JWT), false));
318405
}
319406

320407
#[test]
@@ -332,7 +419,7 @@ mod tests {
332419
assert_err!(authenticate_http(&auth, "Basic d29qdgvronrozwjlyxi="));
333420
assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWZveA=="));
334421

335-
assert_err!(auth.authenticate_http(None));
422+
assert_err!(auth.authenticate_http(None, false));
336423
assert_err!(authenticate_http(&auth, ""));
337424
assert_err!(authenticate_http(&auth, "foobar"));
338425
assert_err!(authenticate_http(&auth, "foo bar"));
@@ -356,7 +443,10 @@ mod tests {
356443

357444
assert_eq!(
358445
authenticate_http(&auth, &format!("Bearer {VALID_READONLY_JWT}")).unwrap(),
359-
Authenticated::Authorized(Authorized::ReadOnly)
446+
Authenticated::Authorized(Authorized {
447+
namespace: None,
448+
permission: Permission::ReadOnly
449+
})
360450
);
361451
}
362452

@@ -366,7 +456,7 @@ mod tests {
366456
jwt_key: Some(parse_jwt_key(VALID_JWT_KEY).unwrap()),
367457
..Auth::default()
368458
};
369-
assert_ok!(auth.authenticate_jwt(Some(VALID_JWT)));
370-
assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80])));
459+
assert_ok!(auth.authenticate_jwt(Some(VALID_JWT), false));
460+
assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80]), false));
371461
}
372462
}

sqld/src/connection/libsql.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use sqld_libsql_bindings::wal_hook::WalMethodsHook;
88
use tokio::sync::{oneshot, watch};
99
use tracing::warn;
1010

11-
use crate::auth::{Authenticated, Authorized};
11+
use crate::auth::{Authenticated, Authorized, Permission};
1212
use crate::error::Error;
1313
use crate::libsql::wal_hook::WalHook;
1414
use crate::query::Query;
@@ -507,7 +507,13 @@ fn check_program_auth(auth: Authenticated, pgm: &Program) -> Result<()> {
507507
}
508508
(StmtKind::Read, Authenticated::Authorized(_)) => (),
509509
(StmtKind::TxnBegin, _) | (StmtKind::TxnEnd, _) => (),
510-
(_, Authenticated::Authorized(Authorized::FullAccess)) => (),
510+
(
511+
_,
512+
Authenticated::Authorized(Authorized {
513+
permission: Permission::FullAccess,
514+
..
515+
}),
516+
) => (),
511517
_ => {
512518
return Err(Error::NotAuthorized(format!(
513519
"Current session is not authorized to run: {}",

sqld/src/connection/write_proxy.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ impl Connection for WriteProxyConnection {
282282
// transaction, so we rollback the replica, and execute again on the primary.
283283
let (builder, new_state) = self
284284
.read_conn
285-
.execute_program(pgm.clone(), auth, builder)
285+
.execute_program(pgm.clone(), auth.clone(), builder)
286286
.await?;
287287
if new_state != State::Init {
288-
self.read_conn.rollback(auth).await?;
288+
self.read_conn.rollback(auth.clone()).await?;
289289
self.execute_remote(pgm, &mut state, auth, builder).await
290290
} else {
291291
Ok((builder, new_state))

sqld/src/hrana/http/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async fn handle_pipeline<C: Connection>(
112112
let mut results = Vec::with_capacity(req_body.requests.len());
113113
for request in req_body.requests.into_iter() {
114114
tracing::debug!("pipeline:{{ {:?}, {:?} }}", version, request);
115-
let result = request::handle(&mut stream_guard, auth, request, version).await?;
115+
let result = request::handle(&mut stream_guard, auth.clone(), request, version).await?;
116116
results.push(result);
117117
}
118118

sqld/src/hrana/ws/conn.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use tokio::sync::oneshot;
1111
use tokio_tungstenite::tungstenite;
1212
use tungstenite::protocol::frame::coding::CloseCode;
1313

14-
use crate::connection::MakeConnection;
1514
use crate::database::Database;
1615
use crate::namespace::{MakeNamespace, NamespaceName};
1716

@@ -35,7 +34,8 @@ struct Conn<F: MakeNamespace> {
3534
join_set: tokio::task::JoinSet<()>,
3635
/// Future responses to requests that we have received but are evaluating asynchronously.
3736
responses: FuturesUnordered<ResponseFuture>,
38-
connection_maker: Arc<dyn MakeConnection<Connection = <F::Database as Database>::Connection>>,
37+
/// Namespace queried by this connections
38+
namespace: NamespaceName,
3939
}
4040

4141
/// A `Future` that stores a handle to a future response to request which is being evaluated
@@ -95,10 +95,6 @@ async fn handle_ws<F: MakeNamespace>(
9595
conn_id: u64,
9696
namespace: NamespaceName,
9797
) -> Result<()> {
98-
let connection_maker = server
99-
.namespaces
100-
.with(namespace, |ns| ns.db.connection_maker())
101-
.await?;
10298
let mut conn = Conn {
10399
conn_id,
104100
server,
@@ -109,7 +105,7 @@ async fn handle_ws<F: MakeNamespace>(
109105
session: None,
110106
join_set: tokio::task::JoinSet::new(),
111107
responses: FuturesUnordered::new(),
112-
connection_maker,
108+
namespace,
113109
};
114110

115111
loop {
@@ -259,7 +255,7 @@ async fn handle_request_msg<F: MakeNamespace>(
259255
session,
260256
&mut conn.join_set,
261257
request,
262-
conn.connection_maker.clone(),
258+
conn.namespace.clone(),
263259
)
264260
.await
265261
.unwrap_or_else(|err| {

0 commit comments

Comments
 (0)