Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,3 @@ publish = false # disable publishing to crates.io
# ensure deps are compatible: https://www.gnu.org/licenses/license-list.en.html#GPLCompatibleLicenses
async-std = { version = "1.13.0", features = ["attributes", "tokio1"] }
rstest = "0.25.0"
serial_test = "3.2.0"
tokio = "1.43.0"
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ comment:
require_changes: false # if true: only post the comment if coverage changes

ignore:
- "tests"
- "**/tests/**"
- "third-party"
2 changes: 0 additions & 2 deletions crates/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ objc2-core-foundation = "0.3.0"
[dev-dependencies]
async-std.workspace = true
rstest.workspace = true
serial_test.workspace = true
tokio.workspace = true

[package.metadata.ci]
cargo-run-bin = "1.7.4"
Expand Down
19 changes: 11 additions & 8 deletions crates/server/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ impl OpenApiFromRequest<'_> for AdminGuard {
/// Claims for the JWT.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub(crate) sub: String,
exp: usize,
/// Subject (user ID) of the JWT
pub sub: String,
/// Expiration time as Unix timestamp
pub exp: usize,
}

const BEARER: &str = "Bearer ";
Expand All @@ -86,9 +88,9 @@ const BEARER: &str = "Bearer ";
pub fn create_token(
user_id: &str,
secret: &str,
) -> String {
) -> Result<String, jsonwebtoken::errors::Error> {
let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::seconds(60))
.checked_add_signed(chrono::Duration::hours(24))
.expect("valid timestamp")
.timestamp();

Expand All @@ -102,7 +104,6 @@ pub fn create_token(
&claims,
&EncodingKey::from_secret(secret.as_ref()),
)
.unwrap()
}

/// Decode a JWT token.
Expand Down Expand Up @@ -161,11 +162,13 @@ pub(crate) fn get_jwt_secret() -> &'static str {
&JWT_SECRET
}

