@@ -3,12 +3,19 @@ use async_nats::jetstream::{self, consumer};
33use clap:: Parser ;
44use config:: Config ;
55use futures:: { SinkExt , StreamExt } ;
6+ use http_body_util:: Full ;
7+ use hyper:: body:: Bytes ;
8+ use hyper:: server:: conn:: http1;
9+ use hyper:: { Request , Response } ;
10+ use hyper_util:: rt:: TokioIo ;
611use pyth_stream:: utils:: setup_jetstream;
712use serde:: { Deserialize , Serialize } ;
813use std:: clone:: Clone ;
914use std:: collections:: { HashMap , HashSet } ;
15+ use std:: net:: SocketAddr ;
1016use std:: panic;
1117use std:: path:: PathBuf ;
18+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
1219use std:: sync:: Arc ;
1320use tokio:: net:: { TcpListener , TcpStream } ;
1421use tokio:: sync:: { mpsc, Mutex } ;
@@ -37,6 +44,7 @@ struct ServerResponse {
3744struct AppConfig {
3845 nats : NatsConfig ,
3946 websocket : WebSocketConfig ,
47+ healthcheck : HealthCheckConfig ,
4048}
4149
4250#[ derive( Debug , Deserialize , Clone ) ]
@@ -49,6 +57,11 @@ struct WebSocketConfig {
4957 addr : String ,
5058}
5159
60+ #[ derive( Debug , Deserialize , Clone ) ]
61+ struct HealthCheckConfig {
62+ addr : String ,
63+ }
64+
5265#[ derive( Parser , Debug , Deserialize ) ]
5366#[ command( author, version, about, long_about = None ) ]
5467struct Args {
@@ -58,6 +71,9 @@ struct Args {
5871
5972type Clients = Arc < Mutex < HashMap < String , ( HashSet < String > , mpsc:: UnboundedSender < Message > ) > > > ;
6073
74+ static NATS_CONNECTED : AtomicBool = AtomicBool :: new ( false ) ;
75+ static WS_LISTENER_ACTIVE : AtomicBool = AtomicBool :: new ( false ) ;
76+
6177#[ tokio:: main]
6278async fn main ( ) -> Result < ( ) > {
6379 fmt ( ) . with_env_filter ( EnvFilter :: from_default_env ( ) ) . init ( ) ;
@@ -76,8 +92,13 @@ async fn main() -> Result<()> {
7692 let clients: Clients = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
7793
7894 let server = TcpListener :: bind ( & config. websocket . addr ) . await ?;
95+ WS_LISTENER_ACTIVE . store ( true , Ordering :: SeqCst ) ;
7996 info ! ( "WebSocket server listening on: {}" , config. websocket. addr) ;
8097
98+ // Start the health check server
99+ let health_addr: SocketAddr = config. healthcheck . addr . parse ( ) ?;
100+ tokio:: spawn ( start_health_check_server ( health_addr) ) ;
101+
81102 let nats_config = config. nats . clone ( ) ;
82103 let clients_clone = clients. clone ( ) ;
83104 tokio:: spawn ( async move {
@@ -331,12 +352,72 @@ fn base58_to_hex(base58: &str) -> Result<String, anyhow::Error> {
331352}
332353
333354async fn connect_and_handle_nats ( config : & NatsConfig , clients : Clients ) -> Result < ( ) > {
355+ NATS_CONNECTED . store ( false , Ordering :: SeqCst ) ;
334356 let nats_client = async_nats:: connect ( & config. url ) . await ?;
335357 let jetstream = setup_jetstream ( & nats_client) . await ?;
336358
359+ NATS_CONNECTED . store ( true , Ordering :: SeqCst ) ;
337360 info ! ( "Connected to NATS server" ) ;
338361
339362 handle_nats_messages ( jetstream, clients) . await ?;
340363
364+ NATS_CONNECTED . store ( false , Ordering :: SeqCst ) ;
341365 Ok ( ( ) )
342366}
367+ async fn start_health_check_server ( addr : SocketAddr ) {
368+ let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
369+ info ! ( "Health check server listening on http://{}" , addr) ;
370+
371+ loop {
372+ let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
373+ tokio:: spawn ( handle_health_check_connection ( stream) ) ;
374+ }
375+ }
376+
377+ async fn handle_health_check_connection ( stream : TcpStream ) {
378+ let io = TokioIo :: new ( stream) ;
379+
380+ if let Err ( err) = http1:: Builder :: new ( )
381+ . serve_connection ( io, hyper:: service:: service_fn ( health_check) )
382+ . await
383+ {
384+ error ! ( "Error serving connection: {:?}" , err) ;
385+ }
386+ }
387+
388+ async fn health_check (
389+ req : Request < hyper:: body:: Incoming > ,
390+ ) -> Result < Response < Full < Bytes > > , hyper:: Error > {
391+ if req. uri ( ) . path ( ) != "/health" {
392+ return Ok ( Response :: builder ( )
393+ . status ( hyper:: StatusCode :: NOT_FOUND )
394+ . body ( Full :: new ( Bytes :: from ( "Not Found" ) ) )
395+ . unwrap ( ) ) ;
396+ }
397+
398+ let nats_status = if NATS_CONNECTED . load ( Ordering :: SeqCst ) {
399+ "Connected"
400+ } else {
401+ "Disconnected"
402+ } ;
403+
404+ let ws_status = if WS_LISTENER_ACTIVE . load ( Ordering :: SeqCst ) {
405+ "Active"
406+ } else {
407+ "Inactive"
408+ } ;
409+
410+ let body = format ! ( "NATS: {}\n WebSocket Listener: {}" , nats_status, ws_status) ;
411+ let status = if nats_status == "Connected" && ws_status == "Active" {
412+ hyper:: StatusCode :: OK
413+ } else {
414+ hyper:: StatusCode :: SERVICE_UNAVAILABLE
415+ } ;
416+
417+ let response = Response :: builder ( )
418+ . status ( status)
419+ . body ( Full :: new ( Bytes :: from ( body) ) )
420+ . unwrap ( ) ;
421+
422+ Ok ( response)
423+ }
0 commit comments