diff --git a/CLAUDE.md b/CLAUDE.md index 5c36f38b02..9c659d97dd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -159,6 +159,11 @@ This is a Rust workspace-based monorepo for Rivet. Key packages and components: - **Shared Libraries** (`shared/{language}/{package}/`) - Libraries shared between the engine and rivetkit (e.g., `shared/typescript/virtual-websocket/`) - **Service Infrastructure** - Distributed services communicate via NATS messaging with service discovery +### Engine Runner Parity +- Keep `engine/sdks/typescript/runner` and `engine/sdks/rust/engine-runner` at feature parity. +- Any behavior, protocol handling, or test coverage added to one runner should be mirrored in the other runner in the same change whenever possible. +- When parity cannot be completed in the same change, explicitly document the gap and add a follow-up task. + ### Important Patterns **Error Handling** diff --git a/Cargo.lock b/Cargo.lock index fba7d437eb..279c7dc43d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4427,16 +4427,29 @@ name = "rivet-engine-runner" version = "2.0.38" dependencies = [ "anyhow", + "async-stream", "async-trait", + "axum 0.8.4", + "base64 0.22.1", + "bytes", "chrono", "futures-util", + "http 1.3.1", + "portpicker", "rand 0.8.5", + "reqwest", + "rivet-config", "rivet-runner-protocol", + "rivet-test-deps", "serde", "serde_bare", "serde_json", + "serde_yaml", + "tempfile", "tokio", + "tokio-stream", "tokio-tungstenite", + "tower 0.5.2", "tracing", "urlencoding", "vbare", diff --git a/engine/packages/engine/tests/actors_alarm.rs b/engine/packages/engine/tests/actors_alarm.rs index 913c6e9918..ed66955c2a 100644 --- a/engine/packages/engine/tests/actors_alarm.rs +++ b/engine/packages/engine/tests/actors_alarm.rs @@ -140,7 +140,7 @@ impl AlarmAndSleepActor { } #[async_trait] -impl TestActor for AlarmAndSleepActor { +impl Actor for AlarmAndSleepActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; 
tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -195,7 +195,7 @@ impl AlarmAndSleepOnceActor { } #[async_trait] -impl TestActor for AlarmAndSleepOnceActor { +impl Actor for AlarmAndSleepOnceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm once actor starting"); @@ -250,7 +250,7 @@ impl AlarmSleepThenClearActor { } #[async_trait] -impl TestActor for AlarmSleepThenClearActor { +impl Actor for AlarmSleepThenClearActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -311,7 +311,7 @@ impl AlarmSleepThenReplaceActor { } #[async_trait] -impl TestActor for AlarmSleepThenReplaceActor { +impl Actor for AlarmSleepThenReplaceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); @@ -374,7 +374,7 @@ impl MultipleAlarmSetActor { } #[async_trait] -impl TestActor for MultipleAlarmSetActor { +impl Actor for MultipleAlarmSetActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "multi alarm actor starting"); @@ -429,7 +429,7 @@ impl MultiCycleAlarmActor { } #[async_trait] -impl TestActor for MultiCycleAlarmActor { +impl Actor for MultiCycleAlarmActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "multi cycle alarm actor starting"); @@ -481,7 +481,7 @@ impl AlarmOnceActor { } #[async_trait] -impl TestActor for AlarmOnceActor { +impl Actor for AlarmOnceActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; 
tracing::info!(?config.actor_id, generation, "alarm once actor starting"); @@ -536,7 +536,7 @@ impl AlarmSleepThenCrashActor { } #[async_trait] -impl TestActor for AlarmSleepThenCrashActor { +impl Actor for AlarmSleepThenCrashActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm crash actor starting"); @@ -599,7 +599,7 @@ impl RapidAlarmCycleActor { } #[async_trait] -impl TestActor for RapidAlarmCycleActor { +impl Actor for RapidAlarmCycleActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "rapid alarm cycle actor starting"); @@ -647,7 +647,7 @@ impl SetClearAlarmAndSleepActor { } #[async_trait] -impl TestActor for SetClearAlarmAndSleepActor { +impl Actor for SetClearAlarmAndSleepActor { async fn on_start(&mut self, config: ActorConfig) -> anyhow::Result { let generation = config.generation; tracing::info!(?config.actor_id, generation, "alarm actor starting"); diff --git a/engine/packages/engine/tests/actors_kv_crud.rs b/engine/packages/engine/tests/actors_kv_crud.rs index b3791486e1..4af4b11bc2 100644 --- a/engine/packages/engine/tests/actors_kv_crud.rs +++ b/engine/packages/engine/tests/actors_kv_crud.rs @@ -38,7 +38,7 @@ impl PutAndGetActor { } #[async_trait] -impl TestActor for PutAndGetActor { +impl Actor for PutAndGetActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "put and get actor starting"); @@ -116,7 +116,7 @@ impl GetNonexistentKeyActor { } #[async_trait] -impl TestActor for GetNonexistentKeyActor { +impl Actor for GetNonexistentKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "get nonexistent key actor starting"); @@ -191,7 +191,7 @@ impl 
PutOverwriteActor { } #[async_trait] -impl TestActor for PutOverwriteActor { +impl Actor for PutOverwriteActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "put overwrite actor starting"); @@ -295,7 +295,7 @@ impl DeleteKeyActor { } #[async_trait] -impl TestActor for DeleteKeyActor { +impl Actor for DeleteKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "delete key actor starting"); @@ -383,7 +383,7 @@ impl DeleteNonexistentKeyActor { } #[async_trait] -impl TestActor for DeleteNonexistentKeyActor { +impl Actor for DeleteNonexistentKeyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "delete nonexistent key actor starting"); @@ -638,7 +638,7 @@ impl BatchPutActor { } #[async_trait] -impl TestActor for BatchPutActor { +impl Actor for BatchPutActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch put actor starting"); @@ -721,7 +721,7 @@ impl BatchGetActor { } #[async_trait] -impl TestActor for BatchGetActor { +impl Actor for BatchGetActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch get actor starting"); @@ -808,7 +808,7 @@ impl BatchDeleteActor { } #[async_trait] -impl TestActor for BatchDeleteActor { +impl Actor for BatchDeleteActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "batch delete actor starting"); diff --git a/engine/packages/engine/tests/actors_kv_drop.rs b/engine/packages/engine/tests/actors_kv_drop.rs index 5b1699ad96..ee4ebc2c8a 100644 --- 
a/engine/packages/engine/tests/actors_kv_drop.rs +++ b/engine/packages/engine/tests/actors_kv_drop.rs @@ -39,7 +39,7 @@ impl DropClearsAllActor { } #[async_trait] -impl TestActor for DropClearsAllActor { +impl Actor for DropClearsAllActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "drop clears all actor starting"); @@ -137,7 +137,7 @@ impl DropEmptyActor { } #[async_trait] -impl TestActor for DropEmptyActor { +impl Actor for DropEmptyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "drop empty actor starting"); diff --git a/engine/packages/engine/tests/actors_kv_list.rs b/engine/packages/engine/tests/actors_kv_list.rs index 01bd456bb9..006135a36a 100644 --- a/engine/packages/engine/tests/actors_kv_list.rs +++ b/engine/packages/engine/tests/actors_kv_list.rs @@ -39,7 +39,7 @@ impl ListAllEmptyActor { } #[async_trait] -impl TestActor for ListAllEmptyActor { +impl Actor for ListAllEmptyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all empty actor starting"); @@ -102,7 +102,7 @@ impl ListAllKeysActor { } #[async_trait] -impl TestActor for ListAllKeysActor { +impl Actor for ListAllKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all keys actor starting"); @@ -199,7 +199,7 @@ impl ListAllLimitActor { } #[async_trait] -impl TestActor for ListAllLimitActor { +impl Actor for ListAllLimitActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all limit actor starting"); @@ -277,7 +277,7 @@ impl ListAllReverseActor { } #[async_trait] -impl TestActor for ListAllReverseActor { +impl 
Actor for ListAllReverseActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list all reverse actor starting"); @@ -368,7 +368,7 @@ impl ListRangeInclusiveActor { } #[async_trait] -impl TestActor for ListRangeInclusiveActor { +impl Actor for ListRangeInclusiveActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list range inclusive actor starting"); @@ -467,7 +467,7 @@ impl ListRangeExclusiveActor { } #[async_trait] -impl TestActor for ListRangeExclusiveActor { +impl Actor for ListRangeExclusiveActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list range exclusive actor starting"); @@ -566,7 +566,7 @@ impl ListPrefixActor { } #[async_trait] -impl TestActor for ListPrefixActor { +impl Actor for ListPrefixActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list prefix actor starting"); @@ -669,7 +669,7 @@ impl ListPrefixNoMatchActor { } #[async_trait] -impl TestActor for ListPrefixNoMatchActor { +impl Actor for ListPrefixNoMatchActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list prefix no match actor starting"); diff --git a/engine/packages/engine/tests/actors_kv_misc.rs b/engine/packages/engine/tests/actors_kv_misc.rs index b68008e1a8..8ccce2c0ff 100644 --- a/engine/packages/engine/tests/actors_kv_misc.rs +++ b/engine/packages/engine/tests/actors_kv_misc.rs @@ -39,7 +39,7 @@ impl BinaryDataActor { } #[async_trait] -impl TestActor for BinaryDataActor { +impl Actor for BinaryDataActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, 
generation = config.generation, "binary data actor starting"); @@ -114,7 +114,7 @@ impl EmptyValueActor { } #[async_trait] -impl TestActor for EmptyValueActor { +impl Actor for EmptyValueActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "empty value actor starting"); @@ -203,7 +203,7 @@ impl LargeValueActor { } #[async_trait] -impl TestActor for LargeValueActor { +impl Actor for LargeValueActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "large value actor starting"); @@ -286,7 +286,7 @@ impl GetEmptyKeysActor { } #[async_trait] -impl TestActor for GetEmptyKeysActor { +impl Actor for GetEmptyKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "get empty keys actor starting"); @@ -350,7 +350,7 @@ impl ListLimitZeroActor { } #[async_trait] -impl TestActor for ListLimitZeroActor { +impl Actor for ListLimitZeroActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "list limit zero actor starting"); @@ -429,7 +429,7 @@ impl KeyOrderingActor { } #[async_trait] -impl TestActor for KeyOrderingActor { +impl Actor for KeyOrderingActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "key ordering actor starting"); @@ -520,7 +520,7 @@ impl ManyKeysActor { } #[async_trait] -impl TestActor for ManyKeysActor { +impl Actor for ManyKeysActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "many keys actor starting"); diff --git a/engine/packages/engine/tests/actors_lifecycle.rs b/engine/packages/engine/tests/actors_lifecycle.rs 
index 88aac16ece..ac21e1c7d5 100644 --- a/engine/packages/engine/tests/actors_lifecycle.rs +++ b/engine/packages/engine/tests/actors_lifecycle.rs @@ -42,7 +42,7 @@ fn actor_basic_create() { "runner should have the actor allocated" ); - tracing::info!(?actor_id, runner_id = ?runner.runner_id, "actor allocated to runner"); + tracing::info!(?actor_id, "actor allocated to runner"); }); } @@ -406,7 +406,7 @@ fn actor_explicit_destroy() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -1143,7 +1143,7 @@ fn runner_at_max_capacity() { actor_id: actor_ids[0].parse().unwrap(), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await diff --git a/engine/packages/engine/tests/actors_scheduling_errors.rs b/engine/packages/engine/tests/actors_scheduling_errors.rs index 07218ab7ce..47ab9714f8 100644 --- a/engine/packages/engine/tests/actors_scheduling_errors.rs +++ b/engine/packages/engine/tests/actors_scheduling_errors.rs @@ -926,7 +926,7 @@ fn serverless_connection_refused_error() { ); } -/// Tests that ServerlessInvalidPayload error is returned when the serverless endpoint +/// Tests that ServerlessInvalidSsePayload error is returned when the serverless endpoint /// returns malformed SSE data. #[test] fn serverless_invalid_payload_error() { @@ -952,7 +952,7 @@ fn serverless_invalid_payload_error() { // Wait for error to be tracked tokio::time::sleep(Duration::from_millis(1500)).await; - // Verify pool error is a base64 or payload error + // Verify pool error is an invalid SSE payload error. 
let pool_error = get_runner_config_pool_error(guard_port, &namespace, &runner_name).await; @@ -961,15 +961,11 @@ fn serverless_invalid_payload_error() { tracing::info!(?pool_error, "pool error received"); match pool_error { - rivet_types::actor::RunnerPoolError::ServerlessInvalidPayload { message } => { - tracing::info!(?message, "got ServerlessInvalidPayload as expected"); - } - // Could also be InvalidBase64 depending on the payload - rivet_types::actor::RunnerPoolError::ServerlessInvalidBase64 => { - tracing::info!("got ServerlessInvalidBase64 as expected"); + rivet_types::actor::RunnerPoolError::ServerlessInvalidSsePayload { message, .. } => { + tracing::info!(?message, "got ServerlessInvalidSsePayload as expected"); } other => panic!( - "expected ServerlessInvalidPayload or ServerlessInvalidBase64, got: {:?}", + "expected ServerlessInvalidSsePayload, got: {:?}", other ), } diff --git a/engine/packages/engine/tests/api_actors_delete.rs b/engine/packages/engine/tests/api_actors_delete.rs index 0a115fc955..d4eba34d27 100644 --- a/engine/packages/engine/tests/api_actors_delete.rs +++ b/engine/packages/engine/tests/api_actors_delete.rs @@ -36,7 +36,7 @@ fn delete_existing_actor_with_namespace() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -82,7 +82,7 @@ fn delete_existing_actor_without_namespace() { common::api_types::actors::delete::DeletePath { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, - common::api_types::actors::delete::DeleteQuery { namespace: None }, + common::api_types::actors::delete::DeleteQuery { namespace: "".to_string() }, ) .await .expect("failed to delete actor"); @@ -123,7 +123,7 @@ fn delete_actor_current_datacenter() { ctx.leader_dc().guard_port(), common::api_types::actors::delete::DeletePath { actor_id: actor_id }, common::api_types::actors::delete::DeleteQuery { - 
namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -171,7 +171,7 @@ fn delete_actor_remote_datacenter() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -198,7 +198,7 @@ fn delete_non_existent_actor() { common::api_types::actors::delete::DeletePath { actor_id: fake_actor_id, }, - common::api_types::actors::delete::DeleteQuery { namespace: None }, + common::api_types::actors::delete::DeleteQuery { namespace: "".to_string() }, ) .await; @@ -240,7 +240,7 @@ fn delete_actor_wrong_namespace() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace2.clone()), + namespace: namespace2.clone(), }, ) .await; @@ -287,7 +287,7 @@ fn delete_with_non_existent_namespace() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some("non-existent-namespace".to_string()), + namespace: "non-existent-namespace".to_string(), }, ) .await; @@ -337,7 +337,7 @@ fn delete_remote_actor_verify_propagation() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -384,7 +384,7 @@ fn delete_already_destroyed_actor() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -398,7 +398,7 @@ fn delete_already_destroyed_actor() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await; @@ -448,7 +448,7 @@ fn 
delete_actor_twice_rapidly() { actor_id: actor_id.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -461,7 +461,7 @@ fn delete_actor_twice_rapidly() { actor_id: actor_id_clone.parse().expect("failed to parse actor_id"), }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace_clone.clone()), + namespace: namespace_clone.clone(), }, ) .await diff --git a/engine/packages/engine/tests/api_actors_get_or_create.rs b/engine/packages/engine/tests/api_actors_get_or_create.rs index fab7befbed..fd8e3f80c0 100644 --- a/engine/packages/engine/tests/api_actors_get_or_create.rs +++ b/engine/packages/engine/tests/api_actors_get_or_create.rs @@ -598,7 +598,7 @@ fn get_or_create_with_destroyed_actor() { actor_id: first_actor_id, }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await diff --git a/engine/packages/engine/tests/api_actors_list.rs b/engine/packages/engine/tests/api_actors_list.rs index db27810010..0309e09ed1 100644 --- a/engine/packages/engine/tests/api_actors_list.rs +++ b/engine/packages/engine/tests/api_actors_list.rs @@ -344,7 +344,7 @@ fn list_with_include_destroyed_false() { actor_id: destroyed_actor_id, }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -424,7 +424,7 @@ fn list_with_include_destroyed_true() { actor_id: res1.actor.actor_id, }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await @@ -938,7 +938,7 @@ fn list_by_key_with_include_destroyed_true() { actor_id: res1.actor.actor_id, }, common::api_types::actors::delete::DeleteQuery { - namespace: Some(namespace.clone()), + namespace: namespace.clone(), }, ) .await diff --git 
a/engine/packages/engine/tests/api_runner_configs_list.rs b/engine/packages/engine/tests/api_runner_configs_list.rs index 69e9f2c054..6ca5d47a00 100644 --- a/engine/packages/engine/tests/api_runner_configs_list.rs +++ b/engine/packages/engine/tests/api_runner_configs_list.rs @@ -324,6 +324,7 @@ fn list_runner_configs_filter_by_variant_serverless() { min_runners: Some(1), max_runners: 5, runners_margin: Some(2), + metadata_poll_interval: None, }, metadata: None, drain_on_version_upgrade: true, @@ -458,6 +459,7 @@ fn list_runner_configs_validates_returned_data() { min_runners: Some(2), max_runners: 10, runners_margin: Some(3), + metadata_poll_interval: None, }, metadata: Some(serde_json::json!({"key": "value"})), drain_on_version_upgrade: true, @@ -573,6 +575,7 @@ fn list_runner_configs_mixed_variants() { min_runners: Some(1), max_runners: 5, runners_margin: Some(2), + metadata_poll_interval: None, }, metadata: None, drain_on_version_upgrade: true, diff --git a/engine/packages/engine/tests/api_runner_configs_upsert.rs b/engine/packages/engine/tests/api_runner_configs_upsert.rs index 50474eb6f9..697e888e2c 100644 --- a/engine/packages/engine/tests/api_runner_configs_upsert.rs +++ b/engine/packages/engine/tests/api_runner_configs_upsert.rs @@ -98,6 +98,7 @@ fn upsert_runner_config_serverless() { min_runners: Some(1), max_runners: 5, runners_margin: Some(2), + metadata_poll_interval: None, }, metadata: None, drain_on_version_upgrade: true, @@ -518,6 +519,7 @@ fn upsert_runner_config_overwrites_different_variant() { min_runners: Some(1), max_runners: 5, runners_margin: Some(2), + metadata_poll_interval: None, }, metadata: None, drain_on_version_upgrade: true, @@ -617,6 +619,7 @@ fn upsert_runner_config_serverless_slots_per_runner_zero() { min_runners: Some(1), max_runners: 5, runners_margin: Some(2), + metadata_poll_interval: None, }, metadata: None, drain_on_version_upgrade: true, diff --git a/engine/sdks/rust/engine-runner/src/behaviors.rs 
b/engine/packages/engine/tests/common/behaviors.rs similarity index 95% rename from engine/sdks/rust/engine-runner/src/behaviors.rs rename to engine/packages/engine/tests/common/behaviors.rs index ee9973b7c7..8831918de3 100644 --- a/engine/sdks/rust/engine-runner/src/behaviors.rs +++ b/engine/packages/engine/tests/common/behaviors.rs @@ -1,6 +1,8 @@ -use crate::actor::*; +#![allow(dead_code)] + use anyhow::Result; use async_trait::async_trait; +use super::test_runner::{Actor, ActorConfig, ActorStartResult, ActorStopResult}; use std::{ sync::{Arc, Mutex}, time::Duration, @@ -22,7 +24,7 @@ impl Default for EchoActor { } #[async_trait] -impl TestActor for EchoActor { +impl Actor for EchoActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!(actor_id = ?config.actor_id, generation = config.generation, "echo actor started"); Ok(ActorStartResult::Running) @@ -67,7 +69,7 @@ impl CrashOnStartActor { } #[async_trait] -impl TestActor for CrashOnStartActor { +impl Actor for CrashOnStartActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::warn!( actor_id = ?config.actor_id, @@ -111,7 +113,7 @@ impl DelayedStartActor { } #[async_trait] -impl TestActor for DelayedStartActor { +impl Actor for DelayedStartActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!( actor_id = ?config.actor_id, @@ -147,7 +149,7 @@ impl Default for TimeoutActor { } #[async_trait] -impl TestActor for TimeoutActor { +impl Actor for TimeoutActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::warn!( actor_id = ?config.actor_id, @@ -192,7 +194,7 @@ impl Default for SleepImmediatelyActor { } #[async_trait] -impl TestActor for SleepImmediatelyActor { +impl Actor for SleepImmediatelyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!( actor_id = ?config.actor_id, @@ -240,7 +242,7 @@ impl Default for StopImmediatelyActor { } #[async_trait] -impl TestActor for 
StopImmediatelyActor { +impl Actor for StopImmediatelyActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!( actor_id = ?config.actor_id, @@ -277,7 +279,7 @@ impl CountingCrashActor { } #[async_trait] -impl TestActor for CountingCrashActor { +impl Actor for CountingCrashActor { async fn on_start(&mut self, config: ActorConfig) -> Result { let count = self .crash_count @@ -303,8 +305,8 @@ impl TestActor for CountingCrashActor { } } -/// Actor that crashes N times then succeeds -/// Used to test crash policy restart with retry reset on success +/// Actor that crashes N times then succeeds. +/// Used to test crash policy restart with retry reset on success. pub struct CrashNTimesThenSucceedActor { crash_count: Arc>, max_crashes: usize, @@ -320,7 +322,7 @@ impl CrashNTimesThenSucceedActor { } #[async_trait] -impl TestActor for CrashNTimesThenSucceedActor { +impl Actor for CrashNTimesThenSucceedActor { async fn on_start(&mut self, config: ActorConfig) -> Result { let mut count = self.crash_count.lock().unwrap(); let current = *count; @@ -358,8 +360,8 @@ impl TestActor for CrashNTimesThenSucceedActor { } } -/// Actor that notifies via a oneshot channel when it starts running -/// This allows tests to wait for the actor to actually start instead of sleeping +/// Actor that notifies via a oneshot channel when it starts running. +/// This allows tests to wait for the actor to actually start instead of sleeping. 
pub struct NotifyOnStartActor { notify_tx: std::sync::Arc>>>, } @@ -373,7 +375,7 @@ impl NotifyOnStartActor { } #[async_trait] -impl TestActor for NotifyOnStartActor { +impl Actor for NotifyOnStartActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!( actor_id = ?config.actor_id, @@ -400,8 +402,8 @@ impl TestActor for NotifyOnStartActor { } } -/// Actor that verifies it received the expected input data -/// Crashes if input doesn't match or is missing, succeeds if it matches +/// Actor that verifies it received the expected input data. +/// Crashes if input doesn't match or is missing, succeeds if it matches. pub struct VerifyInputActor { expected_input: Vec, } @@ -413,7 +415,7 @@ impl VerifyInputActor { } #[async_trait] -impl TestActor for VerifyInputActor { +impl Actor for VerifyInputActor { async fn on_start(&mut self, config: ActorConfig) -> Result { tracing::info!( actor_id = ?config.actor_id, @@ -463,8 +465,8 @@ impl TestActor for VerifyInputActor { } } -/// Generic actor that accepts closures for on_start and on_stop -/// This allows tests to define actor behavior inline without creating separate structs +/// Generic actor that accepts closures for on_start and on_stop. +/// This allows tests to define actor behavior inline without creating separate structs. 
pub struct CustomActor { on_start_fn: Box< dyn Fn( @@ -567,7 +569,7 @@ impl Default for CustomActorBuilder { } #[async_trait] -impl TestActor for CustomActor { +impl Actor for CustomActor { async fn on_start(&mut self, config: ActorConfig) -> Result { (self.on_start_fn)(config).await } diff --git a/engine/packages/engine/tests/common/mod.rs b/engine/packages/engine/tests/common/mod.rs index 8c9b100f13..c5c517c87b 100644 --- a/engine/packages/engine/tests/common/mod.rs +++ b/engine/packages/engine/tests/common/mod.rs @@ -2,6 +2,7 @@ pub mod actors; pub mod api; +pub mod behaviors; pub mod ctx; pub mod test_helpers; pub mod test_runner; diff --git a/engine/packages/engine/tests/common/test_runner/mod.rs b/engine/packages/engine/tests/common/test_runner/mod.rs index 8d3cba0faf..01a51ec0bf 100644 --- a/engine/packages/engine/tests/common/test_runner/mod.rs +++ b/engine/packages/engine/tests/common/test_runner/mod.rs @@ -3,24 +3,306 @@ //! This module provides a `TestRunnerBuilder` that wraps the standalone `rivet-engine-runner` //! package, adding test-specific functionality like building from a `TestDatacenter`. 
-use anyhow::Result; +use anyhow::{Result, anyhow, bail}; +use async_trait::async_trait; +use rivet_runner_protocol::mk2 as rp; use std::collections::HashMap; use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{Mutex, broadcast}; -// Re-export everything from the standalone package +// Re-export from the standalone package pub use rivet_engine_runner::{ - ActorConfig, ActorEvent, ActorLifecycleEvent, ActorStartResult, ActorStopResult, - CountingCrashActor, CrashNTimesThenSucceedActor, CrashOnStartActor, CustomActor, - CustomActorBuilder, DelayedStartActor, EchoActor, KvRequest, NotifyOnStartActor, - PROTOCOL_VERSION, Runner, RunnerBuilder, RunnerBuilderLegacy, RunnerConfig, - SleepImmediatelyActor, StopImmediatelyActor, TestActor, TimeoutActor, VerifyInputActor, + ActorLifecycleEvent, PROTOCOL_VERSION, Runner, RunnerBuilder, RunnerConfig, RunnerHandle, protocol_types, }; -// Type alias for backwards compatibility -pub type TestRunner = Runner; +// Re-export test behavior actors from the local behaviors module +pub use super::behaviors::{ + CountingCrashActor, CrashNTimesThenSucceedActor, CrashOnStartActor, CustomActor, + CustomActorBuilder, DelayedStartActor, EchoActor, NotifyOnStartActor, SleepImmediatelyActor, + StopImmediatelyActor, TimeoutActor, VerifyInputActor, +}; + +#[derive(Clone)] +pub struct TestRunner { + runner: Runner, + handle: RunnerHandle, +} + +impl TestRunner { + pub fn name(&self) -> &str { + self.runner.name() + } + + pub async fn start(&self) -> Result<()> { + self.runner.start().await + } + + pub async fn wait_ready(&self) -> String { + self.runner + .wait_ready() + .await + .expect("runner should become ready") + } + + pub async fn has_actor(&self, actor_id: &str) -> bool { + self.handle.has_actor(actor_id, None).await + } + + pub fn subscribe_lifecycle_events(&self) -> broadcast::Receiver { + self.runner.subscribe_lifecycle_events() + } + + pub async fn get_actor_ids(&self) -> Vec { + self.handle.get_actor_ids().await + } + + 
pub async fn shutdown(&self) { + if let Err(err) = self.runner.shutdown(false).await { + tracing::error!(?err, "failed to shutdown test runner"); + } + } + + pub async fn crash(&self) { + if let Err(err) = self.runner.crash().await { + tracing::error!(?err, "failed to crash test runner"); + } + } +} + +#[derive(Clone)] +pub struct ActorConfig { + pub actor_id: String, + pub generation: u32, + pub actor_name: String, + pub input: Option>, + runner: RunnerHandle, +} + +impl ActorConfig { + fn from_context(ctx: &rivet_engine_runner::ActorContext, runner: RunnerHandle) -> Self { + Self { + actor_id: ctx.actor_id.clone(), + generation: ctx.generation, + actor_name: ctx.actor_name.clone(), + input: ctx.config.input.clone(), + runner, + } + } + + pub async fn send_kv_get(&self, keys: Vec>) -> Result { + let requested_keys = keys.clone(); + let values = self.runner.kv_get(&self.actor_id, keys).await?; + + let mut response_keys = Vec::new(); + let mut response_values = Vec::new(); + for (key, value) in requested_keys.into_iter().zip(values.into_iter()) { + if let Some(value) = value { + response_keys.push(key); + response_values.push(value); + } + } + + Ok(KvReadResponse { + keys: response_keys, + values: response_values, + }) + } + + pub async fn send_kv_put(&self, keys: Vec>, values: Vec>) -> Result<()> { + if keys.len() != values.len() { + bail!( + "mismatched kv put payload lengths: keys={}, values={}", + keys.len(), + values.len() + ); + } + + let entries = keys.into_iter().zip(values).collect::>(); + self.runner.kv_put(&self.actor_id, entries).await + } + + pub async fn send_kv_delete(&self, keys: Vec>) -> Result<()> { + self.runner.kv_delete(&self.actor_id, keys).await + } + + pub async fn send_kv_drop(&self) -> Result<()> { + self.runner.kv_drop(&self.actor_id).await + } + + pub async fn send_kv_list( + &self, + query: rp::KvListQuery, + reverse: Option, + limit: Option, + ) -> Result { + let entries = match query { + rp::KvListQuery::KvListAllQuery => { + 
self.runner + .kv_list_all(&self.actor_id, reverse, limit) + .await? + } + rp::KvListQuery::KvListPrefixQuery(prefix) => { + self.runner + .kv_list_prefix(&self.actor_id, prefix.key, reverse, limit) + .await? + } + rp::KvListQuery::KvListRangeQuery(range) => { + self.runner + .kv_list_range( + &self.actor_id, + range.start, + range.end, + range.exclusive, + reverse, + limit, + ) + .await? + } + }; + + let (keys, values): (Vec<_>, Vec<_>) = entries.into_iter().unzip(); + Ok(KvReadResponse { keys, values }) + } + + pub fn send_sleep_intent(&self) { + let runner = self.runner.clone(); + let actor_id = self.actor_id.clone(); + let generation = self.generation; + tokio::spawn(async move { + if let Err(err) = runner.sleep_actor(&actor_id, Some(generation)).await { + tracing::error!(?err, %actor_id, generation, "failed to send sleep intent"); + } + }); + } + + pub fn send_stop_intent(&self) { + let runner = self.runner.clone(); + let actor_id = self.actor_id.clone(); + let generation = self.generation; + tokio::spawn(async move { + if let Err(err) = runner.stop_actor(&actor_id, Some(generation)).await { + tracing::error!(?err, %actor_id, generation, "failed to send stop intent"); + } + }); + } + + pub fn send_set_alarm(&self, alarm_ts: i64) { + let runner = self.runner.clone(); + let actor_id = self.actor_id.clone(); + let generation = self.generation; + tokio::spawn(async move { + if let Err(err) = runner.set_alarm(&actor_id, Some(alarm_ts), Some(generation)).await { + tracing::error!(?err, %actor_id, generation, alarm_ts, "failed to set alarm"); + } + }); + } -type ActorFactory = Arc Box + Send + Sync>; + pub fn send_clear_alarm(&self) { + let runner = self.runner.clone(); + let actor_id = self.actor_id.clone(); + let generation = self.generation; + tokio::spawn(async move { + if let Err(err) = runner.clear_alarm(&actor_id, Some(generation)).await { + tracing::error!(?err, %actor_id, generation, "failed to clear alarm"); + } + }); + } +} + +#[derive(Debug, Clone)] +pub 
struct KvReadResponse { + pub keys: Vec>, + pub values: Vec>, +} + +#[derive(Debug, Clone)] +pub enum ActorStartResult { + Running, + Crash { code: i32, message: String }, + Delay(Duration), + Timeout, +} + +#[derive(Debug, Clone)] +pub enum ActorStopResult { + Success, +} + +#[async_trait] +pub trait Actor: Send + 'static { + async fn on_start(&mut self, config: ActorConfig) -> Result; + async fn on_stop(&mut self) -> Result; + fn name(&self) -> &str; +} + +type ActorFactory = Arc Box + Send + Sync>; + +#[derive(Clone)] +struct LegacyRunnerApp { + actor_factories: HashMap, + actors: Arc>>>, +} + +impl LegacyRunnerApp { + fn new(actor_factories: HashMap) -> Self { + Self { + actor_factories, + actors: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[async_trait] +impl rivet_engine_runner::RunnerApp for LegacyRunnerApp { + async fn on_actor_start( + &self, + runner: RunnerHandle, + ctx: rivet_engine_runner::ActorContext, + ) -> Result<()> { + let actor_factory = self + .actor_factories + .get(&ctx.actor_name) + .ok_or_else(|| anyhow!("no actor behavior registered for '{}'", ctx.actor_name))? + .clone(); + let legacy_config = ActorConfig::from_context(&ctx, runner); + let mut actor = actor_factory(legacy_config.clone()); + + match actor.on_start(legacy_config).await? 
{ + ActorStartResult::Running => { + self.actors.lock().await.insert(ctx.actor_id, actor); + Ok(()) + } + ActorStartResult::Delay(delay) => { + tokio::time::sleep(delay).await; + self.actors.lock().await.insert(ctx.actor_id, actor); + Ok(()) + } + ActorStartResult::Crash { code, message } => { + bail!("actor crashed with code {}: {}", code, message) + } + ActorStartResult::Timeout => { + std::future::pending::<()>().await; + unreachable!() + } + } + } + + async fn on_actor_stop( + &self, + _runner: RunnerHandle, + ctx: rivet_engine_runner::ActorContext, + ) -> Result<()> { + let Some(mut actor) = self.actors.lock().await.remove(&ctx.actor_id) else { + return Ok(()); + }; + + match actor.on_stop().await? { + ActorStopResult::Success => Ok(()), + } + } +} /// Test-specific runner builder that integrates with TestDatacenter pub struct TestRunnerBuilder { @@ -67,7 +349,7 @@ impl TestRunnerBuilder { /// Register an actor factory for a specific actor name pub fn with_actor_behavior(mut self, actor_name: &str, factory: F) -> Self where - F: Fn(ActorConfig) -> Box + Send + Sync + 'static, + F: Fn(ActorConfig) -> Box + Send + Sync + 'static, { self.actor_factories .insert(actor_name.to_string(), Arc::new(factory)); @@ -75,7 +357,7 @@ impl TestRunnerBuilder { } /// Build the runner using the TestDatacenter's guard port - pub async fn build(self, dc: &super::TestDatacenter) -> Result { + pub async fn build(self, dc: &super::TestDatacenter) -> Result { let endpoint = format!("http://127.0.0.1:{}", dc.guard_port()); let token = "dev".to_string(); @@ -90,14 +372,10 @@ impl TestRunnerBuilder { .total_slots(self.total_slots) .build()?; - // Build the runner - let mut builder = RunnerBuilder::new(config); - - // Register all actor factories - for (name, factory) in self.actor_factories { - builder = builder.with_actor_behavior(&name, move |config| factory(config)); - } + let app = LegacyRunnerApp::new(self.actor_factories); + let runner = 
RunnerBuilder::new(config).app(app).build()?; + let handle = runner.handle(); - builder.build() + Ok(TestRunner { runner, handle }) } } diff --git a/engine/sdks/rust/engine-runner/Cargo.toml b/engine/sdks/rust/engine-runner/Cargo.toml index d623eca8a5..32818f5d8b 100644 --- a/engine/sdks/rust/engine-runner/Cargo.toml +++ b/engine/sdks/rust/engine-runner/Cargo.toml @@ -8,16 +8,31 @@ description = "Rust-based engine runner for Rivet, enabling programmatic actor l [dependencies] anyhow.workspace = true +async-stream.workspace = true async-trait.workspace = true +axum.workspace = true +base64.workspace = true +bytes.workspace = true chrono.workspace = true futures-util.workspace = true +http.workspace = true rand.workspace = true +reqwest.workspace = true rivet-runner-protocol.workspace = true serde_bare.workspace = true serde_json.workspace = true serde.workspace = true +tokio-stream.workspace = true tokio.workspace = true tokio-tungstenite.workspace = true +tower.workspace = true tracing.workspace = true urlencoding.workspace = true vbare.workspace = true + +[dev-dependencies] +portpicker.workspace = true +rivet-config.workspace = true +rivet-test-deps.workspace = true +serde_yaml.workspace = true +tempfile.workspace = true diff --git a/engine/sdks/rust/engine-runner/examples/counter.rs b/engine/sdks/rust/engine-runner/examples/counter.rs new file mode 100644 index 0000000000..40529e419e --- /dev/null +++ b/engine/sdks/rust/engine-runner/examples/counter.rs @@ -0,0 +1,69 @@ +//! Counter example using the Rust engine runner API. 
+ +use anyhow::Result; +use axum::{Json, Router, extract::State, routing::{get, post}}; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, Runner, RunnerConfig, +}; +use serde_json::json; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + let app = AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(|ctx: ActorContext| async move { + tracing::info!(actor_id = %ctx.actor_id, generation = ctx.generation, "counter actor started"); + Ok(()) + }) + .on_stop(|ctx: ActorContext| async move { + tracing::info!(actor_id = %ctx.actor_id, generation = ctx.generation, "counter actor stopped"); + Ok(()) + }), + ); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint("http://127.0.0.1:6420") + .namespace("default") + .runner_name("counter-runner") + .build()?, + ) + .app(app) + .build()?; + + println!( + "runner configured. call runner.start().await in an integration environment with a running engine" + ); + let _ = Arc::new(runner); + Ok(()) +} + +async fn get_count(State(ctx): State) -> Result, axum::http::StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment(State(ctx): State) -> Result, axum::http::StatusCode> { + let next = ctx + .kv_get_u64("count") + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)? 
+ .unwrap_or(0) + + 1; + + ctx.kv_put_u64("count", next) + .await + .map_err(|_| axum::http::StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(json!({ "count": next }))) +} diff --git a/engine/sdks/rust/engine-runner/src/actor.rs b/engine/sdks/rust/engine-runner/src/actor.rs index 89409d6db3..50fb46b728 100644 --- a/engine/sdks/rust/engine-runner/src/actor.rs +++ b/engine/sdks/rust/engine-runner/src/actor.rs @@ -1,278 +1,328 @@ -use anyhow::Result; +use anyhow::{Context, Result, bail}; use async_trait::async_trait; -use rivet_runner_protocol::mk2 as rp; -use std::time::Duration; -use tokio::sync::{mpsc, oneshot}; +use axum::{Router, body::Body}; +use bytes::Bytes; +use futures_util::future::BoxFuture; +use http::{Request, Response}; +use std::{collections::HashMap, future::Future, sync::Arc}; +use tower::ServiceExt; -use crate::protocol; +use crate::runner::RunnerHandle; -/// Configuration passed to actor when it starts -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ActorConfig { - pub actor_id: String, - pub generation: u32, pub name: String, pub key: Option, pub create_ts: i64, pub input: Option>, +} - /// Channel to send events to the runner - pub event_tx: mpsc::UnboundedSender, +#[derive(Clone, Debug)] +pub struct HibernatingRequest { + pub gateway_id: [u8; 4], + pub request_id: [u8; 4], +} + +#[derive(Clone, Debug)] +pub struct ActorContext { + pub actor_id: String, + pub generation: u32, + pub actor_name: String, + pub config: ActorConfig, + pub hibernating_requests: Vec, +} - /// Channel to send KV requests to the runner - pub kv_request_tx: mpsc::UnboundedSender, +#[derive(Clone, Debug)] +pub struct HttpContext { + pub actor_id: String, + pub generation: u32, + pub actor_name: String, + pub gateway_id: [u8; 4], + pub request_id: [u8; 4], } -impl ActorConfig { - pub fn new( - config: &rp::ActorConfig, - actor_id: String, - generation: u32, - event_tx: mpsc::UnboundedSender, - kv_request_tx: mpsc::UnboundedSender, - ) -> Self { - ActorConfig { - 
actor_id, - generation, - name: config.name.clone(), - key: config.key.clone(), - create_ts: config.create_ts, - input: config.input.as_ref().map(|i| i.to_vec()), - event_tx, - kv_request_tx, +#[derive(Clone, Debug)] +pub struct WebSocketContext { + pub actor_id: String, + pub generation: u32, + pub actor_name: String, + pub gateway_id: [u8; 4], + pub request_id: [u8; 4], + pub path: String, + pub headers: HashMap, + pub is_hibernatable: bool, + pub is_restoring_hibernatable: bool, +} + +#[derive(Clone, Debug)] +pub struct HibernatingWebSocketMetadata { + pub gateway_id: [u8; 4], + pub request_id: [u8; 4], + pub client_message_index: u16, + pub server_message_index: u16, + pub path: String, + pub headers: HashMap, +} + +#[derive(Clone, Debug)] +pub struct WebSocketMessage { + pub data: Vec, + pub binary: bool, + pub message_index: u16, +} + +#[derive(Clone)] +pub struct ActorRequestContext { + pub runner: RunnerHandle, + pub actor_id: String, + pub generation: u32, + pub actor_name: String, +} + +impl ActorRequestContext { + pub async fn kv_get(&self, keys: Vec>) -> Result>>> { + self.runner.kv_get(&self.actor_id, keys).await + } + + pub async fn kv_get_u64(&self, key: impl AsRef<[u8]>) -> Result> { + let values = self.kv_get(vec![key.as_ref().to_vec()]).await?; + let Some(raw) = values.into_iter().next().flatten() else { + return Ok(None); + }; + + if raw.len() != 8 { + bail!("expected u64 value to be 8 bytes, got {}", raw.len()); } + + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&raw); + Ok(Some(u64::from_le_bytes(bytes))) } -} -impl ActorConfig { - /// Send a sleep intent - pub fn send_sleep_intent(&self) { - let event = protocol::make_actor_intent(rp::ActorIntent::ActorIntentSleep); - self.send_event(event); + pub async fn kv_put(&self, entries: Vec<(Vec, Vec)>) -> Result<()> { + self.runner.kv_put(&self.actor_id, entries).await } - /// Send a stop intent - pub fn send_stop_intent(&self) { - let event = 
protocol::make_actor_intent(rp::ActorIntent::ActorIntentStop); - self.send_event(event); + pub async fn kv_put_u64(&self, key: impl AsRef<[u8]>, value: u64) -> Result<()> { + self.kv_put(vec![(key.as_ref().to_vec(), value.to_le_bytes().to_vec())]) + .await } - /// Set an alarm to wake at specified timestamp (milliseconds) - pub fn send_set_alarm(&self, alarm_ts: i64) { - let event = protocol::make_set_alarm(Some(alarm_ts)); - self.send_event(event); + pub async fn kv_delete(&self, keys: Vec>) -> Result<()> { + self.runner.kv_delete(&self.actor_id, keys).await } - /// Clear the alarm - pub fn send_clear_alarm(&self) { - let event = protocol::make_set_alarm(None); - self.send_event(event); + pub async fn kv_drop(&self) -> Result<()> { + self.runner.kv_drop(&self.actor_id).await } - /// Send a custom event - fn send_event(&self, event: rp::Event) { - let actor_event = ActorEvent { - actor_id: self.actor_id.clone(), - generation: self.generation, - event, - }; - let _ = self.event_tx.send(actor_event); + pub async fn sleep_actor(&self) -> Result<()> { + self.runner + .sleep_actor(&self.actor_id, Some(self.generation)) + .await } - /// Send a KV get request - pub async fn send_kv_get(&self, keys: Vec>) -> Result { - let (response_tx, response_rx) = oneshot::channel(); - let request = KvRequest { - actor_id: self.actor_id.clone(), - data: rp::KvRequestData::KvGetRequest(rp::KvGetRequest { keys }), - response_tx, - }; - self.kv_request_tx - .send(request) - .map_err(|_| anyhow::anyhow!("failed to send KV get request"))?; - let response: rp::KvResponseData = response_rx + pub async fn stop_actor(&self) -> Result<()> { + self.runner + .stop_actor(&self.actor_id, Some(self.generation)) .await - .map_err(|_| anyhow::anyhow!("KV get request response channel closed"))?; - - match response { - rp::KvResponseData::KvGetResponse(data) => Ok(data), - rp::KvResponseData::KvErrorResponse(err) => { - Err(anyhow::anyhow!("KV get failed: {}", err.message)) - } - _ => 
Err(anyhow::anyhow!("unexpected response type for KV get")), - } } - /// Send a KV list request - pub async fn send_kv_list( - &self, - query: rp::KvListQuery, - reverse: Option, - limit: Option, - ) -> Result { - let (response_tx, response_rx) = oneshot::channel(); - let request = KvRequest { - actor_id: self.actor_id.clone(), - data: rp::KvRequestData::KvListRequest(rp::KvListRequest { - query, - reverse, - limit, - }), - response_tx, - }; - self.kv_request_tx - .send(request) - .map_err(|_| anyhow::anyhow!("failed to send KV list request"))?; - let response: rp::KvResponseData = response_rx + pub async fn set_alarm(&self, alarm_ts: Option) -> Result<()> { + self.runner + .set_alarm(&self.actor_id, alarm_ts, Some(self.generation)) .await - .map_err(|_| anyhow::anyhow!("KV list request response channel closed"))?; - - match response { - rp::KvResponseData::KvListResponse(data) => Ok(data), - rp::KvResponseData::KvErrorResponse(err) => { - Err(anyhow::anyhow!("KV list failed: {}", err.message)) - } - _ => Err(anyhow::anyhow!("unexpected response type for KV list")), - } } - /// Send a KV put request - pub async fn send_kv_put(&self, keys: Vec>, values: Vec>) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - let request = KvRequest { - actor_id: self.actor_id.clone(), - data: rp::KvRequestData::KvPutRequest(rp::KvPutRequest { keys, values }), - response_tx, - }; + pub async fn clear_alarm(&self) -> Result<()> { + self.set_alarm(None).await + } +} - self.kv_request_tx - .send(request) - .map_err(|_| anyhow::anyhow!("failed to send KV put request"))?; +#[async_trait] +pub trait RunnerApp: Send + Sync + 'static { + async fn on_connected(&self, _runner: RunnerHandle) -> Result<()> { + Ok(()) + } - let response: rp::KvResponseData = response_rx - .await - .map_err(|_| anyhow::anyhow!("KV put request response channel closed"))?; - - match response { - rp::KvResponseData::KvPutResponse => Ok(()), - rp::KvResponseData::KvErrorResponse(err) => { - 
Err(anyhow::anyhow!("KV put failed: {}", err.message)) - } - _ => Err(anyhow::anyhow!("unexpected response type for KV put")), - } + async fn on_disconnected(&self, _runner: RunnerHandle, _code: u16, _reason: String) -> Result<()> { + Ok(()) } - /// Send a KV delete request - pub async fn send_kv_delete(&self, keys: Vec>) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - let request = KvRequest { - actor_id: self.actor_id.clone(), - data: rp::KvRequestData::KvDeleteRequest(rp::KvDeleteRequest { keys }), - response_tx, - }; - self.kv_request_tx - .send(request) - .map_err(|_| anyhow::anyhow!("failed to send KV delete request"))?; - let response: rp::KvResponseData = response_rx - .await - .map_err(|_| anyhow::anyhow!("KV delete request response channel closed"))?; - - match response { - rp::KvResponseData::KvDeleteResponse => Ok(()), - rp::KvResponseData::KvErrorResponse(err) => { - Err(anyhow::anyhow!("KV delete failed: {}", err.message)) - } - _ => Err(anyhow::anyhow!("unexpected response type for KV delete")), - } + async fn on_shutdown(&self, _runner: RunnerHandle) -> Result<()> { + Ok(()) } - /// Send a KV drop request - pub async fn send_kv_drop(&self) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - let request = KvRequest { - actor_id: self.actor_id.clone(), - data: rp::KvRequestData::KvDropRequest, - response_tx, - }; - self.kv_request_tx - .send(request) - .map_err(|_| anyhow::anyhow!("failed to send KV drop request"))?; - let response: rp::KvResponseData = response_rx - .await - .map_err(|_| anyhow::anyhow!("KV drop request response channel closed"))?; - - match response { - rp::KvResponseData::KvDropResponse => Ok(()), - rp::KvResponseData::KvErrorResponse(err) => { - Err(anyhow::anyhow!("KV drop failed: {}", err.message)) - } - _ => Err(anyhow::anyhow!("unexpected response type for KV drop")), + async fn on_actor_start(&self, _runner: RunnerHandle, _ctx: ActorContext) -> Result<()> { + Ok(()) + } + + async 
fn on_actor_stop(&self, _runner: RunnerHandle, _ctx: ActorContext) -> Result<()> { + Ok(()) + } + + async fn fetch( + &self, + _runner: RunnerHandle, + _ctx: HttpContext, + _request: Request, + ) -> Result> { + Ok(Response::builder() + .status(501) + .body(Bytes::from_static(b"Not Implemented"))?) + } + + async fn websocket( + &self, + _runner: RunnerHandle, + _ctx: WebSocketContext, + ) -> Result<()> { + Ok(()) + } + + async fn websocket_message( + &self, + _runner: RunnerHandle, + _ctx: WebSocketContext, + _message: WebSocketMessage, + ) -> Result<()> { + Ok(()) + } + + async fn websocket_close( + &self, + _runner: RunnerHandle, + _ctx: WebSocketContext, + _code: Option, + _reason: Option, + ) -> Result<()> { + Ok(()) + } + + fn can_hibernate(&self, _ctx: &WebSocketContext) -> bool { + false + } +} + +type LifecycleHook = Arc BoxFuture<'static, Result<()>> + Send + Sync>; + +#[derive(Clone)] +pub struct AxumActorDefinition { + router: Router, + on_start: Option, + on_stop: Option, +} + +impl AxumActorDefinition { + pub fn new(router: Router) -> Self { + Self { + router, + on_start: None, + on_stop: None, } } + + pub fn on_start(mut self, hook: F) -> Self + where + F: Fn(ActorContext) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + self.on_start = Some(Arc::new(move |ctx| Box::pin(hook(ctx)))); + self + } + + pub fn on_stop(mut self, hook: F) -> Self + where + F: Fn(ActorContext) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + self.on_stop = Some(Arc::new(move |ctx| Box::pin(hook(ctx)))); + self + } } -/// Result of actor start operation -#[derive(Debug, Clone)] -pub enum ActorStartResult { - /// Send ActorStateRunning immediately - Running, - /// Wait specified duration before sending running - Delay(Duration), - /// Never send running (simulates timeout) - Timeout, - /// Crash immediately with exit code - Crash { code: i32, message: String }, +#[derive(Clone, Default)] +pub struct AxumRunnerApp { + actors: HashMap, } 
-/// Result of actor stop operation -#[derive(Debug, Clone)] -pub enum ActorStopResult { - /// Stop successfully (exit code 0) - Success, - /// Wait before stopping - Delay(Duration), - /// Crash with exit code - Crash { code: i32, message: String }, +impl AxumRunnerApp { + pub fn new() -> Self { + Self::default() + } + + pub fn with_actor(mut self, name: impl Into, definition: AxumActorDefinition) -> Self { + self.actors.insert(name.into(), definition); + self + } + + pub fn actor(&self, name: &str) -> Option<&AxumActorDefinition> { + self.actors.get(name) + } } -/// Trait for test actors that can be controlled programmatically #[async_trait] -pub trait TestActor: Send + Sync { - /// Called when actor receives start command - async fn on_start(&mut self, config: ActorConfig) -> Result; +impl RunnerApp for AxumRunnerApp { + async fn on_actor_start(&self, _runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + let actor = self + .actors + .get(&ctx.actor_name) + .with_context(|| format!("actor '{}' is not registered", ctx.actor_name))?; - /// Called when actor receives stop command - async fn on_stop(&mut self) -> Result; + if let Some(on_start) = &actor.on_start { + on_start(ctx).await?; + } - /// Called when actor receives alarm wake signal - async fn on_alarm(&mut self) -> Result<()> { - tracing::debug!("actor received alarm (default no-op)"); Ok(()) } - /// Called when actor receives wake signal (from sleep) - async fn on_wake(&mut self) -> Result<()> { - tracing::debug!("actor received wake (default no-op)"); + async fn on_actor_stop(&self, _runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + let actor = self + .actors + .get(&ctx.actor_name) + .with_context(|| format!("actor '{}' is not registered", ctx.actor_name))?; + + if let Some(on_stop) = &actor.on_stop { + on_stop(ctx).await?; + } + Ok(()) } - /// Get actor's name for logging - fn name(&self) -> &str { - "TestActor" - } -} + async fn fetch( + &self, + runner: RunnerHandle, + ctx: HttpContext, 
+ request: Request, + ) -> Result> { + let actor = self + .actors + .get(&ctx.actor_name) + .with_context(|| format!("actor '{}' is not registered", ctx.actor_name))?; -/// Events that actors can send directly via the event channel -#[derive(Debug, Clone)] -pub struct ActorEvent { - pub actor_id: String, - pub generation: u32, - pub event: rp::Event, -} + let state = ActorRequestContext { + runner, + actor_id: ctx.actor_id.clone(), + generation: ctx.generation, + actor_name: ctx.actor_name, + }; -/// KV requests that actors can send to the runner -pub struct KvRequest { - pub actor_id: String, - pub data: rp::KvRequestData, - pub response_tx: oneshot::Sender, + let (parts, body) = request.into_parts(); + let request = Request::from_parts(parts, Body::from(body)); + + let response = actor + .router + .clone() + .with_state(state) + .oneshot(request) + .await + .context("failed to serve axum actor route")?; + + let (parts, body) = response.into_parts(); + let body = axum::body::to_bytes(body, usize::MAX) + .await + .context("failed to collect actor response body")?; + + Ok(Response::from_parts(parts, body)) + } } diff --git a/engine/sdks/rust/engine-runner/src/lib.rs b/engine/sdks/rust/engine-runner/src/lib.rs index 8cd636fedb..7a088392bf 100644 --- a/engine/sdks/rust/engine-runner/src/lib.rs +++ b/engine/sdks/rust/engine-runner/src/lib.rs @@ -1,46 +1,22 @@ -//! Rust-based engine runner for Rivet. +//! Rust engine runner SDK. //! -//! This library provides a pure Rust implementation of a Rivet runner that can be fully controlled -//! programmatically, allowing simulation of: -//! - Actor crashes with specific exit codes -//! - Protocol timing issues (delays, timeouts) -//! - Custom protocol events (sleep, alarms, etc.) -//! - Runner disconnection/reconnection scenarios -//! -//! # Example -//! -//! ```ignore -//! use rivet_engine_runner::{Runner, RunnerConfig, EchoActor}; -//! -//! let config = RunnerConfig::builder() -//! .endpoint("http://127.0.0.1:8080") -//! 
.token("dev") -//! .namespace("my-namespace") -//! .runner_name("my-runner") -//! .runner_key("unique-key") -//! .build(); -//! -//! let mut runner = Runner::new(config)?; -//! runner.register_actor("echo", |_| Box::new(EchoActor::new())); -//! runner.start().await?; -//! ``` +//! This crate mirrors the TypeScript engine runner semantics with Rust-idiomatic +//! types and async traits. mod actor; -mod behaviors; mod protocol; mod runner; -pub use actor::{ActorConfig, ActorEvent, ActorStartResult, ActorStopResult, KvRequest, TestActor}; -pub use behaviors::{ - CountingCrashActor, CrashNTimesThenSucceedActor, CrashOnStartActor, CustomActor, - CustomActorBuilder, DelayedStartActor, EchoActor, NotifyOnStartActor, SleepImmediatelyActor, - StopImmediatelyActor, TimeoutActor, VerifyInputActor, +pub use actor::{ + ActorConfig, ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, + HibernatingRequest, HibernatingWebSocketMetadata, HttpContext, RunnerApp, WebSocketContext, + WebSocketMessage, }; pub use protocol::PROTOCOL_VERSION; pub use runner::{ - ActorLifecycleEvent, Runner, RunnerBuilder, RunnerBuilderLegacy, RunnerConfig, - RunnerConfigBuilder, + ActorLifecycleEvent, PrepopulateActorName, Runner, RunnerBuilder, RunnerConfig, + RunnerConfigBuilder, RunnerHandle, ServerlessConfig, ServerlessConfigBuilder, ServerlessRunner, + ServerlessRunnerBuilder, }; -// Re-export commonly used types from the protocol pub use rivet_runner_protocol::mk2 as protocol_types; diff --git a/engine/sdks/rust/engine-runner/src/protocol.rs b/engine/sdks/rust/engine-runner/src/protocol.rs index 958fb14ccd..960b45de76 100644 --- a/engine/sdks/rust/engine-runner/src/protocol.rs +++ b/engine/sdks/rust/engine-runner/src/protocol.rs @@ -17,35 +17,3 @@ pub fn encode_to_server(msg: rp2::ToServer) -> Vec { .serialize(PROTOCOL_VERSION) .expect("failed to serialize ToServer") } - -/// Helper to create event wrapper with checkpoint (MK2) -pub fn make_event_wrapper( - actor_id: &str, - 
generation: u32, - index: u64, - event: rp2::Event, -) -> rp2::EventWrapper { - rp2::EventWrapper { - checkpoint: rp2::ActorCheckpoint { - actor_id: actor_id.to_string(), - generation, - index: index as i64, - }, - inner: event, - } -} - -/// Helper to create actor state update event (MK2) -pub fn make_actor_state_update(state: rp2::ActorState) -> rp2::Event { - rp2::Event::EventActorStateUpdate(rp2::EventActorStateUpdate { state }) -} - -/// Helper to create actor intent event (MK2) -pub fn make_actor_intent(intent: rp2::ActorIntent) -> rp2::Event { - rp2::Event::EventActorIntent(rp2::EventActorIntent { intent }) -} - -/// Helper to create set alarm event (MK2) -pub fn make_set_alarm(alarm_ts: Option) -> rp2::Event { - rp2::Event::EventActorSetAlarm(rp2::EventActorSetAlarm { alarm_ts }) -} diff --git a/engine/sdks/rust/engine-runner/src/runner.rs b/engine/sdks/rust/engine-runner/src/runner.rs index 08340cf245..b94ca1c5c6 100644 --- a/engine/sdks/rust/engine-runner/src/runner.rs +++ b/engine/sdks/rust/engine-runner/src/runner.rs @@ -1,62 +1,69 @@ -use crate::{actor::*, protocol}; -use anyhow::{Context, Result}; +use crate::{ + actor::{ + ActorConfig, ActorContext, HttpContext, HibernatingRequest, HibernatingWebSocketMetadata, + RunnerApp, WebSocketContext, WebSocketMessage, + }, + protocol, +}; +use anyhow::{Context, Result, anyhow, bail}; +use async_stream::stream; +use base64::Engine; +use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; +use http::{Request, Response, StatusCode}; use rivet_runner_protocol::mk2 as rp; use std::{ collections::HashMap, + hash::{Hash, Hasher}, sync::{ Arc, - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, }, time::Duration, }; -use tokio::sync::{Mutex, broadcast, mpsc, oneshot}; +use tokio::{ + sync::{Mutex, Notify, broadcast, mpsc, oneshot}, + task::JoinHandle, +}; use tokio_tungstenite::{connect_async, tungstenite::Message}; +use urlencoding::encode; +use vbare::OwnedVersionedData; -const 
RUNNER_PING_INTERVAL: Duration = Duration::from_secs(15); - -type ActorFactory = Arc Box + Send + Sync>; -type WsStream = - tokio_tungstenite::WebSocketStream>; +const COMMAND_ACK_INTERVAL: Duration = Duration::from_secs(5 * 60); +const RECONNECT_INITIAL_DELAY_MS: u64 = 1_000; +const RECONNECT_MAX_DELAY_MS: u64 = 30_000; -/// Lifecycle events for actors that tests can subscribe to #[derive(Debug, Clone)] pub enum ActorLifecycleEvent { Started { actor_id: String, generation: u32 }, Stopped { actor_id: String, generation: u32 }, } -/// Configuration for the engine runner. -/// -/// This matches the TypeScript RunnerConfig interface. -#[derive(Clone)] +#[derive(Clone, Debug, Default)] +pub struct PrepopulateActorName { + pub metadata: serde_json::Value, +} + +#[derive(Clone, Debug)] pub struct RunnerConfig { - /// The endpoint URL to connect to (e.g., "http://127.0.0.1:8080") pub endpoint: String, - /// Authentication token pub token: String, - /// Namespace to connect to pub namespace: String, - /// Name of this runner (machine-readable) pub runner_name: String, - /// Unique key for this runner instance pub runner_key: String, - /// Protocol version number pub version: u32, - /// Total number of actor slots this runner supports pub total_slots: u32, - /// Optional metadata to attach to the runner + pub pegboard_endpoint: Option, + pub prepopulate_actor_names: HashMap, pub metadata: Option, } impl RunnerConfig { - /// Create a new builder for RunnerConfig pub fn builder() -> RunnerConfigBuilder { RunnerConfigBuilder::default() } } -/// Builder for RunnerConfig #[derive(Default)] pub struct RunnerConfigBuilder { endpoint: Option, @@ -66,6 +73,8 @@ pub struct RunnerConfigBuilder { runner_key: Option, version: Option, total_slots: Option, + pegboard_endpoint: Option, + prepopulate_actor_names: HashMap, metadata: Option, } @@ -85,13 +94,13 @@ impl RunnerConfigBuilder { self } - pub fn runner_name(mut self, name: impl Into) -> Self { - self.runner_name = Some(name.into()); 
+ pub fn runner_name(mut self, runner_name: impl Into) -> Self { + self.runner_name = Some(runner_name.into()); self } - pub fn runner_key(mut self, key: impl Into) -> Self { - self.runner_key = Some(key.into()); + pub fn runner_key(mut self, runner_key: impl Into) -> Self { + self.runner_key = Some(runner_key.into()); self } @@ -100,8 +109,23 @@ impl RunnerConfigBuilder { self } - pub fn total_slots(mut self, slots: u32) -> Self { - self.total_slots = Some(slots); + pub fn total_slots(mut self, total_slots: u32) -> Self { + self.total_slots = Some(total_slots); + self + } + + pub fn pegboard_endpoint(mut self, endpoint: impl Into) -> Self { + self.pegboard_endpoint = Some(endpoint.into()); + self + } + + pub fn prepopulate_actor_name( + mut self, + name: impl Into, + metadata: serde_json::Value, + ) -> Self { + self.prepopulate_actor_names + .insert(name.into(), PrepopulateActorName { metadata }); self } @@ -123,602 +147,766 @@ impl RunnerConfigBuilder { .unwrap_or_else(|| format!("key-{:012x}", rand::random::())), version: self.version.unwrap_or(1), total_slots: self.total_slots.unwrap_or(100), + pegboard_endpoint: self.pegboard_endpoint, + prepopulate_actor_names: self.prepopulate_actor_names, metadata: self.metadata, }) } } -/// Internal configuration with actor factories -#[derive(Clone)] -struct InternalConfig { - namespace: String, - runner_name: String, - runner_key: String, - version: u32, - total_slots: u32, - endpoint: String, - token: String, - actor_factories: HashMap, -} - -/// Engine runner for programmatic actor lifecycle control -pub struct Runner { - config: InternalConfig, - - // State - pub runner_id: Arc>>, - actors: Arc>>, - /// Per-actor event indices for MK2 checkpoints - actor_event_indices: Arc>>, - event_history: Arc>>, - shutdown: Arc, - is_child_task: bool, - - // Event channel for actors to push events - event_tx: mpsc::UnboundedSender, - event_rx: Arc>>, - - // KV request channel for actors to send KV requests - kv_request_tx: 
mpsc::UnboundedSender, - kv_request_rx: Arc>>, - next_kv_request_id: Arc>, - kv_pending_requests: Arc>>>, - - // Lifecycle event broadcast channel - lifecycle_tx: broadcast::Sender, - - // Shutdown channel - shutdown_tx: Arc>>>, -} - -struct ActorState { - #[allow(dead_code)] - actor_id: String, - #[allow(dead_code)] - generation: u32, - actor: Box, -} - -/// Builder for creating a Runner instance pub struct RunnerBuilder { config: RunnerConfig, - actor_factories: HashMap, + app: Option>, } impl RunnerBuilder { - /// Create a new RunnerBuilder with the given configuration pub fn new(config: RunnerConfig) -> Self { - Self { - config, - actor_factories: HashMap::new(), - } - } - - /// Create a new RunnerBuilder from a namespace (for backwards compatibility) - /// - /// This is a convenience method that requires endpoint and token to be set later. - pub fn from_namespace(namespace: &str) -> RunnerBuilderLegacy { - RunnerBuilderLegacy { - namespace: namespace.to_string(), - runner_name: "engine-runner".to_string(), - runner_key: format!("key-{:012x}", rand::random::()), - version: 1, - total_slots: 100, - actor_factories: HashMap::new(), - } + Self { config, app: None } } - /// Register an actor factory for a specific actor name - pub fn with_actor_behavior(mut self, actor_name: &str, factory: F) -> Self + pub fn app(mut self, app: A) -> Self where - F: Fn(ActorConfig) -> Box + Send + Sync + 'static, + A: RunnerApp, { - self.actor_factories - .insert(actor_name.to_string(), Arc::new(factory)); + self.app = Some(Arc::new(app)); self } - /// Build the Runner instance - pub fn build(self) -> Result { - let config = InternalConfig { - namespace: self.config.namespace, - runner_name: self.config.runner_name, - runner_key: self.config.runner_key, - version: self.config.version, - total_slots: self.config.total_slots, - endpoint: self.config.endpoint, - token: self.config.token, - actor_factories: self.actor_factories, - }; - - // Create event channel for actors to push 
events - let (event_tx, event_rx) = mpsc::unbounded_channel(); + pub fn app_arc(mut self, app: Arc) -> Self { + self.app = Some(app); + self + } - // Create KV request channel for actors to send KV requests - let (kv_request_tx, kv_request_rx) = mpsc::unbounded_channel(); + pub fn build(self) -> Result { + let app = self + .app + .context("runner app is required; call RunnerBuilder::app")?; - // Create lifecycle event broadcast channel (capacity of 100 for buffering) - let (lifecycle_tx, _) = broadcast::channel(100); + let (lifecycle_tx, _) = broadcast::channel(128); - Ok(Runner { - config, + let inner = Arc::new(RunnerInner { + config: self.config, + app, runner_id: Arc::new(Mutex::new(None)), - actors: Arc::new(Mutex::new(HashMap::new())), - actor_event_indices: Arc::new(Mutex::new(HashMap::new())), - event_history: Arc::new(Mutex::new(Vec::new())), + ready_notify: Arc::new(Notify::new()), + shutdown_notify: Arc::new(Notify::new()), shutdown: Arc::new(AtomicBool::new(false)), - is_child_task: false, - event_tx, - event_rx: Arc::new(Mutex::new(event_rx)), - kv_request_tx, - kv_request_rx: Arc::new(Mutex::new(kv_request_rx)), - next_kv_request_id: Arc::new(Mutex::new(0)), - kv_pending_requests: Arc::new(Mutex::new(HashMap::new())), + started: Arc::new(AtomicBool::new(false)), + ws_sender: Arc::new(Mutex::new(None)), + actors: Arc::new(Mutex::new(HashMap::new())), lifecycle_tx, - shutdown_tx: Arc::new(Mutex::new(None)), + next_kv_request_id: Arc::new(AtomicU32::new(0)), + pending_kv_requests: Arc::new(Mutex::new(HashMap::new())), + pending_http_requests: Arc::new(Mutex::new(HashMap::new())), + tunnel_message_indices: Arc::new(Mutex::new(HashMap::new())), + websockets: Arc::new(Mutex::new(HashMap::new())), + }); + + Ok(Runner { + inner, + task: Arc::new(Mutex::new(None)), }) } } -/// Legacy builder for backwards compatibility with test code -pub struct RunnerBuilderLegacy { - namespace: String, - runner_name: String, - runner_key: String, - version: u32, - 
total_slots: u32, - actor_factories: HashMap, +#[derive(Clone)] +pub struct Runner { + inner: Arc, + task: Arc>>>, } -impl RunnerBuilderLegacy { - pub fn with_runner_name(mut self, name: &str) -> Self { - self.runner_name = name.to_string(); - self +impl Runner { + pub fn builder(config: RunnerConfig) -> RunnerBuilder { + RunnerBuilder::new(config) } - pub fn with_runner_key(mut self, key: &str) -> Self { - self.runner_key = key.to_string(); - self + pub fn handle(&self) -> RunnerHandle { + RunnerHandle { + inner: self.inner.clone(), + } } - pub fn with_version(mut self, version: u32) -> Self { - self.version = version; - self + pub fn subscribe_lifecycle_events(&self) -> broadcast::Receiver { + self.inner.lifecycle_tx.subscribe() } - pub fn with_total_slots(mut self, total_slots: u32) -> Self { - self.total_slots = total_slots; - self + pub async fn start(&self) -> Result<()> { + if self.inner.started.swap(true, Ordering::SeqCst) { + bail!("runner.start called more than once") + } + + self.inner.shutdown.store(false, Ordering::SeqCst); + let inner = self.inner.clone(); + let join = tokio::spawn(async move { + if let Err(err) = inner.run().await { + tracing::error!(?err, "runner terminated with error"); + } + }); + + *self.task.lock().await = Some(join); + Ok(()) } - /// Register an actor factory for a specific actor name - pub fn with_actor_behavior(mut self, actor_name: &str, factory: F) -> Self - where - F: Fn(ActorConfig) -> Box + Send + Sync + 'static, - { - self.actor_factories - .insert(actor_name.to_string(), Arc::new(factory)); - self + pub async fn wait_ready(&self) -> Result { + loop { + if let Some(runner_id) = self.inner.runner_id.lock().await.clone() { + return Ok(runner_id); + } + self.inner.ready_notify.notified().await; + } } - /// Build the Runner with an endpoint and token - pub fn build_with_endpoint(self, endpoint: &str, token: &str) -> Result { - let config = InternalConfig { - namespace: self.namespace, - runner_name: self.runner_name, - 
runner_key: self.runner_key, - version: self.version, - total_slots: self.total_slots, - endpoint: endpoint.to_string(), - token: token.to_string(), - actor_factories: self.actor_factories, - }; + pub async fn shutdown(&self, _immediate: bool) -> Result<()> { + self.inner.shutdown.store(true, Ordering::SeqCst); + self.inner.shutdown_notify.notify_waiters(); - // Create event channel for actors to push events - let (event_tx, event_rx) = mpsc::unbounded_channel(); + if let Some(join) = self.task.lock().await.take() { + let _ = join.await; + } - // Create KV request channel for actors to send KV requests - let (kv_request_tx, kv_request_rx) = mpsc::unbounded_channel(); + self.inner + .app + .on_shutdown(self.handle()) + .await + .context("runner app shutdown callback failed")?; - // Create lifecycle event broadcast channel (capacity of 100 for buffering) - let (lifecycle_tx, _) = broadcast::channel(100); + Ok(()) + } - Ok(Runner { - config, - runner_id: Arc::new(Mutex::new(None)), - actors: Arc::new(Mutex::new(HashMap::new())), - actor_event_indices: Arc::new(Mutex::new(HashMap::new())), - event_history: Arc::new(Mutex::new(Vec::new())), - shutdown: Arc::new(AtomicBool::new(false)), - is_child_task: false, - event_tx, - event_rx: Arc::new(Mutex::new(event_rx)), - kv_request_tx, - kv_request_rx: Arc::new(Mutex::new(kv_request_rx)), - next_kv_request_id: Arc::new(Mutex::new(0)), - kv_pending_requests: Arc::new(Mutex::new(HashMap::new())), - lifecycle_tx, - shutdown_tx: Arc::new(Mutex::new(None)), - }) + pub async fn crash(&self) -> Result<()> { + self.inner.shutdown.store(true, Ordering::SeqCst); + self.inner.shutdown_notify.notify_waiters(); + Ok(()) } -} -impl Runner { - /// Subscribe to actor lifecycle events - pub fn subscribe_lifecycle_events(&self) -> broadcast::Receiver { - self.lifecycle_tx.subscribe() + pub fn name(&self) -> &str { + &self.inner.config.runner_name } - /// Start the runner - pub async fn start(&self) -> Result<()> { - tracing::info!( - 
namespace = %self.config.namespace, - runner_name = %self.config.runner_name, - runner_key = %self.config.runner_key, - "starting engine runner" - ); + pub async fn get_serverless_init_packet(&self) -> Result> { + self.handle().get_serverless_init_packet().await + } +} - let ws_url = self.build_ws_url(); +#[derive(Clone)] +pub struct RunnerHandle { + inner: Arc, +} - tracing::debug!(ws_url = %ws_url, "connecting to pegboard"); +impl RunnerHandle { + pub fn name(&self) -> &str { + &self.inner.config.runner_name + } - // Connect to WebSocket with protocols - let token_protocol = format!("rivet_token.{}", self.config.token); + pub async fn runner_id(&self) -> Option { + self.inner.runner_id.lock().await.clone() + } - // Build the request properly with all WebSocket headers - use tokio_tungstenite::tungstenite::client::IntoClientRequest; - let mut request = ws_url - .into_client_request() - .context("failed to build WebSocket request")?; + pub async fn has_actor(&self, actor_id: &str, generation: Option) -> bool { + let actors = self.inner.actors.lock().await; + actors + .get(actor_id) + .map(|x| generation.map(|g| g == x.generation).unwrap_or(true)) + .unwrap_or(false) + } - // Add the Sec-WebSocket-Protocol header - request.headers_mut().insert( - "Sec-WebSocket-Protocol", - format!("rivet, {}", token_protocol).parse().unwrap(), - ); + pub async fn get_actor_ids(&self) -> Vec { + let actors = self.inner.actors.lock().await; + actors.keys().cloned().collect() + } - let (ws_stream, _response) = connect_async(request) + pub async fn sleep_actor(&self, actor_id: &str, generation: Option) -> Result<()> { + self.inner + .send_actor_intent(actor_id, generation, rp::ActorIntent::ActorIntentSleep) .await - .context("failed to connect to WebSocket")?; - - tracing::info!("websocket connected"); - - // Create shutdown channel - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - *self.shutdown_tx.lock().await = Some(shutdown_tx); - - // Clone self for the spawned task - let 
runner = self.clone_for_task(); + } - tokio::spawn(async move { - if let Err(err) = runner.run_message_loop(ws_stream, shutdown_rx).await { - tracing::error!(?err, "engine runner message loop failed"); - } - }); + pub async fn stop_actor(&self, actor_id: &str, generation: Option) -> Result<()> { + self.inner + .send_actor_intent(actor_id, generation, rp::ActorIntent::ActorIntentStop) + .await + } - Ok(()) + pub async fn set_alarm( + &self, + actor_id: &str, + alarm_ts: Option, + generation: Option, + ) -> Result<()> { + self.inner + .send_alarm_event(actor_id, generation, alarm_ts) + .await } - /// Clone the runner for passing to async tasks - fn clone_for_task(&self) -> Self { - Self { - config: self.config.clone(), - runner_id: self.runner_id.clone(), - actors: self.actors.clone(), - actor_event_indices: self.actor_event_indices.clone(), - event_history: self.event_history.clone(), - is_child_task: true, - shutdown: self.shutdown.clone(), - event_tx: self.event_tx.clone(), - event_rx: self.event_rx.clone(), - kv_request_tx: self.kv_request_tx.clone(), - kv_request_rx: self.kv_request_rx.clone(), - next_kv_request_id: self.next_kv_request_id.clone(), - kv_pending_requests: self.kv_pending_requests.clone(), - lifecycle_tx: self.lifecycle_tx.clone(), - shutdown_tx: self.shutdown_tx.clone(), - } + pub async fn clear_alarm(&self, actor_id: &str, generation: Option) -> Result<()> { + self.set_alarm(actor_id, None, generation).await } - /// Wait for runner to be ready and return runner ID - pub async fn wait_ready(&self) -> String { - // Poll until runner_id is set - loop { - let runner_id = self.runner_id.lock().await; - if let Some(id) = runner_id.as_ref() { - // In MK2, we need to wait for the workflow to process the Init signal - // and mark the runner as eligible for actor allocation. - // This can take some time due to workflow processing: - // 1. Workflow receives Init signal - // 2. Workflow executes MarkEligible activity - // 3. 
Database is updated with runner allocation index - tokio::time::sleep(Duration::from_millis(2000)).await; - return id.clone(); + pub async fn kv_get( + &self, + actor_id: &str, + keys: Vec>, + ) -> Result>>> { + let requested_keys = keys.clone(); + let response = self + .inner + .send_kv_request( + actor_id, + rp::KvRequestData::KvGetRequest(rp::KvGetRequest { keys }), + ) + .await?; + + match response { + rp::KvResponseData::KvGetResponse(resp) => { + let mut values = Vec::with_capacity(requested_keys.len()); + for key in requested_keys { + let mut value = None; + for (idx, response_key) in resp.keys.iter().enumerate() { + if *response_key == key { + value = resp.values.get(idx).cloned(); + break; + } + } + values.push(value); + } + Ok(values) + } + rp::KvResponseData::KvErrorResponse(err) => { + bail!("kv get failed: {}", err.message) } - drop(runner_id); - tokio::time::sleep(Duration::from_millis(100)).await; + other => bail!("unexpected kv get response: {:?}", other), } } - /// Check if runner has an actor - pub async fn has_actor(&self, actor_id: &str) -> bool { - let actors = self.actors.lock().await; - actors.contains_key(actor_id) + pub async fn kv_list_all( + &self, + actor_id: &str, + reverse: Option, + limit: Option, + ) -> Result, Vec)>> { + self.kv_list( + actor_id, + rp::KvListQuery::KvListAllQuery, + reverse, + limit, + ) + .await } - /// Get runner's current actor IDs - pub async fn get_actor_ids(&self) -> Vec { - let actors = self.actors.lock().await; - actors.keys().cloned().collect() + pub async fn kv_list_prefix( + &self, + actor_id: &str, + prefix: Vec, + reverse: Option, + limit: Option, + ) -> Result, Vec)>> { + self.kv_list( + actor_id, + rp::KvListQuery::KvListPrefixQuery(rp::KvListPrefixQuery { key: prefix }), + reverse, + limit, + ) + .await } - pub fn name(&self) -> &str { - &self.config.runner_name + pub async fn kv_list_range( + &self, + actor_id: &str, + start: Vec, + end: Vec, + exclusive: bool, + reverse: Option, + limit: Option, + 
) -> Result, Vec)>> { + self.kv_list( + actor_id, + rp::KvListQuery::KvListRangeQuery(rp::KvListRangeQuery { + start, + end, + exclusive, + }), + reverse, + limit, + ) + .await + } + + async fn kv_list( + &self, + actor_id: &str, + query: rp::KvListQuery, + reverse: Option, + limit: Option, + ) -> Result, Vec)>> { + let response = self + .inner + .send_kv_request( + actor_id, + rp::KvRequestData::KvListRequest(rp::KvListRequest { + query, + reverse, + limit, + }), + ) + .await?; + + match response { + rp::KvResponseData::KvListResponse(resp) => Ok(resp.keys.into_iter().zip(resp.values).collect()), + rp::KvResponseData::KvErrorResponse(err) => { + bail!("kv list failed: {}", err.message) + } + other => bail!("unexpected kv list response: {:?}", other), + } } - /// Shutdown the runner gracefully (destroys actors first) - pub async fn shutdown(&self) { - tracing::info!("shutting down engine runner"); - self.shutdown.store(true, Ordering::SeqCst); + pub async fn kv_put(&self, actor_id: &str, entries: Vec<(Vec, Vec)>) -> Result<()> { + let (keys, values): (Vec<_>, Vec<_>) = entries.into_iter().unzip(); + let response = self + .inner + .send_kv_request( + actor_id, + rp::KvRequestData::KvPutRequest(rp::KvPutRequest { keys, values }), + ) + .await?; + + match response { + rp::KvResponseData::KvPutResponse => Ok(()), + rp::KvResponseData::KvErrorResponse(err) => { + bail!("kv put failed: {}", err.message) + } + other => bail!("unexpected kv put response: {:?}", other), + } + } - // Send shutdown signal to close ws_stream - if let Some(tx) = self.shutdown_tx.lock().await.take() { - let _ = tx.send(()); + pub async fn kv_delete(&self, actor_id: &str, keys: Vec>) -> Result<()> { + let response = self + .inner + .send_kv_request( + actor_id, + rp::KvRequestData::KvDeleteRequest(rp::KvDeleteRequest { keys }), + ) + .await?; + + match response { + rp::KvResponseData::KvDeleteResponse => Ok(()), + rp::KvResponseData::KvErrorResponse(err) => { + bail!("kv delete failed: {}", 
err.message) + } + other => bail!("unexpected kv delete response: {:?}", other), } } - /// Crash the runner without graceful shutdown. - /// This simulates an ungraceful disconnect where the runner stops responding - /// without destroying its actors first. Use this to test RunnerNoResponse errors. - pub async fn crash(&self) { - tracing::info!("crashing engine runner (ungraceful disconnect)"); - self.shutdown.store(true, Ordering::SeqCst); + pub async fn kv_drop(&self, actor_id: &str) -> Result<()> { + let response = self + .inner + .send_kv_request(actor_id, rp::KvRequestData::KvDropRequest) + .await?; - // Just drop the websocket without cleanup - don't send any signals - // The server will detect the disconnect and actors will remain in - // an unresponsive state until they timeout. - if let Some(tx) = self.shutdown_tx.lock().await.take() { - let _ = tx.send(()); + match response { + rp::KvResponseData::KvDropResponse => Ok(()), + rp::KvResponseData::KvErrorResponse(err) => { + bail!("kv drop failed: {}", err.message) + } + other => bail!("unexpected kv drop response: {:?}", other), } + } - // Clear local actor state without notifying server - self.actors.lock().await.clear(); + pub async fn send_hibernatable_websocket_message_ack( + &self, + gateway_id: [u8; 4], + request_id: [u8; 4], + index: u16, + ) -> Result<()> { + self.inner + .send_tunnel_message( + TunnelRequestKey { + gateway_id, + request_id, + }, + rp::ToServerTunnelMessageKind::ToServerWebSocketMessageAck( + rp::ToServerWebSocketMessageAck { index }, + ), + ) + .await } - fn build_ws_url(&self) -> String { - let ws_endpoint = self.config.endpoint.replace("http://", "ws://"); - format!( - "{}/runners/connect?protocol_version={}&namespace={}&runner_key={}", - ws_endpoint.trim_end_matches('/'), - protocol::PROTOCOL_VERSION, - urlencoding::encode(&self.config.namespace), - urlencoding::encode(&self.config.runner_key) - ) + pub async fn send_websocket_message( + &self, + gateway_id: [u8; 4], + 
request_id: [u8; 4], + data: Vec, + binary: bool, + ) -> Result<()> { + self.inner + .send_tunnel_message( + TunnelRequestKey { + gateway_id, + request_id, + }, + rp::ToServerTunnelMessageKind::ToServerWebSocketMessage(rp::ToServerWebSocketMessage { + data, + binary, + }), + ) + .await } - fn build_init_message(&self) -> rp::ToServer { - // MK2 init doesn't have lastCommandIdx - uses checkpoints instead - rp::ToServer::ToServerInit(rp::ToServerInit { - name: self.config.runner_name.clone(), - version: self.config.version, - total_slots: self.config.total_slots, - prepopulate_actor_names: None, - metadata: None, - }) + pub async fn close_websocket( + &self, + gateway_id: [u8; 4], + request_id: [u8; 4], + code: Option, + reason: Option, + hibernate: bool, + ) -> Result<()> { + self.inner + .send_tunnel_message( + TunnelRequestKey { + gateway_id, + request_id, + }, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code, + reason, + hibernate, + }), + ) + .await } - async fn run_message_loop( - self, - mut ws_stream: WsStream, - mut shutdown_rx: oneshot::Receiver<()>, + pub async fn restore_hibernating_requests( + &self, + actor_id: &str, + meta_entries: Vec, ) -> Result<()> { - // Send init message - let init_msg = self.build_init_message(); - let encoded = protocol::encode_to_server(init_msg); - ws_stream - .send(Message::Binary(encoded.into())) + self.inner + .restore_hibernating_requests(actor_id, meta_entries) .await - .context("failed to send init message")?; + } - tracing::debug!("sent init message"); + pub async fn get_serverless_init_packet(&self) -> Result> { + let Some(runner_id) = self.runner_id().await else { + return Ok(None); + }; - let mut ping_interval = tokio::time::interval(RUNNER_PING_INTERVAL); - // We lock here as these rx's are only for run_message_loop - let mut event_rx = self.event_rx.lock().await; - let mut kv_request_rx = self.kv_request_rx.lock().await; + let payload = 
rivet_runner_protocol::versioned::ToServerlessServer::wrap_latest( + rp::ToServerlessServer::ToServerlessServerInit(rp::ToServerlessServerInit { + runner_id, + runner_protocol_version: protocol::PROTOCOL_VERSION, + }), + ) + .serialize_with_embedded_version(protocol::PROTOCOL_VERSION)?; - loop { - tokio::select! { - biased; - _ = &mut shutdown_rx => { - tracing::info!("received shutdown signal, closing websocket"); - let _ = ws_stream.close(None).await; - break; - } + Ok(Some(base64::engine::general_purpose::STANDARD.encode(payload))) + } +} - _ = ping_interval.tick() => { - if self.shutdown.load(Ordering::SeqCst) { - break; - } +#[derive(Clone)] +struct RunnerInner { + config: RunnerConfig, + app: Arc, + runner_id: Arc>>, + ready_notify: Arc, + shutdown_notify: Arc, + shutdown: Arc, + started: Arc, + ws_sender: Arc>>>, + actors: Arc>>, + lifecycle_tx: broadcast::Sender, + next_kv_request_id: Arc, + pending_kv_requests: Arc>>>, + pending_http_requests: Arc>>, + tunnel_message_indices: Arc>>, + websockets: Arc>>, +} - // Send pong (MK2 uses ToServerPong instead of ToServerPing) - let pong = rp::ToServer::ToServerPong(rp::ToServerPong { - ts: chrono::Utc::now().timestamp_millis(), - }); - let encoded = protocol::encode_to_server(pong); - ws_stream.send(Message::Binary(encoded.into())).await?; - } +#[derive(Clone, Debug)] +struct ActorRuntimeState { + generation: u32, + config: ActorConfig, + started: bool, + start_notify: Arc, + last_command_idx: i64, + next_event_idx: i64, + event_history: Vec, + hibernating_requests: Vec, + hibernation_restored: bool, +} - // Listen for events pushed from actors - Some(actor_event) = event_rx.recv() => { - if self.shutdown.load(Ordering::SeqCst) { - tracing::info!("shutting down"); - break; - } +#[derive(Clone, Debug)] +struct PendingHttpRequest { + actor_id: String, + method: String, + path: String, + headers: HashMap, + body: Vec, +} + +#[derive(Clone, Debug)] +struct WebSocketRuntimeState { + actor_id: String, + generation: u32, 
+ actor_name: String, + path: String, + headers: HashMap, + is_hibernatable: bool, + is_restoring_hibernatable: bool, + server_message_index: u16, +} + +#[derive(Clone, Copy, Eq)] +struct TunnelRequestKey { + gateway_id: [u8; 4], + request_id: [u8; 4], +} - tracing::debug!( - actor_id = ?actor_event.actor_id, - generation = actor_event.generation, - "received event from actor" - ); +impl PartialEq for TunnelRequestKey { + fn eq(&self, other: &Self) -> bool { + self.gateway_id == other.gateway_id && self.request_id == other.request_id + } +} - self.send_actor_event(&mut ws_stream, actor_event).await?; - } +impl Hash for TunnelRequestKey { + fn hash(&self, state: &mut H) { + self.gateway_id.hash(state); + self.request_id.hash(state); + } +} - // Listen for KV requests from actors - Some(kv_request) = kv_request_rx.recv() => { - if self.shutdown.load(Ordering::SeqCst) { - break; +impl RunnerInner { + async fn run(self: Arc) -> Result<()> { + let mut reconnect_delay_ms = RECONNECT_INITIAL_DELAY_MS; + + while !self.shutdown.load(Ordering::SeqCst) { + let disconnect_info = self.run_connection_loop().await; + + match disconnect_info { + Ok(Some((code, reason))) => { + if let Err(err) = self + .app + .on_disconnected( + RunnerHandle { + inner: self.clone(), + }, + code, + reason, + ) + .await + { + tracing::error!(?err, "runner app on_disconnected callback failed"); } + } + Ok(None) => {} + Err(err) => { + tracing::warn!(?err, "connection loop failed"); + } + } + + if self.shutdown.load(Ordering::SeqCst) { + break; + } + + tokio::time::sleep(Duration::from_millis(reconnect_delay_ms)).await; + reconnect_delay_ms = (reconnect_delay_ms.saturating_mul(2)).min(RECONNECT_MAX_DELAY_MS); + } + + Ok(()) + } + + async fn run_connection_loop(&self) -> Result> { + let ws_url = self.build_ws_url(); + let ws_stream = self.connect_websocket(&ws_url).await?; + + let (ws_tx, mut ws_rx) = mpsc::unbounded_channel::(); + *self.ws_sender.lock().await = Some(ws_tx.clone()); - tracing::debug!( - 
actor_id = ?kv_request.actor_id, - "received kv request from actor" - ); + let (mut ws_write, mut ws_read) = ws_stream.split(); - self.send_kv_request(&mut ws_stream, kv_request).await?; + let writer = tokio::spawn(async move { + while let Some(message) = ws_rx.recv().await { + let encoded = protocol::encode_to_server(message); + if let Err(err) = ws_write.send(Message::Binary(encoded.into())).await { + return Err::<(), anyhow::Error>(err.into()); } + } + Ok(()) + }); - msg = ws_stream.next() => { - if self.shutdown.load(Ordering::SeqCst) { - break; - } + // Send init as soon as the socket is ready. + ws_tx + .send(self.build_init_message()?) + .map_err(|_| anyhow!("failed to queue init message"))?; + + let mut ack_interval = tokio::time::interval(COMMAND_ACK_INTERVAL); + let mut disconnect_info = None; - match msg { - Some(std::result::Result::Ok(Message::Binary(buf))) => { - self.handle_message(&mut ws_stream, &buf).await?; + loop { + tokio::select! { + _ = self.shutdown_notify.notified() => { + let _ = ws_tx.send(rp::ToServer::ToServerStopping); + break; + } + _ = ack_interval.tick() => { + if let Err(err) = self.send_command_acknowledgement().await { + tracing::warn!(?err, "failed to send command acknowledgement"); + } + } + incoming = ws_read.next() => { + match incoming { + Some(Ok(Message::Binary(payload))) => { + self.handle_incoming_message(&payload).await?; } - Some(std::result::Result::Ok(Message::Close(_))) => { - tracing::info!("websocket closed by server"); + Some(Ok(Message::Close(frame))) => { + let (code, reason) = frame + .map(|f| (u16::from(f.code), f.reason.to_string())) + .unwrap_or((1000, String::new())); + disconnect_info = Some((code, reason)); break; } - Some(std::result::Result::Err(err)) => { - tracing::error!(?err, "websocket error"); + Some(Ok(_)) => {} + Some(Err(err)) => { return Err(err.into()); } None => { - tracing::info!("websocket stream ended"); + disconnect_info = Some((1006, "connection closed".to_string())); break; } - _ => 
{} } } } } - tracing::info!("engine runner message loop exiting"); - Ok(()) + *self.ws_sender.lock().await = None; + drop(ws_tx); + + if let Err(err) = writer.await { + tracing::warn!(?err, "writer task join error"); + } + + Ok(disconnect_info) } - /// Send an event pushed from an actor - async fn send_actor_event( + async fn connect_websocket( &self, - ws_stream: &mut WsStream, - actor_event: ActorEvent, - ) -> Result<()> { - // Get next event index for this actor (MK2 uses per-actor checkpoints) - let mut indices = self.actor_event_indices.lock().await; - let idx = indices.entry(actor_event.actor_id.clone()).or_insert(-1); - *idx += 1; - let event_idx = *idx; - drop(indices); - - let event_wrapper = protocol::make_event_wrapper( - &actor_event.actor_id, - actor_event.generation, - event_idx as u64, - actor_event.event, - ); + ws_url: &str, + ) -> Result>> { + use tokio_tungstenite::tungstenite::client::IntoClientRequest; - self.event_history.lock().await.push(event_wrapper.clone()); + let mut request = ws_url + .into_client_request() + .context("failed to build websocket request")?; - tracing::debug!( - actor_id = ?actor_event.actor_id, - generation = actor_event.generation, - event_idx = event_idx, - "sending actor event" + let protocol_header = format!("rivet, rivet_token.{}", self.config.token); + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + protocol_header.parse().context("invalid protocol header")?, ); - let msg = rp::ToServer::ToServerEvents(vec![event_wrapper]); - let encoded = protocol::encode_to_server(msg); - ws_stream.send(Message::Binary(encoded.into())).await?; - - Ok(()) + let (ws_stream, _response) = connect_async(request) + .await + .context("failed to connect to pegboard websocket")?; + Ok(ws_stream) } - async fn handle_message(&self, ws_stream: &mut WsStream, buf: &[u8]) -> Result<()> { - let msg = protocol::decode_to_client(buf, protocol::PROTOCOL_VERSION)?; + fn build_ws_url(&self) -> String { + let endpoint = self + .config + 
.pegboard_endpoint + .as_ref() + .unwrap_or(&self.config.endpoint) + .replace("http://", "ws://") + .replace("https://", "wss://"); - match msg { - rp::ToClient::ToClientInit(init) => { - self.handle_init(init, ws_stream).await?; - } - rp::ToClient::ToClientCommands(commands) => { - self.handle_commands(commands, ws_stream).await?; - } - rp::ToClient::ToClientAckEvents(ack) => { + let base = endpoint.trim_end_matches('/'); + format!( + "{base}/runners/connect?protocol_version={}&namespace={}&runner_key={}", + protocol::PROTOCOL_VERSION, + encode(&self.config.namespace), + encode(&self.config.runner_key) + ) + } + + fn build_init_message(&self) -> Result { + let prepopulate_actor_names = if self.config.prepopulate_actor_names.is_empty() { + None + } else { + let mut map = HashMap::new(); + for (name, entry) in &self.config.prepopulate_actor_names { + map.insert( + name.clone(), + rp::ActorName { + metadata: serde_json::to_string(&entry.metadata)?, + }, + ); + } + Some(map.into()) + }; + + let metadata = self + .config + .metadata + .as_ref() + .map(serde_json::to_string) + .transpose()?; + + Ok(rp::ToServer::ToServerInit(rp::ToServerInit { + name: self.config.runner_name.clone(), + version: self.config.version, + total_slots: self.config.total_slots, + prepopulate_actor_names, + metadata, + })) + } + + async fn handle_incoming_message(&self, payload: &[u8]) -> Result<()> { + let message = protocol::decode_to_client(payload, protocol::PROTOCOL_VERSION)?; + + match message { + rp::ToClient::ToClientInit(init) => { + *self.runner_id.lock().await = Some(init.runner_id); + self.ready_notify.notify_waiters(); + self.resend_unacknowledged_events().await?; + self.app + .on_connected(RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }) + .await + .context("runner app on_connected callback failed")?; + } + rp::ToClient::ToClientCommands(commands) => { + self.handle_commands(commands).await?; + } + rp::ToClient::ToClientAckEvents(ack) => { 
self.handle_ack_events(ack).await; } rp::ToClient::ToClientKvResponse(response) => { self.handle_kv_response(response).await; } - _ => { - tracing::debug!(?msg, "ignoring message type"); + rp::ToClient::ToClientTunnelMessage(message) => { + self.handle_tunnel_message(message).await?; + } + rp::ToClient::ToClientPing(ping) => { + self.send_to_server(rp::ToServer::ToServerPong(rp::ToServerPong { ts: ping.ts })) + .await?; } } Ok(()) } - async fn handle_init(&self, init: rp::ToClientInit, _ws_stream: &mut WsStream) -> Result<()> { - tracing::info!( - runner_id = %init.runner_id, - "received init from server" - ); - - *self.runner_id.lock().await = Some(init.runner_id.clone()); - - // MK2 doesn't have lastEventIdx in init - events are acked via checkpoints - // For simplicity, we don't resend events on reconnect in the engine runner - - Ok(()) + fn clone_for_handle(&self) -> RunnerInner { + self.clone() } - async fn handle_commands( - &self, - commands: Vec, - ws_stream: &mut WsStream, - ) -> Result<()> { - tracing::info!(count = commands.len(), "received commands"); - - for cmd_wrapper in commands { - let checkpoint = &cmd_wrapper.checkpoint; - tracing::debug!( - actor_id = %checkpoint.actor_id, - generation = checkpoint.generation, - index = checkpoint.index, - command = ?cmd_wrapper.inner, - "processing command" - ); - - match cmd_wrapper.inner { - rp::Command::CommandStartActor(start_cmd) => { - self.handle_start_actor( - checkpoint.actor_id.clone(), - checkpoint.generation, - start_cmd, - ws_stream, - ) - .await?; + async fn handle_commands(&self, commands: Vec) -> Result<()> { + for command in commands { + let checkpoint = command.checkpoint.clone(); + match command.inner { + rp::Command::CommandStartActor(start) => { + self.handle_command_start_actor(checkpoint, start).await?; } rp::Command::CommandStopActor => { - // MK2 CommandStopActor is void - actor info is in checkpoint - self.handle_stop_actor( - checkpoint.actor_id.clone(), - checkpoint.generation, - 
ws_stream, - ) - .await?; + self.handle_command_stop_actor(checkpoint).await?; } } } @@ -726,354 +914,1283 @@ impl Runner { Ok(()) } - async fn handle_start_actor( + async fn handle_command_start_actor( &self, - actor_id: String, - generation: u32, - cmd: rp::CommandStartActor, - _ws_stream: &mut WsStream, + checkpoint: rp::ActorCheckpoint, + start: rp::CommandStartActor, ) -> Result<()> { - tracing::info!(?actor_id, generation, name = %cmd.config.name, "starting actor"); + let actor_id = checkpoint.actor_id.clone(); + let hibernating_requests = start + .hibernating_requests + .into_iter() + .map(|x| HibernatingRequest { + gateway_id: x.gateway_id, + request_id: x.request_id, + }) + .collect::>(); - // Create actor config - let config = ActorConfig::new( - &cmd.config, - actor_id.clone(), - generation, - self.event_tx.clone(), - self.kv_request_tx.clone(), - ); + let actor_config = ActorConfig { + name: start.config.name, + key: start.config.key, + create_ts: start.config.create_ts, + input: start.config.input, + }; + + let ctx = ActorContext { + actor_id: actor_id.clone(), + generation: checkpoint.generation, + actor_name: actor_config.name.clone(), + config: actor_config.clone(), + hibernating_requests: hibernating_requests.clone(), + }; + + let state = ActorRuntimeState { + generation: checkpoint.generation, + config: actor_config, + started: false, + start_notify: Arc::new(Notify::new()), + last_command_idx: checkpoint.index, + next_event_idx: 0, + event_history: Vec::new(), + hibernating_requests, + hibernation_restored: false, + }; + + self.actors.lock().await.insert(actor_id.clone(), state); + + let runner_handle = RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }; + let app = self.app.clone(); + let inner = Arc::new(self.clone_for_handle()); - // Get factory for this actor name - let factory = self - .config - .actor_factories - .get(&cmd.config.name) - .context(format!( - "no factory registered for actor name: {}", - cmd.config.name - ))? 
- .clone(); - - // Clone self for the spawned task - let runner = self.clone_for_task(); - let actor_id_clone = actor_id.clone(); - - // Spawn actor execution in separate task to avoid blocking message loop tokio::spawn(async move { - // Create actor - let mut actor = factory(config.clone()); + let result = app.on_actor_start(runner_handle.clone(), ctx.clone()).await; - tracing::debug!( - ?actor_id, - generation, - actor_type = actor.name(), - "created actor instance" - ); + match result { + Ok(()) => { + if let Err(err) = inner.mark_actor_started(&ctx.actor_id, ctx.generation).await { + tracing::error!(?err, "failed to mark actor as started"); + return; + } - // Call on_start - let start_result = match actor.on_start(config).await { - std::result::Result::Ok(result) => result, + if let Err(err) = inner + .send_actor_state_update( + &ctx.actor_id, + ctx.generation, + rp::ActorState::ActorStateRunning, + ) + .await + { + tracing::error!(?err, "failed to send actor running state"); + } + + let _ = inner.lifecycle_tx.send(ActorLifecycleEvent::Started { + actor_id: ctx.actor_id.clone(), + generation: ctx.generation, + }); + } Err(err) => { - tracing::error!(?actor_id_clone, generation, ?err, "actor on_start failed"); - return; + tracing::error!(?err, actor_id = %ctx.actor_id, "actor start callback failed"); + + let _ = inner + .send_actor_state_update( + &ctx.actor_id, + ctx.generation, + rp::ActorState::ActorStateStopped(rp::ActorStateStopped { + code: rp::StopCode::Error, + message: Some(err.to_string()), + }), + ) + .await; + + inner.remove_actor_websockets(&ctx.actor_id).await; + inner.actors.lock().await.remove(&ctx.actor_id); } - }; + } + }); - tracing::debug!( - ?actor_id_clone, - generation, - ?start_result, - "actor on_start completed" - ); + Ok(()) + } + + async fn handle_command_stop_actor(&self, checkpoint: rp::ActorCheckpoint) -> Result<()> { + let (actor_name, config, hibernating_requests) = { + let actors = self.actors.lock().await; + let state = actors 
+ .get(&checkpoint.actor_id) + .with_context(|| format!("actor {} not found", checkpoint.actor_id))?; + ( + state.config.name.clone(), + state.config.clone(), + state.hibernating_requests.clone(), + ) + }; + + let ctx = ActorContext { + actor_id: checkpoint.actor_id.clone(), + generation: checkpoint.generation, + actor_name, + config, + hibernating_requests, + }; + + self.app + .on_actor_stop( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ctx.clone(), + ) + .await + .context("actor stop callback failed")?; + + self.send_actor_state_update( + &checkpoint.actor_id, + checkpoint.generation, + rp::ActorState::ActorStateStopped(rp::ActorStateStopped { + code: rp::StopCode::Ok, + message: None, + }), + ) + .await?; + + self.remove_actor_websockets(&checkpoint.actor_id).await; + self.actors.lock().await.remove(&checkpoint.actor_id); - runner - .handle_actor_start_result(actor_id_clone, generation, actor, start_result) - .await; + let _ = self.lifecycle_tx.send(ActorLifecycleEvent::Stopped { + actor_id: checkpoint.actor_id, + generation: checkpoint.generation, }); Ok(()) } - async fn handle_actor_start_result( + async fn mark_actor_started(&self, actor_id: &str, generation: u32) -> Result<()> { + let mut actors = self.actors.lock().await; + let state = actors + .get_mut(actor_id) + .with_context(|| format!("actor {actor_id} not found"))?; + if state.generation != generation { + bail!("actor generation mismatch"); + } + state.started = true; + state.start_notify.notify_waiters(); + Ok(()) + } + + async fn wait_for_actor_started( &self, - actor_id: String, - generation: u32, - actor: Box, - start_result: ActorStartResult, - ) { - // Broadcast lifecycle event - tracing::info!("lifecycle_tx start"); - let _ = self.lifecycle_tx.send(ActorLifecycleEvent::Started { - actor_id: actor_id.clone(), + actor_id: &str, + ) -> Option<(u32, ActorConfig)> { + loop { + let notify = { + let actors = self.actors.lock().await; + let state = actors.get(actor_id)?; + if 
state.started { + return Some((state.generation, state.config.clone())); + } + state.start_notify.clone() + }; + + notify.notified().await; + } + } + + async fn send_actor_intent( + &self, + actor_id: &str, + generation: Option, + intent: rp::ActorIntent, + ) -> Result<()> { + let generation = self.resolve_actor_generation(actor_id, generation).await?; + self.emit_event( + actor_id, generation, - }); + rp::Event::EventActorIntent(rp::EventActorIntent { intent }), + ) + .await + } - // Store actor - let actor_state = ActorState { - actor_id: actor_id.clone(), + async fn send_alarm_event( + &self, + actor_id: &str, + generation: Option, + alarm_ts: Option, + ) -> Result<()> { + let generation = self.resolve_actor_generation(actor_id, generation).await?; + self.emit_event( + actor_id, generation, - actor, + rp::Event::EventActorSetAlarm(rp::EventActorSetAlarm { alarm_ts }), + ) + .await + } + + async fn resolve_actor_generation( + &self, + actor_id: &str, + generation: Option, + ) -> Result { + let actors = self.actors.lock().await; + let actor = actors + .get(actor_id) + .with_context(|| format!("actor {actor_id} not found"))?; + let generation = generation.unwrap_or(actor.generation); + if generation != actor.generation { + bail!( + "actor generation mismatch, expected {}, got {}", + actor.generation, + generation + ) + } + Ok(generation) + } + + async fn send_actor_state_update( + &self, + actor_id: &str, + generation: u32, + state: rp::ActorState, + ) -> Result<()> { + self.emit_event( + actor_id, + generation, + rp::Event::EventActorStateUpdate(rp::EventActorStateUpdate { state }), + ) + .await + } + + async fn emit_event(&self, actor_id: &str, generation: u32, event: rp::Event) -> Result<()> { + let wrapper = { + let mut actors = self.actors.lock().await; + let actor = actors + .get_mut(actor_id) + .with_context(|| format!("actor {actor_id} not found"))?; + let index = actor.next_event_idx; + actor.next_event_idx += 1; + + let wrapper = rp::EventWrapper { + 
checkpoint: rp::ActorCheckpoint { + actor_id: actor_id.to_string(), + generation, + index, + }, + inner: event, + }; + actor.event_history.push(wrapper.clone()); + wrapper + }; + + self.send_to_server(rp::ToServer::ToServerEvents(vec![wrapper])) + .await + } + + async fn send_command_acknowledgement(&self) -> Result<()> { + let checkpoints = { + let actors = self.actors.lock().await; + actors + .iter() + .filter(|(_, actor)| actor.last_command_idx >= 0) + .map(|(actor_id, actor)| rp::ActorCheckpoint { + actor_id: actor_id.clone(), + generation: actor.generation, + index: actor.last_command_idx, + }) + .collect::>() + }; + + self.send_to_server(rp::ToServer::ToServerAckCommands(rp::ToServerAckCommands { + last_command_checkpoints: checkpoints, + })) + .await + } + + async fn resend_unacknowledged_events(&self) -> Result<()> { + let events = { + let actors = self.actors.lock().await; + actors + .values() + .flat_map(|x| x.event_history.clone()) + .collect::>() }; - self.actors + + if events.is_empty() { + return Ok(()); + } + + self.send_to_server(rp::ToServer::ToServerEvents(events)).await + } + + async fn handle_ack_events(&self, ack: rp::ToClientAckEvents) { + let mut actors = self.actors.lock().await; + for (actor_id, actor) in actors.iter_mut() { + if let Some(checkpoint) = ack + .last_event_checkpoints + .iter() + .find(|x| x.actor_id == *actor_id) + { + actor.event_history.retain(|entry| { + entry.checkpoint.generation != checkpoint.generation + || entry.checkpoint.index > checkpoint.index + }); + } + } + } + + async fn send_kv_request(&self, actor_id: &str, data: rp::KvRequestData) -> Result { + let request_id = self.next_kv_request_id.fetch_add(1, Ordering::SeqCst); + let (tx, rx) = oneshot::channel(); + + self.pending_kv_requests.lock().await.insert(request_id, tx); + + self.send_to_server(rp::ToServer::ToServerKvRequest(rp::ToServerKvRequest { + actor_id: actor_id.to_string(), + request_id, + data, + })) + .await?; + + let response = 
tokio::time::timeout(Duration::from_secs(30), rx) + .await + .context("timed out waiting for kv response")? + .context("kv response channel closed")?; + + Ok(response) + } + + async fn handle_kv_response(&self, response: rp::ToClientKvResponse) { + let sender = self + .pending_kv_requests .lock() .await - .insert(actor_id.clone(), actor_state); - - // Handle start result and send state update via event - match start_result { - ActorStartResult::Running => { - let event = protocol::make_actor_state_update(rp::ActorState::ActorStateRunning); - self.event_tx - .send(ActorEvent { - actor_id: actor_id.clone(), - generation, - event, - }) - .expect("failed to send state update"); + .remove(&response.request_id); + if let Some(sender) = sender { + let _ = sender.send(response.data); + } + } + + async fn remove_actor_websockets(&self, actor_id: &str) { + let mut websockets = self.websockets.lock().await; + websockets.retain(|_, ws| ws.actor_id != actor_id); + } + + async fn handle_tunnel_message(&self, message: rp::ToClientTunnelMessage) -> Result<()> { + let incoming_message_index = message.message_id.message_index; + let key = TunnelRequestKey { + gateway_id: message.message_id.gateway_id, + request_id: message.message_id.request_id, + }; + + self.ensure_tunnel_index(key, incoming_message_index) + .await; + + match message.message_kind { + rp::ToClientTunnelMessageKind::ToClientRequestStart(req) => { + self.handle_request_start(key, req).await?; } - ActorStartResult::Delay(duration) => { - let actor_id_clone = actor_id.clone(); - let event_tx = self.event_tx.clone(); - tokio::spawn(async move { - tracing::info!( - ?actor_id_clone, - generation, - delay_ms = duration.as_millis(), - "delaying before sending running state" - ); - tokio::time::sleep(duration).await; - let event = - protocol::make_actor_state_update(rp::ActorState::ActorStateRunning); - event_tx - .send(ActorEvent { - actor_id: actor_id_clone, - generation, - event, - }) - .expect("failed to send delayed 
state update"); - }); + rp::ToClientTunnelMessageKind::ToClientRequestChunk(chunk) => { + self.handle_request_chunk(key, chunk).await?; } - ActorStartResult::Timeout => { - tracing::warn!( - ?actor_id, - generation, - "actor will timeout (not sending running)" - ); - // Don't send running state + rp::ToClientTunnelMessageKind::ToClientRequestAbort => { + self.pending_http_requests.lock().await.remove(&key); } - ActorStartResult::Crash { code, message } => { - tracing::warn!(?actor_id, generation, code, %message, "actor crashed on start"); - let event = protocol::make_actor_state_update(rp::ActorState::ActorStateStopped( - rp::ActorStateStopped { - code: if code == 0 { - rp::StopCode::Ok - } else { - rp::StopCode::Error - }, - message: Some(message), - }, - )); - let _ = self - .event_tx - .send(ActorEvent { - actor_id: actor_id.clone(), - generation, - event, - }) - .expect("failed to send crash state update"); - - // Remove actor - self.actors.lock().await.remove(&actor_id); + rp::ToClientTunnelMessageKind::ToClientWebSocketOpen(ws) => { + self.handle_websocket_open(key, ws).await?; + } + rp::ToClientTunnelMessageKind::ToClientWebSocketMessage(message) => { + self.handle_websocket_message(key, incoming_message_index, message) + .await?; + } + rp::ToClientTunnelMessageKind::ToClientWebSocketClose(close) => { + self.handle_websocket_close(key, close).await?; } } + + Ok(()) + } + + async fn handle_request_start( + &self, + key: TunnelRequestKey, + req: rp::ToClientRequestStart, + ) -> Result<()> { + if !req.stream { + let pending = PendingHttpRequest { + actor_id: req.actor_id, + method: req.method, + path: req.path, + headers: req.headers.into(), + body: req.body.unwrap_or_default(), + }; + self.spawn_http_request(key, pending); + return Ok(()); + } + + self.pending_http_requests.lock().await.insert( + key, + PendingHttpRequest { + actor_id: req.actor_id, + method: req.method, + path: req.path, + headers: req.headers.into(), + body: req.body.unwrap_or_default(), + 
}, + ); + Ok(()) } - async fn handle_stop_actor( + async fn handle_request_chunk( &self, - actor_id: String, - generation: u32, - ws_stream: &mut WsStream, + key: TunnelRequestKey, + chunk: rp::ToClientRequestChunk, ) -> Result<()> { - tracing::info!(?actor_id, generation, "stopping actor"); + let mut requests = self.pending_http_requests.lock().await; + let Some(request) = requests.get_mut(&key) else { + return Ok(()); + }; + + request.body.extend_from_slice(&chunk.body); + if chunk.finish { + let request = requests.remove(&key).context("request removed unexpectedly")?; + drop(requests); + self.spawn_http_request(key, request); + } - // Get actor - let mut actors_guard = self.actors.lock().await; - let actor_state = actors_guard.get_mut(&actor_id).context("actor not found")?; + Ok(()) + } - // Call on_stop - let stop_result = actor_state - .actor - .on_stop() - .await - .context("actor on_stop failed")?; + fn spawn_http_request(&self, key: TunnelRequestKey, request: PendingHttpRequest) { + let inner = Arc::new(self.clone_for_handle()); + tokio::spawn(async move { + if let Err(err) = inner.process_http_request(key, request).await { + tracing::error!(?err, "http tunnel request failed"); + } + }); + } + + async fn process_http_request(&self, key: TunnelRequestKey, pending: PendingHttpRequest) -> Result<()> { + let Some((generation, actor_config)) = self.wait_for_actor_started(&pending.actor_id).await else { + self.send_response_error(key, 503, "runner.actor_not_found", "Actor not found") + .await?; + return Ok(()); + }; - tracing::debug!( - ?actor_id, + let request = build_http_request( + pending.method.clone(), + pending.path.clone(), + pending.headers.clone(), + pending.body, + )?; + + let ctx = HttpContext { + actor_id: pending.actor_id.clone(), generation, - ?stop_result, - "actor on_stop completed" - ); + actor_name: actor_config.name.clone(), + gateway_id: key.gateway_id, + request_id: key.request_id, + }; - // Broadcast lifecycle event - let _ = 
self.lifecycle_tx.send(ActorLifecycleEvent::Stopped { - actor_id: actor_id.clone(), + let response = match self + .app + .fetch( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ctx, + request, + ) + .await + { + Ok(response) => response, + Err(err) => { + tracing::error!(?err, "fetch callback failed"); + self.send_response_error(key, 500, "runner.internal", "Internal Server Error") + .await?; + return Ok(()); + } + }; + + self.send_response_start(key, response).await?; + Ok(()) + } + + async fn handle_websocket_open( + &self, + key: TunnelRequestKey, + open: rp::ToClientWebSocketOpen, + ) -> Result<()> { + let Some((generation, actor_config)) = self.wait_for_actor_started(&open.actor_id).await else { + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code: Some(1011), + reason: Some("runner.actor_not_found".to_string()), + hibernate: false, + }), + ) + .await?; + return Ok(()); + }; + + let ws_ctx = WebSocketContext { + actor_id: open.actor_id.clone(), generation, - }); + actor_name: actor_config.name.clone(), + gateway_id: key.gateway_id, + request_id: key.request_id, + path: open.path.clone(), + headers: open.headers.clone().into(), + is_hibernatable: false, // Set after can_hibernate is computed. 
+ is_restoring_hibernatable: false, + }; - // Handle stop result - match stop_result { - ActorStopResult::Success => { - self.send_actor_state_update( - &actor_id, - generation, - rp::ActorState::ActorStateStopped(rp::ActorStateStopped { - code: rp::StopCode::Ok, - message: None, - }), - ws_stream, - ) + let can_hibernate = self.app.can_hibernate(&ws_ctx); + let mut ws_ctx = ws_ctx; + ws_ctx.is_hibernatable = can_hibernate; + + self.websockets.lock().await.insert( + key, + WebSocketRuntimeState { + actor_id: open.actor_id, + generation, + actor_name: actor_config.name, + path: open.path, + headers: open.headers.into(), + is_hibernatable: can_hibernate, + is_restoring_hibernatable: false, + server_message_index: 0, + }, + ); + + if let Err(err) = self + .app + .websocket( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ws_ctx, + ) + .await + { + self.websockets.lock().await.remove(&key); + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code: Some(1011), + reason: Some("ws.open_error".to_string()), + hibernate: false, + }), + ) .await?; + bail!("websocket callback failed: {err}"); + } + + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketOpen(rp::ToServerWebSocketOpen { + can_hibernate, + }), + ) + .await?; + + Ok(()) + } + + async fn handle_websocket_message( + &self, + key: TunnelRequestKey, + message_index: u16, + message: rp::ToClientWebSocketMessage, + ) -> Result<()> { + let maybe_state = { + let mut websockets = self.websockets.lock().await; + let Some(state) = websockets.get_mut(&key) else { + return Ok(()); + }; + + if state.is_hibernatable { + if wrapping_lte_u16(message_index, state.server_message_index) { + return Ok(()); + } + + let expected = wrapping_add_u16(state.server_message_index, 1); + if message_index != expected { + let _ = self + .send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose( + 
rp::ToServerWebSocketClose { + code: Some(1008), + reason: Some("ws.message_index_skip".to_string()), + hibernate: false, + }, + ), + ) + .await; + websockets.remove(&key); + return Ok(()); + } + + state.server_message_index = message_index; } - ActorStopResult::Delay(duration) => { - tracing::info!(?actor_id, generation, ?duration, "delaying stop"); - tokio::time::sleep(duration).await; - self.send_actor_state_update( - &actor_id, - generation, - rp::ActorState::ActorStateStopped(rp::ActorStateStopped { - code: rp::StopCode::Ok, - message: None, + + Some(state.clone()) + }; + + let Some(state) = maybe_state else { + return Ok(()); + }; + + let ctx = self.websocket_context_from_state(key, &state); + let callback_result = self + .app + .websocket_message( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ctx, + WebSocketMessage { + data: message.data, + binary: message.binary, + message_index, + }, + ) + .await; + + if let Err(err) = callback_result { + tracing::warn!(?err, "websocket message callback failed"); + self.websockets.lock().await.remove(&key); + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code: Some(1011), + reason: Some("ws.message_error".to_string()), + hibernate: false, + }), + ) + .await?; + } + + Ok(()) + } + + async fn handle_websocket_close( + &self, + key: TunnelRequestKey, + close: rp::ToClientWebSocketClose, + ) -> Result<()> { + let state = self.websockets.lock().await.remove(&key); + let Some(state) = state else { + return Ok(()); + }; + + let ctx = self.websocket_context_from_state(key, &state); + self.app + .websocket_close( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ctx, + close.code, + close.reason, + ) + .await + .context("websocket close callback failed") + } + + fn websocket_context_from_state( + &self, + key: TunnelRequestKey, + state: &WebSocketRuntimeState, + ) -> WebSocketContext { + WebSocketContext { + actor_id: 
state.actor_id.clone(), + generation: state.generation, + actor_name: state.actor_name.clone(), + gateway_id: key.gateway_id, + request_id: key.request_id, + path: state.path.clone(), + headers: state.headers.clone(), + is_hibernatable: state.is_hibernatable, + is_restoring_hibernatable: state.is_restoring_hibernatable, + } + } + + async fn restore_hibernating_requests( + &self, + actor_id: &str, + meta_entries: Vec, + ) -> Result<()> { + let (generation, actor_name, connected_requests) = { + let mut actors = self.actors.lock().await; + let state = actors + .get_mut(actor_id) + .with_context(|| format!("actor {actor_id} not found"))?; + if state.hibernation_restored { + bail!("actor {actor_id} already restored hibernating requests"); + } + state.hibernation_restored = true; + ( + state.generation, + state.config.name.clone(), + state.hibernating_requests.clone(), + ) + }; + + for connected in &connected_requests { + let key = TunnelRequestKey { + gateway_id: connected.gateway_id, + request_id: connected.request_id, + }; + + let meta = meta_entries.iter().find(|entry| { + entry.gateway_id == connected.gateway_id && entry.request_id == connected.request_id + }); + + let Some(meta) = meta else { + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code: Some(1000), + reason: Some("ws.meta_not_found_during_restore".to_string()), + hibernate: false, }), - ws_stream, ) .await?; - } - ActorStopResult::Crash { code, message } => { - tracing::warn!(?actor_id, generation, code, %message, "actor crashed on stop"); - self.send_actor_state_update( - &actor_id, - generation, - rp::ActorState::ActorStateStopped(rp::ActorStateStopped { - code: if code == 0 { - rp::StopCode::Ok - } else { - rp::StopCode::Error - }, - message: Some(message), + continue; + }; + + self.set_tunnel_index(key, meta.client_message_index).await; + let ws_state = WebSocketRuntimeState { + actor_id: actor_id.to_string(), + generation, + 
actor_name: actor_name.clone(), + path: meta.path.clone(), + headers: meta.headers.clone(), + is_hibernatable: true, + is_restoring_hibernatable: true, + server_message_index: meta.server_message_index, + }; + self.websockets.lock().await.insert(key, ws_state.clone()); + + let ws_ctx = self.websocket_context_from_state(key, &ws_state); + if let Err(err) = self + .app + .websocket( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ws_ctx, + ) + .await + { + tracing::warn!(?err, actor_id, "error restoring websocket"); + self.websockets.lock().await.remove(&key); + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerWebSocketClose(rp::ToServerWebSocketClose { + code: Some(1011), + reason: Some("ws.restore_error".to_string()), + hibernate: false, }), - ws_stream, ) .await?; } } - // Remove actor - actors_guard.remove(&actor_id); + for meta in &meta_entries { + let is_connected = connected_requests.iter().any(|request| { + request.gateway_id == meta.gateway_id && request.request_id == meta.request_id + }); + if is_connected { + continue; + } + + let key = TunnelRequestKey { + gateway_id: meta.gateway_id, + request_id: meta.request_id, + }; + let ws_ctx = WebSocketContext { + actor_id: actor_id.to_string(), + generation, + actor_name: actor_name.clone(), + gateway_id: meta.gateway_id, + request_id: meta.request_id, + path: meta.path.clone(), + headers: meta.headers.clone(), + is_hibernatable: true, + is_restoring_hibernatable: true, + }; + self.app + .websocket_close( + RunnerHandle { + inner: Arc::new(self.clone_for_handle()), + }, + ws_ctx, + Some(1000), + Some("ws.stale_metadata".to_string()), + ) + .await + .context("stale websocket close callback failed")?; + self.websockets.lock().await.remove(&key); + } Ok(()) } - async fn handle_ack_events(&self, ack: rp::ToClientAckEvents) { - // MK2 uses per-actor checkpoints for acknowledgments - let checkpoints = &ack.last_event_checkpoints; - - let mut events = 
self.event_history.lock().await; - let original_len = events.len(); - - // Remove events that have been acknowledged based on checkpoints - events.retain(|e| { - // Check if this event's checkpoint is covered by any ack checkpoint - !checkpoints.iter().any(|ck| { - ck.actor_id == e.checkpoint.actor_id - && ck.generation == e.checkpoint.generation - && ck.index >= e.checkpoint.index - }) - }); + async fn send_response_start(&self, key: TunnelRequestKey, response: Response) -> Result<()> { + let status = response.status().as_u16(); + let (parts, body) = response.into_parts(); + + let mut headers = HashMap::new(); + for (name, value) in &parts.headers { + if let Ok(value) = value.to_str() { + headers.insert(name.to_string(), value.to_string()); + } + } - let pruned = original_len - events.len(); - if pruned > 0 { - tracing::debug!( - checkpoint_count = checkpoints.len(), - pruned, - "pruned acknowledged events" - ); + if !headers.contains_key("content-length") { + headers.insert("content-length".to_string(), body.len().to_string()); } + + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerResponseStart(rp::ToServerResponseStart { + status, + headers: headers.into(), + body: Some(body.to_vec()), + stream: false, + }), + ) + .await } - async fn send_actor_state_update( + async fn send_response_error( &self, - actor_id: &str, - generation: u32, - state: rp::ActorState, - ws_stream: &mut WsStream, + key: TunnelRequestKey, + status: u16, + error_code: &str, + message: &str, ) -> Result<()> { - let event = protocol::make_actor_state_update(state); + let mut headers = HashMap::new(); + headers.insert("content-type".to_string(), "text/plain".to_string()); + headers.insert("x-rivet-error".to_string(), error_code.to_string()); + + self.send_tunnel_message( + key, + rp::ToServerTunnelMessageKind::ToServerResponseStart(rp::ToServerResponseStart { + status, + headers: headers.into(), + body: Some(message.as_bytes().to_vec()), + stream: false, + }), + ) + 
.await + } - self.send_actor_event( - ws_stream, - ActorEvent { - actor_id: actor_id.to_string(), - generation, - event, + async fn ensure_tunnel_index(&self, key: TunnelRequestKey, incoming_index: u16) { + let mut indices = self.tunnel_message_indices.lock().await; + indices.entry(key).or_insert(incoming_index); + } + + async fn set_tunnel_index(&self, key: TunnelRequestKey, next_index: u16) { + let mut indices = self.tunnel_message_indices.lock().await; + indices.insert(key, next_index); + } + + async fn send_tunnel_message( + &self, + key: TunnelRequestKey, + message_kind: rp::ToServerTunnelMessageKind, + ) -> Result<()> { + let message_index = { + let mut indices = self.tunnel_message_indices.lock().await; + let idx = indices.entry(key).or_insert(0); + let current = *idx; + *idx = idx.wrapping_add(1); + current + }; + + self.send_to_server(rp::ToServer::ToServerTunnelMessage(rp::ToServerTunnelMessage { + message_id: rp::MessageId { + gateway_id: key.gateway_id, + request_id: key.request_id, + message_index, }, - ) - .await?; + message_kind, + })) + .await + } - Ok(()) + async fn send_to_server(&self, message: rp::ToServer) -> Result<()> { + let sender = self.ws_sender.lock().await.clone(); + let Some(sender) = sender else { + bail!("runner websocket sender unavailable") + }; + + sender + .send(message) + .map_err(|_| anyhow!("failed to queue message to websocket writer")) } +} - async fn send_kv_request(&self, ws_stream: &mut WsStream, kv_request: KvRequest) -> Result<()> { - let mut request_id = self.next_kv_request_id.lock().await; - let id = *request_id; - *request_id += 1; - drop(request_id); +fn build_http_request( + method: String, + path: String, + headers: HashMap, + body: Vec, +) -> Result> { + let uri = if path.starts_with('/') { + format!("http://actor{path}") + } else { + format!("http://actor/{path}") + }; + + let mut builder = Request::builder() + .method(method.parse::().context("invalid method")?) 
+ .uri(uri.parse::().context("invalid uri")?); + + for (name, value) in headers { + builder = builder.header(name, value); + } - // Store the response channel - self.kv_pending_requests - .lock() - .await - .insert(id, kv_request.response_tx); + Ok(builder.body(Bytes::from(body))?) +} - tracing::debug!( - actor_id = ?kv_request.actor_id, - request_id = id, - "sending kv request" - ); +fn wrapping_add_u16(a: u16, b: u16) -> u16 { + a.wrapping_add(b) +} - let msg = rp::ToServer::ToServerKvRequest(rp::ToServerKvRequest { - actor_id: kv_request.actor_id, - request_id: id, - data: kv_request.data, - }); - let encoded = protocol::encode_to_server(msg); - ws_stream.send(Message::Binary(encoded.into())).await?; +fn wrapping_sub_u16(a: u16, b: u16) -> u16 { + a.wrapping_sub(b) +} - Ok(()) +fn wrapping_lt_u16(a: u16, b: u16) -> bool { + a != b && wrapping_sub_u16(b, a) < (u16::MAX / 2) +} + +fn wrapping_lte_u16(a: u16, b: u16) -> bool { + a == b || wrapping_lt_u16(a, b) +} + +#[derive(Clone, Debug)] +pub struct ServerlessConfig { + pub runner: RunnerConfig, + pub max_runners: u32, + pub slots_per_runner: u32, + pub request_lifespan: u32, +} + +impl ServerlessConfig { + pub fn builder() -> ServerlessConfigBuilder { + ServerlessConfigBuilder::default() + } +} + +#[derive(Default)] +pub struct ServerlessConfigBuilder { + runner: RunnerConfigBuilder, + max_runners: Option, + slots_per_runner: Option, + request_lifespan: Option, +} + +impl ServerlessConfigBuilder { + pub fn endpoint(mut self, endpoint: impl Into) -> Self { + self.runner = self.runner.endpoint(endpoint); + self } - async fn handle_kv_response(&self, response: rp::ToClientKvResponse) { - let request_id = response.request_id; + pub fn token(mut self, token: impl Into) -> Self { + self.runner = self.runner.token(token); + self + } + + pub fn namespace(mut self, namespace: impl Into) -> Self { + self.runner = self.runner.namespace(namespace); + self + } - tracing::debug!(request_id, "received kv response"); + pub fn 
runner_name(mut self, runner_name: impl Into) -> Self { + self.runner = self.runner.runner_name(runner_name); + self + } - let response_tx = self.kv_pending_requests.lock().await.remove(&request_id); + pub fn runner_key(mut self, runner_key: impl Into) -> Self { + self.runner = self.runner.runner_key(runner_key); + self + } - if let Some(tx) = response_tx { - let _ = tx.send(response.data); - } else { - tracing::warn!(request_id, "received kv response for unknown request id"); - } + pub fn version(mut self, version: u32) -> Self { + self.runner = self.runner.version(version); + self + } + + pub fn total_slots(mut self, total_slots: u32) -> Self { + self.runner = self.runner.total_slots(total_slots); + self + } + + pub fn pegboard_endpoint(mut self, endpoint: impl Into) -> Self { + self.runner = self.runner.pegboard_endpoint(endpoint); + self + } + + pub fn prepopulate_actor_name( + mut self, + name: impl Into, + metadata: serde_json::Value, + ) -> Self { + self.runner = self.runner.prepopulate_actor_name(name, metadata); + self + } + + pub fn metadata(mut self, metadata: serde_json::Value) -> Self { + self.runner = self.runner.metadata(metadata); + self + } + + pub fn max_runners(mut self, max_runners: u32) -> Self { + self.max_runners = Some(max_runners); + self + } + + pub fn slots_per_runner(mut self, slots_per_runner: u32) -> Self { + self.slots_per_runner = Some(slots_per_runner); + self + } + + pub fn request_lifespan(mut self, request_lifespan: u32) -> Self { + self.request_lifespan = Some(request_lifespan); + self + } + + pub fn build(self) -> Result { + Ok(ServerlessConfig { + runner: self.runner.build()?, + max_runners: self.max_runners.unwrap_or(1000), + slots_per_runner: self.slots_per_runner.unwrap_or(1), + request_lifespan: self.request_lifespan.unwrap_or(300), + }) + } +} + +pub struct ServerlessRunnerBuilder { + config: ServerlessConfig, + app: Option>, +} + +impl ServerlessRunnerBuilder { + pub fn app(mut self, app: A) -> Self + where + A: 
RunnerApp, + { + self.app = Some(Arc::new(app)); + self + } + + pub fn app_arc(mut self, app: Arc) -> Self { + self.app = Some(app); + self + } + + pub fn build(self) -> Result { + let app = self + .app + .context("runner app is required; call ServerlessRunnerBuilder::app")?; + let runner = Runner::builder(self.config.runner.clone()) + .app_arc(app) + .build()?; + Ok(ServerlessRunner { + runner, + config: self.config, + }) } } -impl Drop for Runner { - fn drop(&mut self) { - if self.is_child_task { - return; +#[derive(Clone)] +pub struct ServerlessRunner { + runner: Runner, + config: ServerlessConfig, +} + +impl ServerlessRunner { + pub fn builder(config: ServerlessConfig) -> ServerlessRunnerBuilder { + ServerlessRunnerBuilder { config, app: None } + } + + pub fn runner(&self) -> Runner { + self.runner.clone() + } + + pub fn axum_routes(self: Arc) -> axum::Router { + let state = self.clone(); + axum::Router::new() + .route("/api/rivet/start", axum::routing::get(serverless_start)) + .route("/start", axum::routing::get(serverless_start)) + .route("/api/rivet/metadata", axum::routing::get(serverless_metadata)) + .route("/metadata", axum::routing::get(serverless_metadata)) + .with_state(state) + } + + pub async fn upsert_serverless_runner_config(&self, public_url: &str) -> Result<()> { + let endpoint = self.config.runner.endpoint.trim_end_matches('/'); + let url = format!( + "{endpoint}/runner-configs/{}?namespace={}", + encode(&self.config.runner.runner_name), + encode(&self.config.runner.namespace) + ); + + let client = reqwest::Client::new(); + let response = client + .put(url) + .bearer_auth(&self.config.runner.token) + .json(&self.serverless_runner_config_body(public_url, "default")) + .send() + .await + .context("failed to send serverless runner config request")?; + + if response.status().is_success() { + return Ok(()); + } + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if status == reqwest::StatusCode::BAD_REQUEST 
{ + if let Some(datacenter_name) = self.resolve_datacenter_name(&client, endpoint).await? { + if datacenter_name != "default" { + let retry_response = client + .put(format!( + "{endpoint}/runner-configs/{}?namespace={}", + encode(&self.config.runner.runner_name), + encode(&self.config.runner.namespace) + )) + .bearer_auth(&self.config.runner.token) + .json(&self.serverless_runner_config_body(public_url, &datacenter_name)) + .send() + .await + .context("failed to retry serverless runner config request")?; + + if retry_response.status().is_success() { + return Ok(()); + } + + let retry_status = retry_response.status(); + let retry_body = retry_response.text().await.unwrap_or_default(); + bail!("serverless runner config upsert failed: {retry_status} {retry_body}"); + } + } } - // Signal shutdown when runner is dropped - self.shutdown.store(true, Ordering::SeqCst); - tracing::debug!("engine runner dropped, shutdown signaled"); + + bail!("serverless runner config upsert failed: {status} {body}"); } + + fn serverless_runner_config_body( + &self, + public_url: &str, + datacenter_name: &str, + ) -> serde_json::Value { + serde_json::json!({ + "datacenters": { + datacenter_name: { + "serverless": { + "url": public_url, + "max_runners": self.config.max_runners, + "slots_per_runner": self.config.slots_per_runner, + "request_lifespan": self.config.request_lifespan, + } + } + } + }) + } + + async fn resolve_datacenter_name( + &self, + client: &reqwest::Client, + endpoint: &str, + ) -> Result> { + let response = client + .get(format!("{endpoint}/datacenters")) + .bearer_auth(&self.config.runner.token) + .send() + .await + .context("failed to fetch datacenters")?; + + if !response.status().is_success() { + return Ok(None); + } + + let body: serde_json::Value = response + .json() + .await + .context("failed to decode datacenters response")?; + let datacenter_name = body + .get("datacenters") + .and_then(serde_json::Value::as_array) + .and_then(|x| x.first()) + .and_then(|x| 
x.get("name")) + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + Ok(datacenter_name) + } +} + +async fn serverless_start( + axum::extract::State(runner): axum::extract::State>, +) -> Result< + axum::response::sse::Sse>>, + (StatusCode, String), +> { + use axum::response::sse::{Event, Sse}; + + if !runner.runner.inner.started.load(Ordering::SeqCst) { + runner.runner.start().await.map_err(internal_error)?; + } + + runner.runner.wait_ready().await.map_err(internal_error)?; + let packet = runner + .runner + .get_serverless_init_packet() + .await + .map_err(internal_error)? + .context("runner init packet unavailable") + .map_err(internal_error)?; + + let stream = stream! { + yield Ok::(Event::default().event("message").data(packet)); + let mut interval = tokio::time::interval(Duration::from_secs(15)); + loop { + interval.tick().await; + yield Ok::(Event::default().event("ping").data("")); + } + }; + + Ok(Sse::new(stream)) +} + +async fn serverless_metadata( + axum::extract::State(runner): axum::extract::State>, +) -> axum::Json { + let actor_names = runner + .config + .runner + .prepopulate_actor_names + .iter() + .map(|(name, entry)| (name.clone(), entry.metadata.clone())) + .collect::>(); + + axum::Json(serde_json::json!({ + "runtime": "rivetkit", + "version": "1", + "actorNames": actor_names, + "runner": { + "version": runner.config.runner.version, + } + })) +} + +fn internal_error(err: anyhow::Error) -> (StatusCode, String) { + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) } diff --git a/engine/sdks/rust/engine-runner/tests/common/mod.rs b/engine/sdks/rust/engine-runner/tests/common/mod.rs new file mode 100644 index 0000000000..99bbed105b --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/common/mod.rs @@ -0,0 +1,357 @@ +use anyhow::{Context, Result, bail}; +use reqwest::Method; +use serde_json::{Value, json}; +use std::{ + fmt::Write as _, + path::PathBuf, + process::{Child, Command, Stdio}, + sync::{Arc, OnceLock}, + 
time::{Duration, Instant}, +}; +use tempfile::TempDir; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_tungstenite::{ + connect_async, + tungstenite::client::IntoClientRequest, + WebSocketStream, + MaybeTlsStream, +}; +use urlencoding::encode; + +pub struct EngineProcess { + pub deps: rivet_test_deps::TestDeps, + child: Child, + _config_dir: TempDir, +} + +impl EngineProcess { + pub async fn start() -> Result { + let deps = rivet_test_deps::TestDeps::new().await?; + + let config_dir = tempfile::tempdir().context("failed to create config dir")?; + let config_path = config_dir.path().join("rivet.test.yaml"); + let mut root = (**deps.config()).clone(); + if let Some(rivet_config::config::Database::FileSystem(database)) = root.database.as_mut() { + let db_path = config_dir.path().join("engine-db"); + std::fs::create_dir_all(&db_path).context("failed to create engine db dir")?; + database.path = db_path; + } + + let config_yaml = serde_yaml::to_string(&root) + .context("failed to serialize config")?; + std::fs::write(&config_path, config_yaml).context("failed to write config")?; + + let engine_bin = ensure_engine_binary()?; + let mut cmd = Command::new(engine_bin); + cmd.arg("--config") + .arg(&config_path) + .arg("start") + .arg("-s") + .arg("api_peer") + .arg("-s") + .arg("guard") + .arg("-s") + .arg("workflow_worker") + .arg("-s") + .arg("bootstrap") + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .stdin(Stdio::null()); + + let child = cmd.spawn().context("failed to spawn rivet-engine")?; + + wait_for_port(deps.api_peer_port()).await?; + wait_for_port(deps.guard_port()).await?; + + Ok(Self { + deps, + child, + _config_dir: config_dir, + }) + } + + pub fn guard_url(&self) -> String { + format!("http://127.0.0.1:{}", self.deps.guard_port()) + } + + pub async fn create_actor( + &self, + namespace: &str, + name: &str, + runner_name_selector: &str, + key: Option<&str>, + ) -> Result { + let client = reqwest::Client::new(); + let response = client 
+ .post(format!("{}/actors", self.guard_url())) + .query(&[("namespace", namespace)]) + .json(&json!({ + "datacenter": null, + "name": name, + "key": key, + "input": null, + "runner_name_selector": runner_name_selector, + "crash_policy": "sleep", + })) + .send() + .await + .context("failed to create actor")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("create actor failed: {status} {body}"); + } + + let body: Value = response.json().await.context("failed to decode actor response")?; + let actor_id = body + .get("actor") + .and_then(|x| x.get("actor_id")) + .and_then(Value::as_str) + .context("actor id missing from create actor response")?; + Ok(actor_id.to_string()) + } + + #[allow(dead_code)] + pub async fn actor_request_json( + &self, + method: Method, + actor_id: &str, + path: &str, + body: Option, + ) -> Result { + let response = self + .actor_request_with_retry(method, actor_id, path, body) + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("actor request failed: {status} {body}"); + } + + response + .json() + .await + .context("failed to decode actor response json") + } + + #[allow(dead_code)] + pub async fn get_actor(&self, namespace: &str, actor_id: &str) -> Result> { + let client = reqwest::Client::new(); + let response = client + .get(format!("{}/actors", self.guard_url())) + .query(&[ + ("namespace", namespace), + ("actor_id", actor_id), + ("include_destroyed", "true"), + ]) + .send() + .await + .context("failed to fetch actors list")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("actors list request failed: {status} {body}"); + } + + let body: Value = response + .json() + .await + .context("failed to decode actors list response json")?; + let actor = body + 
.get("actors") + .and_then(Value::as_array) + .and_then(|actors| actors.first()) + .cloned(); + Ok(actor) + } + + #[allow(dead_code)] + pub async fn actor_request_with_retry( + &self, + method: Method, + actor_id: &str, + path: &str, + body: Option, + ) -> Result { + let url = format!("{}{}", self.guard_url(), path); + let client = reqwest::Client::new(); + + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let mut last_error: Option = None; + + loop { + if start.elapsed() > timeout { + if let Some(err) = last_error { + return Err(err).context("timed out waiting for actor response"); + } + bail!("timed out waiting for actor response"); + } + + let mut request = client + .request(method.clone(), &url) + .header("x-rivet-target", "actor") + .header("x-rivet-token", "dev") + .header("x-rivet-actor", actor_id); + + if let Some(json) = &body { + request = request.json(json); + } + + match request.send().await { + Ok(response) + if response.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE + || response.status() == reqwest::StatusCode::NOT_FOUND => + { + tokio::time::sleep(Duration::from_millis(250)).await; + continue; + } + Ok(response) if response.status() == reqwest::StatusCode::BAD_REQUEST => { + tokio::time::sleep(Duration::from_millis(250)).await; + drop(response); + continue; + } + Ok(response) => return Ok(response), + Err(err) => { + last_error = Some(err.into()); + tokio::time::sleep(Duration::from_millis(250)).await; + } + } + } + } + + #[allow(dead_code)] + pub async fn actor_websocket_connect( + &self, + actor_id: &str, + path: &str, + ) -> Result>> { + let start = Instant::now(); + let timeout = Duration::from_secs(30); + let mut last_error: Option = None; + + loop { + if start.elapsed() > timeout { + if let Some(err) = last_error { + return Err(err).context("timed out connecting actor websocket"); + } + bail!("timed out connecting actor websocket"); + } + + let mut ws_url = self.guard_url().replace("http://", "ws://"); + if 
path.starts_with('/') { + ws_url.push_str(path); + } else { + ws_url.push('/'); + ws_url.push_str(path); + } + + let mut request = ws_url + .into_client_request() + .context("failed to build websocket request")?; + request + .headers_mut() + .insert("x-rivet-target", "actor".parse().context("invalid target header")?); + request + .headers_mut() + .insert("x-rivet-token", "dev".parse().context("invalid token header")?); + request + .headers_mut() + .insert("x-rivet-actor", actor_id.parse().context("invalid actor header")?); + let actor_id_protocol = format!("rivet_actor.{}", encode(actor_id)); + let websocket_protocol = format!( + "rivet_target.actor, {actor_id_protocol}, rivet_token.dev, rivet" + ); + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + websocket_protocol + .parse() + .context("invalid websocket protocol header")?, + ); + + match connect_async(request).await { + Ok((ws, _response)) => return Ok(ws), + Err(err) => { + last_error = Some(err.into()); + tokio::time::sleep(Duration::from_millis(250)).await; + } + } + } + } +} + +pub async fn acquire_test_lock() -> Result { + static TEST_LOCK: OnceLock> = OnceLock::new(); + let lock = TEST_LOCK + .get_or_init(|| Arc::new(Semaphore::new(1))) + .clone(); + lock.acquire_owned() + .await + .context("failed to acquire test lock") +} + +pub fn random_name(prefix: &str) -> String { + let mut name = String::with_capacity(prefix.len() + 17); + let _ = write!(&mut name, "{}-{:016x}", prefix, rand::random::()); + name +} + +impl Drop for EngineProcess { + fn drop(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +fn ensure_engine_binary() -> Result { + static BUILD_RESULT: OnceLock> = OnceLock::new(); + + let result = BUILD_RESULT.get_or_init(|| { + let workspace = workspace_root(); + let status = Command::new("cargo") + .arg("build") + .arg("-p") + .arg("rivet-engine") + .current_dir(&workspace) + .status(); + + match status { + Ok(status) if status.success() => { + let bin = 
workspace.join("target").join("debug").join("rivet-engine"); + if bin.exists() { + Ok(bin) + } else { + Err(format!("engine binary not found at {}", bin.display())) + } + } + Ok(status) => Err(format!("cargo build -p rivet-engine failed with status {status}")), + Err(err) => Err(format!("failed to execute cargo build: {err}")), + } + }); + + result.clone().map_err(anyhow::Error::msg) +} + +fn workspace_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../../../../") + .canonicalize() + .expect("workspace root") +} + +async fn wait_for_port(port: u16) -> Result<()> { + let addr = format!("127.0.0.1:{port}"); + let start = Instant::now(); + let timeout = Duration::from_secs(30); + + loop { + match tokio::net::TcpStream::connect(&addr).await { + Ok(_) => return Ok(()), + Err(_) if start.elapsed() <= timeout => tokio::time::sleep(Duration::from_millis(100)).await, + Err(err) => return Err(err).with_context(|| format!("timed out waiting for port {port}")), + } + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs new file mode 100644 index 0000000000..16efde33b7 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_runner.rs @@ -0,0 +1,201 @@ +mod common; + +use anyhow::{Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, Runner, RunnerConfig, +}; +use serde_json::{Value, json}; +use std::{collections::HashSet, sync::Arc, time::{Duration, Instant}}; +use tokio::sync::Mutex; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_runner_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = 
common::random_name("rust-counter-runner"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let runner = Runner::builder( + RunnerConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .token("dev") + .total_slots(16) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + runner.start().await?; + runner.wait_ready().await?; + + let actor_id = engine + .create_actor(&namespace, "counter", &runner_name, Some(&actor_key)) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + let actor = engine + .get_actor(&namespace, &actor_id) + .await? + .ok_or_else(|| anyhow::anyhow!("actor missing after create: {actor_id}"))?; + if actor.get("destroy_ts").is_some_and(|x| !x.is_null()) { + bail!("actor is already destroyed before first request: {actor}"); + } + + let count = match engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("initial actor request failed actor={actor:?}: {err}"); + } + }; + assert_count(&count, 0)?; + + let incremented = match engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("first increment request failed actor={actor:?}: {err}"); + } + }; + assert_count(&incremented, 1)?; + + let incremented_again = match engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("second increment request failed actor={actor:?}: {err}"); + } + }; + 
assert_count(&incremented_again, 2)?; + + runner.handle().sleep_actor(&actor_id, None).await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = match engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await + { + Ok(value) => value, + Err(err) => { + let actor = engine.get_actor(&namespace, &actor_id).await?; + bail!("persisted count request failed actor={actor:?}: {err}"); + } + }; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + runner.shutdown(true).await?; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 
+ .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs new file mode 100644 index 0000000000..2a5e46e6be --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless.rs @@ -0,0 +1,218 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, ServerlessConfig, + ServerlessRunner, +}; +use serde_json::{Value, json}; +use std::{collections::HashSet, sync::Arc, time::{Duration, Instant}}; +use tokio::sync::{Mutex, oneshot}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_serverless_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = 
common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = common::random_name("rust-counter-serverless"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let serverless_runner = ServerlessRunner::builder( + ServerlessConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .prepopulate_actor_name("counter", json!({})) + .token("dev") + .total_slots(1) + .max_runners(1000) + .slots_per_runner(1) + .request_lifespan(300) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .context("failed to bind serverless test listener")?; + let addr = listener.local_addr().context("missing listener local addr")?; + let serverless_url = format!("http://localhost:{}", addr.port()); + + let routes = Arc::new(serverless_runner.clone()).axum_routes(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let mut server_task = tokio::spawn(async move { + axum::serve(listener, routes) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }) + .await + .context("serverless axum server exited with error") + }); + + let metadata_response = reqwest::get(format!("{serverless_url}/api/rivet/metadata")) + .await + .context("failed to call serverless metadata endpoint")?; + if metadata_response.status() != reqwest::StatusCode::OK { + bail!("metadata endpoint returned {}", metadata_response.status()); + } + + let start_response = reqwest::Client::new() + .get(format!("{serverless_url}/api/rivet/start")) + .send() + .await?; + if start_response.status() != reqwest::StatusCode::OK { + bail!("serverless start endpoint returned {}", start_response.status()); + } + + let actor_id = engine + .create_actor(&namespace, "counter", 
&runner_name, Some(&actor_key)) + .await?; + + let count = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&count, 0)?; + + let incremented = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented, 1)?; + + let incremented_again = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented_again, 2)?; + + tokio::time::timeout(Duration::from_secs(30), serverless_runner.runner().wait_ready()) + .await + .context("timed out waiting for serverless runner init")??; + + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner + .runner() + .handle() + .sleep_actor(&actor_id, None) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner.runner().shutdown(true).await?; + + let _ = shutdown_tx.send(()); + if tokio::time::timeout(Duration::from_secs(10), &mut server_task) + .await + .is_err() + { + server_task.abort(); + } + let _ = server_task.await; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + 
actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment( + State(ctx): State, +) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs new file mode 100644 index 0000000000..fd91cf6469 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_counter_serverless_upsert.rs @@ -0,0 +1,208 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + routing::{get, post}, +}; +use reqwest::Method; +use 
rivet_engine_runner::{ + ActorContext, ActorRequestContext, AxumActorDefinition, AxumRunnerApp, ServerlessConfig, + ServerlessRunner, +}; +use serde_json::{Value, json}; +use std::{ + collections::HashSet, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, oneshot}; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn counter_actor_serverless_upsert_config_http_kv_e2e() -> Result<()> { + let _test_lock = common::acquire_test_lock().await?; + let engine = common::EngineProcess::start().await?; + let namespace = "default".to_string(); + + let runner_name = common::random_name("rust-counter-serverless-upsert"); + let runner_key = common::random_name("key"); + let actor_key = common::random_name("counter"); + let actor_registry = Arc::new(Mutex::new(HashSet::::new())); + + let serverless_runner = ServerlessRunner::builder( + ServerlessConfig::builder() + .endpoint(engine.guard_url()) + .namespace(namespace.clone()) + .runner_name(runner_name.clone()) + .runner_key(runner_key) + .prepopulate_actor_name("counter", json!({})) + .token("dev") + .total_slots(1) + .max_runners(1000) + .slots_per_runner(1) + .request_lifespan(300) + .build()?, + ) + .app(build_counter_app(actor_registry.clone())) + .build()?; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .context("failed to bind serverless test listener")?; + let addr = listener.local_addr().context("missing listener local addr")?; + let serverless_url = format!("http://localhost:{}", addr.port()); + + let routes = Arc::new(serverless_runner.clone()).axum_routes(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let mut server_task = tokio::spawn(async move { + axum::serve(listener, routes) + .with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }) + .await + .context("serverless axum server exited with error") + }); + + serverless_runner + .upsert_serverless_runner_config(&serverless_url) + .await + .context("failed to upsert 
serverless runner config")?; + + let actor_id = engine + .create_actor(&namespace, "counter", &runner_name, Some(&actor_key)) + .await?; + + let count = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&count, 0)?; + + let incremented = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented, 1)?; + + let incremented_again = engine + .actor_request_json(Method::POST, &actor_id, "/increment", None) + .await?; + assert_count(&incremented_again, 2)?; + + tokio::time::timeout(Duration::from_secs(30), serverless_runner.runner().wait_ready()) + .await + .context("timed out waiting for serverless runner init")??; + + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner + .runner() + .handle() + .sleep_actor(&actor_id, None) + .await?; + wait_for_actor_presence(&actor_registry, &actor_id, false, Duration::from_secs(30)).await?; + + let persisted = engine + .actor_request_json(Method::GET, &actor_id, "/count", None) + .await?; + assert_count(&persisted, 2)?; + wait_for_actor_presence(&actor_registry, &actor_id, true, Duration::from_secs(30)).await?; + + serverless_runner.runner().shutdown(true).await?; + + let _ = shutdown_tx.send(()); + if tokio::time::timeout(Duration::from_secs(10), &mut server_task) + .await + .is_err() + { + server_task.abort(); + } + let _ = server_task.await; + + Ok(()) +} + +fn build_counter_app(actor_registry: Arc>>) -> AxumRunnerApp { + let on_start_registry = actor_registry.clone(); + let on_stop_registry = actor_registry; + + AxumRunnerApp::new().with_actor( + "counter", + AxumActorDefinition::new( + Router::new() + .route("/count", get(get_count)) + .route("/increment", post(increment)), + ) + .on_start(move |ctx: ActorContext| { + let actor_registry = on_start_registry.clone(); + async move { + actor_registry.lock().await.insert(ctx.actor_id); + Ok(()) + } + }) + .on_stop(move |ctx: 
ActorContext| { + let actor_registry = on_stop_registry.clone(); + async move { + actor_registry.lock().await.remove(&ctx.actor_id); + Ok(()) + } + }), + ) +} + +async fn get_count(State(ctx): State) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0); + Ok(Json(json!({ "count": count }))) +} + +async fn increment(State(ctx): State) -> Result, StatusCode> { + let count = ctx + .kv_get_u64("count") + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .unwrap_or(0) + + 1; + ctx.kv_put_u64("count", count) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(json!({ "count": count }))) +} + +fn assert_count(value: &Value, expected: u64) -> Result<()> { + let actual = value + .get("count") + .and_then(Value::as_u64) + .ok_or_else(|| anyhow::anyhow!("response missing `count` field: {value}"))?; + if actual != expected { + bail!("count mismatch: expected {expected}, got {actual}"); + } + Ok(()) +} + +async fn wait_for_actor_presence( + actor_registry: &Arc>>, + actor_id: &str, + expected: bool, + timeout: Duration, +) -> Result<()> { + let deadline = Instant::now() + timeout; + loop { + let present = actor_registry.lock().await.contains(actor_id); + if present == expected { + return Ok(()); + } + if Instant::now() >= deadline { + bail!( + "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}", + actor_id, + expected, + present + ); + } + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs b/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs new file mode 100644 index 0000000000..8db6160173 --- /dev/null +++ b/engine/sdks/rust/engine-runner/tests/e2e_websocket.rs @@ -0,0 +1,386 @@ +mod common; + +use anyhow::{Context, Result, bail}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use rivet_engine_runner::{ + 
ActorContext, HibernatingWebSocketMetadata, Runner, RunnerApp, RunnerConfig, RunnerHandle, + ServerlessConfig, ServerlessRunner, WebSocketContext, WebSocketMessage, +}; +use serde_json::json; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::sync::{Mutex, oneshot}; +use tokio_tungstenite::tungstenite::Message; + +#[derive(Clone, Default)] +struct EchoWebSocketApp { + actors: Arc>>, + closes: Arc>>, + hibernating_metadata: + Arc>>>, +} + +#[async_trait] +impl RunnerApp for EchoWebSocketApp { + async fn on_actor_start(&self, runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + self.actors.lock().await.insert(ctx.actor_id.clone()); + let metadata = self + .hibernating_metadata + .lock() + .await + .get(&ctx.actor_id) + .map(|entries| entries.values().cloned().collect()) + .unwrap_or_default(); + runner + .restore_hibernating_requests(&ctx.actor_id, metadata) + .await?; + Ok(()) + } + + async fn on_actor_stop(&self, _runner: RunnerHandle, ctx: ActorContext) -> Result<()> { + self.actors.lock().await.remove(&ctx.actor_id); + Ok(()) + } + + async fn websocket(&self, _runner: RunnerHandle, ctx: WebSocketContext) -> Result<()> { + self.hibernating_metadata + .lock() + .await + .entry(ctx.actor_id.clone()) + .or_default() + .insert( + (ctx.gateway_id, ctx.request_id), + HibernatingWebSocketMetadata { + gateway_id: ctx.gateway_id, + request_id: ctx.request_id, + client_message_index: 0, + server_message_index: 0, + path: ctx.path, + headers: ctx.headers, + }, + ); + Ok(()) + } + + async fn websocket_message( + &self, + runner: RunnerHandle, + ctx: WebSocketContext, + message: WebSocketMessage, + ) -> Result<()> { + if ctx.is_hibernatable { + runner + .send_hibernatable_websocket_message_ack( + ctx.gateway_id, + ctx.request_id, + message.message_index, + ) + .await?; + } + + let response_data = message.data.clone(); + let response_binary = message.binary; + runner + .send_websocket_message( + ctx.gateway_id, + 
// NOTE(review): this region was recovered from a collapsed diff chunk. Generic
// type parameters lost in extraction (`Option<u16>`, `Option<String>`,
// `Arc<Mutex<HashSet<String>>>`, `Arc<Mutex<Vec<String>>>`) were reconstructed
// from usage — confirm against the original file before merging.

                ctx.request_id,
                response_data,
                response_binary,
            )
            .await?;

        // Record hibernation bookkeeping for this (gateway, request) pair:
        // remember the last server-side message index and bump the expected
        // client index so a later restore can resume the stream in order.
        if let Some(actor_entries) = self
            .hibernating_metadata
            .lock()
            .await
            .get_mut(&ctx.actor_id)
        {
            if let Some(meta) = actor_entries.get_mut(&(ctx.gateway_id, ctx.request_id)) {
                meta.server_message_index = message.message_index;
                meta.client_message_index = meta.client_message_index.wrapping_add(1);
            }
        }

        Ok(())
    }

    /// Called when a websocket closes: records the close (so tests can wait on
    /// it) and drops any hibernation metadata for the connection.
    async fn websocket_close(
        &self,
        _runner: RunnerHandle,
        ctx: WebSocketContext,
        // NOTE(review): parameter generics reconstructed — presumably the
        // close code and reason; confirm against the trait definition.
        _code: Option<u16>,
        _reason: Option<String>,
    ) -> Result<()> {
        let actor_id = ctx.actor_id.clone();
        self.closes.lock().await.push(actor_id.clone());
        if let Some(actor_entries) = self
            .hibernating_metadata
            .lock()
            .await
            .get_mut(&actor_id)
        {
            actor_entries.remove(&(ctx.gateway_id, ctx.request_id));
        }
        Ok(())
    }

    /// Echo app always opts in to hibernation so the restore tests can run.
    fn can_hibernate(&self, _ctx: &WebSocketContext) -> bool {
        true
    }
}

