From 551d37fa60b105c0dcb418717468e6edac07d74c Mon Sep 17 00:00:00 2001 From: Paul Holzinger Date: Fri, 1 Aug 2025 17:51:12 +0200 Subject: [PATCH] type safety: make Socket aware of namespace Using a PhantomData and zero sized struct we can attach an additional generic data for either the host or container namespace to the struct. Because it is zero sized and not used it is optimized away and only the type checker sees it to enforce the right types are used. With that we basically create two different netlink socket types Socket and Socket so they must be used in all type signatures from now on. To keep the changes smaller I have set HostNS as default generic for the struct so we don't need to change most function signatures. For al call sides where we pass sockets around the compiler now enforces that we use the right ones avoid any possible mix ups. Signed-off-by: Paul Holzinger --- src/commands/setup.rs | 7 +++++-- src/dhcp_proxy/ip.rs | 8 ++++---- src/network/bridge.rs | 16 +++++++++++----- src/network/core_utils.rs | 15 +++++++++------ src/network/dhcp.rs | 5 ++++- src/network/driver.rs | 10 ++++++++-- src/network/netlink.rs | 18 +++++++++++++++--- src/network/plugin.rs | 10 ++++++++-- src/network/vlan.rs | 12 +++++++++--- src/test/netlink.rs | 15 +++++++++------ 10 files changed, 82 insertions(+), 34 deletions(-) diff --git a/src/commands/setup.rs b/src/commands/setup.rs index d83c219ae..139e66cef 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -160,8 +160,11 @@ impl Setup { } } -fn teardown_drivers<'a, I>(drivers: I, host: &mut netlink::Socket, netns: &mut netlink::Socket) -where +fn teardown_drivers<'a, I>( + drivers: I, + host: &mut netlink::Socket, + netns: &mut netlink::Socket, +) where I: Iterator>, { for driver in drivers { diff --git a/src/dhcp_proxy/ip.rs b/src/dhcp_proxy/ip.rs index 131bad5bc..f91937c93 100644 --- a/src/dhcp_proxy/ip.rs +++ b/src/dhcp_proxy/ip.rs @@ -34,8 +34,8 @@ trait Address { fn new(l: &Lease, interface: &str) -> Result where Self: Sized; - fn add_ip(&self, nls: &mut Socket) -> Result<(), ProxyError>; - fn add_gws(&self, nls: &mut Socket) -> Result<(), ProxyError>; + fn add_ip(&self, nls: &mut Socket) -> Result<(), ProxyError>; + fn add_gws(&self, nls: &mut Socket) -> Result<(), ProxyError>; } fn handle_gws(g: Vec, netmask: &str) -> Result, ProxyError> { @@ -112,7 +112,7 @@ impl Address for MacVLAN { } // add the ip address to the container namespace - fn add_ip(&self, nls: &mut Socket) -> Result<(), ProxyError> { + fn add_ip(&self, nls: &mut netlink::Socket) -> Result<(), ProxyError> { debug!("adding network information for {}", self.interface); let ip = IpNet::new(self.address, self.prefix_length)?; let dev = nls.get_link(netlink::LinkID::Name(self.interface.clone()))?; @@ -123,7 +123,7 @@ impl Address for MacVLAN { } // add one or more routes to the container namespace - fn add_gws(&self, nls: &mut Socket) -> Result<(), ProxyError> { + fn add_gws(&self, nls: &mut Socket) -> Result<(), ProxyError> { debug!("adding gateways to {}", self.interface); match core_utils::add_default_routes(nls, &self.gateways, None) { Ok(_) => Ok(()), diff --git a/src/network/bridge.rs b/src/network/bridge.rs index 0a4ef72fc..84c133eda 100644 --- a/src/network/bridge.rs +++ b/src/network/bridge.rs @@ -147,7 +147,10 @@ impl driver::NetworkDriver for Bridge<'_> { fn setup( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> NetavarkResult<(StatusBlock, Option)> { let data = match &self.data { Some(d) => d, @@ -302,7 +305,10 @@ impl driver::NetworkDriver for Bridge<'_> { fn teardown( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> NetavarkResult<()> { let mode: Option = parse_option(&self.info.network.options, OPTION_MODE)?; let mode = get_bridge_mode_from_string(mode.as_deref())?; @@ -547,7 +553,7 @@ const IPV6_FORWARD: &str = "net/ipv6/conf/all/forwarding"; /// returns the container veth mac address fn create_interfaces( host: &mut netlink::Socket, - netns: &mut netlink::Socket, + netns: &mut netlink::Socket, data: &InternalData, internal: bool, rootless: bool, @@ -745,7 +751,7 @@ fn create_interfaces( #[allow(clippy::too_many_arguments)] fn create_veth_pair<'fd>( host: &mut netlink::Socket, - netns: &mut netlink::Socket, + netns: &mut netlink::Socket, data: &InternalData, primary_index: u32, bridge_mac: Option>, @@ -992,7 +998,7 @@ fn check_link_is_vrf(msg: LinkMessage, vrf_name: &str) -> NetavarkResult, mode: BridgeMode, br_name: &str, container_veth_name: &str, diff --git a/src/network/core_utils.rs b/src/network/core_utils.rs index ed1c1b9f9..d645e9617 100644 --- a/src/network/core_utils.rs +++ b/src/network/core_utils.rs @@ -270,24 +270,27 @@ macro_rules! exec_netns { }}; } -pub struct NamespaceOptions { +pub struct NamespaceOptions { /// Note we have to return the File object since the fd is only valid /// as long as the File object is valid pub file: File, - pub netlink: netlink::Socket, + pub netlink: netlink::Socket, } pub fn open_netlink_sockets( netns_path: &str, -) -> NetavarkResult<(NamespaceOptions, NamespaceOptions)> { +) -> NetavarkResult<( + NamespaceOptions, + NamespaceOptions, +)> { let netns = open_netlink_socket(netns_path).wrap("open container netns")?; let hostns = open_netlink_socket("/proc/self/ns/net").wrap("open host netns")?; - let host_socket = netlink::Socket::new().wrap("host netlink socket")?; + let host_socket = netlink::Socket::::new().wrap("host netlink socket")?; let netns_sock = exec_netns!( hostns.as_fd(), netns.as_fd(), - netlink::Socket::new().wrap("netns netlink socket") + netlink::Socket::::new().wrap("netns netlink socket") )?; Ok(( @@ -307,7 +310,7 @@ fn open_netlink_socket(netns_path: &str) -> NetavarkResult { } pub fn add_default_routes( - sock: &mut netlink::Socket, + sock: &mut netlink::Socket, gws: &[ipnet::IpNet], metric: Option, ) -> NetavarkResult<()> { diff --git a/src/network/dhcp.rs b/src/network/dhcp.rs index dec4c3e6d..e7849b49f 100644 --- a/src/network/dhcp.rs +++ b/src/network/dhcp.rs @@ -159,7 +159,10 @@ pub fn release_dhcp_lease( Ok(()) } -pub fn dhcp_teardown(info: &DriverInfo, sock: &mut netlink::Socket) -> NetavarkResult<()> { +pub fn dhcp_teardown( + info: &DriverInfo, + sock: &mut netlink::Socket, +) -> NetavarkResult<()> { let ipam = core_utils::get_ipam_addresses(info.per_network_opts, info.network)?; let if_name = info.per_network_opts.interface_name.clone(); diff --git a/src/network/driver.rs b/src/network/driver.rs index f4df8ba0e..79683fd75 100644 --- a/src/network/driver.rs +++ b/src/network/driver.rs @@ -38,12 +38,18 @@ pub trait NetworkDriver { /// setup the network interfaces/firewall rules for this driver fn setup( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> NetavarkResult<(StatusBlock, Option)>; /// teardown the network interfaces/firewall rules for this driver fn teardown( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> NetavarkResult<()>; /// return the network name diff --git a/src/network/netlink.rs b/src/network/netlink.rs index 67ceef9fd..acd96c9f9 100644 --- a/src/network/netlink.rs +++ b/src/network/netlink.rs @@ -24,11 +24,22 @@ use netlink_packet_route::{ }; use netlink_sys::{protocols::NETLINK_ROUTE, SocketAddr}; -pub struct Socket { +/// Marker trait for the Socket so we know it refers to either HostNS or ContainerNS. +pub trait Namespace {} +pub struct HostNS; +pub struct ContainerNS; +impl Namespace for HostNS {} +impl Namespace for ContainerNS {} + +pub struct Socket { socket: netlink_sys::Socket, sequence_number: u32, /// buffer size for reading netlink messages, see NLMSG_GOODSIZE in the kernel buffer: [u8; 8192], + + /// Marker for host or container namespace, allows functions to specify which netns + /// socket they need which then gets enforced at compile time without any runtime overhead. + _marker: std::marker::PhantomData, } #[derive(Clone)] @@ -110,8 +121,8 @@ macro_rules! function { }}; } -impl Socket { - pub fn new() -> NetavarkResult { +impl Socket { + pub fn new() -> NetavarkResult> { let mut socket = wrap!(netlink_sys::Socket::new(NETLINK_ROUTE), "open")?; let addr = &SocketAddr::new(0, 0); wrap!(socket.bind(addr), "bind")?; @@ -121,6 +132,7 @@ impl Socket { socket, sequence_number: 0, buffer: [0; 8192], + _marker: std::marker::PhantomData, }) } diff --git a/src/network/plugin.rs b/src/network/plugin.rs index 3c4d9e7df..ed4da8e2f 100644 --- a/src/network/plugin.rs +++ b/src/network/plugin.rs @@ -36,7 +36,10 @@ impl NetworkDriver for PluginDriver<'_> { fn setup( &self, - _netlink_sockets: (&mut super::netlink::Socket, &mut super::netlink::Socket), + _netlink_sockets: ( + &mut super::netlink::Socket, + &mut super::netlink::Socket, + ), ) -> NetavarkResult<(types::StatusBlock, Option)> { let result = self.exec_plugin(true, self.info.netns_path).wrap(format!( "plugin {:?} failed", @@ -49,7 +52,10 @@ impl NetworkDriver for PluginDriver<'_> { fn teardown( &self, - _netlink_sockets: (&mut super::netlink::Socket, &mut super::netlink::Socket), + _netlink_sockets: ( + &mut super::netlink::Socket, + &mut super::netlink::Socket, + ), ) -> NetavarkResult<()> { self.exec_plugin(false, self.info.netns_path).wrap(format!( "plugin {:?} failed", diff --git a/src/network/vlan.rs b/src/network/vlan.rs index 489aaa46c..2453ba00c 100644 --- a/src/network/vlan.rs +++ b/src/network/vlan.rs @@ -145,7 +145,10 @@ impl driver::NetworkDriver for Vlan<'_> { fn setup( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> Result<(StatusBlock, Option), NetavarkError> { let data = match &self.data { Some(d) => d, @@ -218,7 +221,10 @@ impl driver::NetworkDriver for Vlan<'_> { fn teardown( &self, - netlink_sockets: (&mut netlink::Socket, &mut netlink::Socket), + netlink_sockets: ( + &mut netlink::Socket, + &mut netlink::Socket, + ), ) -> NetavarkResult<()> { dhcp_teardown(&self.info, netlink_sockets.1)?; @@ -236,7 +242,7 @@ impl driver::NetworkDriver for Vlan<'_> { fn setup( host: &mut netlink::Socket, - netns: &mut netlink::Socket, + netns: &mut netlink::Socket, if_name: &str, data: &InternalData, hostns_fd: BorrowedFd<'_>, diff --git a/src/test/netlink.rs b/src/test/netlink.rs index 6b7eb651c..c45dbd97e 100644 --- a/src/test/netlink.rs +++ b/src/test/netlink.rs @@ -28,13 +28,16 @@ mod tests { #[test] fn test_socket_new() { test_setup!(); - assert!(Socket::new().is_ok(), "Netlink Socket::new() should work"); + assert!( + Socket::::new().is_ok(), + "Netlink Socket::::new() should work" + ); } #[test] fn test_add_link() { test_setup!(); - let mut sock = Socket::new().expect("Socket::new()"); + let mut sock = Socket::::new().expect("Socket::::new()"); let name = String::from("test1"); sock.create_link(CreateLinkOptions::new(name.clone(), InfoKind::Dummy)) @@ -49,7 +52,7 @@ mod tests { #[test] fn test_add_addr() { test_setup!(); - let mut sock = Socket::new().expect("Socket::new()"); + let mut sock = Socket::::new().expect("Socket::::new()"); let out = run_command!("ip", "link", "add", "test1", "type", "dummy"); eprintln!("{}", String::from_utf8(out.stderr).unwrap()); @@ -72,7 +75,7 @@ mod tests { #[test] fn test_del_addr() { test_setup!(); - let mut sock = Socket::new().expect("Socket::new()"); + let mut sock = Socket::::new().expect("Socket::::new()"); let out = run_command!("ip", "link", "add", "test1", "type", "dummy"); eprintln!("{}", String::from_utf8(out.stderr).unwrap()); @@ -110,7 +113,7 @@ mod tests { #[ignore] fn test_del_route() { test_setup!(); - let mut sock = Socket::new().expect("Socket::new()"); + let mut sock = Socket::::new().expect("Socket::::new()"); let out = run_command!("ip", "link", "add", "test1", "type", "dummy"); eprintln!("{}", String::from_utf8(out.stderr).unwrap()); @@ -159,7 +162,7 @@ mod tests { #[test] fn test_dump_addr() { test_setup!(); - let mut sock = Socket::new().expect("Socket::new()"); + let mut sock = Socket::::new().expect("Socket::::new()"); let out = run_command!("ip", "link", "add", "test1", "type", "dummy"); eprintln!("{}", String::from_utf8(out.stderr).unwrap());