Skip to content

Commit 0aafd49

Browse files
committed
refactor(server): normalize compute driver config acquisition
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 7dd0077 commit 0aafd49

4 files changed

Lines changed: 621 additions & 420 deletions

File tree

crates/openshell-server/src/cli.rs

Lines changed: 90 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ use tracing::{info, warn};
1414
use tracing_subscriber::EnvFilter;
1515

1616
use crate::certgen;
17-
use crate::compute::{DockerComputeConfig, VmComputeConfig};
17+
use crate::compute::driver_config::GuestTlsPaths;
1818
use crate::config_file::{self, ConfigFile, GatewayFileSection};
1919
use crate::defaults::{self, LocalTlsPaths};
20-
use crate::{run_server, tracing_bus::TracingLogBus};
20+
use crate::{ServerStartupConfig, run_server, tracing_bus::TracingLogBus};
2121

2222
/// `OpenShell` gateway process - gRPC and HTTP server with protocol multiplexing.
2323
///
@@ -232,34 +232,30 @@ pub async fn run_cli() -> Result<()> {
232232
}
233233
}
234234

235-
async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
235+
fn prepare_server_config(args: &mut RunArgs, matches: &ArgMatches) -> Result<ServerStartupConfig> {
236236
// Load TOML when explicitly requested, or from the default XDG location
237237
// when that file exists. Missing default config is not an error: runtime
238238
// defaults and OPENSHELL_* env vars are enough for package-managed starts.
239-
let config_path = resolve_config_path(&args)?;
239+
let config_path = resolve_config_path(args)?;
240240
let file: Option<ConfigFile> = if let Some(path) = config_path {
241241
Some(config_file::load(&path).map_err(|e| miette::miette!("{e}"))?)
242242
} else {
243243
None
244244
};
245245
if let Some(file) = file.as_ref() {
246-
merge_file_into_args(&mut args, &file.openshell.gateway, &matches);
246+
merge_file_into_args(args, &file.openshell.gateway, matches);
247247
}
248-
normalize_compute_driver_socket_args(&mut args, &matches)?;
248+
normalize_compute_driver_socket_args(args, matches)?;
249249

250-
let local_tls = apply_runtime_defaults(&mut args)?;
250+
let local_tls = apply_runtime_defaults(args)?;
251+
let guest_tls = local_tls.as_ref().map(GuestTlsPaths::from);
251252
let local_jwt = defaults::complete_local_jwt_config()?;
252253

253-
let tracing_log_bus = TracingLogBus::new();
254-
tracing_log_bus.install_subscriber(
255-
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)),
256-
);
257-
258254
let bind = SocketAddr::new(args.bind_address, args.port);
259255

260256
let has_client_ca = args.tls_client_ca.is_some();
261257
let has_oidc = args.oidc_issuer.is_some();
262-
let mtls_auth_enabled = resolve_mtls_auth_enabled(&args, &matches, file.as_ref());
258+
let mtls_auth_enabled = resolve_mtls_auth_enabled(args, matches, file.as_ref());
263259