/// End-to-end: start an engine + runner, open a websocket to an actor, and
/// verify text, small-binary, and large-binary frames are echoed back, then
/// that the actor can be put to sleep cleanly.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn websocket_runner_e2e() -> Result<()> {
    let _test_lock = common::acquire_test_lock().await?;
    let engine = common::EngineProcess::start().await?;
    let namespace = "default".to_string();
    let runner_name = common::random_name("rust-ws-runner");
    let actor_key = common::random_name("ws");
    let app = EchoWebSocketApp::default();

    let runner = Runner::builder(
        RunnerConfig::builder()
            .endpoint(engine.guard_url())
            .namespace(namespace.clone())
            .runner_name(runner_name.clone())
            .runner_key(common::random_name("key"))
            .token("dev")
            .total_slots(16)
            .build()?,
    )
    .app(app.clone())
    .build()?;
    runner.start().await?;
    runner.wait_ready().await?;

    let actor_id = engine
        .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key))
        .await?;
    wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?;

    let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?;

    // Text frame round-trip.
    ws.send(Message::Text("ping".to_string().into())).await?;
    let echoed = ws
        .next()
        .await
        .context("missing echoed text frame")??;
    assert_text_message(&echoed, "ping")?;

    // Small binary frame round-trip.
    ws.send(Message::Binary(vec![1u8, 2, 3].into())).await?;
    let echoed_binary = ws
        .next()
        .await
        .context("missing echoed binary frame")??;
    assert_binary_message(&echoed_binary, &[1, 2, 3])?;

    // Large (64 KiB) binary frame, filled with a non-trivial byte pattern
    // (mod 251, a prime, so the pattern doesn't align with power-of-two
    // buffer sizes and truncation bugs are caught).
    let mut large_payload = vec![0u8; 64 * 1024];
    for (idx, byte) in large_payload.iter_mut().enumerate() {
        *byte = (idx % 251) as u8;
    }
    ws.send(Message::Binary(large_payload.clone().into())).await?;
    let echoed_large_binary = ws
        .next()
        .await
        .context("missing echoed large binary frame")??;
    assert_binary_message(&echoed_large_binary, &large_payload)?;

    ws.close(None).await?;
    wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?;

    // Sleep the actor and confirm it leaves the registry before shutdown.
    runner.handle().sleep_actor(&actor_id, None).await?;
    wait_for_actor_presence(&app.actors, &actor_id, false, Duration::from_secs(30)).await?;
    runner.shutdown(true).await?;
    Ok(())
}

