Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ mod tests {

// Construct a system sender.
let system_sender = BoxedMailboxSender::new(MailboxClient::new(
channel::dial(server_handle.local_addr().clone()).unwrap(),
channel::dial(server_handle.local_addr().clone(), "test".to_string()).unwrap(),
));

// Construct a proc forwarder in terms of the system sender.
Expand Down Expand Up @@ -1360,7 +1360,7 @@ mod tests {

// Construct a system sender.
let system_sender = BoxedMailboxSender::new(MailboxClient::new(
channel::dial(server_handle.local_addr().clone()).unwrap(),
channel::dial(server_handle.local_addr().clone(), "test".to_string()).unwrap(),
));

// Construct a proc forwarder in terms of the system sender.
Expand Down Expand Up @@ -1572,8 +1572,11 @@ mod tests {
let world_id = id!(world);
// Set up a local actor.
let local_proc_id = world_id.proc_id(rank);
let (local_proc_addr, local_proc_rx) =
channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap();
let (local_proc_addr, local_proc_rx) = channel::serve(
ChannelAddr::any(ChannelTransport::Local),
"test".to_string(),
)
.unwrap();
let local_proc_mbox = Mailbox::new_detached(
local_proc_id.actor_id(format!("test_dummy_proc{}", idx).to_string(), 0),
);
Expand Down
19 changes: 11 additions & 8 deletions hyperactor/benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ fn bench_message_sizes(c: &mut Criterion) {
assert!(!socket_addr.ip().is_loopback());
}

let (listen_addr, mut rx) = serve::<Message>(addr).unwrap();
let tx = dial::<Message>(listen_addr).unwrap();
let (listen_addr, mut rx) = serve::<Message>(addr, "".to_string()).unwrap();
let tx = dial::<Message>(listen_addr, "".to_string()).unwrap();
let msg = Message::new(0, size);
let start = Instant::now();
for _ in 0..iters {
Expand Down Expand Up @@ -127,7 +127,8 @@ fn bench_message_rates(c: &mut Criterion) {
b.iter_custom(|iters| async move {
let total_msgs = iters * rate;
let addr = ChannelAddr::any(transport.clone());
let (listen_addr, mut rx) = serve::<Message>(addr).unwrap();
let (listen_addr, mut rx) =
serve::<Message>(addr, "bench".to_string()).unwrap();
tokio::spawn(async move {
let mut received_count = 0;

Expand All @@ -141,7 +142,7 @@ fn bench_message_rates(c: &mut Criterion) {
}
});

let tx = dial::<Message>(listen_addr).unwrap();
let tx = dial::<Message>(listen_addr, "".to_string()).unwrap();
let message = Message::new(0, payload_size);
let start = Instant::now();

Expand Down Expand Up @@ -212,13 +213,15 @@ async fn channel_ping_pong(
struct Message(Part);

let (client_addr, mut client_rx) =
channel::serve::<Message>(ChannelAddr::any(transport.clone())).unwrap();
channel::serve::<Message>(ChannelAddr::any(transport.clone()), "client".to_string())
.unwrap();
let (server_addr, mut server_rx) =
channel::serve::<Message>(ChannelAddr::any(transport.clone())).unwrap();
channel::serve::<Message>(ChannelAddr::any(transport.clone()), "server".to_string())
.unwrap();

let _server_handle: tokio::task::JoinHandle<Result<(), anyhow::Error>> =
tokio::spawn(async move {
let client_tx = channel::dial(client_addr)?;
let client_tx = channel::dial(client_addr, "client".to_string())?;
loop {
let message = server_rx.recv().await?;
client_tx.post(message);
Expand All @@ -227,7 +230,7 @@ async fn channel_ping_pong(

let client_handle: tokio::task::JoinHandle<Result<(), anyhow::Error>> =
tokio::spawn(async move {
let server_tx = channel::dial(server_addr)?;
let server_tx = channel::dial(server_addr, "server".to_string())?;
let message = Message(Part::from(vec![0u8; message_size]));
for _ in 0..num_iter {
server_tx.post(message.clone() /*cheap */);
Expand Down
25 changes: 17 additions & 8 deletions hyperactor/example/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn server(mut server_rx: ChannelRx<Message>) -> Result<(), anyhow::Error>
.await?
.into_hello()
.map_err(|_| anyhow::anyhow!("expected hello message"))?;
let client_tx = channel::dial(client_addr)?;
let client_tx = channel::dial(client_addr, "client".to_string())?;
loop {
let message = server_rx.recv().await?;
client_tx.post(message);
Expand All @@ -62,10 +62,13 @@ async fn client(
message_size: usize,
num_iter: Option<usize>,
) -> anyhow::Result<()> {
let server_tx = channel::dial(server_addr)?;
let server_tx = channel::dial(server_addr, "server".to_string())?;

let (client_addr, mut client_rx) =
channel::serve::<Message>(ChannelAddr::any(server_tx.addr().transport().clone())).unwrap();
let (client_addr, mut client_rx) = channel::serve::<Message>(
ChannelAddr::any(server_tx.addr().transport().clone()),
"client".to_string(),
)
.unwrap();

server_tx.post(Message::Hello(client_addr));

Expand Down Expand Up @@ -163,8 +166,11 @@ async fn main() -> Result<(), anyhow::Error> {

match args.command {
Some(Commands::Server) => {
let (server_addr, server_rx) =
channel::serve::<Message>(ChannelAddr::any(args.transport.clone())).unwrap();
let (server_addr, server_rx) = channel::serve::<Message>(
ChannelAddr::any(args.transport.clone()),
"server".to_string(),
)
.unwrap();
eprintln!("server listening on {}", server_addr);
server(server_rx).await?;
}
Expand All @@ -175,8 +181,11 @@ async fn main() -> Result<(), anyhow::Error> {

// No command: run a self-contained benchmark.
None => {
let (server_addr, server_rx) =
channel::serve::<Message>(ChannelAddr::any(args.transport.clone())).unwrap();
let (server_addr, server_rx) = channel::serve::<Message>(
ChannelAddr::any(args.transport.clone()),
"server".to_string(),
)
.unwrap();
let _server_handle = tokio::spawn(server(server_rx));
let client_handle = tokio::spawn(client(server_addr, args.message_size, args.num_iter));

Expand Down
41 changes: 25 additions & 16 deletions hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
let _ignore = self.try_post(message, oneshot::channel().0);
}

/// Send a message synchronously, returning when the messsage has
/// Send a message synchronously, returning when the message has
/// been delivered to the remote end of the channel.
async fn send(&self, message: M) -> Result<(), SendError<M>> {
let (tx, rx) = oneshot::channel();
Expand Down Expand Up @@ -825,14 +825,19 @@ impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
/// if the channel cannot be established. The underlying connection is
/// dropped whenever the returned Tx is dropped.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
pub fn dial<M: RemoteMessage>(
addr: ChannelAddr,
label: String,
) -> Result<ChannelTx<M>, ChannelError> {
tracing::debug!(name = "dial", "dialing channel {}", addr);
let inner = match addr {
ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr, label)),
ChannelAddr::MetaTls(meta_addr) => {
ChannelTxKind::MetaTls(net::meta::dial(meta_addr, label)?)
}
ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path, label)),
};
Ok(ChannelTx { inner })
}
Expand All @@ -842,19 +847,20 @@ pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, Channel
#[crate::instrument]
pub fn serve<M: RemoteMessage>(
addr: ChannelAddr,
label: String,
) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
tracing::debug!(name = "serve", "serving channel address {}", addr);
match addr {
ChannelAddr::Tcp(addr) => {
let (addr, rx) = net::tcp::serve::<M>(addr)?;
let (addr, rx) = net::tcp::serve::<M>(addr, label)?;
Ok((addr, ChannelRxKind::Tcp(rx)))
}
ChannelAddr::MetaTls(meta_addr) => {
let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
let (addr, rx) = net::meta::serve::<M>(meta_addr, label)?;
Ok((addr, ChannelRxKind::MetaTls(rx)))
}
ChannelAddr::Unix(path) => {
let (addr, rx) = net::unix::serve::<M>(path)?;
let (addr, rx) = net::unix::serve::<M>(path, label)?;
Ok((addr, ChannelRxKind::Unix(rx)))
}
ChannelAddr::Local(0) => {
Expand Down Expand Up @@ -1044,13 +1050,14 @@ mod tests {
#[tokio::test]
async fn test_multiple_connections() {
for addr in ChannelTransport::all().map(ChannelAddr::any) {
let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
let (listen_addr, mut rx) =
crate::channel::serve::<u64>(addr, "test".to_string()).unwrap();

let mut sends: JoinSet<()> = JoinSet::new();
for message in 0u64..100u64 {
let addr = listen_addr.clone();
sends.spawn(async move {
let tx = dial::<u64>(addr).unwrap();
let tx = dial::<u64>(addr, "test".to_string()).unwrap();
tx.try_post(message, oneshot::channel().0).unwrap();
});
}
Expand Down Expand Up @@ -1083,9 +1090,9 @@ mod tests {
continue;
}

let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
let (listen_addr, rx) = crate::channel::serve::<u64>(addr, "test".to_string()).unwrap();

let tx = dial::<u64>(listen_addr).unwrap();
let tx = dial::<u64>(listen_addr, "test".to_string()).unwrap();
tx.try_post(123, oneshot::channel().0).unwrap();
drop(rx);

Expand Down Expand Up @@ -1132,8 +1139,9 @@ mod tests {
#[cfg_attr(not(feature = "fb"), ignore)]
async fn test_dial_serve() {
for addr in addrs() {
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
let tx = crate::channel::dial(listen_addr).unwrap();
let (listen_addr, mut rx) =
crate::channel::serve::<i32>(addr, "test".to_string()).unwrap();
let tx = crate::channel::dial(listen_addr, "test".to_string()).unwrap();
tx.try_post(123, oneshot::channel().0).unwrap();
assert_eq!(rx.recv().await.unwrap(), 123);
}
Expand All @@ -1152,8 +1160,9 @@ mod tests {
);
let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
for addr in addrs() {
let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
let tx = crate::channel::dial(listen_addr).unwrap();
let (listen_addr, mut rx) =
crate::channel::serve::<i32>(addr, "test".to_string()).unwrap();
let tx = crate::channel::dial(listen_addr, "test".to_string()).unwrap();
tx.send(123).await.unwrap();
assert_eq!(rx.recv().await.unwrap(), 123);

Expand Down
Loading