264260
if args.disable_tls && has_client_ca {
265261
return Err(miette::miette!(
@@ -278,7 +274,7 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
278274
}
279275
if mtls_auth_enabled
280276
&& matches!(
281-
effective_single_driver(&args),
277+
effective_single_driver(args),
282278
Some(ComputeDriverKind::Kubernetes)
283279
)
284280
{
@@ -329,14 +325,14 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
329325
let health_bind = resolve_aux_listener(
330326
args.bind_address,
331327
args.health_port,
332-
&matches,
328+
matches,
333329
"health_port",
334330
|| file_gateway.and_then(|g| g.health_bind_address),
335331
);
336332
let metrics_bind = resolve_aux_listener(
337333
args.bind_address,
338334
args.metrics_port,
339-
&matches,
335+
matches,
340336
"metrics_port",
341337
|| file_gateway.and_then(|g| g.metrics_bind_address),
342338
);
@@ -422,15 +418,31 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
422418
config.gateway_jwt = Some(jwt);
423419
}
424420

425-
let vm_config = build_vm_config(
426-
file.as_ref(),
427-
local_tls.as_ref(),
428-
args.disable_tls,
429-
args.port,
430-
)?;
431-
let docker_config = build_docker_config(file.as_ref(), local_tls.as_ref())?;
421+
Ok(ServerStartupConfig {
422+
config,
423+
config_file: file,
424+
guest_tls,
425+
})
426+
}
432427

433-
if args.disable_tls {
428+
async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
429+
let prepared = prepare_server_config(&mut args, &matches)?;
430+
431+
let tracing_log_bus = TracingLogBus::new();
432+
tracing_log_bus.install_subscriber(
433+
EnvFilter::try_from_default_env()
434+
.unwrap_or_else(|_| EnvFilter::new(&prepared.config.log_level)),
435+
);
436+
437+
let has_client_ca = prepared
438+
.config
439+
.tls
440+
.as_ref()
441+
.and_then(|tls| tls.client_ca_path.as_ref())
442+
.is_some();
443+
let has_oidc = prepared.config.oidc.is_some();
444+
445+
if prepared.config.tls.is_none() {
434446
warn!("TLS disabled — listening on plaintext HTTP");
435447
} else {
436448
info!("TLS enabled — listening on encrypted HTTPS");
@@ -439,40 +451,34 @@ async fn run_from_args(mut args: RunArgs, matches: ArgMatches) -> Result<()> {
439451
if has_client_ca {
440452
info!("TLS client certificate verification enabled");
441453
}
442-
if config.mtls_auth.enabled {
454+
if prepared.config.mtls_auth.enabled {
443455
info!("mTLS user authentication enabled");
444456
}
445457
if has_oidc {
446458
info!("OIDC authentication enabled");
447459
}
448-
if config.auth.allow_unauthenticated_users {
460+
if prepared.config.auth.allow_unauthenticated_users {
449461
warn!(
450462
"Unauthenticated user access enabled — only use this for trusted local development or a fully trusted fronting proxy"
451463
);
452464
}
453465

454-
if !config.auth.allow_unauthenticated_users
455-
&& !config.mtls_auth.enabled
466+
if !prepared.config.auth.allow_unauthenticated_users
467+
&& !prepared.config.mtls_auth.enabled
456468
&& !has_oidc
457-
&& config.gateway_jwt.is_none()
469+
&& prepared.config.gateway_jwt.is_none()
458470
{
459471
warn!(
460472
"Neither mTLS user auth nor OIDC nor sandbox JWT auth is configured — \
461473
the gateway has no authentication mechanism"
462474
);
463475
}
464476

465-
info!(bind = %config.bind_address, "Starting OpenShell server");
477+
info!(bind = %prepared.config.bind_address, "Starting OpenShell server");
466478

467-
Box::pin(run_server(
468-
config,
469-
vm_config,
470-
docker_config,
471-
file,
472-
tracing_log_bus,
473-
))
474-
.await
475-
.into_diagnostic()
479+
Box::pin(run_server(prepared, tracing_log_bus))
480+
.await
481+
.into_diagnostic()
476482
}
477483

478484
fn parse_compute_driver(value: &str) -> std::result::Result<String, String> {
@@ -751,87 +757,6 @@ fn resolve_mtls_auth_enabled(
751757
is_singleplayer_driver(args)
752758
}
753759

754-
/// Build [`VmComputeConfig`] from the `[openshell.drivers.vm]` table
755-
/// inherited from `[openshell.gateway]`.
756-
fn build_vm_config(
757-
file: Option<&ConfigFile>,
758-
local_tls: Option<&LocalTlsPaths>,
759-
disable_tls: bool,
760-
gateway_port: u16,
761-
) -> Result<VmComputeConfig> {
762-
let mut cfg = if let Some(file) = file {
763-
let merged = config_file::driver_table(
764-
ComputeDriverKind::Vm.as_str(),
765-
&file.openshell.gateway,
766-
file.openshell.drivers.get("vm"),
767-
);
768-
merged
769-
.try_into::<VmComputeConfig>()
770-
.map_err(|e| miette::miette!("invalid [openshell.drivers.vm] table: {e}"))?
771-
} else {
772-
VmComputeConfig::default()
773-
};
774-
775-
if cfg.state_dir.as_os_str().is_empty() {
776-
cfg.state_dir = VmComputeConfig::default_state_dir();
777-
}
778-
if cfg.grpc_endpoint.trim().is_empty() && (disable_tls || local_tls.is_some()) {
779-
let scheme = if disable_tls { "http" } else { "https" };
780-
cfg.grpc_endpoint = format!("{scheme}://127.0.0.1:{gateway_port}");
781-
}
782-
apply_guest_tls_defaults(
783-
&mut cfg.guest_tls_ca,
784-
&mut cfg.guest_tls_cert,
785-
&mut cfg.guest_tls_key,
786-
local_tls,
787-
);
788-
Ok(cfg)
789-
}
790-
791-
/// Build [`DockerComputeConfig`] using the same inheritance pattern as
792-
/// [`build_vm_config`].
793-
fn build_docker_config(
794-
file: Option<&ConfigFile>,
795-
local_tls: Option<&LocalTlsPaths>,
796-
) -> Result<DockerComputeConfig> {
797-
let mut cfg = if let Some(file) = file {
798-
let merged = config_file::driver_table(
799-
ComputeDriverKind::Docker.as_str(),
800-
&file.openshell.gateway,
801-
file.openshell.drivers.get("docker"),
802-
);
803-
merged
804-
.try_into::<DockerComputeConfig>()
805-
.map_err(|e| miette::miette!("invalid [openshell.drivers.docker] table: {e}"))?
806-
} else {
807-
DockerComputeConfig::default()
808-
};
809-
apply_guest_tls_defaults(
810-
&mut cfg.guest_tls_ca,
811-
&mut cfg.guest_tls_cert,
812-
&mut cfg.guest_tls_key,
813-
local_tls,
814-
);
815-
Ok(cfg)
816-
}
817-
818-
fn apply_guest_tls_defaults(
819-
ca: &mut Option<PathBuf>,
820-
cert: &mut Option<PathBuf>,
821-
key: &mut Option<PathBuf>,
822-
local_tls: Option<&LocalTlsPaths>,
823-
) {
824-
if ca.is_none()
825-
&& cert.is_none()
826-
&& key.is_none()
827-
&& let Some(paths) = local_tls
828-
{
829-
*ca = Some(paths.ca.clone());
830-
*cert = Some(paths.client_cert.clone());
831-
*key = Some(paths.client_key.clone());
832-
}
833-
}
834-
835760
#[cfg(test)]
836761
mod tests {
837762
use super::{Cli, command};
@@ -1793,6 +1718,51 @@ enable_loopback_service_http = false
17931718
);
17941719
}
17951720

1721+
#[test]
1722+
fn server_config_preparation_ignores_unselected_driver_tables() {
1723+
let _lock = ENV_LOCK
1724+
.lock()
1725+
.unwrap_or_else(std::sync::PoisonError::into_inner);
1726+
let state = tempfile::tempdir().unwrap();
1727+
let local_tls = tempfile::tempdir().unwrap();
1728+
let _g1 = EnvVarGuard::set("XDG_STATE_HOME", state.path().to_str().unwrap());
1729+
let _g2 = EnvVarGuard::set(
1730+
"OPENSHELL_LOCAL_TLS_DIR",
1731+
local_tls.path().to_str().unwrap(),
1732+
);
1733+
let config_path = state.path().join("gateway.toml");
1734+
std::fs::write(
1735+
&config_path,
1736+
r#"
1737+
[openshell.drivers.docker]
1738+
unknown_docker_key = true
1739+
1740+
[openshell.drivers.vm]
1741+
mem_mib = "not-a-number"
1742+
"#,
1743+
)
1744+
.unwrap();
1745+
1746+
let (mut args, matches) = parse_with_args(&[
1747+
"openshell-gateway",
1748+
"--config",
1749+
config_path.to_str().unwrap(),
1750+
"--db-url",
1751+
"sqlite::memory:",
1752+
"--drivers",
1753+
"podman",
1754+
"--disable-tls",
1755+
]);
1756+
1757+
let prepared =
1758+
super::prepare_server_config(&mut args, &matches).expect("server config is prepared");
1759+
1760+
assert_eq!(prepared.config.compute_drivers, vec!["podman".to_string()]);
1761+
let file = prepared.config_file.expect("config file is preserved");
1762+
assert!(file.openshell.drivers.contains_key("docker"));
1763+
assert!(file.openshell.drivers.contains_key("vm"));
1764+
}
1765+
17961766
#[test]
17971767
fn driver_inherits_shared_image_from_gateway_section() {
17981768
// [openshell.gateway].default_image inherits into the K8s driver
@@ -1839,18 +1809,4 @@ default_image = "k8s-specific:1.0"
18391809
.expect("deserializes");
18401810
assert_eq!(parsed.default_image, "k8s-specific:1.0");
18411811
}
1842-
1843-
#[test]
1844-
fn docker_config_reads_bind_mount_opt_in_from_driver_table() {
1845-
let file = config_file_from_toml(
1846-
r"
1847-
[openshell.drivers.docker]
1848-
enable_bind_mounts = true
1849-
",
1850-
);
1851-
1852-
let cfg = super::build_docker_config(Some(&file), None).expect("docker config");
1853-
1854-
assert!(cfg.enable_bind_mounts);
1855-
}
18561812
}

0 commit comments

Comments
 (0)