/// End-to-end hibernation: a websocket stays usable across an actor
/// sleep/wake cycle — a frame sent while the actor is asleep wakes it and is
/// still echoed back on the same client connection.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn websocket_hibernation_restore_runner_e2e() -> Result<()> {
    let _test_lock = common::acquire_test_lock().await?;
    let engine = common::EngineProcess::start().await?;
    let namespace = "default".to_string();
    let runner_name = common::random_name("rust-ws-hibernation-runner");
    let actor_key = common::random_name("ws");
    let app = EchoWebSocketApp::default();

    let runner = Runner::builder(
        RunnerConfig::builder()
            .endpoint(engine.guard_url())
            .namespace(namespace.clone())
            .runner_name(runner_name.clone())
            .runner_key(common::random_name("key"))
            .token("dev")
            .total_slots(16)
            .build()?,
    )
    .app(app.clone())
    .build()?;
    runner.start().await?;
    runner.wait_ready().await?;

    let actor_id = engine
        .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key))
        .await?;
    wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?;

    let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?;
    ws.send(Message::Text("before-sleep".to_string().into())).await?;
    let echoed = ws
        .next()
        .await
        .context("missing echoed before-sleep frame")??;
    assert_text_message(&echoed, "before-sleep")?;

    // Hibernate the actor; it should leave the registry while the client's
    // websocket connection remains open.
    runner.handle().sleep_actor(&actor_id, None).await?;
    wait_for_actor_presence(&app.actors, &actor_id, false, Duration::from_secs(30)).await?;

    // Sending on the dormant connection must wake the actor and still echo.
    ws.send(Message::Text("after-sleep".to_string().into())).await?;
    let echoed_after_sleep = ws
        .next()
        .await
        .context("missing echoed after-sleep frame")??;
    assert_text_message(&echoed_after_sleep, "after-sleep")?;
    wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?;

    ws.close(None).await?;
    wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?;
    runner.shutdown(true).await?;
    Ok(())
}

