diff --git a/clients/rust/Cargo.toml b/clients/rust/Cargo.toml index c9ef97acc..0ff9edf77 100644 --- a/clients/rust/Cargo.toml +++ b/clients/rust/Cargo.toml @@ -13,7 +13,7 @@ anyhow = "1.0" base64 = "0.22.1" eventsource-client = "0.14.0" futures-util = "0.3.31" -reqwest = "0.12.12" +reqwest = { version = "0.12.12", features = ["json"] } serde = { version = "1.0", features = ["derive"] } serde_cbor = "0.11.2" serde_json = "1.0" diff --git a/clients/rust/src/client.rs b/clients/rust/src/client.rs index 67e3b6e9f..46a7ee786 100644 --- a/clients/rust/src/client.rs +++ b/clients/rust/src/client.rs @@ -4,9 +4,10 @@ use anyhow::Result; use serde_json::{Value as JsonValue}; use crate::{ - common::{resolve_actor_id, ActorKey, EncodingKind, TransportKind}, + common::{ActorKey, EncodingKind, TransportKind}, handle::ActorHandle, - protocol::query::* + protocol::query::*, + remote_manager::RemoteManager, }; #[derive(Default)] @@ -35,7 +36,7 @@ pub struct CreateOptions { pub struct Client { - manager_endpoint: String, + remote_manager: RemoteManager, encoding_kind: EncodingKind, transport_kind: TransportKind, shutdown_tx: Arc>, @@ -48,7 +49,21 @@ impl Client { encoding_kind: EncodingKind, ) -> Self { Self { - manager_endpoint: manager_endpoint.to_string(), + remote_manager: RemoteManager::new(manager_endpoint, None), + encoding_kind, + transport_kind, + shutdown_tx: Arc::new(tokio::sync::broadcast::channel(1).0) + } + } + + pub fn new_with_token( + manager_endpoint: &str, + token: String, + transport_kind: TransportKind, + encoding_kind: EncodingKind, + ) -> Self { + Self { + remote_manager: RemoteManager::new(manager_endpoint, Some(token)), encoding_kind, transport_kind, shutdown_tx: Arc::new(tokio::sync::broadcast::channel(1).0) @@ -61,7 +76,7 @@ impl Client { query: ActorQuery ) -> ActorHandle { let handle = ActorHandle::new( - &self.manager_endpoint, + self.remote_manager.clone(), params, query, self.shutdown_tx.clone(), @@ -95,11 +110,13 @@ impl Client { pub fn get_for_id( &self, + name: &str, actor_id: &str, opts: GetOptions ) -> Result { let actor_query = ActorQuery::GetForId { get_for_id: GetForIdRequest { + name: name.to_string(), actor_id: actor_id.to_string(), } }; @@ -145,25 +162,17 @@ impl Client { opts: CreateOptions ) -> Result { let input = opts.input; - let region = opts.region; + let _region = opts.region; - let create_query = ActorQuery::Create { - create: CreateRequest { - name: name.to_string(), - key, - input, - region - } - }; - - let actor_id = resolve_actor_id( - &self.manager_endpoint, - create_query, - self.encoding_kind + let actor_id = self.remote_manager.create_actor( + name, + &key, + input, ).await?; let get_query = ActorQuery::GetForId { get_for_id: GetForIdRequest { + name: name.to_string(), actor_id, } }; diff --git a/clients/rust/src/common.rs b/clients/rust/src/common.rs index 62c62974c..fca42ca95 100644 --- a/clients/rust/src/common.rs +++ b/clients/rust/src/common.rs @@ -1,20 +1,37 @@ -use anyhow::Result; -use reqwest::{header::USER_AGENT, RequestBuilder}; -use serde::{de::DeserializeOwned, Serialize}; -use serde_json::{json, Value as JsonValue}; -use tracing::debug; - -use crate::protocol::query::ActorQuery; +#[allow(dead_code)] pub const VERSION: &str = env!("CARGO_PKG_VERSION"); pub const USER_AGENT_VALUE: &str = concat!("ActorClient-Rust/", env!("CARGO_PKG_VERSION")); -pub const HEADER_ACTOR_QUERY: &str = "X-AC-Query"; -pub const HEADER_ENCODING: &str = "X-AC-Encoding"; -pub const HEADER_CONN_PARAMS: &str = "X-AC-Conn-Params"; -pub const HEADER_ACTOR_ID: &str = "X-AC-Actor"; -pub const HEADER_CONN_ID: &str = "X-AC-Conn"; -pub const HEADER_CONN_TOKEN: &str = "X-AC-Conn-Token"; +// Headers +#[allow(dead_code)] +pub const HEADER_ACTOR_QUERY: &str = "x-rivet-query"; +pub const HEADER_ENCODING: &str = "x-rivet-encoding"; +pub const HEADER_CONN_PARAMS: &str = "x-rivet-conn-params"; +#[allow(dead_code)] +pub const HEADER_ACTOR_ID: &str = "x-rivet-actor"; +#[allow(dead_code)] +pub const HEADER_CONN_ID: &str = "x-rivet-conn"; +#[allow(dead_code)] +pub const HEADER_CONN_TOKEN: &str = "x-rivet-conn-token"; + +// Gateway headers +pub const HEADER_RIVET_TARGET: &str = "x-rivet-target"; +pub const HEADER_RIVET_ACTOR: &str = "x-rivet-actor"; +pub const HEADER_RIVET_TOKEN: &str = "x-rivet-token"; + +// Paths +pub const PATH_CONNECT_WEBSOCKET: &str = "/connect/websocket"; + +// WebSocket protocol prefixes +pub const WS_PROTOCOL_STANDARD: &str = "rivet"; +pub const WS_PROTOCOL_TARGET: &str = "rivet_target."; +pub const WS_PROTOCOL_ACTOR: &str = "rivet_actor."; +pub const WS_PROTOCOL_ENCODING: &str = "rivet_encoding."; +pub const WS_PROTOCOL_CONN_PARAMS: &str = "rivet_conn_params."; +pub const WS_PROTOCOL_CONN_ID: &str = "rivet_conn."; +pub const WS_PROTOCOL_CONN_TOKEN: &str = "rivet_conn_token."; +pub const WS_PROTOCOL_TOKEN: &str = "rivet_token."; #[derive(Debug, Clone, Copy)] pub enum TransportKind { @@ -46,154 +63,4 @@ impl ToString for EncodingKind { // Max size of each entry is 128 bytes -pub type ActorKey = Vec; - -pub struct HttpRequestOptions<'a, T: Serialize> { - pub method: &'a str, - pub url: &'a str, - pub headers: Vec<(&'a str, String)>, - pub body: Option, - pub encoding_kind: EncodingKind -} - -impl<'a, T: Serialize> Default for HttpRequestOptions<'a, T> { - fn default() -> Self { - Self { - method: "GET", - url: "", - headers: Vec::new(), - body: None, - encoding_kind: EncodingKind::Json - } - } -} - -fn build_http_request(opts: &HttpRequestOptions) -> Result -where - RQ: Serialize -{ - let client = reqwest::Client::new(); - let mut req = client.request( - reqwest::Method::from_bytes(opts.method.as_bytes()).unwrap(), - opts.url, - ); - - for (key, value) in &opts.headers { - req = req.header(*key, value); - } - - if opts.method == "POST" || opts.method == "PUT" { - let Some(body) = &opts.body else { - return Err(anyhow::anyhow!("Body is required for POST/PUT requests")); - }; - - match opts.encoding_kind { - EncodingKind::Json => { - req = req.header("Content-Type", "application/json"); - let body = serde_json::to_string(&body)?; - req = req.body(body); - } - EncodingKind::Cbor => { - req = req.header("Content-Type", "application/octet-stream"); - let body =serde_cbor::to_vec(&body)?; - req = req.body(body); - } - } - }; - - req = req.header(USER_AGENT, USER_AGENT_VALUE); - - Ok(req) -} - -async fn send_http_request_raw(req: reqwest::RequestBuilder) -> Result { - let res = req.send().await?; - - if !res.status().is_success() { - // TODO: Decode - /* - let data: Option = match opts.encoding_kind { - EncodingKind::Json => { - let data = res.text().await?; - - serde_json::from_str::(&data).ok() - } - EncodingKind::Cbor => { - let data = res.bytes().await?; - - serde_cbor::from_slice(&data).ok() - } - }; - - match data { - Some(data) => { - return Err(anyhow::anyhow!( - "HTTP request failed with status: {}, error: {}", - res.status(), - data.m - )); - }, - None => { - - } - } - */ - return Err(anyhow::anyhow!( - "HTTP request failed with status: {}", - res.status() - )); - } - - Ok(res) -} - -pub async fn send_http_request<'a, RQ, RS>(opts: HttpRequestOptions<'a, RQ>) -> Result -where - RQ: Serialize, - RS: DeserializeOwned, -{ - let req = build_http_request(&opts)?; - let res = send_http_request_raw(req).await?; - - let res: RS = match opts.encoding_kind { - EncodingKind::Json => { - let data = res.text().await?; - serde_json::from_str(&data)? - } - EncodingKind::Cbor => { - let bytes = res.bytes().await?; - serde_cbor::from_slice(&bytes)? - } - }; - - Ok(res) -} - - -pub async fn resolve_actor_id( - manager_endpoint: &str, - query: ActorQuery, - encoding_kind: EncodingKind -) -> Result { - #[derive(serde::Serialize, serde::Deserialize)] - struct ResolveResponse { - i: String, - } - - let query = serde_json::to_string(&query)?; - - let res = send_http_request::( - HttpRequestOptions { - method: "POST", - url: &format!("{}/actors/resolve", manager_endpoint), - headers: vec![ - (HEADER_ENCODING, encoding_kind.to_string()), - (HEADER_ACTOR_QUERY, query), - ], - body: Some(json!({})), - encoding_kind, - } - ).await?; - - Ok(res.i) -} \ No newline at end of file +pub type ActorKey = Vec; \ No newline at end of file diff --git a/clients/rust/src/connection.rs b/clients/rust/src/connection.rs index 935059b74..aa81eb6f5 100644 --- a/clients/rust/src/connection.rs +++ b/clients/rust/src/connection.rs @@ -3,7 +3,7 @@ use futures_util::FutureExt; use serde_json::Value; use std::fmt::Debug; use std::ops::Deref; -use std::sync::atomic::{AtomicI64, Ordering}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{broadcast, oneshot, watch, Mutex}; @@ -12,6 +12,7 @@ use crate::{ backoff::Backoff, protocol::{query::ActorQuery, *}, drivers::*, + remote_manager::RemoteManager, EncodingKind, TransportKind }; @@ -45,7 +46,7 @@ struct ConnectionAttempt { } pub struct ActorConnectionInner { - endpoint: String, + remote_manager: RemoteManager, transport_kind: TransportKind, encoding_kind: EncodingKind, query: ActorQuery, @@ -54,34 +55,42 @@ pub struct ActorConnectionInner { driver: Mutex>, msg_queue: Mutex>>, - rpc_counter: AtomicI64, - in_flight_rpcs: Mutex>>, + rpc_counter: AtomicU64, + in_flight_rpcs: Mutex>>, event_subscriptions: Mutex>>>, + // Connection info for reconnection + actor_id: Mutex>, + connection_id: Mutex>, + connection_token: Mutex>, + dc_watch: WatchPair, disconnection_rx: Mutex>>, } impl ActorConnectionInner { pub(crate) fn new( - endpoint: String, + remote_manager: RemoteManager, query: ActorQuery, transport_kind: TransportKind, encoding_kind: EncodingKind, parameters: Option, ) -> ActorConnection { Arc::new(Self { - endpoint: endpoint.clone(), + remote_manager, transport_kind, encoding_kind, query, parameters, driver: Mutex::new(None), msg_queue: Mutex::new(Vec::new()), - rpc_counter: AtomicI64::new(0), + rpc_counter: AtomicU64::new(0), in_flight_rpcs: Mutex::new(HashMap::new()), event_subscriptions: Mutex::new(HashMap::new()), + actor_id: Mutex::new(None), + connection_id: Mutex::new(None), + connection_token: Mutex::new(None), dc_watch: watch::channel(false), disconnection_rx: Mutex::new(None), }) @@ -92,13 +101,19 @@ impl ActorConnectionInner { } async fn try_connect(self: &Arc) -> ConnectionAttempt { + // Get connection info for reconnection + let conn_id = self.connection_id.lock().await.clone(); + let conn_token = self.connection_token.lock().await.clone(); + let Ok((driver, mut recver, task)) = connect_driver( self.transport_kind, DriverConnectArgs { - endpoint: self.endpoint.clone(), + remote_manager: self.remote_manager.clone(), query: self.query.clone(), encoding_kind: self.encoding_kind, parameters: self.parameters.clone(), + conn_id, + conn_token, } ).await else { // Either from immediate disconnect (local device connection refused) @@ -143,7 +158,7 @@ impl ActorConnectionInner { continue; }; - if let to_client::ToClientBody::Init { i: _ } = &msg.b { + if let to_client::ToClientBody::Init(_) = &msg.body { did_connection_open = true; } @@ -173,6 +188,11 @@ impl ActorConnectionInner { async fn on_open(self: &Arc, init: &to_client::Init) { debug!("Connected to server: {:?}", init); + // Store connection info for reconnection + *self.actor_id.lock().await = Some(init.actor_id.clone()); + *self.connection_id.lock().await = Some(init.connection_id.clone()); + *self.connection_token.lock().await = Some(init.connection_token.clone()); + for (event_name, _) in self.event_subscriptions.lock().await.iter() { self.send_subscription(event_name.clone(), true).await; } @@ -186,14 +206,14 @@ impl ActorConnectionInner { } async fn on_message(self: &Arc, msg: Arc) { - let body = &msg.b; + let body = &msg.body; match body { - to_client::ToClientBody::Init { i: init } => { + to_client::ToClientBody::Init(init) => { self.on_open(init).await; } - to_client::ToClientBody::ActionResponse { ar } => { - let id = ar.i; + to_client::ToClientBody::ActionResponse(ar) => { + let id = ar.id; let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; let Some(tx) = in_flight_rpcs.remove(&id) else { debug!("Unexpected response: rpc id not found"); @@ -204,16 +224,25 @@ impl ActorConnectionInner { return; } } - to_client::ToClientBody::EventMessage { ev } => { + to_client::ToClientBody::Event(ev) => { + // Decode CBOR args + let args: Vec = match serde_cbor::from_slice(&ev.args) { + Ok(a) => a, + Err(e) => { + debug!("Failed to decode event args: {:?}", e); + return; + } + }; + let listeners = self.event_subscriptions.lock().await; - if let Some(callbacks) = listeners.get(&ev.n) { + if let Some(callbacks) = listeners.get(&ev.name) { for cb in callbacks { - cb(&ev.a); + cb(&args); } } } - to_client::ToClientBody::Error { e } => { - if let Some(action_id) = e.ai { + to_client::ToClientBody::Error(e) => { + if let Some(action_id) = e.action_id { let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; let Some(tx) = in_flight_rpcs.remove(&action_id) else { debug!("Unexpected response: rpc id not found"); @@ -227,7 +256,7 @@ impl ActorConnectionInner { return; } - + debug!("Connection error: {} - {}", e.code, e.message); } } } @@ -256,39 +285,53 @@ impl ActorConnectionInner { } pub async fn action(self: &Arc, method: &str, params: Vec) -> Result { - let id: i64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst); + let id: u64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst); let (tx, rx) = oneshot::channel(); self.in_flight_rpcs.lock().await.insert(id, tx); + // Encode params as CBOR + let args_cbor = serde_cbor::to_vec(¶ms)?; + self.send_msg( Arc::new(to_server::ToServer { - b: to_server::ToServerBody::ActionRequest { - ar: to_server::ActionRequest { - i: id, - n: method.to_string(), - a: params, + body: to_server::ToServerBody::ActionRequest( + to_server::ActionRequest { + id, + name: method.to_string(), + args: args_cbor, }, - }, + ), }), SendMsgOpts::default(), ) .await; let Ok(res) = rx.await else { - // Verbosity return Err(anyhow::anyhow!("Socket closed during rpc")); }; match res { - Ok(ok) => Ok(ok.o), + Ok(ok) => { + // Decode CBOR output + let output: Value = serde_cbor::from_slice(&ok.output)?; + Ok(output) + } Err(err) => { - let metadata = err.md.unwrap_or(Value::Null); + let metadata = if let Some(md) = &err.metadata { + match serde_cbor::from_slice::(md) { + Ok(v) => v, + Err(_) => Value::Null, + } + } else { + Value::Null + }; Err(anyhow::anyhow!( - "RPC Error({}): {:?}, {:#}", - err.c, - err.m, + "RPC Error({}/{}): {}, {:#}", + err.group, + err.code, + err.message, metadata )) } @@ -298,12 +341,12 @@ impl ActorConnectionInner { async fn send_subscription(self: &Arc, event_name: String, subscribe: bool) { self.send_msg( Arc::new(to_server::ToServer { - b: to_server::ToServerBody::SubscriptionRequest { - sr: to_server::SubscriptionRequest { - e: event_name, - s: subscribe, + body: to_server::ToServerBody::SubscriptionRequest( + to_server::SubscriptionRequest { + event_name, + subscribe, }, - }, + ), }), SendMsgOpts { ephemeral: true }, ) @@ -426,7 +469,6 @@ pub fn start_connection( impl Debug for ActorConnectionInner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ActorConnection") - .field("endpoint", &self.endpoint) .field("transport_kind", &self.transport_kind) .field("encoding_kind", &self.encoding_kind) .finish() diff --git a/clients/rust/src/drivers/mod.rs b/clients/rust/src/drivers/mod.rs index 8db37c88a..4eaf99792 100644 --- a/clients/rust/src/drivers/mod.rs +++ b/clients/rust/src/drivers/mod.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use crate::{ protocol::{query, to_client, to_server}, + remote_manager::RemoteManager, EncodingKind, TransportKind }; use anyhow::Result; @@ -65,10 +66,12 @@ pub type DriverConnection = ( ); pub struct DriverConnectArgs { - pub endpoint: String, + pub remote_manager: RemoteManager, pub encoding_kind: EncodingKind, pub query: query::ActorQuery, pub parameters: Option, + pub conn_id: Option, + pub conn_token: Option, } pub async fn connect_driver( diff --git a/clients/rust/src/drivers/sse.rs b/clients/rust/src/drivers/sse.rs index be314baa7..87731bc54 100644 --- a/clients/rust/src/drivers/sse.rs +++ b/clients/rust/src/drivers/sse.rs @@ -1,243 +1,11 @@ -use anyhow::{Result}; -use base64::prelude::*; -use eventsource_client::{BoxStream, Client, ClientBuilder, ReconnectOptionsBuilder, SSE}; -use futures_util::StreamExt; -use reqwest::header::USER_AGENT; -use std::sync::Arc; -use tokio::sync::mpsc; -use tracing::debug; +use anyhow::Result; -use crate::{ - common::{EncodingKind, HEADER_ACTOR_ID, HEADER_ACTOR_QUERY, HEADER_CONN_ID, HEADER_CONN_PARAMS, HEADER_CONN_TOKEN, HEADER_ENCODING, USER_AGENT_VALUE}, - protocol::{to_client, to_server} -}; +use super::{DriverConnectArgs, DriverConnection}; -use super::{ - DriverConnectArgs, DriverConnection, DriverHandle, DriverStopReason, MessageToClient, MessageToServer -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -struct ConnectionDetails { - actor_id: String, - id: String, - token: String, -} - - -struct Context { - conn: ConnectionDetails, - encoding_kind: EncodingKind, - endpoint: String, +pub(crate) async fn connect(_args: DriverConnectArgs) -> Result { + // SSE transport is not currently supported with the new gateway architecture + // TODO: Implement SSE support via gateway + Err(anyhow::anyhow!( + "SSE transport not yet supported with gateway architecture" + )) } - -pub(crate) async fn connect(args: DriverConnectArgs) -> Result { - let endpoint = format!("{}/actors/connect/sse", args.endpoint); - - let params_string = match args.parameters { - Some(p) => Some(serde_json::to_string(&p)).transpose(), - None => Ok(None), - }?; - - let client = ClientBuilder::for_url(&endpoint)? - .header(USER_AGENT.as_str(), USER_AGENT_VALUE)? - .header(HEADER_ENCODING, args.encoding_kind.as_str())? - .header(HEADER_ACTOR_QUERY, serde_json::to_string(&args.query)?.as_str())?; - - let client = match params_string { - Some(p) => client.header(HEADER_CONN_PARAMS, p.as_str())?, - None => client, - }; - let client = client.reconnect(ReconnectOptionsBuilder::new(false).build()) - .build(); - - let (in_tx, in_rx) = mpsc::channel::(32); - let (out_tx, out_rx) = mpsc::channel::(32); - - let task = tokio::spawn(start(client, args.endpoint, args.encoding_kind, in_tx, out_rx)); - - let handle = DriverHandle::new(out_tx, task.abort_handle()); - Ok((handle, in_rx, task)) -} - -async fn sse_send_msg(ctx: &Context, msg: MessageToServer) -> Result { - let msg = serialize(ctx.encoding_kind, &msg)?; - - // Add connection ID and token to the request URL - let request_url = format!( - "{}/actors/message", - ctx.endpoint - ); - - let res = reqwest::Client::new() - .post(request_url) - .body(msg) - .header(USER_AGENT, USER_AGENT_VALUE) - .header(HEADER_ENCODING, ctx.encoding_kind.as_str()) - .header(HEADER_ACTOR_ID, ctx.conn.actor_id.as_str()) - .header(HEADER_CONN_ID, ctx.conn.id.as_str()) - .header(HEADER_CONN_TOKEN, ctx.conn.token.as_str()) - .send() - .await?; - - - if !res.status().is_success() { - return Err(anyhow::anyhow!("Failed to send message: {:?}", res)); - } - - let res = res.text().await?; - - Ok(res) -} - -async fn start( - client: impl Client, - endpoint: String, - encoding_kind: EncodingKind, - in_tx: mpsc::Sender, - mut out_rx: mpsc::Receiver, -) -> DriverStopReason { - let mut stream = client.stream(); - - let ctx = Context { - conn: match do_handshake(&mut stream, encoding_kind, &in_tx).await { - Ok(conn) => conn, - Err(reason) => return reason - }, - encoding_kind, - endpoint, - }; - - debug!("Handshake completed successfully"); - - loop { - tokio::select! { - // Handle outgoing messages - msg = out_rx.recv() => { - let Some(msg) = msg else { - return DriverStopReason::UserAborted; - }; - - let res = match sse_send_msg(&ctx, msg).await { - Ok(res) => res, - Err(e) => { - debug!("Failed to send message: {:?}", e); - continue; - } - }; - - debug!("Response: {:?}", res); - }, - msg = stream.next() => { - let Some(msg) = msg else { - // Receiver dropped - return DriverStopReason::ServerDisconnect; - }; - - match msg { - Ok(msg) => match msg { - SSE::Comment(comment) => debug!("Sse comment: {}", comment), - SSE::Connected(_) => debug!("warning: received sse connection past-handshake"), - SSE::Event(event) => { - let msg = match deserialize(encoding_kind, &event.data) { - Ok(msg) => msg, - Err(e) => { - debug!("Failed to deserialize {:?} {:?}", event, e); - continue; - } - }; - - if let Err(e) = in_tx.send(Arc::new(msg)).await { - debug!("Receiver in_rx dropped {:?}", e); - return DriverStopReason::UserAborted; - } - }, - } - Err(e) => { - debug!("Sse error: {}", e); - return DriverStopReason::ServerError; - } - } - } - } - } -} - -async fn do_handshake( - stream: &mut BoxStream>, - encoding_kind: EncodingKind, - in_tx: &mpsc::Sender, -) -> Result { - loop { - tokio::select! { - // Handle sse incoming - msg = stream.next() => { - let Some(msg) = msg else { - debug!("Receiver dropped"); - return Err(DriverStopReason::ServerDisconnect); - }; - - match msg { - Ok(msg) => match msg { - SSE::Comment(comment) => debug!("Sse comment {:?}", comment), - SSE::Connected(_) => debug!("Connected Sse"), - SSE::Event(event) => { - let msg = match deserialize(encoding_kind, &event.data) { - Ok(msg) => msg, - Err(e) => { - debug!("Failed to deserialize {:?} {:?}", event, e); - continue; - } - }; - - let msg = Arc::new(msg); - - if let Err(e) = in_tx.send(msg.clone()).await { - debug!("Receiver in_rx dropped {:?}", e); - return Err(DriverStopReason::UserAborted); - } - - // Wait until we get an Init packet - let to_client::ToClientBody::Init { i } = &msg.b else { - continue; - }; - - // Mark handshake complete - return Ok(ConnectionDetails { - actor_id: i.ai.to_string(), - id: i.ci.clone(), - token: i.ct.clone() - }) - }, - } - Err(e) => { - eprintln!("Sse error: {}", e); - return Err(DriverStopReason::ServerError); - } - } - } - } - } -} - -fn deserialize(encoding_kind: EncodingKind, msg: &str) -> Result { - match encoding_kind { - EncodingKind::Json => { - Ok(serde_json::from_str::(msg)?) - }, - EncodingKind::Cbor => { - let msg = serde_cbor::from_slice::( - &BASE64_STANDARD.decode(msg.as_bytes())? - )?; - - Ok(msg) - } - } -} - -fn serialize(encoding_kind: EncodingKind, msg: &to_server::ToServer) -> Result> { - match encoding_kind { - EncodingKind::Json => Ok(serde_json::to_vec(msg)?), - EncodingKind::Cbor => Ok(serde_cbor::to_vec(msg)?), - } -} - diff --git a/clients/rust/src/drivers/ws.rs b/clients/rust/src/drivers/ws.rs index e8d694500..c4252ed66 100644 --- a/clients/rust/src/drivers/ws.rs +++ b/clients/rust/src/drivers/ws.rs @@ -1,10 +1,8 @@ use anyhow::{Context, Result}; use futures_util::{SinkExt, StreamExt}; use std::sync::Arc; -use tokio::net::TcpStream; use tokio::sync::mpsc; use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tracing::debug; use crate::{ @@ -17,33 +15,20 @@ use super::{ DriverConnectArgs, DriverConnection, DriverHandle, DriverStopReason, MessageToClient, MessageToServer }; -fn build_connection_url(args: &DriverConnectArgs) -> Result { - let actor_query_string = serde_json::to_string(&args.query)?; - // TODO: Should replace http:// only at the start of the string - let url = args.endpoint - .to_string() - .replace("http://", "ws://") - .replace("https://", "wss://"); - - let url = format!( - "{}/actors/connect/websocket?encoding={}&query={}", - url, - args.encoding_kind.as_str(), - urlencoding::encode(&actor_query_string) - ); - - Ok(url) -} - - pub(crate) async fn connect(args: DriverConnectArgs) -> Result { - let url = build_connection_url(&args)?; + // Resolve actor ID + let actor_id = args.remote_manager.resolve_actor_id(&args.query).await?; - debug!("Connecting to: {}", url); + debug!("Opening WebSocket connection to actor via gateway: {}", actor_id); - let (ws, _res) = tokio_tungstenite::connect_async(url) - .await - .context("Failed to connect to WebSocket")?; + // Open WebSocket via remote manager (gateway) + let ws = args.remote_manager.open_websocket( + &actor_id, + args.encoding_kind, + args.parameters, + args.conn_id, + args.conn_token, + ).await.context("Failed to connect to WebSocket via gateway")?; let (in_tx, in_rx) = mpsc::channel::(32); let (out_tx, out_rx) = mpsc::channel::(32); @@ -51,21 +36,11 @@ pub(crate) async fn connect(args: DriverConnectArgs) -> Result let task = tokio::spawn(start(ws, args.encoding_kind, in_tx, out_rx)); let handle = DriverHandle::new(out_tx, task.abort_handle()); - handle.send(Arc::new( - to_server::ToServer { - b: to_server::ToServerBody::Init { - i: to_server::Init { - p: args.parameters - } - }, - } - )).await?; - Ok((handle, in_rx, task)) } async fn start( - ws: WebSocketStream>, + ws: tokio_tungstenite::WebSocketStream>, encoding_kind: EncodingKind, in_tx: mpsc::Sender, mut out_rx: mpsc::Receiver, diff --git a/clients/rust/src/handle.rs b/clients/rust/src/handle.rs index 4abda92af..2f17df27b 100644 --- a/clients/rust/src/handle.rs +++ b/clients/rust/src/handle.rs @@ -1,17 +1,16 @@ use std::{cell::RefCell, ops::Deref, sync::Arc}; use serde_json::Value as JsonValue; use anyhow::{anyhow, Result}; -use urlencoding::encode as url_encode; +use serde_cbor; use crate::{ - common::{resolve_actor_id, send_http_request, HttpRequestOptions, HEADER_ACTOR_QUERY, HEADER_CONN_PARAMS, HEADER_ENCODING}, + common::{EncodingKind, TransportKind, HEADER_ENCODING, HEADER_CONN_PARAMS}, connection::{start_connection, ActorConnection, ActorConnectionInner}, protocol::query::*, - EncodingKind, - TransportKind + remote_manager::RemoteManager, }; pub struct ActorHandleStateless { - endpoint: String, + remote_manager: RemoteManager, params: Option, encoding_kind: EncodingKind, query: RefCell, @@ -19,13 +18,13 @@ pub struct ActorHandleStateless { impl ActorHandleStateless { pub fn new( - endpoint: &str, + remote_manager: RemoteManager, params: Option, encoding_kind: EncodingKind, query: ActorQuery ) -> Self { Self { - endpoint: endpoint.to_string(), + remote_manager, params, encoding_kind, query: RefCell::new(query) @@ -33,78 +32,76 @@ impl ActorHandleStateless { } pub async fn action(&self, name: &str, args: Vec) -> Result { - #[derive(serde::Serialize)] - struct ActionRequest { - a: Vec, - } - #[derive(serde::Deserialize)] - struct ActionResponse { - o: JsonValue, - } + // Resolve actor ID + let query = self.query.borrow().clone(); + let actor_id = self.remote_manager.resolve_actor_id(&query).await?; - let actor_query = serde_json::to_string(&self.query)?; + // Encode args as CBOR + let args_cbor = serde_cbor::to_vec(&args)?; // Build headers let mut headers = vec![ (HEADER_ENCODING, self.encoding_kind.to_string()), - (HEADER_ACTOR_QUERY, actor_query), ]; if let Some(params) = &self.params { headers.push((HEADER_CONN_PARAMS, serde_json::to_string(params)?)); } - let res = send_http_request::(HttpRequestOptions { - url: &format!( - "{}/actors/actions/{}", - self.endpoint, - url_encode(name) - ), - method: "POST", + // Send request via gateway + let path = format!("/action/{}", urlencoding::encode(name)); + let res = self.remote_manager.send_request( + &actor_id, + &path, + "POST", headers, - body: Some(ActionRequest { - a: args, - }), - encoding_kind: self.encoding_kind, - }).await?; + Some(args_cbor), + ).await?; + + if !res.status().is_success() { + return Err(anyhow!("action failed: {}", res.status())); + } - Ok(res.o) + // Decode response + let output_cbor = res.bytes().await?; + let output: JsonValue = serde_cbor::from_slice(&output_cbor)?; + + Ok(output) } pub async fn resolve(&self) -> Result { let query = { - // None of this is async or runs on multithreads, - // it cannot fail given that both borrows are - // well contained, and cannot overlap. let Ok(query) = self.query.try_borrow() else { return Err(anyhow!("Failed to borrow actor query")); }; - query.clone() }; match query { - ActorQuery::Create { create: _query } => { + ActorQuery::Create { .. } => { Err(anyhow!("actor query cannot be create")) }, - ActorQuery::GetForId { get_for_id: query } => { - Ok(query.clone().actor_id) + ActorQuery::GetForId { get_for_id } => { + Ok(get_for_id.actor_id.clone()) }, _ => { - let actor_id = resolve_actor_id( - &self.endpoint, - query, - self.encoding_kind - ).await?; + let actor_id = self.remote_manager.resolve_actor_id(&query).await?; + + // Get name from the original query + let name = match &query { + ActorQuery::GetForKey { get_for_key } => get_for_key.name.clone(), + ActorQuery::GetOrCreateForKey { get_or_create_for_key } => get_or_create_for_key.name.clone(), + _ => return Err(anyhow!("unexpected query type")), + }; { - let Ok(mut query) = self.query.try_borrow_mut() else { - // Following code will not run (see prior note) + let Ok(mut query_mut) = self.query.try_borrow_mut() else { return Err(anyhow!("Failed to borrow actor query mutably")); }; - *query = ActorQuery::GetForId { + *query_mut = ActorQuery::GetForId { get_for_id: GetForIdRequest { + name, actor_id: actor_id.clone(), } }; @@ -118,7 +115,7 @@ impl ActorHandleStateless { pub struct ActorHandle { handle: ActorHandleStateless, - endpoint: String, + remote_manager: RemoteManager, params: Option, query: ActorQuery, client_shutdown_tx: Arc>, @@ -128,7 +125,7 @@ pub struct ActorHandle { impl ActorHandle { pub fn new( - endpoint: &str, + remote_manager: RemoteManager, params: Option, query: ActorQuery, client_shutdown_tx: Arc>, @@ -136,7 +133,7 @@ impl ActorHandle { encoding_kind: EncodingKind ) -> Self { let handle = ActorHandleStateless::new( - endpoint, + remote_manager.clone(), params.clone(), encoding_kind, query.clone() @@ -144,7 +141,7 @@ impl ActorHandle { Self { handle, - endpoint: endpoint.to_string(), + remote_manager, params, query, client_shutdown_tx, @@ -155,7 +152,7 @@ impl ActorHandle { pub fn connect(&self) -> ActorConnection { let conn = ActorConnectionInner::new( - self.endpoint.clone(), + self.remote_manager.clone(), self.query.clone(), self.transport_kind, self.encoding_kind, diff --git a/clients/rust/src/lib.rs b/clients/rust/src/lib.rs index 0bc31def2..b61abcf09 100644 --- a/clients/rust/src/lib.rs +++ b/clients/rust/src/lib.rs @@ -1,5 +1,6 @@ mod backoff; mod common; +mod remote_manager; pub mod client; pub mod drivers; pub mod connection; diff --git a/clients/rust/src/protocol/query.rs b/clients/rust/src/protocol/query.rs index 88cd2849a..b5e0d54b7 100644 --- a/clients/rust/src/protocol/query.rs +++ b/clients/rust/src/protocol/query.rs @@ -21,6 +21,7 @@ pub struct GetForKeyRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GetForIdRequest { + pub name: String, #[serde(rename = "actorId")] pub actor_id: String, } diff --git a/clients/rust/src/protocol/to_client.rs b/clients/rust/src/protocol/to_client.rs index 5fe67d484..dca584dc8 100644 --- a/clients/rust/src/protocol/to_client.rs +++ b/clients/rust/src/protocol/to_client.rs @@ -1,59 +1,50 @@ use serde::{Deserialize, Serialize}; -use serde_json::Value as JsonValue; -// Only called for SSE because we don't need this for WebSockets #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Init { - // Actor ID - pub ai: String, - // Connection ID - pub ci: String, - // Connection token - pub ct: String, + #[serde(rename = "actorId")] + pub actor_id: String, + #[serde(rename = "connectionId")] + pub connection_id: String, + #[serde(rename = "connectionToken")] + pub connection_token: String, } // Used for connection errors (both during initialization and afterwards) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Error { - // Code - pub c: String, - // Message - pub m: String, - // Metadata + pub group: String, + pub code: String, + pub message: String, #[serde(skip_serializing_if = "Option::is_none")] - pub md: Option, - // Action ID + pub metadata: Option>, + #[serde(rename = "actionId")] #[serde(skip_serializing_if = "Option::is_none")] - pub ai: Option + pub action_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ActionResponse { - // ID - pub i: i64, - // Output - pub o: JsonValue + pub id: u64, + pub output: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Event { - // Event name - pub n: String, - // Event arguments - pub a: Vec, + pub name: String, + pub args: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(tag = "tag", content = "val")] pub enum ToClientBody { - Init { i: Init }, - Error { e: Error }, - ActionResponse { ar: ActionResponse }, - EventMessage { ev: Event }, + Init(Init), + Error(Error), + ActionResponse(ActionResponse), + Event(Event), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToClient { - // Body - pub b: ToClientBody, + pub body: ToClientBody, } \ No newline at end of file diff --git a/clients/rust/src/protocol/to_server.rs b/clients/rust/src/protocol/to_server.rs index 5f4a3c1d4..ac609e37f 100644 --- a/clients/rust/src/protocol/to_server.rs +++ b/clients/rust/src/protocol/to_server.rs @@ -1,40 +1,27 @@ use serde::{Deserialize, Serialize}; -use serde_json::Value as JsonValue; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Init { - // Conn Params - #[serde(skip_serializing_if = "Option::is_none")] - pub p: Option -} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ActionRequest { - // ID - pub i: i64, - // Name - pub n: String, - // Args - pub a: Vec, + pub id: u64, + pub name: String, + pub args: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SubscriptionRequest { - // Event name - pub e: String, - // Subscribe - pub s: bool, + #[serde(rename = "eventName")] + pub event_name: String, + pub subscribe: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(tag = "tag", content = "val")] pub enum ToServerBody { - Init { i: Init }, - ActionRequest { ar: ActionRequest }, - SubscriptionRequest { sr: SubscriptionRequest }, + ActionRequest(ActionRequest), + SubscriptionRequest(SubscriptionRequest), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToServer { - pub b: ToServerBody, + pub body: ToServerBody, } diff --git a/clients/rust/src/remote_manager.rs b/clients/rust/src/remote_manager.rs new file mode 100644 index 000000000..8c107e93b --- /dev/null +++ b/clients/rust/src/remote_manager.rs @@ -0,0 +1,328 @@ +use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose, Engine as _}; +use reqwest::header::USER_AGENT; +use serde::{Deserialize, Serialize}; +use serde_cbor; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; + +use crate::{ + common::{ + ActorKey, EncodingKind, USER_AGENT_VALUE, + HEADER_RIVET_TARGET, HEADER_RIVET_ACTOR, HEADER_RIVET_TOKEN, + WS_PROTOCOL_STANDARD, WS_PROTOCOL_TARGET, WS_PROTOCOL_ACTOR, + WS_PROTOCOL_ENCODING, WS_PROTOCOL_CONN_PARAMS, WS_PROTOCOL_CONN_ID, + WS_PROTOCOL_CONN_TOKEN, WS_PROTOCOL_TOKEN, PATH_CONNECT_WEBSOCKET, + }, + protocol::query::ActorQuery, +}; + +#[derive(Clone)] +pub struct RemoteManager { + endpoint: String, + token: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Actor { + actor_id: String, + name: String, + key: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ActorsListResponse { + actors: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ActorsGetOrCreateRequest { + name: String, + key: String, + #[serde(skip_serializing_if = "Option::is_none")] + input: Option, // base64-encoded CBOR +} + +#[derive(Debug, Serialize, Deserialize)] +struct ActorsGetOrCreateResponse { + actor: Actor, + created: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ActorsCreateRequest { + name: String, + key: String, + #[serde(skip_serializing_if = "Option::is_none")] + input: Option, // base64-encoded CBOR +} + +#[derive(Debug, Serialize, Deserialize)] +struct ActorsCreateResponse { + actor: Actor, +} + +impl RemoteManager { + pub fn new(endpoint: &str, token: Option) -> Self { + Self { + endpoint: endpoint.to_string(), + token, + } + } + + pub async fn get_for_id(&self, name: &str, actor_id: &str) -> Result> { + let url = format!("{}/actors?name={}&actor_ids={}", self.endpoint, urlencoding::encode(name), urlencoding::encode(actor_id)); + + let client = reqwest::Client::new(); + let mut req = client.get(&url).header(USER_AGENT, USER_AGENT_VALUE); + + if let Some(token) = &self.token { + req = req.header(HEADER_RIVET_TOKEN, token); + } + + let res = req.send().await?; + + if !res.status().is_success() { + return Err(anyhow!("failed to get actor: {}", res.status())); + } + + let data: ActorsListResponse = res.json().await?; + + if let Some(actor) = data.actors.first() { + if actor.name == name { + Ok(Some(actor.actor_id.clone())) + } else { + Ok(None) + } + } else { + Ok(None) + } + } + + pub async fn get_with_key(&self, name: &str, key: &ActorKey) -> Result> { + let key_str = serde_json::to_string(key)?; + let url = format!("{}/actors?name={}&key={}", self.endpoint, urlencoding::encode(name), urlencoding::encode(&key_str)); + + let client = reqwest::Client::new(); + let mut req = client.get(&url).header(USER_AGENT, USER_AGENT_VALUE); + + if let Some(token) = &self.token { + req = req.header(HEADER_RIVET_TOKEN, token); + } + + let res = req.send().await?; + + if !res.status().is_success() { + if res.status() == 404 { + return Ok(None); + } + return Err(anyhow!("failed to get actor by key: {}", res.status())); + } + + let data: ActorsListResponse = res.json().await?; + + if let Some(actor) = data.actors.first() { + Ok(Some(actor.actor_id.clone())) + } else { + Ok(None) + } + } + + pub async fn get_or_create_with_key( + &self, + name: &str, + key: &ActorKey, + input: Option, + ) -> Result { + let key_str = serde_json::to_string(key)?; + + let input_encoded = if let Some(inp) = input { + let cbor = serde_cbor::to_vec(&inp)?; + Some(general_purpose::STANDARD.encode(cbor)) + } else { + None + }; + + let request_body = ActorsGetOrCreateRequest { + name: name.to_string(), + key: key_str, + input: input_encoded, + }; + + let client = reqwest::Client::new(); + let mut req = client + .put(format!("{}/actors", self.endpoint)) + .header(USER_AGENT, USER_AGENT_VALUE) + .json(&request_body); + + if let Some(token) = &self.token { + req = req.header(HEADER_RIVET_TOKEN, token); + } + + let res = req.send().await?; + + if !res.status().is_success() { + return Err(anyhow!("failed to get or create actor: {}", res.status())); + } + + let data: ActorsGetOrCreateResponse = res.json().await?; + Ok(data.actor.actor_id) + } + + pub async fn create_actor( + &self, + name: &str, + key: &ActorKey, + input: Option, + ) -> Result { + let key_str = serde_json::to_string(key)?; + + let input_encoded = if let Some(inp) = input { + let cbor = serde_cbor::to_vec(&inp)?; + Some(general_purpose::STANDARD.encode(cbor)) + } else { + None + }; + + let request_body = ActorsCreateRequest { + name: name.to_string(), + key: key_str, + input: input_encoded, + }; + + let client = reqwest::Client::new(); + let mut req = client + .post(format!("{}/actors", self.endpoint)) + .header(USER_AGENT, USER_AGENT_VALUE) + .json(&request_body); + + if let Some(token) = &self.token { + req = req.header(HEADER_RIVET_TOKEN, token); + } + + let res = req.send().await?; + + if !res.status().is_success() { + return Err(anyhow!("failed to create actor: {}", res.status())); + } + + let data: ActorsCreateResponse = res.json().await?; + Ok(data.actor.actor_id) + } + + pub async fn resolve_actor_id(&self, query: &ActorQuery) -> Result { + match query { + ActorQuery::GetForId { get_for_id } => { + self.get_for_id(&get_for_id.name, &get_for_id.actor_id) + .await? + .ok_or_else(|| anyhow!("actor not found")) + } + ActorQuery::GetForKey { get_for_key } => { + self.get_with_key(&get_for_key.name, &get_for_key.key) + .await? + .ok_or_else(|| anyhow!("actor not found")) + } + ActorQuery::GetOrCreateForKey { get_or_create_for_key } => { + self.get_or_create_with_key( + &get_or_create_for_key.name, + &get_or_create_for_key.key, + get_or_create_for_key.input.clone(), + ) + .await + } + ActorQuery::Create { create } => { + self.create_actor(&create.name, &create.key, create.input.clone()) + .await + } + } + } + + pub async fn send_request( + &self, + actor_id: &str, + path: &str, + method: &str, + headers: Vec<(&str, String)>, + body: Option>, + ) -> Result { + let url = format!("{}{}", self.endpoint, path); + + let client = reqwest::Client::new(); + let mut req = client + .request( + reqwest::Method::from_bytes(method.as_bytes())?, + &url, + ) + .header(USER_AGENT, USER_AGENT_VALUE) + .header(HEADER_RIVET_TARGET, "actor") + .header(HEADER_RIVET_ACTOR, actor_id); + + if let Some(token) = &self.token { + req = req.header(HEADER_RIVET_TOKEN, token); + } + + for (key, value) in headers { + req = req.header(key, value); + } + + if let Some(body_data) = body { + req = req.body(body_data); + } + + let res = req.send().await?; + Ok(res) + } + + pub async fn open_websocket( + &self, + actor_id: &str, + encoding: EncodingKind, + params: Option, + conn_id: Option, + conn_token: Option, + ) -> Result>> { + use tokio_tungstenite::connect_async; + + // Build WebSocket URL + let ws_url = if self.endpoint.starts_with("https://") { + format!("wss://{}{}", &self.endpoint[8..], PATH_CONNECT_WEBSOCKET) + } else if self.endpoint.starts_with("http://") { + format!("ws://{}{}", &self.endpoint[7..], PATH_CONNECT_WEBSOCKET) + } else { + return Err(anyhow!("invalid endpoint URL")); + }; + + // Build protocols + let mut protocols = vec![ + WS_PROTOCOL_STANDARD.to_string(), + format!("{}actor", WS_PROTOCOL_TARGET), + format!("{}{}", WS_PROTOCOL_ACTOR, actor_id), + format!("{}{}", WS_PROTOCOL_ENCODING, encoding.as_str()), + ]; + + if let Some(token) = &self.token { + protocols.push(format!("{}{}", WS_PROTOCOL_TOKEN, token)); + } + + if let Some(p) = params { + let params_str = serde_json::to_string(&p)?; + protocols.push(format!("{}{}", WS_PROTOCOL_CONN_PARAMS, urlencoding::encode(¶ms_str))); + } + + if let Some(cid) = conn_id { + protocols.push(format!("{}{}", WS_PROTOCOL_CONN_ID, cid)); + } + + if let Some(ct) = conn_token { + protocols.push(format!("{}{}", WS_PROTOCOL_CONN_TOKEN, ct)); + } + + let mut request = ws_url.into_client_request()?; + request.headers_mut().insert( + "Sec-WebSocket-Protocol", + protocols.join(", ").parse()?, + ); + + let (ws_stream, _) = connect_async(request).await?; + Ok(ws_stream) + } +} diff --git a/clients/rust/tests/e2e.rs b/clients/rust/tests/e2e.rs index 016373748..a0ad4a64b 100644 --- a/clients/rust/tests/e2e.rs +++ b/clients/rust/tests/e2e.rs @@ -2,7 +2,6 @@ use rivetkit_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; use fs_extra; use portpicker; use serde_json::json; -use tracing_subscriber::EnvFilter; use std::process::{Child, Command}; use std::time::Duration; use tempfile;