Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ mod tests {
// Set up a local actor.
let local_proc_id = world_id.proc_id(rank);
let (local_proc_addr, local_proc_rx) =
channel::serve(ChannelAddr::any(ChannelTransport::Local), "mock_proc_actor").unwrap();
channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap();
let local_proc_mbox = Mailbox::new_detached(
local_proc_id.actor_id(format!("test_dummy_proc{}", idx).to_string(), 0),
);
Expand Down
8 changes: 4 additions & 4 deletions hyperactor/benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fn bench_message_sizes(c: &mut Criterion) {
assert!(!socket_addr.ip().is_loopback());
}

let (listen_addr, mut rx) = serve::<Message>(addr, "bench").unwrap();
let (listen_addr, mut rx) = serve::<Message>(addr).unwrap();
let tx = dial::<Message>(listen_addr).unwrap();
let msg = Message::new(0, size);
let start = Instant::now();
Expand Down Expand Up @@ -127,7 +127,7 @@ fn bench_message_rates(c: &mut Criterion) {
b.iter_custom(|iters| async move {
let total_msgs = iters * rate;
let addr = ChannelAddr::any(transport.clone());
let (listen_addr, mut rx) = serve::<Message>(addr, "bench").unwrap();
let (listen_addr, mut rx) = serve::<Message>(addr).unwrap();
tokio::spawn(async move {
let mut received_count = 0;

Expand Down Expand Up @@ -212,9 +212,9 @@ async fn channel_ping_pong(
struct Message(Part);

let (client_addr, mut client_rx) =
channel::serve::<Message>(ChannelAddr::any(transport.clone()), "ping_pong_client").unwrap();
channel::serve::<Message>(ChannelAddr::any(transport.clone())).unwrap();
let (server_addr, mut server_rx) =
channel::serve::<Message>(ChannelAddr::any(transport.clone()), "ping_pong_server").unwrap();
channel::serve::<Message>(ChannelAddr::any(transport.clone())).unwrap();

let _server_handle: tokio::task::JoinHandle<Result<(), anyhow::Error>> =
tokio::spawn(async move {
Expand Down
13 changes: 4 additions & 9 deletions hyperactor/example/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,8 @@ async fn client(
) -> anyhow::Result<()> {
let server_tx = channel::dial(server_addr)?;

let (client_addr, mut client_rx) = channel::serve::<Message>(
ChannelAddr::any(server_tx.addr().transport().clone()),
"example",
)
.unwrap();
let (client_addr, mut client_rx) =
channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone())).unwrap();

server_tx.post(Message::Hello(client_addr));

Expand Down Expand Up @@ -167,8 +164,7 @@ async fn main() -> Result<(), anyhow::Error> {
match args.command {
Some(Commands::Server) => {
let (server_addr, server_rx) =
channel::serve::<Message>(ChannelAddr::any(args.transport.clone()), "example")
.unwrap();
channel::serve::<Message>(ChannelAddr::any(args.transport.clone())).unwrap();
eprintln!("server listening on {}", server_addr);
server(server_rx).await?;
}
Expand All @@ -180,8 +176,7 @@ async fn main() -> Result<(), anyhow::Error> {
// No command: run a self-contained benchmark.
None => {
let (server_addr, server_rx) =
channel::serve::<Message>(ChannelAddr::any(args.transport.clone()), "example")
.unwrap();
channel::serve::<Message>(ChannelAddr::any(args.transport.clone())).unwrap();
let _server_handle = tokio::spawn(server(server_rx));
let client_handle = tokio::spawn(client(server_addr, args.message_size, args.num_iter));

Expand Down
13 changes: 5 additions & 8 deletions hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,6 @@ pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, Channel
#[crate::instrument]
pub fn serve<M: RemoteMessage>(
addr: ChannelAddr,
reason: &str,
) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
match addr {
ChannelAddr::Tcp(addr) => {
Expand Down Expand Up @@ -871,9 +870,7 @@ pub fn serve<M: RemoteMessage>(
.map(|(addr, inner)| {
tracing::debug!(
name = "serve",
"serving channel address {} for {}",
addr,
reason
%addr,
);
(addr, ChannelRx { inner })
})
Expand Down Expand Up @@ -1050,7 +1047,7 @@ mod tests {
#[tokio::test]
async fn test_multiple_connections() {
for addr in ChannelTransport::all().map(ChannelAddr::any) {
let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr, "test").unwrap();
let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();

let mut sends: JoinSet<()> = JoinSet::new();
for message in 0u64..100u64 {
Expand Down Expand Up @@ -1089,7 +1086,7 @@ mod tests {
continue;
}

let (listen_addr, rx) = crate::channel::serve::<u64>(addr, "test").unwrap();
let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

let tx = dial::<u64>(listen_addr).unwrap();
tx.try_post(123, oneshot::channel().0).unwrap();
Expand Down Expand Up @@ -1138,7 +1135,7 @@ mod tests {
#[cfg_attr(not(feature = "fb"), ignore)]
async fn test_dial_serve() {
for addr in addrs() {
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr, "test").unwrap();
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
let tx = crate::channel::dial(listen_addr).unwrap();
tx.try_post(123, oneshot::channel().0).unwrap();
assert_eq!(rx.recv().await.unwrap(), 123);
Expand All @@ -1158,7 +1155,7 @@ mod tests {
);
let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
for addr in addrs() {
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr, "test").unwrap();
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
let tx = crate::channel::dial(listen_addr).unwrap();
tx.send(123).await.unwrap();
assert_eq!(rx.recv().await.unwrap(), 123);
Expand Down
25 changes: 10 additions & 15 deletions hyperactor/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ impl<M: ProcManager> Host<M> {
/// Serve a host using the provided ProcManager, on the provided `addr`.
/// On success, the host will multiplex messages for procs on the host
/// on the address of the host.
#[tracing::instrument(skip(manager))]
pub async fn serve(
manager: M,
addr: ChannelAddr,
) -> Result<(Self, MailboxServerHandle), HostError> {
let (frontend_addr, frontend_rx) = channel::serve(addr, "host frontend")?;
let (frontend_addr, frontend_rx) = channel::serve(addr)?;

// We set up a cascade of routers: first, the outer router supports
// sending to the the system proc, while the dial router manages dialed
Expand All @@ -144,8 +145,7 @@ impl<M: ProcManager> Host<M> {

// Establish a backend channel on the preferred transport. We currently simply
// serve the same router on both.
let (backend_addr, backend_rx) =
channel::serve(ChannelAddr::any(manager.transport()), "host backend")?;
let (backend_addr, backend_rx) = channel::serve(ChannelAddr::any(manager.transport()))?;

// Set up a system proc. This is often used to manage the host itself.
let service_proc_id = ProcId::Direct(frontend_addr.clone(), "service".to_string());
Expand Down Expand Up @@ -863,6 +863,7 @@ where
ChannelTransport::Local
}

#[tracing::instrument(skip(self, _config))]
async fn spawn(
&self,
proc_id: ProcId,
Expand All @@ -874,10 +875,7 @@ where
proc_id.clone(),
MailboxClient::dial(forwarder_addr)?.into_boxed(),
);
let (proc_addr, rx) = channel::serve(
ChannelAddr::any(transport),
&format!("LocalProcManager spawning: {}", &proc_id),
)?;
let (proc_addr, rx) = channel::serve(ChannelAddr::any(transport))?;
self.procs
.lock()
.await
Expand Down Expand Up @@ -1043,16 +1041,15 @@ where
ChannelTransport::Unix
}

#[tracing::instrument(skip(self, _config))]
async fn spawn(
&self,
proc_id: ProcId,
forwarder_addr: ChannelAddr,
_config: (),
) -> Result<Self::Handle, HostError> {
let (callback_addr, mut callback_rx) = channel::serve(
ChannelAddr::any(ChannelTransport::Unix),
&format!("ProcessProcManager spawning: {}", &proc_id),
)?;
let (callback_addr, mut callback_rx) =
channel::serve(ChannelAddr::any(ChannelTransport::Unix))?;

let mut cmd = Command::new(&self.program);
cmd.env("HYPERACTOR_HOST_PROC_ID", proc_id.to_string());
Expand Down Expand Up @@ -1138,6 +1135,7 @@ where
/// forwarding messages to the provided `backend_addr`,
/// and returning the proc's address and agent actor on
/// the provided `callback_addr`.
#[tracing::instrument(skip(spawn))]
pub async fn spawn_proc<A, S, F>(
proc_id: ProcId,
backend_addr: ChannelAddr,
Expand All @@ -1163,10 +1161,7 @@ where

// Finally serve the proc on the same transport as the backend address,
// and call back.
let (proc_addr, proc_rx) = channel::serve(
ChannelAddr::any(backend_transport),
&format!("proc addr of: {}", &proc_id),
)?;
let (proc_addr, proc_rx) = channel::serve(ChannelAddr::any(backend_transport))?;
proc.clone().serve(proc_rx);
channel::dial(callback_addr)?
.send((proc_addr, agent_handle.bind::<A>()))
Expand Down
5 changes: 2 additions & 3 deletions hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2868,7 +2868,7 @@ mod tests {
.unwrap(),
);

let (_, rx) = serve::<MessageEnvelope>(ChannelAddr::Sim(dst_addr.clone()), "test").unwrap();
let (_, rx) = serve::<MessageEnvelope>(ChannelAddr::Sim(dst_addr.clone())).unwrap();
let tx = dial::<MessageEnvelope>(src_to_dst).unwrap();
let mbox = Mailbox::new_detached(id!(test[0].actor0));
let serve_handle = mbox.clone().serve(rx);
Expand Down Expand Up @@ -2997,8 +2997,7 @@ mod tests {

let mut handles = Vec::new(); // hold on to handles, or channels get closed
for mbox in mailboxes.iter() {
let (addr, rx) =
channel::serve(ChannelAddr::any(ChannelTransport::Local), "test").unwrap();
let (addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap();
let handle = (*mbox).clone().serve(rx);
handles.push(handle);

Expand Down
22 changes: 13 additions & 9 deletions hyperactor/src/proc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tracing::Instrument;
use tracing::Level;
use uuid::Uuid;

use crate as hyperactor;
Expand Down Expand Up @@ -333,24 +334,23 @@ impl Proc {
}

/// Create a new direct-addressed proc.
#[tracing::instrument]
pub async fn direct(addr: ChannelAddr, name: String) -> Result<Self, ChannelError> {
let (addr, rx) = channel::serve(addr, &format!("creating Proc::direct: {}", name))?;
let (addr, rx) = channel::serve(addr)?;
let proc_id = ProcId::Direct(addr, name);
let proc = Self::new(proc_id, DialMailboxRouter::new().into_boxed());
proc.clone().serve(rx);
Ok(proc)
}

/// Create a new direct-addressed proc with a default sender for the forwarder.
#[tracing::instrument(skip(default))]
pub fn direct_with_default(
addr: ChannelAddr,
name: String,
default: BoxedMailboxSender,
) -> Result<Self, ChannelError> {
let (addr, rx) = channel::serve(
addr,
&format!("creating Proc::direct_with_default: {}", name),
)?;
let (addr, rx) = channel::serve(addr)?;
let proc_id = ProcId::Direct(addr, name);
let proc = Self::new(
proc_id,
Expand Down Expand Up @@ -495,15 +495,18 @@ impl Proc {
params: A::Params,
) -> Result<ActorHandle<A>, anyhow::Error> {
let actor_id = self.allocate_root_id(name)?;
let _ = tracing::debug_span!(
let span = tracing::span!(
Level::INFO,
"spawn_actor",
actor_name = name,
actor_type = std::any::type_name::<A>(),
actor_id = actor_id.to_string(),
);
let (instance, mut actor_loop_receivers, work_rx) =
Instance::new(self.clone(), actor_id.clone(), false, None);
let actor = A::new(params).await?;
let (instance, mut actor_loop_receivers, work_rx) = {
let _guard = span.clone().entered();
Instance::new(self.clone(), actor_id.clone(), false, None)
};
let actor = A::new(params).instrument(span.clone()).await?;
// Add this actor to the proc's actor ledger. We do not actively remove
// inactive actors from ledger, because the actor's state can be inferred
// from its weak cell.
Expand All @@ -513,6 +516,7 @@ impl Proc {

instance
.start(actor, actor_loop_receivers.take().unwrap(), work_rx)
.instrument(span)
.await
}

Expand Down
2 changes: 1 addition & 1 deletion hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ mod tests {
let config = hyperactor::config::global::lock();
let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);

let (_, mut rx) = serve::<usize>(addr, "test").unwrap();
let (_, mut rx) = serve::<usize>(addr).unwrap();

let expected_ranks = selection
.eval(
Expand Down
8 changes: 3 additions & 5 deletions hyperactor_mesh/src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ impl AllocAssignedAddr {

pub(crate) fn serve_with_config<M: RemoteMessage>(
self,
reason: &str,
) -> anyhow::Result<(ChannelAddr, ChannelRx<M>)> {
fn set_as_inaddr_any(original: &mut SocketAddr) {
let inaddr_any: IpAddr = match &original {
Expand Down Expand Up @@ -552,7 +551,7 @@ impl AllocAssignedAddr {
}
};

let (mut bound, rx) = channel::serve(bind_to, reason)?;
let (mut bound, rx) = channel::serve(bind_to)?;

// Restore the original IP address if we used INADDR_ANY.
match &mut bound {
Expand Down Expand Up @@ -837,14 +836,13 @@ pub(crate) mod testing {
transport: ChannelTransport,
) -> (DialMailboxRouter, Instance<()>, Proc, ChannelAddr) {
let (router_channel_addr, router_rx) =
channel::serve(ChannelAddr::any(transport.clone()), "test").unwrap();
channel::serve(ChannelAddr::any(transport.clone())).unwrap();
let router =
DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed());
router.clone().serve(router_rx);

let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0);
let (client_proc_addr, client_rx) =
channel::serve(ChannelAddr::any(transport), "test").unwrap();
let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(transport)).unwrap();
let client_proc = Proc::new(
client_proc_id.clone(),
BoxedMailboxSender::new(router.clone()),
Expand Down
5 changes: 1 addition & 4 deletions hyperactor_mesh/src/alloc/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ impl Alloc for LocalAlloc {
match self.todo_rx.recv().await? {
Action::Start(rank) => {
let (addr, proc_rx) = loop {
match channel::serve(
ChannelAddr::any(self.transport()),
"LocalAlloc next proc addr",
) {
match channel::serve(ChannelAddr::any(self.transport())) {
Ok(addr_and_proc_rx) => break addr_and_proc_rx,
Err(err) => {
tracing::error!(
Expand Down
7 changes: 2 additions & 5 deletions hyperactor_mesh/src/alloc/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,8 @@ impl Allocator for ProcessAllocator {

#[hyperactor::instrument(fields(name = "process_allocate", monarch_client_trace_id = spec.constraints.match_labels.get(CLIENT_TRACE_ID_LABEL).cloned().unwrap_or_else(|| "".to_string())))]
async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
let (bootstrap_addr, rx) = channel::serve(
ChannelAddr::any(ChannelTransport::Unix),
"ProcessAllocator allocate bootstrap_addr",
)
.map_err(anyhow::Error::from)?;
let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
.map_err(anyhow::Error::from)?;

if spec.transport == ChannelTransport::Local {
return Err(AllocatorError::Other(anyhow::anyhow!(
Expand Down
Loading