/// End-to-end serverless mode: the runner is hosted behind a local axum
/// server that the engine starts on demand via `/api/rivet/start`; a
/// websocket echo round-trip must still work through it.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn websocket_serverless_e2e() -> Result<()> {
    let _test_lock = common::acquire_test_lock().await?;
    let engine = common::EngineProcess::start().await?;
    let namespace = "default".to_string();
    let runner_name = common::random_name("rust-ws-serverless");
    let actor_key = common::random_name("ws");
    let app = EchoWebSocketApp::default();

    let serverless_runner = ServerlessRunner::builder(
        ServerlessConfig::builder()
            .endpoint(engine.guard_url())
            .namespace(namespace.clone())
            .runner_name(runner_name.clone())
            .runner_key(common::random_name("key"))
            .token("dev")
            .prepopulate_actor_name("ws-echo", json!({}))
            .total_slots(1)
            .max_runners(1000)
            .slots_per_runner(1)
            .request_lifespan(300)
            .build()?,
    )
    .app(app.clone())
    .build()?;

    // Bind to an OS-assigned port so parallel test runs don't collide.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .context("failed to bind serverless test listener")?;
    let addr = listener.local_addr().context("missing listener local addr")?;
    let serverless_url = format!("http://localhost:{}", addr.port());

    let routes = Arc::new(serverless_runner.clone()).axum_routes();
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let mut server_task = tokio::spawn(async move {
        axum::serve(listener, routes)
            .with_graceful_shutdown(async move {
                // A dropped sender also resolves this, so shutdown can't hang
                // on a lost channel.
                let _ = shutdown_rx.await;
            })
            .await
            .context("serverless axum server exited with error")
    });

    // Kick the serverless runner into life the same way the engine would.
    let start_response = reqwest::Client::new()
        .get(format!("{serverless_url}/api/rivet/start"))
        .send()
        .await?;
    if start_response.status() != reqwest::StatusCode::OK {
        bail!("serverless start endpoint returned {}", start_response.status());
    }

    let actor_id = engine
        .create_actor(&namespace, "ws-echo", &runner_name, Some(&actor_key))
        .await?;
    wait_for_actor_presence(&app.actors, &actor_id, true, Duration::from_secs(30)).await?;

    let mut ws = engine.actor_websocket_connect(&actor_id, "/ws").await?;
    ws.send(Message::Text("pong".to_string().into())).await?;
    let echoed = ws
        .next()
        .await
        .context("missing echoed text frame")??;
    assert_text_message(&echoed, "pong")?;

    ws.close(None).await?;
    wait_for_close(&app.closes, &actor_id, Duration::from_secs(10)).await?;

    // Tear down: stop the runner, signal the axum server, and abort it if
    // graceful shutdown doesn't complete within 10s.
    serverless_runner.runner().shutdown(true).await?;
    let _ = shutdown_tx.send(());
    if tokio::time::timeout(Duration::from_secs(10), &mut server_task)
        .await
        .is_err()
    {
        server_task.abort();
    }
    let _ = server_task.await;
    Ok(())
}