pub(crate) fn hash_password(password: &str) -> String {
hash(password, DEFAULT_COST).unwrap()
/// Hash a password using BCrypt (handles salting internally)
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> {
hash(password, DEFAULT_COST)
}

pub(crate) fn verify_password(
/// Verify a password against a BCrypt hash
pub fn verify_password(
password: &str,
hash: &str,
) -> bool {
Expand Down
2 changes: 0 additions & 2 deletions crates/server/src/db/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ table! {
id -> Integer,
username -> Text,
password -> Text,
password_salt -> Text,
pin -> Nullable<Text>,
pin_salt -> Nullable<Text>,
admin -> Bool,
}
}
10 changes: 9 additions & 1 deletion crates/server/src/web/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ use crate::globals;

/// Build the web server.
pub fn rocket() -> rocket::Rocket<rocket::Build> {
rocket_with_db_path(None)
}

/// Build the web server with a custom database path (primarily for testing).
pub fn rocket_with_db_path(custom_db_path: Option<String>) -> rocket::Rocket<rocket::Build> {
// the cert path changes depending on if the user wants to use custom certs
let (cert_path, key_path);
if !GLOBAL_SETTINGS.server.use_custom_certs {
Expand All @@ -32,12 +37,15 @@ pub fn rocket() -> rocket::Rocket<rocket::Build> {
certs::ensure_certificates_exist(cert_path.clone(), key_path.clone());
}

// Use custom database path for tests, or default for production
let db_path = custom_db_path.unwrap_or_else(|| globals::APP_PATHS.db_path.clone());

let figment = Figment::from(Config::default())
.merge((
"databases",
rocket::figment::map! {
"sqlite_db" => rocket::figment::map! {
"url" => format!("sqlite://{}", globals::APP_PATHS.db_path),
"url" => format!("sqlite://{}", db_path),
}
},
))
Expand Down
11 changes: 10 additions & 1 deletion crates/server/src/web/routes/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,21 @@ pub async fn login(
// debug print user info from db
println!("Found user in db: {:?}", user);

// Verify password using BCrypt
if !crate::auth::verify_password(&form.password, &user.password) {
println!("Password verification failed");
return Err(Status::Unauthorized);
}

let token = crate::auth::create_token(&user.id.to_string(), crate::auth::get_jwt_secret());
let token = match crate::auth::create_token(&user.id.to_string(), crate::auth::get_jwt_secret())
{
Ok(token) => token,
Err(e) => {
println!("Failed to create token: {}", e);
return Err(Status::InternalServerError);
}
};

Ok(Json(TokenResponse { token }))
}

Expand Down
25 changes: 18 additions & 7 deletions crates/server/src/web/routes/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,47 @@ pub struct CreateUserForm {
pub async fn create_user(
db: DbConn,
user_form: Json<CreateUserForm>,
admin_guard: Option<AdminGuard>,
auth_guard: Option<AdminGuard>,
) -> Result<&'static str, Status> {
use crate::db::schema::users::dsl::*;
let existing = db

// Check if this is the first user (no authentication required)
let existing_count = db
.run(|conn| users.count().get_result::<i64>(conn))
.await
.unwrap_or(0);

// If there are users, require admin privileges
if existing > 0 && admin_guard.is_none() {
// If there are existing users, require admin privileges
if existing_count > 0 && auth_guard.is_none() {
return Err(Status::Unauthorized);
}

let form = user_form.into_inner();
let hashed_password = crate::auth::hash_password(&form.password);

// Hash password using BCrypt
let hashed_password = match crate::auth::hash_password(&form.password) {
Ok(hash) => hash,
Err(_) => return Err(Status::InternalServerError),
};

// Hash PIN if provided
let hashed_pin = if let Some(pin_value) = form.pin {
if pin_value.parse::<i32>().is_err() {
return Err(Status::BadRequest);
}
if pin_value.len() < 4 || pin_value.len() > 6 {
return Err(Status::BadRequest);
}
Some(crate::auth::hash_password(&pin_value))
match crate::auth::hash_password(&pin_value) {
Ok(hash) => Some(hash),
Err(_) => return Err(Status::InternalServerError),
}
} else {
None
};

let user = User {
id: 0,
id: 0, // This will be auto-incremented by SQLite
username: form.username,
password: hashed_password,
pin: hashed_pin,
Expand Down
59 changes: 35 additions & 24 deletions crates/server/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// standard imports
use std::fs;
use std::path::Path;
use std::path::PathBuf;

// lib imports
use diesel::Connection;
use diesel::sqlite::SqliteConnection;
use diesel_migrations::MigrationHarness;
use once_cell::sync::Lazy;
use rocket::http::Status;
use rocket::local::asynchronous::Client;
use rstest::fixture;
Expand All @@ -18,27 +17,25 @@ use koko::globals::CURRENT_ENV;
use koko::web::rocket;

// test imports
use crate::test_web::test_request;

// constants
static DB_PATH: Lazy<&'static Path> = Lazy::new(|| Path::new("./test_data/koko.db"));
use crate::test_utils::{TestResponse, make_request};

pub struct TestDb {
pub client: Client,
db_path: PathBuf,
}

impl Drop for TestDb {
fn drop(&mut self) {
if DB_PATH.exists() {
if let Ok(mut conn) = SqliteConnection::establish(DB_PATH.to_str().unwrap()) {
if self.db_path.exists() {
if let Ok(mut conn) = SqliteConnection::establish(self.db_path.to_str().unwrap()) {
let _ = conn.revert_all_migrations(MIGRATIONS);
}

// Sleep to try to all the processes to release the database file
// Sleep to allow processes to release the database file
std::thread::sleep(std::time::Duration::from_secs(1));

// Delete the database file
match fs::remove_file(DB_PATH.clone()) {
match fs::remove_file(&self.db_path) {
Ok(_) => (),
Err(e) => eprintln!("Warning: Failed to delete test database: {}", e),
}
Expand All @@ -50,23 +47,36 @@ impl Drop for TestDb {
pub async fn db_fixture(#[default(false)] base_user: bool) -> TestDb {
CURRENT_ENV.store(1, std::sync::atomic::Ordering::SeqCst);

if let Some(parent) = DB_PATH.parent() {
fs::create_dir_all(parent).expect("Failed to create test_data directory");
}
// Create a unique database file for this test
let test_id = std::thread::current().id();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let db_path = PathBuf::from(format!(
"./test_data/test_{}_{}.db",
timestamp,
format!("{:?}", test_id)
.replace("ThreadId(", "")
.replace(")", "")
));

// Initialize database with migrations
if let Ok(mut conn) = SqliteConnection::establish(DB_PATH.to_str().unwrap()) {
conn.run_pending_migrations(MIGRATIONS)
.expect("Failed to run migrations");
// Ensure test_data directory exists
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent).expect("Failed to create test_data directory");
}

let rocket = rocket();
let client = Client::tracked(rocket)
// Set the database URL for this test
std::env::set_var("DATABASE_URL", format!("sqlite:{}", db_path.display()));

let rocket_instance = rocket();
let client = Client::tracked(rocket_instance)
.await
.expect("Failed to launch web server");
.expect("Failed to launch rocket for test");

if base_user {
let response = test_request(
let response: TestResponse = make_request(
Some(&client),
"post",
"/create_user",
Some(json!({
Expand All @@ -75,13 +85,14 @@ pub async fn db_fixture(#[default(false)] base_user: bool) -> TestDb {
"pin": "1234",
"admin": true,
})),
Status::Ok,
Some(&client),
None,
Some(Status::Ok),
Some(false),
)
.await;

assert_eq!(response.body, "User created");
}

TestDb { client }
TestDb { client, db_path }
}
2 changes: 2 additions & 0 deletions crates/server/tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod test_auth;
pub mod test_dependencies;
pub mod test_tray;
pub mod test_utils;
pub mod test_web;

pub mod fixtures;
Loading
Loading