/// Asserts that `message` is a text frame whose payload equals `expected`.
fn assert_text_message(message: &Message, expected: &str) -> Result<()> {
    match message {
        Message::Text(text) if text.as_str() == expected => Ok(()),
        _ => bail!("expected text websocket message `{expected}`, got `{message:?}`"),
    }
}

/// Asserts that `message` is a binary frame whose payload equals `expected`.
fn assert_binary_message(message: &Message, expected: &[u8]) -> Result<()> {
    match message {
        Message::Binary(data) if data.as_ref() == expected => Ok(()),
        _ => bail!("expected binary websocket message `{expected:?}`, got `{message:?}`"),
    }
}

/// Polls the app's actor registry every 100ms until the actor's presence
/// matches `expected`, or bails after `timeout`.
// NOTE(review): registry generics reconstructed as Arc<Mutex<HashSet<String>>>
// from the `.lock().await.contains(actor_id)` usage — confirm.
async fn wait_for_actor_presence(
    actor_registry: &Arc<Mutex<HashSet<String>>>,
    actor_id: &str,
    expected: bool,
    timeout: Duration,
) -> Result<()> {
    let deadline = Instant::now() + timeout;
    loop {
        let present = actor_registry.lock().await.contains(actor_id);
        if present == expected {
            return Ok(());
        }
        if Instant::now() >= deadline {
            bail!(
                "timed out waiting for actor presence state actor_id={} expected_present={} actual_present={}",
                actor_id,
                expected,
                present
            );
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}

/// Polls the app's close log every 100ms until it records a close for
/// `actor_id`, or bails after `timeout`.
// NOTE(review): registry generics reconstructed as Arc<Mutex<Vec<String>>>
// from the `.iter().any(..)` / `.push(..)` usage — confirm.
async fn wait_for_close(
    close_registry: &Arc<Mutex<Vec<String>>>,
    actor_id: &str,
    timeout: Duration,
) -> Result<()> {
    let deadline = Instant::now() + timeout;
    loop {
        if close_registry.lock().await.iter().any(|x| x == actor_id) {
            return Ok(());
        }
        if Instant::now() >= deadline {
            bail!("timed out waiting for websocket close callback actor_id={actor_id}");
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}