diff --git a/docs/index.rst b/docs/index.rst index ce468008f..30ebea4be 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,6 +45,7 @@ kernel/cpu_arch/index kernel/container/index kernel/libs/index + kernel/net/index kernel/trace/index kernel/syscall/index diff --git a/docs/kernel/net/index.rst b/docs/kernel/net/index.rst new file mode 100644 index 000000000..f1a49ea6d --- /dev/null +++ b/docs/kernel/net/index.rst @@ -0,0 +1,9 @@ +网络子系统 +==================================== +DragonOS 网络子系统 + +.. toctree:: + :maxdepth: 1 + + inet + unix \ No newline at end of file diff --git a/docs/kernel/net/inet.md b/docs/kernel/net/inet.md new file mode 100644 index 000000000..674fd36bb --- /dev/null +++ b/docs/kernel/net/inet.md @@ -0,0 +1,38 @@ +# Internet Protocol Socket + +众所都周之,这个 Inet Socket 常用的分为 TCP, UDP 和 ICMP。基于实用性,目前实现的是 TCP 和 UDP。 + +整个 Inet 网络协议栈与网卡的交互基于 `smoltcp` crate 来实现。 + +## Roadmap + +- [ ] TCP + - [x] 接受连接 + - [ ] 发起连接 + - [ ] 半双工关闭 +- [x] UDP + - [x] 传输数据 +- [ ] ICMP +- [ ] ioctl +- [ ] Misc + - [ ] 硬中断转软中断的锁处理(避免死锁) + - [ ] epoll_item 优化 + - [ ] 优化 `inet port` 资源管理 + +## TCP + +根据 TCP 状态机来 TCP Socket 的几个状态类 +- `Init`: 裸状态 + - `Unbound`: 创建出来的状态 + - `Bound`: 绑定了地址 +- `Listening`: 监听状态 +- `Connecting`: 连接中状态 +- `Established`: 连接建立状态 + +## UDP + +UDP 是无连接的,所以没有连接状态。UDP 的状态只有 `Unbound` 和 `Bound` 两种。 + +## BoundInner + +另一个对于 Inet Socket 的抽象,用于处理绑定网卡的 `socket`,从而封装 `smoltcp` 的接口,提供统一的资源管理。 \ No newline at end of file diff --git a/docs/kernel/net/unix.md b/docs/kernel/net/unix.md new file mode 100644 index 000000000..5de22e15e --- /dev/null +++ b/docs/kernel/net/unix.md @@ -0,0 +1,22 @@ +# UNIX + +## unix socket + +unix - 用于进程间通信的socket + + +## 描述 + +AF_UNIX socket family 用于在同一台机器中的不同进程之间的通信(IPC)。unix socket地址现支持绑定文件地址,未支持绑定abstract namespace抽象命名空间。 + +目前unix 域中合法的socket type有:SOCK_STREAM, 提供stream-oriented socket,可靠有序传输消息;SOCK_SEQPACKET,提供connection-oriented,消息边界和按发送顺序交付消息保证的socket。 + +### unix stream socket 进程通信描述 + +unix stream socket 提供进程间流式传输消息的功能。假设对端进程作为服务端,本端进程作为客户端。进程间使用stream socket通信过程如下: + +分别在对端进程和本端进程创建socket,服务端需要bind地址,客户端不必须bind地址。通信过程类似tcp三次握手流程:服务端调用listen系统调用进入监听状态,监听服务端bind的地址;客户端调用connect系统调用连接服务端地址;服务端调用accept系统调用接受来自客户端的连接,返回建立连接的新的socket。成功建立连接后可以调用write\send\sendto\sendmsg进行写操作,调用read\recv\recvfrom\recvmsg进行读操作。目前尚未支持非阻塞式读写,默认为阻塞式读写。读写完毕后调用close系统调用关闭socket连接。 + +### unix seqpacket socket 进程通信描述 + + diff --git a/kernel/Cargo.lock b/kernel/Cargo.lock index 6213178fc..ca935e7dc 100644 --- a/kernel/Cargo.lock +++ b/kernel/Cargo.lock @@ -1534,9 +1534,9 @@ checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "smoltcp" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a1a996951e50b5971a2c8c0fa05a381480d70a933064245c4a223ddc87ccc97" +checksum = "dad095989c1533c1c266d9b1e8d70a1329dd3723c3edac6d03bbd67e7bf6f4bb" dependencies = [ "bitflags 1.3.2", "byteorder 1.5.0", diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 0b8fe0bff..d7384c5a5 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -48,7 +48,7 @@ linkme = "=0.3.27" num = { version = "=0.4.0", default-features = false } num-derive = "=0.3" num-traits = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/num-traits.git", rev = "1597c1c", default-features = false } -smoltcp = { version = "=0.11.0", default-features = false, features = [ +smoltcp = { version = "=0.12.0", default-features = false, features = [ "alloc", "socket-raw", "socket-udp", @@ -58,6 +58,7 @@ smoltcp = { version = "=0.11.0", default-features = false, features = [ "socket-dns", "proto-ipv4", "proto-ipv6", + "medium-ip", ] } syscall_table_macros = { path = "crates/syscall_table_macros" } system_error = { path = "crates/system_error" } diff --git a/kernel/src/driver/acpi/sysfs.rs b/kernel/src/driver/acpi/sysfs.rs index 886c0a3ac..34b93db35 100644 --- a/kernel/src/driver/acpi/sysfs.rs +++ b/kernel/src/driver/acpi/sysfs.rs @@ -150,7 +150,7 @@ impl AcpiManager { acpi_table_attr_list().write().push(attr); self.acpi_table_data_init(&header)?; } - + // TODO:UEVENT return Ok(()); } diff --git a/kernel/src/driver/base/device/dd.rs b/kernel/src/driver/base/device/dd.rs index 81adaa184..f79dabd21 100644 --- a/kernel/src/driver/base/device/dd.rs +++ b/kernel/src/driver/base/device/dd.rs @@ -572,6 +572,7 @@ impl DriverManager { } // todo: 发送kobj bind的uevent + // kobject_uevent(); } fn driver_is_bound(&self, device: &Arc) -> bool { diff --git a/kernel/src/driver/base/device/driver.rs b/kernel/src/driver/base/device/driver.rs index dad6682c3..a40762919 100644 --- a/kernel/src/driver/base/device/driver.rs +++ b/kernel/src/driver/base/device/driver.rs @@ -17,7 +17,6 @@ use alloc::{ use core::fmt::Debug; use log::error; use system_error::SystemError; - /// @brief: Driver error #[allow(dead_code)] #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -219,7 +218,8 @@ impl DriverManager { bus_manager().remove_driver(&driver); })?; - // todo: 发送uevent + // todo: 发送uevent,类型问题 + // deferred_probe_extend_timeout(); return Ok(()); } diff --git a/kernel/src/driver/base/device/mod.rs b/kernel/src/driver/base/device/mod.rs index 18199ce03..5f2dbd50e 100644 --- a/kernel/src/driver/base/device/mod.rs +++ b/kernel/src/driver/base/device/mod.rs @@ -611,7 +611,7 @@ impl DeviceManager { } // todo: 发送uevent: KOBJ_ADD - + // kobject_uevent(); // probe drivers for a new device bus_probe_device(&device); diff --git a/kernel/src/driver/base/kobject.rs b/kernel/src/driver/base/kobject.rs index 10447b25f..0b8cec571 100644 --- a/kernel/src/driver/base/kobject.rs +++ b/kernel/src/driver/base/kobject.rs @@ -103,10 +103,9 @@ bitflags! { const ADD_UEVENT_SENT = 1 << 1; const REMOVE_UEVENT_SENT = 1 << 2; const INITIALIZED = 1 << 3; + const UEVENT_SUPPRESS = 1 << 4; } - } - #[derive(Debug)] pub struct LockedKObjectState(RwLock); @@ -251,7 +250,7 @@ impl KObjectManager { } // todo: 发送uevent: KOBJ_REMOVE - + // kobject_uevent(); sysfs_instance().remove_dir(&kobj); kobj.update_kobj_state(None, Some(KObjectState::IN_SYSFS)); let kset = kobj.kset(); diff --git a/kernel/src/driver/base/kset.rs b/kernel/src/driver/base/kset.rs index fa6b4575d..4d890d721 100644 --- a/kernel/src/driver/base/kset.rs +++ b/kernel/src/driver/base/kset.rs @@ -91,6 +91,7 @@ impl KSet { pub fn register(&self, join_kset: Option>) -> Result<(), SystemError> { return KObjectManager::add_kobj(self.self_ref.upgrade().unwrap(), join_kset); // todo: 引入uevent之后,发送uevent + // kobject_uevent(); } /// 注销一个kset diff --git a/kernel/src/driver/base/platform/platform_device.rs b/kernel/src/driver/base/platform/platform_device.rs index ee4cec639..062a69aae 100644 --- a/kernel/src/driver/base/platform/platform_device.rs +++ b/kernel/src/driver/base/platform/platform_device.rs @@ -16,7 +16,7 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - filesystem::kernfs::KernFSInode, + filesystem::{kernfs::KernFSInode, sysfs::AttributeGroup}, libs::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, @@ -329,4 +329,8 @@ impl Device for PlatformBusDevice { fn set_dev_parent(&self, dev_parent: Option>) { self.inner().device_common.parent = dev_parent; } + + fn attribute_groups(&self) -> Option<&'static [&'static dyn AttributeGroup]> { + None + } } diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index 6010ef7ac..1f3e758b8 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -8,7 +8,9 @@ use crate::{ device::{bus::Bus, driver::Driver, Device, DeviceCommonData, DeviceType, IdTable}, kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, }, - net::{register_netdevice, NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}, + net::{ + register_netdevice, Iface, IfaceCommon, NetDeivceState, NetDeviceCommonData, Operstate, + }, }, libs::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, @@ -27,11 +29,8 @@ use core::{ ops::{Deref, DerefMut}, }; use log::info; -use smoltcp::{ - phy, - wire::{self, HardwareAddress}, -}; -use system_error::SystemError; +use smoltcp::{phy, wire::HardwareAddress}; +// use system_error::SystemError; use super::e1000e::{E1000EBuffer, E1000EDevice}; @@ -78,12 +77,12 @@ impl Debug for E1000EDriverWrapper { } } -#[cast_to([sync] NetDevice)] +#[cast_to([sync] Iface)] #[cast_to([sync] Device)] +#[derive(Debug)] pub struct E1000EInterface { driver: E1000EDriverWrapper, - iface_id: usize, - iface: SpinLock, + common: IfaceCommon, name: String, inner: SpinLock, locked_kobj_state: LockedKObjectState, @@ -97,11 +96,11 @@ pub struct InnerE1000EInterface { } impl phy::RxToken for E1000ERxToken { - fn consume(mut self, f: F) -> R + fn consume(self, f: F) -> R where - F: FnOnce(&mut [u8]) -> R, + F: FnOnce(&[u8]) -> R, { - let result = f(self.0.as_mut_slice()); + let result = f(self.0.as_slice()); self.0.free_buffer(); return result; } @@ -201,11 +200,9 @@ impl E1000EInterface { let iface = smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); - let driver: E1000EDriverWrapper = E1000EDriverWrapper(UnsafeCell::new(driver)); let result = Arc::new(E1000EInterface { - driver, - iface_id, - iface: SpinLock::new(iface), + driver: E1000EDriverWrapper(UnsafeCell::new(driver)), + common: IfaceCommon::new(iface_id, false, iface), name: format!("eth{}", iface_id), inner: SpinLock::new(InnerE1000EInterface { netdevice_common: NetDeviceCommonData::default(), @@ -223,16 +220,6 @@ impl E1000EInterface { } } -impl Debug for E1000EInterface { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("E1000EInterface") - .field("iface_id", &self.iface_id) - .field("iface", &"smoltcp::iface::Interface") - .field("name", &self.name) - .finish() - } -} - impl Device for E1000EInterface { fn dev_type(&self) -> DeviceType { DeviceType::Net @@ -302,52 +289,23 @@ impl Device for E1000EInterface { } } -impl NetDevice for E1000EInterface { +impl Iface for E1000EInterface { + fn common(&self) -> &IfaceCommon { + return &self.common; + } + fn mac(&self) -> smoltcp::wire::EthernetAddress { let mac = self.driver.inner.lock().mac_address(); return smoltcp::wire::EthernetAddress::from_bytes(&mac); } - #[inline] - fn nic_id(&self) -> usize { - return self.iface_id; - } - #[inline] fn iface_name(&self) -> String { return self.name.clone(); } - fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> { - if ip_addrs.len() != 1 { - return Err(SystemError::EINVAL); - } - - self.iface.lock().update_ip_addrs(|addrs| { - let dest = addrs.iter_mut().next(); - - if let Some(dest) = dest { - *dest = ip_addrs[0]; - } else { - addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full"); - } - }); - return Ok(()); - } - - fn poll(&self, sockets: &mut smoltcp::iface::SocketSet) -> Result<(), SystemError> { - let timestamp: smoltcp::time::Instant = Instant::now().into(); - let mut guard = self.iface.lock(); - let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets); - if poll_res { - return Ok(()); - } - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); - } - - #[inline(always)] - fn inner_iface(&self) -> &SpinLock { - return &self.iface; + fn poll(&self) { + self.common.poll(self.driver.force_get_mut()) } fn addr_assign_type(&self) -> u8 { diff --git a/kernel/src/driver/net/irq_handle.rs b/kernel/src/driver/net/irq_handle.rs index 6a1a3a328..5a6cd0db0 100644 --- a/kernel/src/driver/net/irq_handle.rs +++ b/kernel/src/driver/net/irq_handle.rs @@ -1,13 +1,10 @@ use alloc::sync::Arc; use system_error::SystemError; -use crate::{ - exception::{ - irqdata::IrqHandlerData, - irqdesc::{IrqHandler, IrqReturn}, - IrqNumber, - }, - net::net_core::poll_ifaces_try_lock_onetime, +use crate::exception::{ + irqdata::IrqHandlerData, + irqdesc::{IrqHandler, IrqReturn}, + IrqNumber, }; /// 默认的网卡中断处理函数 @@ -21,7 +18,7 @@ impl IrqHandler for DefaultNetIrqHandler { _static_data: Option<&dyn IrqHandlerData>, _dynamic_data: Option>, ) -> Result { - poll_ifaces_try_lock_onetime().ok(); + super::kthread::wakeup_poll_thread(); Ok(IrqReturn::Handled) } } diff --git a/kernel/src/driver/net/kthread.rs b/kernel/src/driver/net/kthread.rs new file mode 100644 index 000000000..3c634499b --- /dev/null +++ b/kernel/src/driver/net/kthread.rs @@ -0,0 +1,47 @@ +use alloc::borrow::ToOwned; +use alloc::sync::Arc; +use unified_init::macros::unified_init; + +use crate::arch::CurrentIrqArch; +use crate::exception::InterruptArch; +use crate::init::initcall::INITCALL_SUBSYS; +use crate::net::NET_DEVICES; +use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; +use crate::process::{ProcessControlBlock, ProcessManager}; +use crate::sched::{schedule, SchedMode}; + +static mut NET_POLL_THREAD: Option> = None; + +#[unified_init(INITCALL_SUBSYS)] +pub fn net_poll_init() -> Result<(), system_error::SystemError> { + let closure = KernelThreadClosure::StaticEmptyClosure((&(net_poll_thread as fn() -> i32), ())); + let pcb = KernelThreadMechanism::create_and_run(closure, "net_poll".to_owned()) + .ok_or("") + .expect("create net_poll thread failed"); + log::info!("net_poll thread created"); + unsafe { + NET_POLL_THREAD = Some(pcb); + } + return Ok(()); +} + +fn net_poll_thread() -> i32 { + log::info!("net_poll thread started"); + loop { + for (_, iface) in NET_DEVICES.read_irqsave().iter() { + iface.poll(); + } + let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; + ProcessManager::mark_sleep(true).expect("clocksource_watchdog_kthread:mark sleep failed"); + drop(irq_guard); + schedule(SchedMode::SM_NONE); + } +} + +/// 拉起线程 +pub(super) fn wakeup_poll_thread() { + if unsafe { NET_POLL_THREAD.is_none() } { + return; + } + let _ = ProcessManager::wakeup(unsafe { NET_POLL_THREAD.as_ref().unwrap() }); +} diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index cb251ce4a..42fb170c3 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -28,7 +28,9 @@ use smoltcp::{ use system_error::SystemError; use unified_init::macros::unified_init; -use super::{register_netdevice, NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}; +use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate}; + +use super::{Iface, IfaceCommon}; const DEVICE_NAME: &str = "loopback"; @@ -48,11 +50,11 @@ impl phy::RxToken for LoopbackRxToken { /// /// ## 返回值 /// 返回函数 `f` 在 `self.buffer` 上的调用结果。 - fn consume(mut self, f: F) -> R + fn consume(self, f: F) -> R where - F: FnOnce(&mut [u8]) -> R, + F: FnOnce(&[u8]) -> R, { - f(self.buffer.as_mut_slice()) + f(self.buffer.as_slice()) } } @@ -81,6 +83,7 @@ impl phy::TxToken for LoopbackTxToken { let result = f(buffer.as_mut_slice()); let mut device = self.driver.inner.lock(); device.loopback_transmit(buffer); + // debug!("lo transmit!"); result } } @@ -112,7 +115,7 @@ impl Loopback { let buffer = self.queue.pop_front(); match buffer { Some(buffer) => { - //debug!("lo receive:{:?}", buffer); + // debug!("lo receive:{:?}", buffer); return buffer; } None => { @@ -127,7 +130,7 @@ impl Loopback { /// - &mut self:自身可变引用 /// - buffer:需要发送的数据包 pub fn loopback_transmit(&mut self, buffer: Vec) { - //debug!("lo transmit!"); + // debug!("lo transmit:{:?}", buffer); self.queue.push_back(buffer) } } @@ -136,6 +139,7 @@ impl Loopback { /// 为实现获得不可变引用的Interface的内部可变性,故为Driver提供UnsafeCell包裹器 /// /// 参考virtio_net.rs +#[derive(Debug)] struct LoopbackDriverWapper(UnsafeCell); unsafe impl Send for LoopbackDriverWapper {} unsafe impl Sync for LoopbackDriverWapper {} @@ -200,7 +204,7 @@ impl phy::Device for LoopbackDriver { let mut result = phy::DeviceCapabilities::default(); result.max_transmission_unit = 65535; result.max_burst_size = Some(1); - result.medium = smoltcp::phy::Medium::Ethernet; + result.medium = smoltcp::phy::Medium::Ip; return result; } /// ## Loopback驱动处理接受数据事件 @@ -220,8 +224,10 @@ impl phy::Device for LoopbackDriver { let buffer = self.inner.lock().loopback_receive(); //receive队列为为空,返回NONE值以通知上层没有可以receive的包 if buffer.is_empty() { + // log::debug!("lo receive none!"); return Option::None; } + // log::debug!("lo receive!"); let rx = LoopbackRxToken { buffer }; let tx = LoopbackTxToken { driver: self.clone(), @@ -238,6 +244,7 @@ impl phy::Device for LoopbackDriver { /// ## 返回值 /// - 返回一个 `Some`,其中包含一个发送令牌,该令牌包含一个对自身的克隆引用 fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + // log::debug!("lo transmit!"); Some(LoopbackTxToken { driver: self.clone(), }) @@ -246,13 +253,12 @@ impl phy::Device for LoopbackDriver { /// ## LoopbackInterface结构 /// 封装驱动包裹器和iface,设置接口名称 -#[cast_to([sync] NetDevice)] +#[cast_to([sync] Iface)] #[cast_to([sync] Device)] +#[derive(Debug)] pub struct LoopbackInterface { driver: LoopbackDriverWapper, - iface_id: usize, - iface: SpinLock, - name: String, + common: IfaceCommon, inner: SpinLock, locked_kobj_state: LockedKObjectState, } @@ -265,6 +271,8 @@ pub struct InnerLoopbackInterface { } impl LoopbackInterface { + pub const DEVICE_NAME: &str = "lo"; + /// ## `new` 是一个公共函数,用于创建一个新的 `LoopbackInterface` 实例。 /// 生成一个新的接口 ID。创建一个新的接口配置,设置其硬件地址和随机种子,使用接口配置和驱动器创建一个新的 `smoltcp::iface::Interface` 实例。 /// 设置接口的 IP 地址为 127.0.0.1。 @@ -277,25 +285,44 @@ impl LoopbackInterface { /// 返回一个 `Arc`,即一个指向新创建的 `LoopbackInterface` 实例的智能指针。 pub fn new(mut driver: LoopbackDriver) -> Arc { let iface_id = generate_iface_id(); - let mut iface_config = smoltcp::iface::Config::new(HardwareAddress::Ethernet( - smoltcp::wire::EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]), - )); + + // let hardware_addr = HardwareAddress::Ethernet(smoltcp::wire::EthernetAddress([ + // 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // ])); + + let hardware_addr = HardwareAddress::Ip; + + let mut iface_config = smoltcp::iface::Config::new(hardware_addr); + iface_config.random_seed = rand() as u64; let mut iface = smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); + + iface.set_any_ip(true); + + let addr = IpAddress::v4(127, 0, 0, 1); + let cidr = IpCidr::new(addr, 8); + //设置网卡地址为127.0.0.1 iface.update_ip_addrs(|ip_addrs| { - ip_addrs - .push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) - .unwrap(); + ip_addrs.push(cidr).expect("Push ipCidr failed: full"); + }); + + iface.routes_mut().update(|routes_map| { + routes_map + .push(smoltcp::iface::Route { + cidr, + via_router: addr, + preferred_until: None, + expires_at: None, + }) + .expect("Add default ipv4 route failed: full"); }); - let driver = LoopbackDriverWapper(UnsafeCell::new(driver)); + Arc::new(LoopbackInterface { - driver, - iface_id, - iface: SpinLock::new(iface), - name: "lo".to_string(), + driver: LoopbackDriverWapper(UnsafeCell::new(driver)), + common: IfaceCommon::new(iface_id, false, iface), inner: SpinLock::new(InnerLoopbackInterface { netdevice_common: NetDeviceCommonData::default(), device_common: DeviceCommonData::default(), @@ -310,16 +337,7 @@ impl LoopbackInterface { } } -impl Debug for LoopbackInterface { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("LoopbackInterface") - .field("iface_id", &self.iface_id) - .field("iface", &"smtoltcp::iface::Interface") - .field("name", &self.name) - .finish() - } -} - +//TODO: 向sysfs注册lo设备 impl KObject for LoopbackInterface { fn as_any_ref(&self) -> &dyn core::any::Any { self @@ -354,7 +372,7 @@ impl KObject for LoopbackInterface { } fn name(&self) -> String { - self.name.clone() + Self::DEVICE_NAME.to_string() } fn set_name(&self, _name: String) { @@ -447,72 +465,23 @@ impl Device for LoopbackInterface { } } -impl NetDevice for LoopbackInterface { - /// 由于lo网卡设备不是实际的物理设备,其mac地址需要手动设置为一个默认值,这里默认为00:00:00:00:00 - fn mac(&self) -> smoltcp::wire::EthernetAddress { - let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; - smoltcp::wire::EthernetAddress(mac) - } - - #[inline] - fn nic_id(&self) -> usize { - self.iface_id +impl Iface for LoopbackInterface { + fn common(&self) -> &IfaceCommon { + &self.common } - #[inline] fn iface_name(&self) -> String { - self.name.clone() + Self::DEVICE_NAME.to_string() } - /// ## `update_ip_addrs` 用于更新接口的 IP 地址。 - /// - /// ## 参数 - /// - `&self` :自身引用 - /// - `ip_addrs` :一个包含 `smoltcp::wire::IpCidr` 的切片,表示要设置的 IP 地址和子网掩码 - /// - /// ## 返回值 - /// - 如果 `ip_addrs` 的长度不为 1,返回 `Err(SystemError::EINVAL)`,表示输入参数无效 - /// - 如果更新成功,返回 `Ok(())` - fn update_ip_addrs( - &self, - ip_addrs: &[smoltcp::wire::IpCidr], - ) -> Result<(), system_error::SystemError> { - if ip_addrs.len() != 1 { - return Err(SystemError::EINVAL); - } - - self.iface.lock().update_ip_addrs(|addrs| { - let dest = addrs.iter_mut().next(); - if let Some(dest) = dest { - *dest = ip_addrs[0]; - } else { - addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full"); - } - }); - return Ok(()); - } - /// ## `poll` 用于轮询接口的状态。 - /// - /// ## 参数 - /// - `&self` :自身引用 - /// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集 - /// - /// ## 返回值 - /// - 如果轮询成功,返回 `Ok(())` - /// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞 - fn poll(&self, sockets: &mut smoltcp::iface::SocketSet) -> Result<(), SystemError> { - let timestamp: smoltcp::time::Instant = Instant::now().into(); - let mut guard = self.iface.lock(); - let poll_res = guard.poll(timestamp, self.driver.force_get_mut(), sockets); - if poll_res { - return Ok(()); - } - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + /// 由于lo网卡设备不是实际的物理设备,其mac地址需要手动设置为一个默认值,这里默认为00:00:00:00:00 + fn mac(&self) -> smoltcp::wire::EthernetAddress { + let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + smoltcp::wire::EthernetAddress(mac) } - #[inline(always)] - fn inner_iface(&self) -> &SpinLock { - return &self.iface; + fn poll(&self) { + self.common.poll(self.driver.force_get_mut()) } fn addr_assign_type(&self) -> u8 { @@ -544,7 +513,7 @@ impl NetDevice for LoopbackInterface { pub fn loopback_probe() { loopback_driver_init(); } -/// ## lo网卡设备初始化函数 +/// # lo网卡设备初始化函数 /// 创建驱动和iface,初始化一个lo网卡,添加到全局NET_DEVICES中 pub fn loopback_driver_init() { let driver = LoopbackDriver::new(); @@ -554,7 +523,7 @@ pub fn loopback_driver_init() { NET_DEVICES .write_irqsave() - .insert(iface.iface_id, iface.clone()); + .insert(iface.nic_id(), iface.clone()); register_netdevice(iface.clone()).expect("register lo device failed"); } @@ -563,5 +532,6 @@ pub fn loopback_driver_init() { #[unified_init(INITCALL_DEVICE)] pub fn loopback_init() -> Result<(), SystemError> { loopback_probe(); + log::debug!("Successfully init loopback device"); return Ok(()); } diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 9a137b0fc..c9d026532 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -1,18 +1,20 @@ +use alloc::{fmt, vec::Vec}; use alloc::{string::String, sync::Arc}; -use smoltcp::{ - iface, - wire::{self, EthernetAddress}, -}; use sysfs::netdev_register_kobject; -use super::base::device::Device; -use crate::libs::spinlock::SpinLock; +use crate::{ + libs::{rwlock::RwLock, spinlock::SpinLock}, + net::socket::inet::{common::PortManager, InetSocket}, + process::ProcessState, +}; +use smoltcp; use system_error::SystemError; pub mod class; mod dma; pub mod e1000e; pub mod irq_handle; +pub mod kthread; pub mod loopback; pub mod sysfs; pub mod virtio_net; @@ -52,23 +54,63 @@ pub enum Operstate { } #[allow(dead_code)] -pub trait NetDevice: Device { - /// @brief 获取网卡的MAC地址 - fn mac(&self) -> EthernetAddress; +pub trait Iface: crate::driver::base::device::Device { + /// # `common` + /// 获取网卡的公共信息 + fn common(&self) -> &IfaceCommon; + + /// # `mac` + /// 获取网卡的MAC地址 + fn mac(&self) -> smoltcp::wire::EthernetAddress; + /// # `name` + /// 获取网卡名 fn iface_name(&self) -> String; - /// @brief 获取网卡的id - fn nic_id(&self) -> usize; + /// # `nic_id` + /// 获取网卡id + fn nic_id(&self) -> usize { + self.common().iface_id + } - fn poll(&self, sockets: &mut iface::SocketSet) -> Result<(), SystemError>; + /// # `poll` + /// 用于轮询接口的状态。 + /// ## 参数 + /// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集 + /// ## 返回值 + /// - 成功返回 `Ok(())` + /// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞 + fn poll(&self); - fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError>; + /// # `update_ip_addrs` + /// 用于更新接口的 IP 地址 + /// ## 参数 + /// - `ip_addrs` :一个包含 `smoltcp::wire::IpCidr` 的切片,表示要设置的 IP 地址和子网掩码 + /// ## 返回值 + /// - 如果 `ip_addrs` 的长度不为 1,返回 `Err(SystemError::EINVAL)`,表示输入参数无效 + fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> { + self.common().update_ip_addrs(ip_addrs) + } /// @brief 获取smoltcp的网卡接口类型 - fn inner_iface(&self) -> &SpinLock; + #[inline(always)] + fn smol_iface(&self) -> &SpinLock { + &self.common().smol_iface + } // fn as_any_ref(&'static self) -> &'static dyn core::any::Any; + /// # `sockets` + /// 获取网卡的套接字集 + fn sockets(&self) -> &SpinLock> { + &self.common().sockets + } + + /// # `port_manager` + /// 用于管理网卡的端口 + fn port_manager(&self) -> &PortManager { + &self.common().port_manager + } + fn addr_assign_type(&self) -> u8; fn net_device_type(&self) -> u16; @@ -108,7 +150,7 @@ impl Default for NetDeviceCommonData { /// 将网络设备注册到sysfs中 /// 参考:https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/dev.c?fi=register_netdev#5373 -fn register_netdevice(dev: Arc) -> Result<(), SystemError> { +fn register_netdevice(dev: Arc) -> Result<(), SystemError> { // 在sysfs中注册设备 netdev_register_kobject(dev.clone())?; @@ -117,3 +159,143 @@ fn register_netdevice(dev: Arc) -> Result<(), SystemError> { return Ok(()); } + +pub struct IfaceCommon { + iface_id: usize, + smol_iface: SpinLock, + /// 存smoltcp网卡的套接字集 + sockets: SpinLock>, + /// 存 kernel wrap smoltcp socket 的集合 + bounds: RwLock>>, + /// 端口管理器 + port_manager: PortManager, + /// 下次轮询的时间 + poll_at_ms: core::sync::atomic::AtomicU64, + /// 默认网卡标识 + /// TODO: 此字段设置目的是解决对bind unspecified地址的分包问题,需要在inet实现多网卡监听或路由子系统实现后移除 + default_iface: bool, +} + +impl fmt::Debug for IfaceCommon { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IfaceCommon") + .field("iface_id", &self.iface_id) + .field("sockets", &self.sockets) + .field("bounds", &self.bounds) + .field("port_manager", &self.port_manager) + .field("poll_at_ms", &self.poll_at_ms) + .finish() + } +} + +impl IfaceCommon { + pub fn new(iface_id: usize, default_iface: bool, iface: smoltcp::iface::Interface) -> Self { + IfaceCommon { + iface_id, + smol_iface: SpinLock::new(iface), + sockets: SpinLock::new(smoltcp::iface::SocketSet::new(Vec::new())), + bounds: RwLock::new(Vec::new()), + port_manager: PortManager::new(), + poll_at_ms: core::sync::atomic::AtomicU64::new(0), + default_iface, + } + } + + pub fn poll(&self, device: &mut D) + where + D: smoltcp::phy::Device + ?Sized, + { + let timestamp = crate::time::Instant::now().into(); + let mut sockets = self.sockets.lock_irqsave(); + let mut interface = self.smol_iface.lock_irqsave(); + + let (has_events, poll_at) = { + ( + matches!( + interface.poll(timestamp, device, &mut sockets), + smoltcp::iface::PollResult::SocketStateChanged + ), + loop { + let poll_at = interface.poll_at(timestamp, &sockets); + let Some(instant) = poll_at else { + break poll_at; + }; + if instant > timestamp { + break poll_at; + } + }, + ) + }; + + // drop sockets here to avoid deadlock + drop(interface); + drop(sockets); + + use core::sync::atomic::Ordering; + if let Some(instant) = poll_at { + let _old_instant = self.poll_at_ms.load(Ordering::Relaxed); + let new_instant = instant.total_millis() as u64; + self.poll_at_ms.store(new_instant, Ordering::Relaxed); + + // TODO: poll at + // if old_instant == 0 || new_instant < old_instant { + // self.polling_wait_queue.wake_all(); + // } + } else { + self.poll_at_ms.store(0, Ordering::Relaxed); + } + + self.bounds.read_irqsave().iter().for_each(|bound_socket| { + // incase our inet socket missed the event, we manually notify it each time we poll + if has_events { + bound_socket.on_iface_events(); + let _woke = bound_socket + .wait_queue() + .wakeup(Some(ProcessState::Blocked(true))); + } + }); + + // TODO: remove closed sockets + // let closed_sockets = self + // .closing_sockets + // .lock_irq_disabled() + // .extract_if(|closing_socket| closing_socket.is_closed()) + // .collect::>(); + // drop(closed_sockets); + } + + pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> { + if ip_addrs.len() != 1 { + return Err(SystemError::EINVAL); + } + + self.smol_iface.lock().update_ip_addrs(|addrs| { + let dest = addrs.iter_mut().next(); + + if let Some(dest) = dest { + *dest = ip_addrs[0]; + } else { + addrs.push(ip_addrs[0]).expect("Push ipCidr failed: full"); + } + }); + return Ok(()); + } + + // 需要bounds储存具体的Inet Socket信息,以提供不同种类inet socket的事件分发 + pub fn bind_socket(&self, socket: Arc) { + self.bounds.write().push(socket); + } + + pub fn unbind_socket(&self, socket: Arc) { + let mut bounds = self.bounds.write(); + if let Some(index) = bounds.iter().position(|s| Arc::ptr_eq(s, &socket)) { + bounds.remove(index); + log::debug!("unbind socket success"); + } + } + + // TODO: 需要在inet实现多网卡监听或路由子系统实现后移除 + pub fn is_default_iface(&self) -> bool { + self.default_iface + } +} diff --git a/kernel/src/driver/net/sysfs.rs b/kernel/src/driver/net/sysfs.rs index 8878b4a7d..5318790b9 100644 --- a/kernel/src/driver/net/sysfs.rs +++ b/kernel/src/driver/net/sysfs.rs @@ -17,11 +17,11 @@ use intertrait::cast::CastArc; use log::error; use system_error::SystemError; -use super::{class::sys_class_net_instance, NetDeivceState, NetDevice, Operstate}; +use super::{class::sys_class_net_instance, Iface, NetDeivceState, Operstate}; /// 将设备注册到`/sys/class/net`目录下 /// 参考:https://code.dragonos.org.cn/xref/linux-2.6.39/net/core/net-sysfs.c?fi=netdev_register_kobject#1311 -pub fn netdev_register_kobject(dev: Arc) -> Result<(), SystemError> { +pub fn netdev_register_kobject(dev: Arc) -> Result<(), SystemError> { // 初始化设备 device_manager().device_default_initialize(&(dev.clone() as Arc)); @@ -103,8 +103,8 @@ impl Attribute for AttrAddrAssignType { } fn show(&self, kobj: Arc, buf: &mut [u8]) -> Result { - let net_device = kobj.cast::().map_err(|_| { - error!("AttrAddrAssignType::show() failed: kobj is not a NetDevice"); + let net_device = kobj.cast::().map_err(|_| { + error!("AttrAddrAssignType::show() failed: kobj is not a Iface"); SystemError::EINVAL })?; let addr_assign_type = net_device.addr_assign_type(); @@ -271,8 +271,8 @@ impl Attribute for AttrType { } fn show(&self, kobj: Arc, buf: &mut [u8]) -> Result { - let net_deive = kobj.cast::().map_err(|_| { - error!("AttrType::show() failed: kobj is not a NetDevice"); + let net_deive = kobj.cast::().map_err(|_| { + error!("AttrType::show() failed: kobj is not a Iface"); SystemError::EINVAL })?; let net_type = net_deive.net_device_type(); @@ -322,8 +322,8 @@ impl Attribute for AttrAddress { } fn show(&self, kobj: Arc, buf: &mut [u8]) -> Result { - let net_device = kobj.cast::().map_err(|_| { - error!("AttrAddress::show() failed: kobj is not a NetDevice"); + let net_device = kobj.cast::().map_err(|_| { + error!("AttrAddress::show() failed: kobj is not a Iface"); SystemError::EINVAL })?; let mac_addr = net_device.mac(); @@ -373,8 +373,8 @@ impl Attribute for AttrCarrier { } fn show(&self, kobj: Arc, buf: &mut [u8]) -> Result { - let net_device = kobj.cast::().map_err(|_| { - error!("AttrCarrier::show() failed: kobj is not a NetDevice"); + let net_device = kobj.cast::().map_err(|_| { + error!("AttrCarrier::show() failed: kobj is not a Iface"); SystemError::EINVAL })?; if net_device @@ -489,8 +489,8 @@ impl Attribute for AttrOperstate { } fn show(&self, _kobj: Arc, _buf: &mut [u8]) -> Result { - let net_device = _kobj.cast::().map_err(|_| { - error!("AttrOperstate::show() failed: kobj is not a NetDevice"); + let net_device = _kobj.cast::().map_err(|_| { + error!("AttrOperstate::show() failed: kobj is not a Iface"); SystemError::EINVAL })?; if !net_device diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index b2b226dba..2500a1e12 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -15,7 +15,7 @@ use smoltcp::{iface, phy, wire}; use unified_init::macros::unified_init; use virtio_drivers::device::net::VirtIONet; -use super::{NetDeivceState, NetDevice, NetDeviceCommonData, Operstate}; +use super::{Iface, NetDeivceState, NetDeviceCommonData, Operstate}; use crate::{ arch::rand::rand, driver::{ @@ -40,13 +40,13 @@ use crate::{ }, }, exception::{irqdesc::IrqReturn, IrqNumber}, - filesystem::kernfs::KernFSInode, + filesystem::{kernfs::KernFSInode, sysfs::AttributeGroup}, init::initcall::INITCALL_POSTCORE, libs::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, - net::{generate_iface_id, net_core::poll_ifaces_try_lock_onetime, NET_DEVICES}, + net::{generate_iface_id, NET_DEVICES}, time::Instant, }; use system_error::SystemError; @@ -261,13 +261,15 @@ impl Device for VirtIONetDevice { fn set_dev_parent(&self, parent: Option>) { self.inner().device_common.parent = parent; } + + fn attribute_groups(&self) -> Option<&'static [&'static dyn AttributeGroup]> { + None + } } impl VirtIODevice for VirtIONetDevice { fn handle_irq(&self, _irq: IrqNumber) -> Result { - if poll_ifaces_try_lock_onetime().is_err() { - log::error!("virtio_net: try lock failed"); - } + super::kthread::wakeup_poll_thread(); return Ok(IrqReturn::Handled); } @@ -376,13 +378,13 @@ impl Debug for VirtIONicDeviceInner { } } -#[cast_to([sync] NetDevice)] +#[cast_to([sync] Iface)] #[cast_to([sync] Device)] +#[derive(Debug)] pub struct VirtioInterface { device_inner: VirtIONicDeviceInnerWrapper, - iface_id: usize, iface_name: String, - iface: SpinLock, + iface_common: super::IfaceCommon, inner: SpinLock, locked_kobj_state: LockedKObjectState, } @@ -394,17 +396,6 @@ struct InnerVirtIOInterface { netdevice_common: NetDeviceCommonData, } -impl core::fmt::Debug for VirtioInterface { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("VirtioInterface") - .field("iface_id", &self.iface_id) - .field("iface_name", &self.iface_name) - .field("inner", &self.inner) - .field("locked_kobj_state", &self.locked_kobj_state) - .finish() - } -} - impl VirtioInterface { pub fn new(mut device_inner: VirtIONicDeviceInner) -> Arc { let iface_id = generate_iface_id(); @@ -417,10 +408,9 @@ impl VirtioInterface { let result = Arc::new(VirtioInterface { device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)), - iface_id, locked_kobj_state: LockedKObjectState::default(), - iface: SpinLock::new(iface), iface_name: format!("eth{}", iface_id), + iface_common: super::IfaceCommon::new(iface_id, true, iface), inner: SpinLock::new(InnerVirtIOInterface { kobj_common: KObjectCommonData::default(), device_common: DeviceCommonData::default(), @@ -445,7 +435,7 @@ impl VirtioInterface { impl Drop for VirtioInterface { fn drop(&mut self) { // 从全局的网卡接口信息表中删除这个网卡的接口信息 - NET_DEVICES.write_irqsave().remove(&self.iface_id); + NET_DEVICES.write_irqsave().remove(&self.nic_id()); } } @@ -612,11 +602,11 @@ impl phy::TxToken for VirtioNetToken { impl phy::RxToken for VirtioNetToken { fn consume(self, f: F) -> R where - F: FnOnce(&mut [u8]) -> R, + F: FnOnce(&[u8]) -> R, { // 为了线程安全,这里需要对VirtioNet进行加【写锁】,以保证对设备的互斥访问。 - let mut rx_buf = self.rx_buffer.unwrap(); - let result = f(rx_buf.packet_mut()); + let rx_buf = self.rx_buffer.unwrap(); + let result = f(rx_buf.packet()); self.driver .inner .lock() @@ -644,57 +634,26 @@ pub fn virtio_net( } } -impl NetDevice for VirtioInterface { +impl Iface for VirtioInterface { + fn common(&self) -> &super::IfaceCommon { + &self.iface_common + } + fn mac(&self) -> wire::EthernetAddress { let mac: [u8; 6] = self.device_inner.inner.lock().mac_address(); return wire::EthernetAddress::from_bytes(&mac); } - #[inline] - fn nic_id(&self) -> usize { - return self.iface_id; - } - #[inline] fn iface_name(&self) -> String { return self.iface_name.clone(); } - fn update_ip_addrs(&self, ip_addrs: &[wire::IpCidr]) -> Result<(), SystemError> { - if ip_addrs.len() != 1 { - return Err(SystemError::EINVAL); - } - - self.iface.lock().update_ip_addrs(|addrs| { - let dest = addrs.iter_mut().next(); - - if let Some(dest) = dest { - *dest = ip_addrs[0]; - } else { - addrs - .push(ip_addrs[0]) - .expect("Push wire::IpCidr failed: full"); - } - }); - return Ok(()); - } - - fn poll(&self, sockets: &mut iface::SocketSet) -> Result<(), SystemError> { - let timestamp: smoltcp::time::Instant = Instant::now().into(); - let mut guard = self.iface.lock(); - let poll_res = guard.poll(timestamp, self.device_inner.force_get_mut(), sockets); - // todo: notify!!! - // debug!("Virtio Interface poll:{poll_res}"); - if poll_res { - return Ok(()); - } - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + fn poll(&self) { + // log::debug!("VirtioInterface: poll"); + self.iface_common.poll(self.device_inner.force_get_mut()) } - #[inline(always)] - fn inner_iface(&self) -> &SpinLock { - return &self.iface; - } // fn as_any_ref(&'static self) -> &'static dyn core::any::Any { // return self; // } @@ -786,9 +745,7 @@ impl KObject for VirtioInterface { #[unified_init(INITCALL_POSTCORE)] fn virtio_net_driver_init() -> Result<(), SystemError> { let driver = VirtIONetDriver::new(); - virtio_driver_manager() - .register(driver.clone() as Arc) - .expect("Add virtio net driver failed"); + virtio_driver_manager().register(driver.clone() as Arc)?; unsafe { VIRTIO_NET_DRIVER = Some(driver); } @@ -859,7 +816,7 @@ impl VirtIODriver for VirtIONetDriver { // 设置iface的父设备为virtio_net_device iface.set_dev_parent(Some(Arc::downgrade(&virtio_net_device) as Weak)); // 在sysfs中注册iface - register_netdevice(iface.clone() as Arc)?; + register_netdevice(iface.clone() as Arc)?; // 将网卡的接口信息注册到全局的网卡接口信息表中 NET_DEVICES diff --git a/kernel/src/driver/rtc/class.rs b/kernel/src/driver/rtc/class.rs index d10f11cff..6ed46b97b 100644 --- a/kernel/src/driver/rtc/class.rs +++ b/kernel/src/driver/rtc/class.rs @@ -13,6 +13,7 @@ use crate::{ kobject::KObject, subsys::SubSysPrivate, }, + filesystem::sysfs::AttributeGroup, init::initcall::INITCALL_SUBSYS, time::{timekeeping::do_settimeofday64, PosixTimeSpec}, }; @@ -78,6 +79,9 @@ impl Class for RtcClass { fn subsystem(&self) -> &SubSysPrivate { return &self.subsystem; } + fn dev_groups(&self) -> &'static [&'static dyn AttributeGroup] { + return &[]; + } } /// 注册rtc通用设备 diff --git a/kernel/src/driver/video/fbdev/base/fbmem.rs b/kernel/src/driver/video/fbdev/base/fbmem.rs index c3749015d..7cd78879c 100644 --- a/kernel/src/driver/video/fbdev/base/fbmem.rs +++ b/kernel/src/driver/video/fbdev/base/fbmem.rs @@ -112,6 +112,10 @@ impl Class for GraphicsClass { fn subsystem(&self) -> &SubSysPrivate { return &self.subsystem; } + + fn dev_groups(&self) -> &'static [&'static dyn AttributeGroup] { + return &[]; + } } /// 帧缓冲区管理器 diff --git a/kernel/src/filesystem/vfs/open.rs b/kernel/src/filesystem/vfs/open.rs index a1b1994ff..64a2f7b65 100644 --- a/kernel/src/filesystem/vfs/open.rs +++ b/kernel/src/filesystem/vfs/open.rs @@ -1,5 +1,4 @@ use alloc::sync::Arc; -use log::warn; use system_error::SystemError; use super::{ @@ -62,7 +61,7 @@ pub fn do_fchmodat(dirfd: i32, path: *const u8, _mode: ModeType) -> Result Result { let check = check_unshare_flags(flags)?; diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 84ad6c892..805761cc6 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -1,23 +1,22 @@ -use core::{ - fmt::{self, Debug}, - sync::atomic::AtomicUsize, -}; +//! # 网络模块 +//! 注意,net模块下,为了方便导入,模块细分,且共用部分模块直接使用 +//! `pub use`导出,导入时也常见`use crate::net::socket::*`的写法, +//! 敬请注意。 +use core::sync::atomic::AtomicUsize; use alloc::{collections::BTreeMap, sync::Arc}; -use crate::{driver::net::NetDevice, libs::rwlock::RwLock}; -use smoltcp::wire::IpEndpoint; - -use self::socket::SocketInode; +use crate::{driver::net::Iface, libs::rwlock::RwLock}; pub mod net_core; +pub mod posix; pub mod socket; pub mod syscall; lazy_static! { /// # 所有网络接口的列表 /// 这个列表在中断上下文会使用到,因此需要irqsave - pub static ref NET_DEVICES: RwLock>> = RwLock::new(BTreeMap::new()); + pub static ref NET_DEVICES: RwLock>> = RwLock::new(BTreeMap::new()); } /// 生成网络接口的id (全局自增) @@ -25,120 +24,3 @@ pub fn generate_iface_id() -> usize { static IFACE_ID: AtomicUsize = AtomicUsize::new(0); return IFACE_ID.fetch_add(1, core::sync::atomic::Ordering::SeqCst); } - -bitflags! { - /// @brief 用于指定socket的关闭类型 - /// 参考:https://code.dragonos.org.cn/xref/linux-6.1.9/include/net/sock.h?fi=SHUTDOWN_MASK#1573 - pub struct ShutdownType: u8 { - const RCV_SHUTDOWN = 1; - const SEND_SHUTDOWN = 2; - const SHUTDOWN_MASK = 3; - } -} - -#[derive(Debug, Clone)] -pub enum Endpoint { - /// 链路层端点 - LinkLayer(LinkLayerEndpoint), - /// 网络层端点 - Ip(Option), - /// inode端点 - Inode(Option>), - // todo: 增加NetLink机制后,增加NetLink端点 -} - -/// @brief 链路层端点 -#[derive(Debug, Clone)] -pub struct LinkLayerEndpoint { - /// 网卡的接口号 - pub interface: usize, -} - -impl LinkLayerEndpoint { - /// @brief 创建一个链路层端点 - /// - /// @param interface 网卡的接口号 - /// - /// @return 返回创建的链路层端点 - pub fn new(interface: usize) -> Self { - Self { interface } - } -} - -/// IP datagram encapsulated protocol. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -#[repr(u8)] -pub enum Protocol { - HopByHop = 0x00, - Icmp = 0x01, - Igmp = 0x02, - Tcp = 0x06, - Udp = 0x11, - Ipv6Route = 0x2b, - Ipv6Frag = 0x2c, - Icmpv6 = 0x3a, - Ipv6NoNxt = 0x3b, - Ipv6Opts = 0x3c, - Unknown(u8), -} - -impl fmt::Display for Protocol { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Protocol::HopByHop => write!(f, "Hop-by-Hop"), - Protocol::Icmp => write!(f, "ICMP"), - Protocol::Igmp => write!(f, "IGMP"), - Protocol::Tcp => write!(f, "TCP"), - Protocol::Udp => write!(f, "UDP"), - Protocol::Ipv6Route => write!(f, "IPv6-Route"), - Protocol::Ipv6Frag => write!(f, "IPv6-Frag"), - Protocol::Icmpv6 => write!(f, "ICMPv6"), - Protocol::Ipv6NoNxt => write!(f, "IPv6-NoNxt"), - Protocol::Ipv6Opts => write!(f, "IPv6-Opts"), - Protocol::Unknown(id) => write!(f, "0x{id:02x}"), - } - } -} - -impl From for Protocol { - fn from(value: smoltcp::wire::IpProtocol) -> Self { - let x: u8 = value.into(); - Protocol::from(x) - } -} - -impl From for Protocol { - fn from(value: u8) -> Self { - match value { - 0x00 => Protocol::HopByHop, - 0x01 => Protocol::Icmp, - 0x02 => Protocol::Igmp, - 0x06 => Protocol::Tcp, - 0x11 => Protocol::Udp, - 0x2b => Protocol::Ipv6Route, - 0x2c => Protocol::Ipv6Frag, - 0x3a => Protocol::Icmpv6, - 0x3b => Protocol::Ipv6NoNxt, - 0x3c => Protocol::Ipv6Opts, - _ => Protocol::Unknown(value), - } - } -} - -impl From for u8 { - fn from(value: Protocol) -> Self { - match value { - Protocol::HopByHop => 0x00, - Protocol::Icmp => 0x01, - Protocol::Igmp => 0x02, - Protocol::Tcp => 0x06, - Protocol::Udp => 0x11, - Protocol::Ipv6Route => 0x2b, - Protocol::Ipv6Frag => 0x2c, - Protocol::Icmpv6 => 0x3a, - Protocol::Ipv6NoNxt => 0x3b, - Protocol::Ipv6Opts => 0x3c, - Protocol::Unknown(id) => id, - } - } -} diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index d066fa4db..c62badc04 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -1,53 +1,26 @@ -use alloc::{boxed::Box, collections::BTreeMap, sync::Arc}; -use log::{debug, info, warn}; use smoltcp::{socket::dhcpv4, wire}; use system_error::SystemError; -use super::socket::{handle::GlobalSocketHandle, inet::TcpSocket, HANDLE_MAP, SOCKET_SET}; use crate::{ - driver::net::{NetDevice, Operstate}, - filesystem::epoll::{event_poll::EventPoll, EPollEventType}, - libs::rwlock::RwLockReadGuard, - net::{socket::SocketPollMethod, NET_DEVICES}, - time::{ - sleep::nanosleep, - timer::{next_n_ms_timer_jiffies, Timer, TimerFunction}, - PosixTimeSpec, - }, + driver::net::Operstate, + net::NET_DEVICES, + time::{sleep::nanosleep, PosixTimeSpec}, }; -/// The network poll function, which will be called by timer. -/// -/// The main purpose of this function is to poll all network interfaces. -#[derive(Debug)] -#[allow(dead_code)] -struct NetWorkPollFunc; - -impl TimerFunction for NetWorkPollFunc { - fn run(&mut self) -> Result<(), SystemError> { - poll_ifaces_try_lock(10).ok(); - let next_time = next_n_ms_timer_jiffies(10); - let timer = Timer::new(Box::new(NetWorkPollFunc), next_time); - timer.activate(); - return Ok(()); - } -} - pub fn net_init() -> Result<(), SystemError> { - dhcp_query()?; - // Init poll timer function - // let next_time = next_n_ms_timer_jiffies(5); - // let timer = Timer::new(Box::new(NetWorkPollFunc), next_time); - // timer.activate(); - return Ok(()); + dhcp_query() } fn dhcp_query() -> Result<(), SystemError> { let binding = NET_DEVICES.write_irqsave(); - //由于现在os未实现在用户态为网卡动态分配内存,而lo网卡的id最先分配且ip固定不能被分配 - //所以特判取用id为1的网卡(也就是virto_net) - let net_face = binding.get(&1).ok_or(SystemError::ENODEV)?.clone(); + // Default iface, misspelled to net_face + let net_face = binding + .iter() + .find(|(_, iface)| iface.common().is_default_iface()) + .unwrap() + .1 + .clone(); drop(binding); @@ -60,13 +33,18 @@ fn dhcp_query() -> Result<(), SystemError> { // IMPORTANT: This should be removed in production. dhcp_socket.set_max_lease_duration(Some(smoltcp::time::Duration::from_secs(10))); - let dhcp_handle = SOCKET_SET.lock_irqsave().add(dhcp_socket); + let sockets = || net_face.sockets().lock_irqsave(); - const DHCP_TRY_ROUND: u8 = 10; + let dhcp_handle = sockets().add(dhcp_socket); + defer::defer!({ + sockets().remove(dhcp_handle); + }); + + const DHCP_TRY_ROUND: u8 = 100; for i in 0..DHCP_TRY_ROUND { - debug!("DHCP try round: {}", i); - net_face.poll(&mut SOCKET_SET.lock_irqsave()).ok(); - let mut binding = SOCKET_SET.lock_irqsave(); + log::debug!("DHCP try round: {}", i); + net_face.poll(); + let mut binding = sockets(); let event = binding.get_mut::(dhcp_handle).poll(); match event { @@ -82,22 +60,35 @@ fn dhcp_query() -> Result<(), SystemError> { .ok(); if let Some(router) = config.router { - net_face - .inner_iface() - .lock() + let mut smol_iface = net_face.smol_iface().lock(); + smol_iface.routes_mut().update(|table| { + let _ = table.push(smoltcp::iface::Route { + cidr: smoltcp::wire::IpCidr::Ipv4(smoltcp::wire::Ipv4Cidr::new( + smoltcp::wire::Ipv4Address::new(127, 0, 0, 0), + 8, + )), + via_router: smoltcp::wire::IpAddress::v4(127, 0, 0, 1), + preferred_until: None, + expires_at: None, + }); + }); + if smol_iface .routes_mut() .add_default_ipv4_route(router) - .unwrap(); - let cidr = net_face.inner_iface().lock().ip_addrs().first().cloned(); + .is_err() + { + log::warn!("Route table full"); + } + let cidr = smol_iface.ip_addrs().first().cloned(); if let Some(cidr) = cidr { // 这里先在这里将网卡设置为up,后面等netlink实现了再修改 net_face.set_operstate(Operstate::IF_OPER_UP); - info!("Successfully allocated ip by Dhcpv4! Ip:{}", cidr); + log::info!("Successfully allocated ip by Dhcpv4! Ip:{}", cidr); return Ok(()); } } else { net_face - .inner_iface() + .smol_iface() .lock() .routes_mut() .remove_default_ipv4_route(); @@ -105,7 +96,7 @@ fn dhcp_query() -> Result<(), SystemError> { } Some(dhcpv4::Event::Deconfigured) => { - debug!("Dhcp v4 deconfigured"); + log::debug!("Dhcp v4 deconfigured"); net_face .update_ip_addrs(&[smoltcp::wire::IpCidr::Ipv4(wire::Ipv4Cidr::new( wire::Ipv4Address::UNSPECIFIED, @@ -113,7 +104,7 @@ fn dhcp_query() -> Result<(), SystemError> { ))]) .ok(); net_face - .inner_iface() + .smol_iface() .lock() .routes_mut() .remove_default_ipv4_route(); @@ -123,133 +114,11 @@ fn dhcp_query() -> Result<(), SystemError> { drop(binding); let sleep_time = PosixTimeSpec { - tv_sec: 5, - tv_nsec: 0, + tv_sec: 0, + tv_nsec: 50, }; let _ = nanosleep(sleep_time)?; } return Err(SystemError::ETIMEDOUT); } - -pub fn poll_ifaces() { - let guard: RwLockReadGuard>> = NET_DEVICES.read_irqsave(); - if guard.len() == 0 { - warn!("poll_ifaces: No net driver found!"); - return; - } - let mut sockets = SOCKET_SET.lock_irqsave(); - for (_, iface) in guard.iter() { - iface.poll(&mut sockets).ok(); - } - let _ = send_event(&sockets); -} - -/// 对ifaces进行轮询,最多对SOCKET_SET尝试times次加锁。 -/// -/// @return 轮询成功,返回Ok(()) -/// @return 加锁超时,返回SystemError::EAGAIN_OR_EWOULDBLOCK -/// @return 没有网卡,返回SystemError::ENODEV -pub fn poll_ifaces_try_lock(times: u16) -> Result<(), SystemError> { - let mut i = 0; - while i < times { - let guard: RwLockReadGuard>> = - NET_DEVICES.read_irqsave(); - if guard.len() == 0 { - warn!("poll_ifaces: No net driver found!"); - // 没有网卡,返回错误 - return Err(SystemError::ENODEV); - } - let sockets = SOCKET_SET.try_lock_irqsave(); - // 加锁失败,继续尝试 - if sockets.is_err() { - i += 1; - continue; - } - - let mut sockets = sockets.unwrap(); - for (_, iface) in guard.iter() { - iface.poll(&mut sockets).ok(); - } - send_event(&sockets)?; - return Ok(()); - } - // 尝试次数用完,返回错误 - return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); -} - -/// 对ifaces进行轮询,最多对SOCKET_SET尝试一次加锁。 -/// -/// @return 轮询成功,返回Ok(()) -/// @return 加锁超时,返回SystemError::EAGAIN_OR_EWOULDBLOCK -/// @return 没有网卡,返回SystemError::ENODEV -pub fn poll_ifaces_try_lock_onetime() -> Result<(), SystemError> { - let guard: RwLockReadGuard>> = NET_DEVICES.read_irqsave(); - if guard.len() == 0 { - warn!("poll_ifaces: No net driver found!"); - // 没有网卡,返回错误 - return Err(SystemError::ENODEV); - } - let mut sockets = SOCKET_SET.try_lock_irqsave()?; - for (_, iface) in guard.iter() { - iface.poll(&mut sockets).ok(); - } - send_event(&sockets)?; - return Ok(()); -} - -/// ### 处理轮询后的事件 -fn send_event(sockets: &smoltcp::iface::SocketSet) -> Result<(), SystemError> { - for (handle, socket_type) in sockets.iter() { - let handle_guard = HANDLE_MAP.read_irqsave(); - let global_handle = GlobalSocketHandle::new_smoltcp_handle(handle); - let item: Option<&super::socket::SocketHandleItem> = handle_guard.get(&global_handle); - if item.is_none() { - continue; - } - - let handle_item = item.unwrap(); - let posix_item = handle_item.posix_item(); - if posix_item.is_none() { - continue; - } - let posix_item = posix_item.unwrap(); - - // 获取socket上的事件 - let mut events = SocketPollMethod::poll(socket_type, handle_item).bits() as u64; - - // 分发到相应类型socket处理 - match socket_type { - smoltcp::socket::Socket::Raw(_) | smoltcp::socket::Socket::Udp(_) => { - posix_item.wakeup_any(events); - } - smoltcp::socket::Socket::Icmp(_) => unimplemented!("Icmp socket hasn't unimplemented"), - smoltcp::socket::Socket::Tcp(inner_socket) => { - if inner_socket.is_active() { - events |= TcpSocket::CAN_ACCPET; - } - if inner_socket.state() == smoltcp::socket::tcp::State::Established { - events |= TcpSocket::CAN_CONNECT; - } - if inner_socket.state() == smoltcp::socket::tcp::State::CloseWait { - events |= EPollEventType::EPOLLHUP.bits() as u64; - } - - posix_item.wakeup_any(events); - } - smoltcp::socket::Socket::Dhcpv4(_) => {} - smoltcp::socket::Socket::Dns(_) => unimplemented!("Dns socket hasn't unimplemented"), - } - EventPoll::wakeup_epoll( - &posix_item.epitems, - EPollEventType::from_bits_truncate(events as u32), - )?; - drop(handle_guard); - // crate::debug!( - // "{} send_event {:?}", - // handle, - // EPollEventType::from_bits_truncate(events as u32) - // ); - } - Ok(()) -} diff --git a/kernel/src/net/posix.rs b/kernel/src/net/posix.rs new file mode 100644 index 000000000..b4d6c07de --- /dev/null +++ b/kernel/src/net/posix.rs @@ -0,0 +1,378 @@ +// +// posix.rs 记录了系统调用时用到的结构 +// + +bitflags::bitflags! { + // #[derive(PartialEq, Eq, Debug, Clone, Copy)] + pub struct PosixArgsSocketType: u32 { + const DGRAM = 1; // 0b0000_0001 + const STREAM = 2; // 0b0000_0010 + const RAW = 3; // 0b0000_0011 + const RDM = 4; // 0b0000_0100 + const SEQPACKET = 5; // 0b0000_0101 + const DCCP = 6; // 0b0000_0110 + const PACKET = 10; // 0b0000_1010 + + const NONBLOCK = crate::filesystem::vfs::file::FileMode::O_NONBLOCK.bits(); + const CLOEXEC = crate::filesystem::vfs::file::FileMode::O_CLOEXEC.bits(); + } +} + +impl PosixArgsSocketType { + #[inline(always)] + pub fn types(&self) -> PosixArgsSocketType { + PosixArgsSocketType::from_bits(self.bits() & 0b_1111).unwrap() + } + + #[inline(always)] + pub fn is_nonblock(&self) -> bool { + self.contains(PosixArgsSocketType::NONBLOCK) + } + + #[inline(always)] + pub fn is_cloexec(&self) -> bool { + self.contains(PosixArgsSocketType::CLOEXEC) + } +} + +use alloc::string::String; +use alloc::sync::Arc; +use core::ffi::CStr; +use system_error::SystemError; + +use crate::{ + filesystem::vfs::{FileType, IndexNode, ROOT_INODE, VFS_MAX_FOLLOW_SYMLINK_TIMES}, + mm::{verify_area, VirtAddr}, + net::socket::unix::ns::abs::{alloc_abs_addr, look_up_abs_addr}, + process::ProcessManager, +}; + +use super::socket::{endpoint::Endpoint, AddressFamily}; + +// 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32 +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrIn { + pub sin_family: u16, + pub sin_port: u16, + pub sin_addr: u32, + pub sin_zero: [u8; 8], +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrUn { + pub sun_family: u16, + pub sun_path: [u8; 108], +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrLl { + pub sll_family: u16, + pub sll_protocol: u16, + pub sll_ifindex: u32, + pub sll_hatype: u16, + pub sll_pkttype: u8, + pub sll_halen: u8, + pub sll_addr: [u8; 8], +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrNl { + pub nl_family: AddressFamily, + pub nl_pad: u16, + pub nl_pid: u32, + pub nl_groups: u32, +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct SockAddrPlaceholder { + pub family: u16, + pub data: [u8; 14], +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub union SockAddr { + pub family: u16, + pub addr_in: SockAddrIn, + pub addr_un: SockAddrUn, + pub addr_ll: SockAddrLl, + pub addr_nl: SockAddrNl, + pub addr_ph: SockAddrPlaceholder, +} + +impl SockAddr { + /// @brief 把用户传入的SockAddr转换为Endpoint结构体 + pub fn to_endpoint(addr: *const SockAddr, len: u32) -> Result { + use crate::net::socket::AddressFamily; + + let addr = unsafe { addr.as_ref() }.ok_or(SystemError::EFAULT)?; + unsafe { + match AddressFamily::try_from(addr.family)? { + AddressFamily::INet => { + if len < addr.len()? { + log::error!("len < addr.len()"); + return Err(SystemError::EINVAL); + } + + let addr_in: SockAddrIn = addr.addr_in; + + use smoltcp::wire; + let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv4Address::from_bits( + u32::from_be(addr_in.sin_addr), + )); + let port = u16::from_be(addr_in.sin_port); + + return Ok(Endpoint::Ip(wire::IpEndpoint::new(ip, port))); + } + // AddressFamily::INet6 => { + // if len < addr.len()? { + // log::error!("len < addr.len()"); + // return Err(SystemError::EINVAL); + // } + // log::debug!("INet6"); + // let addr_in: SockAddrIn = addr.addr_in; + + // use smoltcp::wire; + // let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv6Address::from_bits( + // u128::from_be(addr_in.sin_addr), + // )); + // let port = u16::from_be(addr_in.sin_port); + + // return Ok(Endpoint::Ip(wire::IpEndpoint::new(ip, port))); + // } + AddressFamily::Unix => { + let addr_un: SockAddrUn = addr.addr_un; + + if addr_un.sun_path[0] == 0 { + // 抽象地址空间,与文件系统没有关系 + let path = CStr::from_bytes_until_nul(&addr_un.sun_path[1..]) + .map_err(|_| { + log::error!("CStr::from_bytes_until_nul fail"); + SystemError::EINVAL + })? + .to_str() + .map_err(|_| { + log::error!("CStr::to_str fail"); + SystemError::EINVAL + })?; + + // 向抽象地址管理器申请或查找抽象地址 + let spath = String::from(path); + log::debug!("abs path: {}", spath); + let abs_find = match look_up_abs_addr(&spath) { + Ok(result) => result, + Err(_) => { + //未找到尝试分配abs + match alloc_abs_addr(spath.clone()) { + Ok(result) => { + log::debug!("alloc abs addr success!"); + return Ok(result); + } + Err(e) => { + log::debug!("alloc abs addr failed!"); + return Err(e); + } + }; + } + }; + log::debug!("find alloc abs addr success!"); + return Ok(abs_find); + } + + let path = CStr::from_bytes_until_nul(&addr_un.sun_path) + .map_err(|_| { + log::error!("CStr::from_bytes_until_nul fail"); + SystemError::EINVAL + })? + .to_str() + .map_err(|_| { + log::error!("CStr::to_str fail"); + SystemError::EINVAL + })?; + + let (inode_begin, path) = crate::filesystem::vfs::utils::user_path_at( + &ProcessManager::current_pcb(), + crate::filesystem::vfs::fcntl::AtFlags::AT_FDCWD.bits(), + path.trim(), + )?; + let inode0: Result, SystemError> = + inode_begin.lookup_follow_symlink(&path, VFS_MAX_FOLLOW_SYMLINK_TIMES); + + let inode = match inode0 { + Ok(inode) => inode, + Err(_) => { + let (filename, parent_path) = + crate::filesystem::vfs::utils::rsplit_path(&path); + // 查找父目录 + log::debug!("filename {:?} parent_path {:?}", filename, parent_path); + + let parent_inode: Arc = + ROOT_INODE().lookup(parent_path.unwrap_or("/"))?; + // 创建文件 + let inode: Arc = match parent_inode.create( + filename, + FileType::File, + crate::filesystem::vfs::syscall::ModeType::from_bits_truncate( + 0o755, + ), + ) { + Ok(inode) => inode, + Err(e) => { + log::debug!("inode create fail {:?}", e); + return Err(e); + } + }; + inode + } + }; + + return Ok(Endpoint::Unixpath((inode.metadata()?.inode_id, path))); + } + _ => { + log::warn!("not support address family {:?}", addr.family); + return Err(SystemError::EINVAL); + } + } + } + } + + /// @brief 获取地址长度 + pub fn len(&self) -> Result { + match AddressFamily::try_from(unsafe { self.family })? { + AddressFamily::INet => Ok(core::mem::size_of::()), + AddressFamily::Packet => Ok(core::mem::size_of::()), + AddressFamily::Netlink => Ok(core::mem::size_of::()), + AddressFamily::Unix => Ok(core::mem::size_of::()), + _ => Err(SystemError::EINVAL), + } + .map(|x| x as u32) + } + + /// @brief 把SockAddr的数据写入用户空间 + /// + /// @param addr 用户空间的SockAddr的地址 + /// @param len 要写入的长度 + /// + /// @return 成功返回写入的长度,失败返回错误码 + pub unsafe fn write_to_user( + &self, + addr: *mut SockAddr, + addr_len: *mut u32, + ) -> Result { + // 当用户传入的地址或者长度为空时,直接返回0 + if addr.is_null() || addr_len.is_null() { + return Ok(0); + } + + // 检查用户传入的地址是否合法 + verify_area( + VirtAddr::new(addr as usize), + core::mem::size_of::(), + ) + .map_err(|_| SystemError::EFAULT)?; + + verify_area( + VirtAddr::new(addr_len as usize), + core::mem::size_of::(), + ) + .map_err(|_| SystemError::EFAULT)?; + + let to_write = core::cmp::min(self.len()?, *addr_len); + if to_write > 0 { + let buf = core::slice::from_raw_parts_mut(addr as *mut u8, to_write as usize); + buf.copy_from_slice(core::slice::from_raw_parts( + self as *const SockAddr as *const u8, + to_write as usize, + )); + } + *addr_len = self.len()?; + return Ok(to_write); + } + + pub unsafe fn is_empty(&self) -> bool { + unsafe { self.family == 0 && self.addr_ph.data == [0; 14] } + } +} + +impl From for SockAddr { + fn from(value: Endpoint) -> Self { + match value { + Endpoint::Ip(ip_endpoint) => match ip_endpoint.addr { + smoltcp::wire::IpAddress::Ipv4(ipv4_addr) => { + let addr_in = SockAddrIn { + sin_family: AddressFamily::INet as u16, + sin_port: ip_endpoint.port.to_be(), + sin_addr: ipv4_addr.to_bits(), + sin_zero: [0; 8], + }; + + return SockAddr { addr_in }; + } + _ => { + unimplemented!("not support ipv6"); + } + }, + + Endpoint::LinkLayer(link_endpoint) => { + let addr_ll = SockAddrLl { + sll_family: AddressFamily::Packet as u16, + sll_protocol: 0, + sll_ifindex: link_endpoint.interface as u32, + sll_hatype: 0, + sll_pkttype: 0, + sll_halen: 0, + sll_addr: [0; 8], + }; + + return SockAddr { addr_ll }; + } + + Endpoint::Inode((_, path)) => { + log::debug!("from unix path {:?}", path); + let bytes = path.as_bytes(); + let mut sun_path = [0u8; 108]; + if bytes.len() <= 108 { + sun_path[..bytes.len()].copy_from_slice(bytes); + } else { + panic!("unix address path too long!"); + } + let addr_un = SockAddrUn { + sun_family: AddressFamily::Unix as u16, + sun_path, + }; + return SockAddr { addr_un }; + } + + _ => { + // todo: support other endpoint, like Netlink... + unimplemented!("not support {value:?}"); + } + } + } +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct MsgHdr { + /// 指向一个SockAddr结构体的指针 + pub msg_name: *mut SockAddr, + /// SockAddr结构体的大小 + pub msg_namelen: u32, + /// scatter/gather array + pub msg_iov: *mut crate::filesystem::vfs::iov::IoVec, + /// elements in msg_iov + pub msg_iovlen: usize, + /// 辅助数据 + pub msg_control: *mut u8, + /// 辅助数据长度 + pub msg_controllen: u32, + /// 接收到的消息的标志 + pub msg_flags: u32, +} + +// TODO: 从用户态读取MsgHdr,以及写入MsgHdr diff --git a/kernel/src/net/socket/base.rs b/kernel/src/net/socket/base.rs new file mode 100644 index 000000000..a772dacc2 --- /dev/null +++ b/kernel/src/net/socket/base.rs @@ -0,0 +1,140 @@ +use crate::{libs::wait_queue::WaitQueue, net::posix::MsgHdr}; +use alloc::sync::Arc; +use core::any::Any; +use core::fmt::Debug; +use system_error::SystemError; + +use super::{ + common::shutdown::ShutdownTemp, + endpoint::Endpoint, + posix::{PMSG, PSOL}, + SocketInode, +}; + +/// # `Socket` methods +/// ## Reference +/// - [Posix standard](https://pubs.opengroup.org/onlinepubs/9699919799/) +#[allow(unused_variables)] +pub trait Socket: Sync + Send + Debug + Any { + /// # `wait_queue` + /// 获取socket的wait queue + fn wait_queue(&self) -> &WaitQueue; + /// # `socket_poll` + /// 获取socket的事件。 + fn poll(&self) -> usize; + + fn send_buffer_size(&self) -> usize; + fn recv_buffer_size(&self) -> usize; + /// # `accept` + /// 接受连接,仅用于listening stream socket + /// ## Block + /// 如果没有连接到来,会阻塞 + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + Err(SystemError::ENOSYS) + } + /// # `bind` + /// 对应于POSIX的bind函数,用于绑定到本机指定的端点 + fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + Err(SystemError::ENOSYS) + } + /// # `close` + /// 关闭socket + fn close(&self) -> Result<(), SystemError> { + Ok(()) + } + /// # `connect` + /// 对应于POSIX的connect函数,用于连接到指定的远程服务器端点 + fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { + Err(SystemError::ENOSYS) + } + // fnctl + // freeaddrinfo + // getaddrinfo + // getnameinfo + /// # `get_peer_name` + /// 获取对端的地址 + fn get_peer_name(&self) -> Result { + Err(SystemError::ENOSYS) + } + /// # `get_name` + /// 获取socket的地址 + fn get_name(&self) -> Result { + Err(SystemError::ENOSYS) + } + /// # `get_option` + /// 对应于 Posix `getsockopt` ,获取socket选项 + fn get_option(&self, level: PSOL, name: usize, value: &mut [u8]) -> Result { + log::warn!("getsockopt is not implemented"); + Ok(0) + } + /// # `listen` + /// 监听socket,仅用于stream socket + fn listen(&self, backlog: usize) -> Result<(), SystemError> { + Err(SystemError::ENOSYS) + } + // poll + // pselect + /// # `read` + fn read(&self, buffer: &mut [u8]) -> Result { + self.recv(buffer, PMSG::empty()) + } + /// # `recv` + /// 接收数据,`read` = `recv` with flags = 0 + fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + Err(SystemError::ENOSYS) + } + /// # `recv_from` + fn recv_from( + &self, + buffer: &mut [u8], + flags: PMSG, + address: Option, + ) -> Result<(usize, Endpoint), SystemError> { + Err(SystemError::ENOSYS) + } + /// # `recv_msg` + fn recv_msg(&self, msg: &mut MsgHdr, flags: PMSG) -> Result { + Err(SystemError::ENOSYS) + } + // select + /// # `send` + fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + Err(SystemError::ENOSYS) + } + /// # `send_msg` + fn send_msg(&self, msg: &MsgHdr, flags: PMSG) -> Result { + Err(SystemError::ENOSYS) + } + /// # `send_to` + fn send_to(&self, buffer: &[u8], flags: PMSG, address: Endpoint) -> Result { + Err(SystemError::ENOSYS) + } + /// # `set_option` + /// Posix `setsockopt` ,设置socket选项 + /// ## Parameters + /// - level 选项的层次 + /// - name 选项的名称 + /// - value 选项的值 + /// ## Reference + /// https://code.dragonos.org.cn/s?refs=sk_setsockopt&project=linux-6.6.21 + fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> { + log::warn!("setsockopt is not implemented"); + Ok(()) + } + /// # `shutdown` + fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + // TODO 构建shutdown系统调用 + // set shutdown bit + Err(SystemError::ENOSYS) + } + // sockatmark + // socket + // socketpair + /// # `write` + fn write(&self, buffer: &[u8]) -> Result { + self.send(buffer, PMSG::empty()) + } + // fn write_buffer(&self, _buf: &[u8]) -> Result { + // todo!() + // } +} diff --git a/kernel/src/net/socket/buffer.rs b/kernel/src/net/socket/buffer.rs new file mode 100644 index 000000000..66334dc16 --- /dev/null +++ b/kernel/src/net/socket/buffer.rs @@ -0,0 +1,95 @@ +#!(allow(unused)) +use alloc::vec::Vec; + +use alloc::sync::Arc; +use system_error::SystemError; + +use crate::libs::spinlock::SpinLock; + +#[derive(Debug)] +pub struct Buffer { + metadata: Metadata, + read_buffer: SpinLock>, + write_buffer: SpinLock>, +} + +impl Buffer { + pub fn new() -> Arc { + Arc::new(Self { + metadata: Metadata::default(), + read_buffer: SpinLock::new(Vec::new()), + write_buffer: SpinLock::new(Vec::new()), + }) + } + + pub fn is_read_buf_empty(&self) -> bool { + return self.read_buffer.lock().is_empty(); + } + + pub fn is_read_buf_full(&self) -> bool { + return self.metadata.buf_size - self.read_buffer.lock().len() == 0; + } + + #[allow(dead_code)] + pub fn is_write_buf_empty(&self) -> bool { + return self.write_buffer.lock().is_empty(); + } + + #[allow(dead_code)] + pub fn is_write_buf_full(&self) -> bool { + return self.write_buffer.lock().len() >= self.metadata.buf_size; + } + + pub fn read_read_buffer(&self, buf: &mut [u8]) -> Result { + let mut read_buffer = self.read_buffer.lock_irqsave(); + let len = core::cmp::min(buf.len(), read_buffer.len()); + buf[..len].copy_from_slice(&read_buffer[..len]); + let _ = read_buffer.split_off(len); + // log::debug!("recv buf {}", String::from_utf8_lossy(buf)); + + return Ok(len); + } + + pub fn write_read_buffer(&self, buf: &[u8]) -> Result { + let mut buffer = self.read_buffer.lock_irqsave(); + // log::debug!("send buf {}", String::from_utf8_lossy(buf)); + let len = buf.len(); + if self.metadata.buf_size - buffer.len() < len { + return Err(SystemError::ENOBUFS); + } + buffer.extend_from_slice(buf); + + Ok(len) + } + + #[allow(dead_code)] + pub fn write_write_buffer(&self, buf: &[u8]) -> Result { + let mut buffer = self.write_buffer.lock_irqsave(); + + let len = buf.len(); + if self.metadata.buf_size - buffer.len() < len { + return Err(SystemError::ENOBUFS); + } + buffer.extend_from_slice(buf); + + Ok(len) + } +} + +#[derive(Debug)] +pub struct Metadata { + /// 默认的元数据缓冲区大小 + #[allow(dead_code)] + metadata_buf_size: usize, + /// 默认的缓冲区大小 + buf_size: usize, +} + +impl Default for Metadata { + fn default() -> Self { + Self { + metadata_buf_size: 1024, + buf_size: 64 * 1024, + } + } +} diff --git a/kernel/src/net/socket/common/epoll_items.rs b/kernel/src/net/socket/common/epoll_items.rs new file mode 100644 index 000000000..b81f7ec80 --- /dev/null +++ b/kernel/src/net/socket/common/epoll_items.rs @@ -0,0 +1,62 @@ +use alloc::{ + collections::LinkedList, + sync::{Arc, Weak}, + vec::Vec, +}; +use system_error::SystemError; + +use crate::{ + filesystem::epoll::{event_poll::EventPoll, EPollItem}, + libs::spinlock::SpinLock, +}; + +#[derive(Debug, Clone)] +pub struct EPollItems { + items: Arc>>>, +} + +impl Default for EPollItems { + fn default() -> Self { + Self { + items: Arc::new(SpinLock::new(LinkedList::new())), + } + } +} + +impl EPollItems { + pub fn add(&self, item: Arc) { + self.items.lock_irqsave().push_back(item); + } + + pub fn remove(&self, item: &Weak>) -> Result<(), SystemError> { + let to_remove = self + .items + .lock_irqsave() + .extract_if(|x| x.epoll().ptr_eq(item)) + .collect::>(); + + let result = if !to_remove.is_empty() { + Ok(()) + } else { + Err(SystemError::ENOENT) + }; + + drop(to_remove); + return result; + } + + pub fn clear(&self) -> Result<(), SystemError> { + let mut guard = self.items.lock_irqsave(); + let mut result = Ok(()); + guard.iter().for_each(|item| { + if let Some(epoll) = item.epoll().upgrade() { + let _ = EventPoll::ep_remove(&mut epoll.lock_irqsave(), item.fd(), None, item) + .map_err(|e| { + result = Err(e); + }); + } + }); + guard.clear(); + return result; + } +} diff --git a/kernel/src/net/socket/common/mod.rs b/kernel/src/net/socket/common/mod.rs new file mode 100644 index 000000000..8662b4d81 --- /dev/null +++ b/kernel/src/net/socket/common/mod.rs @@ -0,0 +1,18 @@ +// pub mod poll_unit; +mod epoll_items; + +pub mod shutdown; +pub use epoll_items::EPollItems; + +// /// @brief 在trait Socket的metadata函数中返回该结构体供外部使用 +// #[derive(Debug, Clone)] +// pub struct Metadata { +// /// 接收缓冲区的大小 +// pub rx_buf_size: usize, +// /// 发送缓冲区的大小 +// pub tx_buf_size: usize, +// /// 元数据的缓冲区的大小 +// pub metadata_buf_size: usize, +// /// socket的选项 +// pub options: SocketOptions, +// } diff --git a/kernel/src/net/socket/common/poll_unit.rs b/kernel/src/net/socket/common/poll_unit.rs new file mode 100644 index 000000000..ee88d7884 --- /dev/null +++ b/kernel/src/net/socket/common/poll_unit.rs @@ -0,0 +1,72 @@ +use alloc::{ + collections::LinkedList, + sync::{Arc, Weak}, + vec::Vec, +}; +use system_error::SystemError; + +use crate::{ + libs::{spinlock::SpinLock, wait_queue::EventWaitQueue}, + net::event_poll::{EPollEventType, EPollItem, EventPoll}, + process::ProcessManager, + sched::{schedule, SchedMode}, +}; + +#[derive(Debug, Clone)] +pub struct WaitQueue { + /// socket的waitqueue + wait_queue: Arc, +} + +impl Default for WaitQueue { + fn default() -> Self { + Self { + wait_queue: Default::default(), + } + } +} + +impl WaitQueue { + pub fn new(wait_queue: EventWaitQueue) -> Self { + Self { + wait_queue: Arc::new(wait_queue), + } + } + + /// # `wakeup_any` + /// 唤醒该队列上等待events的进程 + /// ## 参数 + /// - events: 发生的事件 + /// 需要注意的是,只要触发了events中的任意一件事件,进程都会被唤醒 + pub fn wakeup_any(&self, events: EPollEventType) { + self.wait_queue.wakeup_any(events.bits() as u64); + } + + /// # `wait_for` + /// 等待events事件发生 + pub fn wait_for(&self, events: EPollEventType) { + unsafe { + ProcessManager::preempt_disable(); + self.wait_queue.sleep_without_schedule(events.bits() as u64); + ProcessManager::preempt_enable(); + } + schedule(SchedMode::SM_NONE); + } + + /// # `busy_wait` + /// 轮询一个会返回EPAGAIN_OR_EWOULDBLOCK的函数 + pub fn busy_wait(&self, events: EPollEventType, mut f: F) -> Result + where + F: FnMut() -> Result, + { + loop { + match f() { + Ok(r) => return Ok(r), + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + self.wait_for(events); + } + Err(e) => return Err(e), + } + } + } +} diff --git a/kernel/src/net/socket/common/shutdown.rs b/kernel/src/net/socket/common/shutdown.rs new file mode 100644 index 000000000..609527ceb --- /dev/null +++ b/kernel/src/net/socket/common/shutdown.rs @@ -0,0 +1,135 @@ +// TODO: 其他模块需要实现shutdown的具体逻辑 +#![allow(dead_code)] +use core::sync::atomic::AtomicU8; + +use system_error::SystemError; + +bitflags! { + /// @brief 用于指定socket的关闭类型 + /// 参考:https://code.dragonos.org.cn/xref/linux-6.1.9/include/net/sock.h?fi=SHUTDOWN_MASK#1573 + pub struct ShutdownBit: u8 { + const SHUT_RD = 0; + const SHUT_WR = 1; + const SHUT_RDWR = 2; + } +} + +const RCV_SHUTDOWN: u8 = 0x01; +const SEND_SHUTDOWN: u8 = 0x02; +const SHUTDOWN_MASK: u8 = 0x03; + +#[derive(Debug, Default)] +pub struct Shutdown { + bit: AtomicU8, +} + +impl From for Shutdown { + fn from(shutdown_bit: ShutdownBit) -> Self { + match shutdown_bit { + ShutdownBit::SHUT_RD => Shutdown { + bit: AtomicU8::new(RCV_SHUTDOWN), + }, + ShutdownBit::SHUT_WR => Shutdown { + bit: AtomicU8::new(SEND_SHUTDOWN), + }, + ShutdownBit::SHUT_RDWR => Shutdown { + bit: AtomicU8::new(SHUTDOWN_MASK), + }, + _ => Shutdown::default(), + } + } +} + +impl Shutdown { + pub fn new() -> Self { + Self { + bit: AtomicU8::new(0), + } + } + + pub fn recv_shutdown(&self) { + self.bit + .fetch_or(RCV_SHUTDOWN, core::sync::atomic::Ordering::SeqCst); + } + + pub fn send_shutdown(&self) { + self.bit + .fetch_or(SEND_SHUTDOWN, core::sync::atomic::Ordering::SeqCst); + } + + pub fn is_recv_shutdown(&self) -> bool { + self.bit.load(core::sync::atomic::Ordering::SeqCst) & RCV_SHUTDOWN != 0 + } + + pub fn is_send_shutdown(&self) -> bool { + self.bit.load(core::sync::atomic::Ordering::SeqCst) & SEND_SHUTDOWN != 0 + } + + pub fn is_both_shutdown(&self) -> bool { + self.bit.load(core::sync::atomic::Ordering::SeqCst) & SHUTDOWN_MASK == SHUTDOWN_MASK + } + + pub fn is_empty(&self) -> bool { + self.bit.load(core::sync::atomic::Ordering::SeqCst) == 0 + } + + pub fn from_how(how: usize) -> Self { + Self::from(ShutdownBit::from_bits_truncate(how as u8)) + } + + pub fn get(&self) -> ShutdownTemp { + ShutdownTemp { + bit: self.bit.load(core::sync::atomic::Ordering::SeqCst), + } + } +} + +pub struct ShutdownTemp { + bit: u8, +} + +impl ShutdownTemp { + pub fn is_recv_shutdown(&self) -> bool { + self.bit & RCV_SHUTDOWN != 0 + } + + pub fn is_send_shutdown(&self) -> bool { + self.bit & SEND_SHUTDOWN != 0 + } + + pub fn is_both_shutdown(&self) -> bool { + self.bit & SHUTDOWN_MASK == SHUTDOWN_MASK + } + + pub fn is_empty(&self) -> bool { + self.bit == 0 + } + + pub fn bits(&self) -> ShutdownBit { + ShutdownBit { bits: self.bit } + } +} + +impl From for ShutdownTemp { + fn from(shutdown_bit: ShutdownBit) -> Self { + match shutdown_bit { + ShutdownBit::SHUT_RD => Self { bit: RCV_SHUTDOWN }, + ShutdownBit::SHUT_WR => Self { bit: SEND_SHUTDOWN }, + ShutdownBit::SHUT_RDWR => Self { bit: SHUTDOWN_MASK }, + _ => Self { bit: 0 }, + } + } +} + +impl TryFrom for ShutdownTemp { + type Error = SystemError; + + fn try_from(value: usize) -> Result { + match value { + 0..2 => Ok(ShutdownTemp { + bit: value as u8 + 1, + }), + _ => Err(SystemError::EINVAL), + } + } +} diff --git a/kernel/src/net/socket/endpoint.rs b/kernel/src/net/socket/endpoint.rs new file mode 100644 index 000000000..35fa06bd2 --- /dev/null +++ b/kernel/src/net/socket/endpoint.rs @@ -0,0 +1,44 @@ +use crate::{filesystem::vfs::InodeId, net::socket}; +use alloc::{string::String, sync::Arc}; + +pub use smoltcp::wire::IpEndpoint; + +use super::unix::ns::abs::AbsHandle; + +#[derive(Debug, Clone)] +pub enum Endpoint { + /// 链路层端点 + LinkLayer(LinkLayerEndpoint), + /// 网络层端点 + Ip(IpEndpoint), + /// inode端点,Unix实际保存的端点 + Inode((Arc, String)), + /// Unix传递id索引和path所用的端点 + Unixpath((InodeId, String)), + /// Unix抽象端点 + Abspath((AbsHandle, String)), +} + +/// @brief 链路层端点 +#[derive(Debug, Clone)] +pub struct LinkLayerEndpoint { + /// 网卡的接口号 + pub interface: usize, +} + +impl LinkLayerEndpoint { + /// @brief 创建一个链路层端点 + /// + /// @param interface 网卡的接口号 + /// + /// @return 返回创建的链路层端点 + pub fn new(interface: usize) -> Self { + Self { interface } + } +} + +impl From for Endpoint { + fn from(endpoint: IpEndpoint) -> Self { + Self::Ip(endpoint) + } +} diff --git a/kernel/src/net/socket/family.rs b/kernel/src/net/socket/family.rs new file mode 100644 index 000000000..26200961a --- /dev/null +++ b/kernel/src/net/socket/family.rs @@ -0,0 +1,124 @@ +/// # AddressFamily +/// Socket address families. +/// ## Reference +/// https://code.dragonos.org.cn/xref/linux-5.19.10/include/linux/socket.h#180 +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum AddressFamily { + /// AF_UNSPEC 表示地址族未指定 + Unspecified = 0, + /// AF_UNIX 表示Unix域的socket (与AF_LOCAL相同) + Unix = 1, + /// AF_INET 表示IPv4的socket + INet = 2, + /// AF_AX25 表示AMPR AX.25的socket + AX25 = 3, + /// AF_IPX 表示IPX的socket + IPX = 4, + /// AF_APPLETALK 表示Appletalk的socket + Appletalk = 5, + /// AF_NETROM 表示AMPR NET/ROM的socket + Netrom = 6, + /// AF_BRIDGE 表示多协议桥接的socket + Bridge = 7, + /// AF_ATMPVC 表示ATM PVCs的socket + Atmpvc = 8, + /// AF_X25 表示X.25的socket + X25 = 9, + /// AF_INET6 表示IPv6的socket + INet6 = 10, + /// AF_ROSE 表示AMPR ROSE的socket + Rose = 11, + /// AF_DECnet Reserved for DECnet project + Decnet = 12, + /// AF_NETBEUI Reserved for 802.2LLC project + Netbeui = 13, + /// AF_SECURITY 表示Security callback的伪AF + Security = 14, + /// AF_KEY 表示Key management API + Key = 15, + /// AF_NETLINK 表示Netlink的socket + Netlink = 16, + /// AF_PACKET 表示Low level packet interface + Packet = 17, + /// AF_ASH 表示Ash + Ash = 18, + /// AF_ECONET 表示Acorn Econet + Econet = 19, + /// AF_ATMSVC 表示ATM SVCs + Atmsvc = 20, + /// AF_RDS 表示Reliable Datagram Sockets + Rds = 21, + /// AF_SNA 表示Linux SNA Project + Sna = 22, + /// AF_IRDA 表示IRDA sockets + Irda = 23, + /// AF_PPPOX 表示PPPoX sockets + Pppox = 24, + /// AF_WANPIPE 表示WANPIPE API sockets + WanPipe = 25, + /// AF_LLC 表示Linux LLC + Llc = 26, + /// AF_IB 表示Native InfiniBand address + /// 介绍:https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks + Ib = 27, + /// AF_MPLS 表示MPLS + Mpls = 28, + /// AF_CAN 表示Controller Area Network + Can = 29, + /// AF_TIPC 表示TIPC sockets + Tipc = 30, + /// AF_BLUETOOTH 表示Bluetooth sockets + Bluetooth = 31, + /// AF_IUCV 表示IUCV sockets + Iucv = 32, + /// AF_RXRPC 表示RxRPC sockets + Rxrpc = 33, + /// AF_ISDN 表示mISDN sockets + Isdn = 34, + /// AF_PHONET 表示Phonet sockets + Phonet = 35, + /// AF_IEEE802154 表示IEEE 802.15.4 sockets + Ieee802154 = 36, + /// AF_CAIF 表示CAIF sockets + Caif = 37, + /// AF_ALG 表示Algorithm sockets + Alg = 38, + /// AF_NFC 表示NFC sockets + Nfc = 39, + /// AF_VSOCK 表示vSockets + Vsock = 40, + /// AF_KCM 表示Kernel Connection Multiplexor + Kcm = 41, + /// AF_QIPCRTR 表示Qualcomm IPC Router + Qipcrtr = 42, + /// AF_SMC 表示SMC-R sockets. + /// reserve number for PF_SMC protocol family that reuses AF_INET address family + Smc = 43, + /// AF_XDP 表示XDP sockets + Xdp = 44, + /// AF_MCTP 表示Management Component Transport Protocol + Mctp = 45, + /// AF_MAX 表示最大的地址族 + Max = 46, +} + +impl core::convert::TryFrom for AddressFamily { + type Error = system_error::SystemError; + fn try_from(x: u16) -> Result { + use num_traits::FromPrimitive; + // this will return EINVAL but still works, idk why + return ::from_u16(x).ok_or(Self::Error::EINVAL); + } +} + +use crate::net::socket; +use alloc::sync::Arc; + +use super::PSOCK; + +pub trait Family { + fn socket( + stype: PSOCK, + protocol: u32, + ) -> Result, system_error::SystemError>; +} diff --git a/kernel/src/net/socket/handle.rs b/kernel/src/net/socket/handle.rs deleted file mode 100644 index a94f7255d..000000000 --- a/kernel/src/net/socket/handle.rs +++ /dev/null @@ -1,42 +0,0 @@ -use ida::IdAllocator; -use smoltcp::iface::SocketHandle; - -use crate::libs::spinlock::SpinLock; - -int_like!(KernelHandle, usize); - -/// # socket的句柄管理组件 -/// 它在smoltcp的SocketHandle上封装了一层,增加更多的功能。 -/// 比如,在socket被关闭时,自动释放socket的资源,通知系统的其他组件。 -#[derive(Debug, Hash, Eq, PartialEq, Clone, Copy)] -pub enum GlobalSocketHandle { - Smoltcp(SocketHandle), - Kernel(KernelHandle), -} - -static KERNEL_HANDLE_IDA: SpinLock = - SpinLock::new(IdAllocator::new(0, usize::MAX).unwrap()); - -impl GlobalSocketHandle { - pub fn new_smoltcp_handle(handle: SocketHandle) -> Self { - return Self::Smoltcp(handle); - } - - pub fn new_kernel_handle() -> Self { - return Self::Kernel(KernelHandle::new(KERNEL_HANDLE_IDA.lock().alloc().unwrap())); - } - - pub fn smoltcp_handle(&self) -> Option { - if let Self::Smoltcp(sh) = *self { - return Some(sh); - } - None - } - - pub fn kernel_handle(&self) -> Option { - if let Self::Kernel(kh) = *self { - return Some(kh); - } - None - } -} diff --git a/kernel/src/net/socket/inet.rs b/kernel/src/net/socket/inet.rs deleted file mode 100644 index 547126459..000000000 --- a/kernel/src/net/socket/inet.rs +++ /dev/null @@ -1,1010 +0,0 @@ -use alloc::{boxed::Box, sync::Arc, vec::Vec}; -use log::{error, warn}; -use smoltcp::{ - socket::{raw, tcp, udp}, - wire, -}; -use system_error::SystemError; - -use crate::{ - driver::net::NetDevice, - filesystem::epoll::EPollEventType, - libs::rwlock::RwLock, - net::{net_core::poll_ifaces, Endpoint, Protocol, ShutdownType, NET_DEVICES}, -}; - -use super::{ - handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketHandleItem, SocketMetadata, - SocketOptions, SocketPollMethod, SocketType, HANDLE_MAP, PORT_MANAGER, SOCKET_SET, -}; - -/// @brief 表示原始的socket。原始套接字绕过传输层协议(如 TCP 或 UDP)并提供对网络层协议(如 IP)的直接访问。 -/// -/// ref: https://man7.org/linux/man-pages/man7/raw.7.html -#[derive(Debug, Clone)] -pub struct RawSocket { - handle: GlobalSocketHandle, - /// 用户发送的数据包是否包含了IP头. - /// 如果是true,用户发送的数据包,必须包含IP头。(即用户要自行设置IP头+数据) - /// 如果是false,用户发送的数据包,不包含IP头。(即用户只要设置数据) - header_included: bool, - /// socket的metadata - metadata: SocketMetadata, - posix_item: Arc, -} - -impl RawSocket { - /// 元数据的缓冲区的大小 - pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; - /// 默认的接收缓冲区的大小 receive - pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; - /// 默认的发送缓冲区的大小 transmiss - pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; - - /// @brief 创建一个原始的socket - /// - /// @param protocol 协议号 - /// @param options socket的选项 - /// - /// @return 返回创建的原始的socket - pub fn new(protocol: Protocol, options: SocketOptions) -> Self { - let rx_buffer = raw::PacketBuffer::new( - vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], - vec![0; Self::DEFAULT_RX_BUF_SIZE], - ); - let tx_buffer = raw::PacketBuffer::new( - vec![raw::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], - vec![0; Self::DEFAULT_TX_BUF_SIZE], - ); - let protocol: u8 = protocol.into(); - let socket = raw::Socket::new( - wire::IpVersion::Ipv4, - wire::IpProtocol::from(protocol), - rx_buffer, - tx_buffer, - ); - - // 把socket添加到socket集合中,并得到socket的句柄 - let handle = GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); - - let metadata = SocketMetadata::new( - SocketType::Raw, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - options, - ); - - let posix_item = Arc::new(PosixSocketHandleItem::new(None)); - - return Self { - handle, - header_included: false, - metadata, - posix_item, - }; - } -} - -impl Socket for RawSocket { - fn posix_item(&self) -> Arc { - self.posix_item.clone() - } - - fn close(&mut self) { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - if let smoltcp::socket::Socket::Udp(mut sock) = - socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) - { - sock.close(); - } - drop(socket_set_guard); - poll_ifaces(); - } - - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { - poll_ifaces(); - loop { - // 如何优化这里? - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = - socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); - - match socket.recv_slice(buf) { - Ok(len) => { - let packet = wire::Ipv4Packet::new_unchecked(buf); - return ( - Ok(len), - Endpoint::Ip(Some(wire::IpEndpoint { - addr: wire::IpAddress::Ipv4(packet.src_addr()), - port: 0, - })), - ); - } - Err(_) => { - if !self.metadata.options.contains(SocketOptions::BLOCK) { - // 如果是非阻塞的socket,就返回错误 - return (Err(SystemError::EAGAIN_OR_EWOULDBLOCK), Endpoint::Ip(None)); - } - } - } - drop(socket_set_guard); - self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64); - } - } - - fn write(&self, buf: &[u8], to: Option) -> Result { - // 如果用户发送的数据包,包含IP头,则直接发送 - if self.header_included { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = - socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); - match socket.send_slice(buf) { - Ok(_) => { - return Ok(buf.len()); - } - Err(raw::SendError::BufferFull) => { - return Err(SystemError::ENOBUFS); - } - } - } else { - // 如果用户发送的数据包,不包含IP头,则需要自己构造IP头 - - if let Some(Endpoint::Ip(Some(endpoint))) = to { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket: &mut raw::Socket = - socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); - - // 暴力解决方案:只考虑0号网卡。 TODO:考虑多网卡的情况!!! - let iface = NET_DEVICES.read_irqsave().get(&0).unwrap().clone(); - - // 构造IP头 - let ipv4_src_addr: Option = - iface.inner_iface().lock().ipv4_addr(); - if ipv4_src_addr.is_none() { - return Err(SystemError::ENETUNREACH); - } - let ipv4_src_addr = ipv4_src_addr.unwrap(); - - if let wire::IpAddress::Ipv4(ipv4_dst) = endpoint.addr { - let len = buf.len(); - - // 创建20字节的IPv4头部 - let mut buffer: Vec = vec![0u8; len + 20]; - let mut packet: wire::Ipv4Packet<&mut Vec> = - wire::Ipv4Packet::new_unchecked(&mut buffer); - - // 封装ipv4 header - packet.set_version(4); - packet.set_header_len(20); - packet.set_total_len((20 + len) as u16); - packet.set_src_addr(ipv4_src_addr); - packet.set_dst_addr(ipv4_dst); - - // 设置ipv4 header的protocol字段 - packet.set_next_header(socket.ip_protocol()); - - // 获取IP数据包的负载字段 - let payload: &mut [u8] = packet.payload_mut(); - payload.copy_from_slice(buf); - - // 填充checksum字段 - packet.fill_checksum(); - - // 发送数据包 - socket.send_slice(&buffer).unwrap(); - - iface.poll(&mut socket_set_guard).ok(); - - drop(socket_set_guard); - return Ok(len); - } else { - warn!("Unsupport Ip protocol type!"); - return Err(SystemError::EINVAL); - } - } else { - // 如果没有指定目的地址,则返回错误 - return Err(SystemError::ENOTCONN); - } - } - } - - fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { - Ok(()) - } - - fn metadata(&self) -> SocketMetadata { - self.metadata.clone() - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - - fn socket_handle(&self) -> GlobalSocketHandle { - self.handle - } - - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self - } -} - -/// @brief 表示udp socket -/// -/// https://man7.org/linux/man-pages/man7/udp.7.html -#[derive(Debug, Clone)] -pub struct UdpSocket { - pub handle: GlobalSocketHandle, - remote_endpoint: Option, // 记录远程endpoint提供给connect(), 应该使用IP地址。 - metadata: SocketMetadata, - posix_item: Arc, -} - -impl UdpSocket { - /// 元数据的缓冲区的大小 - pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; - /// 默认的接收缓冲区的大小 receive - pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; - /// 默认的发送缓冲区的大小 transmiss - pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; - - /// @brief 创建一个udp的socket - /// - /// @param options socket的选项 - /// - /// @return 返回创建的udp的socket - pub fn new(options: SocketOptions) -> Self { - let rx_buffer = udp::PacketBuffer::new( - vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], - vec![0; Self::DEFAULT_RX_BUF_SIZE], - ); - let tx_buffer = udp::PacketBuffer::new( - vec![udp::PacketMetadata::EMPTY; Self::DEFAULT_METADATA_BUF_SIZE], - vec![0; Self::DEFAULT_TX_BUF_SIZE], - ); - let socket = udp::Socket::new(rx_buffer, tx_buffer); - - // 把socket添加到socket集合中,并得到socket的句柄 - let handle: GlobalSocketHandle = - GlobalSocketHandle::new_smoltcp_handle(SOCKET_SET.lock_irqsave().add(socket)); - - let metadata = SocketMetadata::new( - SocketType::Udp, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - options, - ); - - let posix_item = Arc::new(PosixSocketHandleItem::new(None)); - - return Self { - handle, - remote_endpoint: None, - metadata, - posix_item, - }; - } - - fn do_bind(&self, socket: &mut udp::Socket, endpoint: Endpoint) -> Result<(), SystemError> { - if let Endpoint::Ip(Some(mut ip)) = endpoint { - // 端口为0则分配随机端口 - if ip.port == 0 { - ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; - } - // 检测端口是否已被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; - - let bind_res = if ip.addr.is_unspecified() { - socket.bind(ip.port) - } else { - socket.bind(ip) - }; - - match bind_res { - Ok(()) => return Ok(()), - Err(_) => return Err(SystemError::EINVAL), - } - } else { - return Err(SystemError::EINVAL); - } - } -} - -impl Socket for UdpSocket { - fn posix_item(&self) -> Arc { - self.posix_item.clone() - } - - fn close(&mut self) { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - if let smoltcp::socket::Socket::Udp(mut sock) = - socket_set_guard.remove(self.handle.smoltcp_handle().unwrap()) - { - sock.close(); - } - drop(socket_set_guard); - poll_ifaces(); - } - - /// @brief 在read函数执行之前,请先bind到本地的指定端口 - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { - loop { - // debug!("Wait22 to Read"); - poll_ifaces(); - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = - socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); - - // debug!("Wait to Read"); - - if socket.can_recv() { - if let Ok((size, metadata)) = socket.recv_slice(buf) { - drop(socket_set_guard); - poll_ifaces(); - return (Ok(size), Endpoint::Ip(Some(metadata.endpoint))); - } - } else { - // 如果socket没有连接,则忙等 - // return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - drop(socket_set_guard); - self.posix_item.sleep(EPollEventType::EPOLLIN.bits() as u64); - } - } - - fn write(&self, buf: &[u8], to: Option) -> Result { - // debug!("udp to send: {:?}, len={}", to, buf.len()); - let remote_endpoint: &wire::IpEndpoint = { - if let Some(Endpoint::Ip(Some(ref endpoint))) = to { - endpoint - } else if let Some(Endpoint::Ip(Some(ref endpoint))) = self.remote_endpoint { - endpoint - } else { - return Err(SystemError::ENOTCONN); - } - }; - // debug!("udp write: remote = {:?}", remote_endpoint); - - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let socket = socket_set_guard.get_mut::(self.handle.smoltcp_handle().unwrap()); - // debug!("is open()={}", socket.is_open()); - // debug!("socket endpoint={:?}", socket.endpoint()); - if socket.can_send() { - // debug!("udp write: can send"); - match socket.send_slice(buf, *remote_endpoint) { - Ok(()) => { - // debug!("udp write: send ok"); - drop(socket_set_guard); - poll_ifaces(); - return Ok(buf.len()); - } - Err(_) => { - // debug!("udp write: send err"); - return Err(SystemError::ENOBUFS); - } - } - } else { - // debug!("udp write: can not send"); - return Err(SystemError::ENOBUFS); - }; - } - - fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - let mut sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get_mut::(self.handle.smoltcp_handle().unwrap()); - // debug!("UDP Bind to {:?}", endpoint); - return self.do_bind(socket, endpoint); - } - - fn poll(&self) -> EPollEventType { - let sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); - - return SocketPollMethod::udp_poll( - socket, - HANDLE_MAP - .read_irqsave() - .get(&self.socket_handle()) - .unwrap() - .shutdown_type(), - ); - } - - fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - if let Endpoint::Ip(_) = endpoint { - self.remote_endpoint = Some(endpoint); - Ok(()) - } else { - Err(SystemError::EINVAL) - } - } - - fn ioctl( - &self, - _cmd: usize, - _arg0: usize, - _arg1: usize, - _arg2: usize, - ) -> Result { - todo!() - } - - fn metadata(&self) -> SocketMetadata { - self.metadata.clone() - } - - fn box_clone(&self) -> Box { - return Box::new(self.clone()); - } - - fn endpoint(&self) -> Option { - let sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get::(self.handle.smoltcp_handle().unwrap()); - let listen_endpoint = socket.endpoint(); - - if listen_endpoint.port == 0 { - return None; - } else { - // 如果listen_endpoint的address是None,意味着“监听所有的地址”。 - // 这里假设所有的地址都是ipv4 - // TODO: 支持ipv6 - let result = wire::IpEndpoint::new( - listen_endpoint - .addr - .unwrap_or(wire::IpAddress::v4(0, 0, 0, 0)), - listen_endpoint.port, - ); - return Some(Endpoint::Ip(Some(result))); - } - } - - fn peer_endpoint(&self) -> Option { - return self.remote_endpoint.clone(); - } - - fn socket_handle(&self) -> GlobalSocketHandle { - self.handle - } - - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self - } -} - -/// @brief 表示 tcp socket -/// -/// https://man7.org/linux/man-pages/man7/tcp.7.html -#[derive(Debug, Clone)] -pub struct TcpSocket { - handles: Vec, - local_endpoint: Option, // save local endpoint for bind() - is_listening: bool, - metadata: SocketMetadata, - posix_item: Arc, -} - -impl TcpSocket { - /// 元数据的缓冲区的大小 - pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; - /// 默认的接收缓冲区的大小 receive - pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; - /// 默认的发送缓冲区的大小 transmiss - pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; - - /// TcpSocket的特殊事件,用于在事件等待队列上sleep - pub const CAN_CONNECT: u64 = 1u64 << 63; - pub const CAN_ACCPET: u64 = 1u64 << 62; - - /// @brief 创建一个tcp的socket - /// - /// @param options socket的选项 - /// - /// @return 返回创建的tcp的socket - pub fn new(options: SocketOptions) -> Self { - // 创建handles数组并把socket添加到socket集合中,并得到socket的句柄 - let handles: Vec = vec![GlobalSocketHandle::new_smoltcp_handle( - SOCKET_SET.lock_irqsave().add(Self::create_new_socket()), - )]; - - let metadata = SocketMetadata::new( - SocketType::Tcp, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - options, - ); - let posix_item = Arc::new(PosixSocketHandleItem::new(None)); - // debug!("when there's a new tcp socket,its'len: {}",handles.len()); - - return Self { - handles, - local_endpoint: None, - is_listening: false, - metadata, - posix_item, - }; - } - - fn do_listen( - &mut self, - socket: &mut tcp::Socket, - local_endpoint: wire::IpEndpoint, - ) -> Result<(), SystemError> { - let listen_result = if local_endpoint.addr.is_unspecified() { - socket.listen(local_endpoint.port) - } else { - socket.listen(local_endpoint) - }; - return match listen_result { - Ok(()) => { - // debug!( - // "Tcp Socket Listen on {local_endpoint}, open?:{}", - // socket.is_open() - // ); - self.is_listening = true; - - Ok(()) - } - Err(_) => Err(SystemError::EINVAL), - }; - } - - /// # create_new_socket - 创建新的TCP套接字 - /// - /// 该函数用于创建一个新的TCP套接字,并返回该套接字的引用。 - fn create_new_socket() -> tcp::Socket<'static> { - // 初始化tcp的buffer - let rx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_RX_BUF_SIZE]); - let tx_buffer = tcp::SocketBuffer::new(vec![0; Self::DEFAULT_TX_BUF_SIZE]); - tcp::Socket::new(rx_buffer, tx_buffer) - } - - /// listening状态的posix socket是需要特殊处理的 - fn tcp_poll_listening(&self) -> EPollEventType { - let socketset_guard = SOCKET_SET.lock_irqsave(); - - let can_accept = self.handles.iter().any(|h| { - if let Some(sh) = h.smoltcp_handle() { - let socket = socketset_guard.get::(sh); - socket.is_active() - } else { - false - } - }); - - if can_accept { - return EPollEventType::EPOLL_LISTEN_CAN_ACCEPT; - } else { - return EPollEventType::empty(); - } - } -} - -impl Socket for TcpSocket { - fn posix_item(&self) -> Arc { - self.posix_item.clone() - } - - fn close(&mut self) { - for handle in self.handles.iter() { - { - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - let smoltcp_handle = handle.smoltcp_handle().unwrap(); - socket_set_guard - .get_mut::(smoltcp_handle) - .close(); - drop(socket_set_guard); - } - poll_ifaces(); - SOCKET_SET - .lock_irqsave() - .remove(handle.smoltcp_handle().unwrap()); - // debug!("[Socket] [TCP] Close: {:?}", handle); - } - } - - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { - if HANDLE_MAP - .read_irqsave() - .get(&self.socket_handle()) - .unwrap() - .shutdown_type() - .contains(ShutdownType::RCV_SHUTDOWN) - { - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - // debug!("tcp socket: read, buf len={}", buf.len()); - // debug!("tcp socket:read, socket'len={}",self.handle.len()); - loop { - poll_ifaces(); - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - - let socket = socket_set_guard - .get_mut::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - - // 如果socket已经关闭,返回错误 - if !socket.is_active() { - // debug!("Tcp Socket Read Error, socket is closed"); - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - - if socket.may_recv() { - match socket.recv_slice(buf) { - Ok(size) => { - if size > 0 { - let endpoint = if let Some(p) = socket.remote_endpoint() { - p - } else { - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - }; - - drop(socket_set_guard); - poll_ifaces(); - return (Ok(size), Endpoint::Ip(Some(endpoint))); - } - } - Err(tcp::RecvError::InvalidState) => { - warn!("Tcp Socket Read Error, InvalidState"); - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - Err(tcp::RecvError::Finished) => { - // 对端写端已关闭,我们应该关闭读端 - HANDLE_MAP - .write_irqsave() - .get_mut(&self.socket_handle()) - .unwrap() - .shutdown_type_writer() - .insert(ShutdownType::RCV_SHUTDOWN); - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - } - } else { - return (Err(SystemError::ENOTCONN), Endpoint::Ip(None)); - } - drop(socket_set_guard); - self.posix_item - .sleep((EPollEventType::EPOLLIN | EPollEventType::EPOLLHUP).bits() as u64); - } - } - - fn write(&self, buf: &[u8], _to: Option) -> Result { - if HANDLE_MAP - .read_irqsave() - .get(&self.socket_handle()) - .unwrap() - .shutdown_type() - .contains(ShutdownType::RCV_SHUTDOWN) - { - return Err(SystemError::ENOTCONN); - } - // debug!("tcp socket:write, socket'len={}",self.handle.len()); - - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - - let socket = socket_set_guard - .get_mut::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - - if socket.is_open() { - if socket.can_send() { - match socket.send_slice(buf) { - Ok(size) => { - drop(socket_set_guard); - poll_ifaces(); - return Ok(size); - } - Err(e) => { - error!("Tcp Socket Write Error {e:?}"); - return Err(SystemError::ENOBUFS); - } - } - } else { - return Err(SystemError::ENOBUFS); - } - } - - return Err(SystemError::ENOTCONN); - } - - fn poll(&self) -> EPollEventType { - // 处理listen的快速路径 - if self.is_listening { - return self.tcp_poll_listening(); - } - // 由于上面处理了listening状态,所以这里只处理非listening状态,这种情况下只有一个handle - - assert!(self.handles.len() == 1); - - let mut socket_set_guard = SOCKET_SET.lock_irqsave(); - // debug!("tcp socket:poll, socket'len={}",self.handle.len()); - - let socket = socket_set_guard - .get_mut::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - let handle_map_guard = HANDLE_MAP.read_irqsave(); - let handle_item = handle_map_guard.get(&self.socket_handle()).unwrap(); - let shutdown_type = handle_item.shutdown_type(); - let is_posix_listen = handle_item.is_posix_listen; - drop(handle_map_guard); - - return SocketPollMethod::tcp_poll(socket, shutdown_type, is_posix_listen); - } - - fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - let mut sockets = SOCKET_SET.lock_irqsave(); - // debug!("tcp socket:connect, socket'len={}", self.handles.len()); - - let socket = - sockets.get_mut::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - - if let Endpoint::Ip(Some(ip)) = endpoint { - let temp_port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; - // 检测端口是否被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, temp_port)?; - - // debug!("temp_port: {}", temp_port); - let iface: Arc = NET_DEVICES.write_irqsave().get(&0).unwrap().clone(); - let mut inner_iface = iface.inner_iface().lock(); - // debug!("to connect: {ip:?}"); - - match socket.connect(inner_iface.context(), ip, temp_port) { - Ok(()) => { - // avoid deadlock - drop(inner_iface); - drop(iface); - drop(sockets); - loop { - poll_ifaces(); - let mut sockets = SOCKET_SET.lock_irqsave(); - let socket = sockets.get_mut::( - self.handles.first().unwrap().smoltcp_handle().unwrap(), - ); - - match socket.state() { - tcp::State::Established => { - return Ok(()); - } - tcp::State::SynSent => { - drop(sockets); - self.posix_item.sleep(Self::CAN_CONNECT); - } - _ => { - return Err(SystemError::ECONNREFUSED); - } - } - } - } - Err(e) => { - // error!("Tcp Socket Connect Error {e:?}"); - match e { - tcp::ConnectError::InvalidState => return Err(SystemError::EISCONN), - tcp::ConnectError::Unaddressable => return Err(SystemError::EADDRNOTAVAIL), - } - } - } - } else { - return Err(SystemError::EINVAL); - } - } - - /// @brief tcp socket 监听 local_endpoint 端口 - /// - /// @param backlog 未处理的连接队列的最大长度 - fn listen(&mut self, backlog: usize) -> Result<(), SystemError> { - if self.is_listening { - return Ok(()); - } - - // debug!( - // "tcp socket:listen, socket'len={}, backlog = {backlog}", - // self.handles.len() - // ); - - let local_endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; - let mut sockets = SOCKET_SET.lock_irqsave(); - // 获取handle的数量 - let handlen = self.handles.len(); - let backlog = handlen.max(backlog); - - // 添加剩余需要构建的socket - // debug!("tcp socket:before listen, socket'len={}", self.handle_list.len()); - let mut handle_guard = HANDLE_MAP.write_irqsave(); - let socket_handle_item_0 = handle_guard.get_mut(&self.socket_handle()).unwrap(); - socket_handle_item_0.is_posix_listen = true; - - self.handles.extend((handlen..backlog).map(|_| { - let socket = Self::create_new_socket(); - let handle = GlobalSocketHandle::new_smoltcp_handle(sockets.add(socket)); - let mut handle_item = SocketHandleItem::new(Arc::downgrade(&self.posix_item)); - handle_item.is_posix_listen = true; - handle_guard.insert(handle, handle_item); - handle - })); - - // debug!("tcp socket:listen, socket'len={}", self.handles.len()); - // debug!("tcp socket:listen, backlog={backlog}"); - - // 监听所有的socket - for i in 0..backlog { - let handle = self.handles.get(i).unwrap(); - - let socket = sockets.get_mut::(handle.smoltcp_handle().unwrap()); - - if !socket.is_listening() { - // debug!("Tcp Socket is already listening on {local_endpoint}"); - self.do_listen(socket, local_endpoint)?; - } - // debug!("Tcp Socket before listen, open={}", socket.is_open()); - } - - return Ok(()); - } - - fn bind(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - if let Endpoint::Ip(Some(mut ip)) = endpoint { - if ip.port == 0 { - ip.port = PORT_MANAGER.get_ephemeral_port(self.metadata.socket_type)?; - } - - // 检测端口是否已被占用 - PORT_MANAGER.bind_port(self.metadata.socket_type, ip.port)?; - // debug!("tcp socket:bind, socket'len={}",self.handle.len()); - - self.local_endpoint = Some(ip); - self.is_listening = false; - - return Ok(()); - } - return Err(SystemError::EINVAL); - } - - fn shutdown(&mut self, shutdown_type: super::ShutdownType) -> Result<(), SystemError> { - // TODO:目前只是在表层判断,对端不知晓,后续需使用tcp实现 - HANDLE_MAP - .write_irqsave() - .get_mut(&self.socket_handle()) - .unwrap() - .shutdown_type = RwLock::new(shutdown_type); - return Ok(()); - } - - fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> { - if !self.is_listening { - return Err(SystemError::EINVAL); - } - let endpoint = self.local_endpoint.ok_or(SystemError::EINVAL)?; - loop { - // debug!("tcp accept: poll_ifaces()"); - poll_ifaces(); - // debug!("tcp socket:accept, socket'len={}", self.handle_list.len()); - - let mut sockset = SOCKET_SET.lock_irqsave(); - // Get the corresponding activated handler - let global_handle_index = self.handles.iter().position(|handle| { - let con_smol_sock = sockset.get::(handle.smoltcp_handle().unwrap()); - con_smol_sock.is_active() - }); - - if let Some(handle_index) = global_handle_index { - let con_smol_sock = sockset - .get::(self.handles[handle_index].smoltcp_handle().unwrap()); - - // debug!("[Socket] [TCP] Accept: {:?}", handle); - // handle is connected socket's handle - let remote_ep = con_smol_sock - .remote_endpoint() - .ok_or(SystemError::ENOTCONN)?; - - let tcp_socket = Self::create_new_socket(); - - let new_handle = GlobalSocketHandle::new_smoltcp_handle(sockset.add(tcp_socket)); - - // let handle in TcpSock be the new empty handle, and return the old connected handle - let old_handle = core::mem::replace(&mut self.handles[handle_index], new_handle); - - let metadata = SocketMetadata::new( - SocketType::Tcp, - Self::DEFAULT_TX_BUF_SIZE, - Self::DEFAULT_RX_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - self.metadata.options, - ); - - let sock_ret = Box::new(TcpSocket { - handles: vec![old_handle], - local_endpoint: self.local_endpoint, - is_listening: false, - metadata, - posix_item: Arc::new(PosixSocketHandleItem::new(None)), - }); - - { - let mut handle_guard = HANDLE_MAP.write_irqsave(); - // 先删除原来的 - let item = handle_guard.remove(&old_handle).unwrap(); - item.reset_shutdown_type(); - assert!(item.is_posix_listen); - - // 按照smoltcp行为,将新的handle绑定到原来的item - let new_item = SocketHandleItem::new(Arc::downgrade(&sock_ret.posix_item)); - handle_guard.insert(old_handle, new_item); - // 插入新的item - handle_guard.insert(new_handle, item); - - let socket = sockset.get_mut::( - self.handles[handle_index].smoltcp_handle().unwrap(), - ); - - if !socket.is_listening() { - self.do_listen(socket, endpoint)?; - } - - drop(handle_guard); - } - - return Ok((sock_ret, Endpoint::Ip(Some(remote_ep)))); - } - - drop(sockset); - - // debug!("[TCP] [Accept] sleeping socket with handle: {:?}", self.handles.first().unwrap().smoltcp_handle().unwrap()); - self.posix_item.sleep(Self::CAN_ACCPET); - // debug!("tcp socket:after sleep, handle_guard'len={}",HANDLE_MAP.write_irqsave().len()); - } - } - - fn endpoint(&self) -> Option { - let mut result: Option = self.local_endpoint.map(|x| Endpoint::Ip(Some(x))); - - if result.is_none() { - let sockets = SOCKET_SET.lock_irqsave(); - // debug!("tcp socket:endpoint, socket'len={}",self.handle.len()); - - let socket = - sockets.get::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - if let Some(ep) = socket.local_endpoint() { - result = Some(Endpoint::Ip(Some(ep))); - } - } - return result; - } - - fn peer_endpoint(&self) -> Option { - let sockets = SOCKET_SET.lock_irqsave(); - // debug!("tcp socket:peer_endpoint, socket'len={}",self.handle.len()); - - let socket = - sockets.get::(self.handles.first().unwrap().smoltcp_handle().unwrap()); - return socket.remote_endpoint().map(|x| Endpoint::Ip(Some(x))); - } - - fn metadata(&self) -> SocketMetadata { - self.metadata.clone() - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - - fn socket_handle(&self) -> GlobalSocketHandle { - // debug!("tcp socket:socket_handle, socket'len={}",self.handle.len()); - - *self.handles.first().unwrap() - } - - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self - } -} diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs new file mode 100644 index 000000000..8b60c0718 --- /dev/null +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -0,0 +1,148 @@ +use crate::net::{Iface, NET_DEVICES}; +use alloc::sync::Arc; + +pub mod port; +pub use port::PortManager; +use system_error::SystemError; + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Types { + Raw, + Icmp, + Udp, + Tcp, + Dhcpv4, + Dns, +} + +/** + * 目前,以下设计仍然没有考虑多网卡的listen问题,仅只解决了socket在绑定单网卡下的问题。 + */ + +#[derive(Debug)] +pub struct BoundInner { + handle: smoltcp::iface::SocketHandle, + iface: Arc, + // inner: Vec<(smoltcp::iface::SocketHandle, Arc)> + // address: smoltcp::wire::IpAddress, +} + +impl BoundInner { + /// # `bind` + /// 将socket绑定到指定的地址上,置入指定的网络接口中 + pub fn bind( + socket: T, + // socket_type: Types, + address: &smoltcp::wire::IpAddress, + ) -> Result + where + T: smoltcp::socket::AnySocket<'static>, + { + if address.is_unspecified() { + // 强绑VirtualIO + let iface = NET_DEVICES + .read_irqsave() + .iter() + .find_map(|(_, v)| { + if v.common().is_default_iface() { + Some(v.clone()) + } else { + None + } + }) + .expect("No default interface"); + + let handle = iface.sockets().lock().add(socket); + return Ok(Self { handle, iface }); + } else { + let iface = get_iface_to_bind(address).ok_or(SystemError::ENODEV)?; + let handle = iface.sockets().lock().add(socket); + return Ok(Self { handle, iface }); + } + } + + pub fn bind_ephemeral( + socket: T, + // socket_type: Types, + remote: smoltcp::wire::IpAddress, + ) -> Result<(Self, smoltcp::wire::IpAddress), SystemError> + where + T: smoltcp::socket::AnySocket<'static>, + { + let (iface, address) = get_ephemeral_iface(&remote); + // let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?; + let handle = iface.sockets().lock().add(socket); + // let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port); + Ok((Self { handle, iface }, address)) + } + + pub fn port_manager(&self) -> &PortManager { + self.iface.port_manager() + } + + pub fn with_mut, R, F: FnMut(&mut T) -> R>( + &self, + mut f: F, + ) -> R { + f(self.iface.sockets().lock().get_mut::(self.handle)) + } + + pub fn with, R, F: Fn(&T) -> R>(&self, f: F) -> R { + f(self.iface.sockets().lock().get::(self.handle)) + } + + pub fn iface(&self) -> &Arc { + &self.iface + } + + pub fn release(&self) { + self.iface.sockets().lock().remove(self.handle); + } +} + +#[inline] +pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option> { + // log::debug!("get_iface_to_bind: {:?}", ip_addr); + // if ip_addr.is_unspecified() + crate::net::NET_DEVICES + .read_irqsave() + .iter() + .find(|(_, iface)| { + let guard = iface.smol_iface().lock(); + // log::debug!("iface name: {}, ip: {:?}", iface.iface_name(), guard.ip_addrs()); + return guard.has_ip_addr(*ip_addr); + }) + .map(|(_, iface)| iface.clone()) +} + +/// Get a suitable iface to deal with sendto/connect request if the socket is not bound to an iface. +/// If the remote address is the same as that of some iface, we will use the iface. +/// Otherwise, we will use a default interface. +fn get_ephemeral_iface( + remote_ip_addr: &smoltcp::wire::IpAddress, +) -> (Arc, smoltcp::wire::IpAddress) { + get_iface_to_bind(remote_ip_addr) + .map(|iface| (iface, *remote_ip_addr)) + .or({ + let ifaces = NET_DEVICES.read_irqsave(); + ifaces.iter().find_map(|(_, iface)| { + iface + .smol_iface() + .lock() + .ip_addrs() + .iter() + .find(|cidr| cidr.contains_addr(remote_ip_addr)) + .map(|cidr| (iface.clone(), cidr.address())) + }) + }) + .or({ + NET_DEVICES.read_irqsave().values().next().map(|iface| { + ( + iface.clone(), + iface.smol_iface().lock().ip_addrs()[0].address(), + ) + }) + }) + .expect("No network interface") +} diff --git a/kernel/src/net/socket/inet/common/port.rs b/kernel/src/net/socket/inet/common/port.rs new file mode 100644 index 000000000..9c55e08ea --- /dev/null +++ b/kernel/src/net/socket/inet/common/port.rs @@ -0,0 +1,114 @@ +use hashbrown::HashMap; +use system_error::SystemError; + +use crate::{ + arch::rand::rand, + libs::spinlock::SpinLock, + process::{Pid, ProcessManager}, +}; + +use super::Types::{self, *}; + +/// # TCP 和 UDP 的端口管理器。 +/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。 +#[derive(Debug)] +pub struct PortManager { + // TCP 端口记录表 + tcp_port_table: SpinLock>, + // UDP 端口记录表 + udp_port_table: SpinLock>, +} + +impl PortManager { + pub fn new() -> Self { + return Self { + tcp_port_table: SpinLock::new(HashMap::new()), + udp_port_table: SpinLock::new(HashMap::new()), + }; + } + + /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE + pub fn get_ephemeral_port(&self, socket_type: Types) -> Result { + // TODO: selects non-conflict high port + + static mut EPHEMERAL_PORT: u16 = 0; + unsafe { + if EPHEMERAL_PORT == 0 { + EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16; + } + } + + let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数 + let mut port: u16; + while remaining > 0 { + unsafe { + if EPHEMERAL_PORT == 65535 { + EPHEMERAL_PORT = 49152; + } else { + EPHEMERAL_PORT += 1; + } + port = EPHEMERAL_PORT; + } + + // 使用 ListenTable 检查端口是否被占用 + let listen_table_guard = match socket_type { + Udp => self.udp_port_table.lock(), + Tcp => self.tcp_port_table.lock(), + _ => panic!("{:?} cann't get a port", socket_type), + }; + if listen_table_guard.get(&port).is_none() { + drop(listen_table_guard); + return Ok(port); + } + remaining -= 1; + } + return Err(SystemError::EADDRINUSE); + } + + #[inline] + pub fn bind_ephemeral_port(&self, socket_type: Types) -> Result { + let port = self.get_ephemeral_port(socket_type)?; + self.bind_port(socket_type, port)?; + return Ok(port); + } + + /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录 + /// + /// TODO: 增加支持端口复用的逻辑 + pub fn bind_port(&self, socket_type: Types, port: u16) -> Result<(), SystemError> { + if port > 0 { + match socket_type { + Udp => { + let mut guard = self.udp_port_table.lock(); + if guard.get(&port).is_some() { + return Err(SystemError::EADDRINUSE); + } + guard.insert(port, ProcessManager::current_pid()); + } + Tcp => { + let mut guard = self.tcp_port_table.lock(); + if guard.get(&port).is_some() { + return Err(SystemError::EADDRINUSE); + } + guard.insert(port, ProcessManager::current_pid()); + } + _ => {} + }; + } + return Ok(()); + } + + /// @brief 在对应的端口记录表中将端口和 socket 解绑 + /// should call this function when socket is closed or aborted + pub fn unbind_port(&self, socket_type: Types, port: u16) { + match socket_type { + Udp => { + self.udp_port_table.lock().remove(&port); + } + Tcp => { + self.tcp_port_table.lock().remove(&port); + } + _ => {} + }; + } +} diff --git a/kernel/src/net/socket/inet/datagram/inner.rs b/kernel/src/net/socket/inet/datagram/inner.rs new file mode 100644 index 000000000..de1c3bd6f --- /dev/null +++ b/kernel/src/net/socket/inet/datagram/inner.rs @@ -0,0 +1,161 @@ +use smoltcp; +use system_error::SystemError; + +use crate::{ + libs::spinlock::SpinLock, + net::socket::inet::common::{BoundInner, Types as InetTypes}, +}; + +pub type SmolUdpSocket = smoltcp::socket::udp::Socket<'static>; + +pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; +pub const DEFAULT_RX_BUF_SIZE: usize = 64 * 1024; +pub const DEFAULT_TX_BUF_SIZE: usize = 64 * 1024; + +#[derive(Debug)] +pub struct UnboundUdp { + socket: SmolUdpSocket, +} + +impl UnboundUdp { + pub fn new() -> Self { + let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE], + vec![0; DEFAULT_RX_BUF_SIZE], + ); + let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![smoltcp::socket::udp::PacketMetadata::EMPTY; DEFAULT_METADATA_BUF_SIZE], + vec![0; DEFAULT_TX_BUF_SIZE], + ); + let socket = SmolUdpSocket::new(rx_buffer, tx_buffer); + + return Self { socket }; + } + + pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result { + let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?; + let bind_addr = local_endpoint.addr; + let bind_port = if local_endpoint.port == 0 { + inner.port_manager().bind_ephemeral_port(InetTypes::Udp)? + } else { + inner + .port_manager() + .bind_port(InetTypes::Udp, local_endpoint.port)?; + local_endpoint.port + }; + + if bind_addr.is_unspecified() { + if inner + .with_mut::(|socket| socket.bind(bind_port)) + .is_err() + { + return Err(SystemError::EINVAL); + } + } else if inner + .with_mut::(|socket| { + socket.bind(smoltcp::wire::IpEndpoint::new(bind_addr, bind_port)) + }) + .is_err() + { + return Err(SystemError::EINVAL); + } + Ok(BoundUdp { + inner, + remote: SpinLock::new(None), + }) + } + + pub fn bind_ephemeral(self, remote: smoltcp::wire::IpAddress) -> Result { + // let (addr, port) = (remote.addr, remote.port); + let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote)?; + let bound_port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?; + let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port); + Ok(BoundUdp { + inner, + remote: SpinLock::new(Some(endpoint)), + }) + } +} + +#[derive(Debug)] +pub struct BoundUdp { + inner: BoundInner, + remote: SpinLock>, +} + +impl BoundUdp { + pub fn with_mut_socket(&self, f: F) -> T + where + F: FnMut(&mut SmolUdpSocket) -> T, + { + self.inner.with_mut(f) + } + + pub fn with_socket(&self, f: F) -> T + where + F: Fn(&SmolUdpSocket) -> T, + { + self.inner.with(f) + } + + pub fn endpoint(&self) -> smoltcp::wire::IpListenEndpoint { + self.inner + .with::(|socket| socket.endpoint()) + } + + pub fn connect(&self, remote: smoltcp::wire::IpEndpoint) { + self.remote.lock().replace(remote); + } + + #[inline] + pub fn try_recv( + &self, + buf: &mut [u8], + ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> { + self.with_mut_socket(|socket| { + if socket.can_recv() { + if let Ok((size, metadata)) = socket.recv_slice(buf) { + return Ok((size, metadata.endpoint)); + } + } + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + }) + } + + pub fn try_send( + &self, + buf: &[u8], + to: Option, + ) -> Result { + let remote = to.or(*self.remote.lock()).ok_or(SystemError::ENOTCONN)?; + let result = self.with_mut_socket(|socket| { + if socket.can_send() && socket.send_slice(buf, remote).is_ok() { + log::debug!("send {} bytes", buf.len()); + return Ok(buf.len()); + } + return Err(SystemError::ENOBUFS); + }); + return result; + } + + pub fn inner(&self) -> &BoundInner { + &self.inner + } + + pub fn close(&self) { + self.inner + .iface() + .port_manager() + .unbind_port(InetTypes::Udp, self.endpoint().port); + self.with_mut_socket(|socket| { + socket.close(); + }); + } +} + +// Udp Inner 负责其内部资源管理 +#[derive(Debug)] +pub enum UdpInner { + Unbound(UnboundUdp), + Bound(BoundUdp), +} diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs new file mode 100644 index 000000000..cdd51a958 --- /dev/null +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -0,0 +1,310 @@ +use inner::{UdpInner, UnboundUdp}; +use smoltcp; +use system_error::SystemError; + +use crate::filesystem::epoll::EPollEventType; +use crate::libs::wait_queue::WaitQueue; +use crate::net::socket::{Socket, PMSG}; +use crate::{libs::rwlock::RwLock, net::socket::endpoint::Endpoint}; +use alloc::sync::{Arc, Weak}; +use core::sync::atomic::AtomicBool; + +use super::InetSocket; + +pub mod inner; + +type EP = crate::filesystem::epoll::EPollEventType; + +// Udp Socket 负责提供状态切换接口、执行状态切换 +#[derive(Debug)] +pub struct UdpSocket { + inner: RwLock>, + nonblock: AtomicBool, + wait_queue: WaitQueue, + self_ref: Weak, +} + +impl UdpSocket { + pub fn new(nonblock: bool) -> Arc { + return Arc::new_cyclic(|me| Self { + inner: RwLock::new(Some(UdpInner::Unbound(UnboundUdp::new()))), + nonblock: AtomicBool::new(nonblock), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + }); + } + + pub fn is_nonblock(&self) -> bool { + self.nonblock.load(core::sync::atomic::Ordering::Relaxed) + } + + pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> { + let mut inner = self.inner.write(); + if let Some(UdpInner::Unbound(unbound)) = inner.take() { + let bound = unbound.bind(local_endpoint)?; + + bound + .inner() + .iface() + .common() + .bind_socket(self.self_ref.upgrade().unwrap()); + *inner = Some(UdpInner::Bound(bound)); + return Ok(()); + } + return Err(SystemError::EINVAL); + } + + pub fn bind_emphemeral(&self, remote: smoltcp::wire::IpAddress) -> Result<(), SystemError> { + let mut inner_guard = self.inner.write(); + let bound = match inner_guard.take().expect("Udp inner is None") { + UdpInner::Bound(inner) => inner, + UdpInner::Unbound(inner) => inner.bind_ephemeral(remote)?, + }; + inner_guard.replace(UdpInner::Bound(bound)); + return Ok(()); + } + + pub fn is_bound(&self) -> bool { + let inner = self.inner.read(); + if let Some(UdpInner::Bound(_)) = &*inner { + return true; + } + return false; + } + + pub fn close(&self) { + let mut inner = self.inner.write(); + if let Some(UdpInner::Bound(bound)) = &mut *inner { + bound.close(); + inner.take(); + } + // unbound socket just drop (only need to free memory) + } + + pub fn try_recv( + &self, + buf: &mut [u8], + ) -> Result<(usize, smoltcp::wire::IpEndpoint), SystemError> { + match self.inner.read().as_ref().expect("Udp Inner is None") { + UdpInner::Bound(bound) => { + let ret = bound.try_recv(buf); + bound.inner().iface().poll(); + ret + } + _ => Err(SystemError::ENOTCONN), + } + } + + #[inline] + pub fn can_recv(&self) -> bool { + self.event().contains(EP::EPOLLIN) + } + + #[inline] + #[allow(dead_code)] + pub fn can_send(&self) -> bool { + self.event().contains(EP::EPOLLOUT) + } + + pub fn try_send( + &self, + buf: &[u8], + to: Option, + ) -> Result { + { + let mut inner_guard = self.inner.write(); + let inner = match inner_guard.take().expect("Udp Inner is None") { + UdpInner::Bound(bound) => bound, + UdpInner::Unbound(unbound) => { + unbound.bind_ephemeral(to.ok_or(SystemError::EADDRNOTAVAIL)?.addr)? + } + }; + // size = inner.try_send(buf, to)?; + inner_guard.replace(UdpInner::Bound(inner)); + }; + // Optimize: 拿两次锁的平均效率是否比一次长时间的读锁效率要高? + let result = match self.inner.read().as_ref().expect("Udp Inner is None") { + UdpInner::Bound(bound) => { + let ret = bound.try_send(buf, to); + bound.inner().iface().poll(); + ret + } + _ => Err(SystemError::ENOTCONN), + }; + return result; + } + + pub fn event(&self) -> EPollEventType { + let mut event = EPollEventType::empty(); + match self.inner.read().as_ref().unwrap() { + UdpInner::Unbound(_) => { + event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); + } + UdpInner::Bound(bound) => { + let (can_recv, can_send) = + bound.with_socket(|socket| (socket.can_recv(), socket.can_send())); + + if can_recv { + event.insert(EP::EPOLLIN | EP::EPOLLRDNORM); + } + + if can_send { + event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); + } + } + } + return event; + } +} + +impl Socket for UdpSocket { + fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + fn poll(&self) -> usize { + self.event().bits() as usize + } + + fn bind(&self, local_endpoint: Endpoint) -> Result<(), SystemError> { + if let Endpoint::Ip(local_endpoint) = local_endpoint { + return self.do_bind(local_endpoint); + } + Err(SystemError::EAFNOSUPPORT) + } + + fn send_buffer_size(&self) -> usize { + match self.inner.read().as_ref().unwrap() { + UdpInner::Bound(bound) => bound.with_socket(|socket| socket.payload_send_capacity()), + _ => inner::DEFAULT_TX_BUF_SIZE, + } + } + + fn recv_buffer_size(&self) -> usize { + match self.inner.read().as_ref().unwrap() { + UdpInner::Bound(bound) => bound.with_socket(|socket| socket.payload_recv_capacity()), + _ => inner::DEFAULT_RX_BUF_SIZE, + } + } + + fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { + if let Endpoint::Ip(remote) = endpoint { + if !self.is_bound() { + self.bind_emphemeral(remote.addr)?; + } + if let UdpInner::Bound(inner) = self.inner.read().as_ref().expect("UDP Inner disappear") + { + inner.connect(remote); + return Ok(()); + } else { + panic!(""); + } + } + return Err(SystemError::EAFNOSUPPORT); + } + + fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + if flags.contains(PMSG::DONTWAIT) { + log::warn!("Nonblock send is not implemented yet"); + } + + return self.try_send(buffer, None); + } + + fn send_to(&self, buffer: &[u8], flags: PMSG, address: Endpoint) -> Result { + if flags.contains(PMSG::DONTWAIT) { + log::warn!("Nonblock send is not implemented yet"); + } + + if let Endpoint::Ip(remote) = address { + return self.try_send(buffer, Some(remote)); + } + + return Err(SystemError::EINVAL); + } + + fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + use crate::sched::SchedMode; + + return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { + self.try_recv(buffer) + } else { + loop { + match self.try_recv(buffer) { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; + } + result => break result, + } + } + } + .map(|(len, _)| len); + } + + fn recv_from( + &self, + buffer: &mut [u8], + flags: PMSG, + address: Option, + ) -> Result<(usize, Endpoint), SystemError> { + use crate::sched::SchedMode; + // could block io + if let Some(endpoint) = address { + self.connect(endpoint)?; + } + + return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { + self.try_recv(buffer) + } else { + loop { + match self.try_recv(buffer) { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; + log::debug!("UdpSocket::recv_from: wake up"); + } + result => break result, + } + } + } + .map(|(len, remote)| (len, Endpoint::Ip(remote))); + } + + fn close(&self) -> Result<(), SystemError> { + self.close(); + Ok(()) + } +} + +impl InetSocket for UdpSocket { + fn on_iface_events(&self) { + return; + } +} + +bitflags! { + pub struct UdpSocketOptions: u32 { + const ZERO = 0; /* No UDP options */ + const UDP_CORK = 1; /* Never send partially complete segments */ + const UDP_ENCAP = 100; /* Set the socket to accept encapsulated packets */ + const UDP_NO_CHECK6_TX = 101; /* Disable sending checksum for UDP6X */ + const UDP_NO_CHECK6_RX = 102; /* Disable accepting checksum for UDP6 */ + const UDP_SEGMENT = 103; /* Set GSO segmentation size */ + const UDP_GRO = 104; /* This socket can receive UDP GRO packets */ + + const UDPLITE_SEND_CSCOV = 10; /* sender partial coverage (as sent) */ + const UDPLITE_RECV_CSCOV = 11; /* receiver partial coverage (threshold ) */ + } +} + +bitflags! { + pub struct UdpEncapTypes: u8 { + const ZERO = 0; + const ESPINUDP_NON_IKE = 1; // draft-ietf-ipsec-nat-t-ike-00/01 + const ESPINUDP = 2; // draft-ietf-ipsec-udp-encaps-06 + const L2TPINUDP = 3; // rfc2661 + const GTP0 = 4; // GSM TS 09.60 + const GTP1U = 5; // 3GPP TS 29.060 + const RXRPC = 6; + const ESPINTCP = 7; // Yikes, this is really xfrm encap types. + } +} diff --git a/kernel/src/net/socket/inet/mod.rs b/kernel/src/net/socket/inet/mod.rs new file mode 100644 index 000000000..8c1d504a0 --- /dev/null +++ b/kernel/src/net/socket/inet/mod.rs @@ -0,0 +1,39 @@ +use smoltcp; + +// pub mod raw; +// pub mod icmp; +pub mod common; +pub mod datagram; +pub mod stream; +pub mod syscall; + +pub use common::BoundInner; +pub use common::Types; +// pub use raw::RawSocket; +pub use datagram::UdpSocket; + +use smoltcp::wire::IpAddress; +use smoltcp::wire::IpEndpoint; +use smoltcp::wire::Ipv4Address; +use smoltcp::wire::Ipv6Address; + +pub use stream::TcpSocket; +pub use syscall::Inet; + +use super::Socket; + +/// A local endpoint, which indicates that the local endpoint is unspecified. +/// +/// According to the Linux man pages and the Linux implementation, `getsockname()` will _not_ fail +/// even if the socket is unbound. Instead, it will return an unspecified socket address. This +/// unspecified endpoint helps with that. +const UNSPECIFIED_LOCAL_ENDPOINT_V4: IpEndpoint = + IpEndpoint::new(IpAddress::Ipv4(Ipv4Address::UNSPECIFIED), 0); +const UNSPECIFIED_LOCAL_ENDPOINT_V6: IpEndpoint = + IpEndpoint::new(IpAddress::Ipv6(Ipv6Address::UNSPECIFIED), 0); + +pub trait InetSocket: Socket { + /// `on_iface_events` + /// 通知socket发生的事件 + fn on_iface_events(&self); +} diff --git a/kernel/src/net/socket/inet/posix/option.rs b/kernel/src/net/socket/inet/posix/option.rs new file mode 100644 index 000000000..5c5947dd0 --- /dev/null +++ b/kernel/src/net/socket/inet/posix/option.rs @@ -0,0 +1,68 @@ + +bitflags! { + pub struct IpOptions: u32 { + const IP_TOS = 1; // Type of service + const IP_TTL = 2; // Time to live + const IP_HDRINCL = 3; // Header compression + const IP_OPTIONS = 4; // IP options + const IP_ROUTER_ALERT = 5; // Router alert + const IP_RECVOPTS = 6; // Receive options + const IP_RETOPTS = 7; // Return options + const IP_PKTINFO = 8; // Packet information + const IP_PKTOPTIONS = 9; // Packet options + const IP_MTU_DISCOVER = 10; // MTU discovery + const IP_RECVERR = 11; // Receive errors + const IP_RECVTTL = 12; // Receive time to live + const IP_RECVTOS = 13; // Receive type of service + const IP_MTU = 14; // MTU + const IP_FREEBIND = 15; // Freebind + const IP_IPSEC_POLICY = 16; // IPsec policy + const IP_XFRM_POLICY = 17; // IPipsec transform policy + const IP_PASSSEC = 18; // Pass security + const IP_TRANSPARENT = 19; // Transparent + + const IP_RECVRETOPTS = 20; // Receive return options (deprecated) + + const IP_ORIGDSTADDR = 21; // Originate destination address (used by TProxy) + const IP_RECVORIGDSTADDR = 21; // Receive originate destination address + + const IP_MINTTL = 22; // Minimum time to live + const IP_NODEFRAG = 23; // Don't fragment (used by TProxy) + const IP_CHECKSUM = 24; // Checksum offload (used by TProxy) + const IP_BIND_ADDRESS_NO_PORT = 25; // Bind to address without port (used by TProxy) + const IP_RECVFRAGSIZE = 26; // Receive fragment size + const IP_RECVERR_RFC4884 = 27; // Receive ICMPv6 error notifications + + const IP_PMTUDISC_DONT = 28; // Don't send DF frames + const IP_PMTUDISC_DO = 29; // Always DF + const IP_PMTUDISC_PROBE = 30; // Ignore dst pmtu + const IP_PMTUDISC_INTERFACE = 31; // Always use interface mtu (ignores dst pmtu) + const IP_PMTUDISC_OMIT = 32; // Weaker version of IP_PMTUDISC_INTERFACE + + const IP_MULTICAST_IF = 33; // Multicast interface + const IP_MULTICAST_TTL = 34; // Multicast time to live + const IP_MULTICAST_LOOP = 35; // Multicast loopback + const IP_ADD_MEMBERSHIP = 36; // Add multicast group membership + const IP_DROP_MEMBERSHIP = 37; // Drop multicast group membership + const IP_UNBLOCK_SOURCE = 38; // Unblock source + const IP_BLOCK_SOURCE = 39; // Block source + const IP_ADD_SOURCE_MEMBERSHIP = 40; // Add source multicast group membership + const IP_DROP_SOURCE_MEMBERSHIP = 41; // Drop source multicast group membership + const IP_MSFILTER = 42; // Multicast source filter + + const MCAST_JOIN_GROUP = 43; // Join a multicast group + const MCAST_BLOCK_SOURCE = 44; // Block a multicast source + const MCAST_UNBLOCK_SOURCE = 45; // Unblock a multicast source + const MCAST_LEAVE_GROUP = 46; // Leave a multicast group + const MCAST_JOIN_SOURCE_GROUP = 47; // Join a multicast source group + const MCAST_LEAVE_SOURCE_GROUP = 48; // Leave a multicast source group + const MCAST_MSFILTER = 49; // Multicast source filter + + const IP_MULTICAST_ALL = 50; // Multicast all + const IP_UNICAST_IF = 51; // Unicast interface + const IP_LOCAL_PORT_RANGE = 52; // Local port range + const IP_PROTOCOL = 53; // Protocol + + // ... other flags ... + } +} \ No newline at end of file diff --git a/kernel/src/net/socket/inet/posix/proto.rs b/kernel/src/net/socket/inet/posix/proto.rs new file mode 100644 index 000000000..39818f658 --- /dev/null +++ b/kernel/src/net/socket/inet/posix/proto.rs @@ -0,0 +1,76 @@ +pub const SOL_SOCKET: u16 = 1; + +#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] +pub enum IPProtocol { + /// Dummy protocol for TCP. + IP = 0, + /// Internet Control Message Protocol. + ICMP = 1, + /// Internet Group Management Protocol. + IGMP = 2, + /// IPIP tunnels (older KA9Q tunnels use 94). + IPIP = 4, + /// Transmission Control Protocol. + TCP = 6, + /// Exterior Gateway Protocol. + EGP = 8, + /// PUP protocol. + PUP = 12, + /// User Datagram Protocol. + UDP = 17, + /// XNS IDP protocol. + IDP = 22, + /// SO Transport Protocol Class 4. + TP = 29, + /// Datagram Congestion Control Protocol. + DCCP = 33, + /// IPv6-in-IPv4 tunnelling. + IPv6 = 41, + /// RSVP Protocol. + RSVP = 46, + /// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702) + GRE = 47, + /// Encapsulation Security Payload protocol + ESP = 50, + /// Authentication Header protocol + AH = 51, + /// Multicast Transport Protocol. + MTP = 92, + /// IP option pseudo header for BEET + BEETPH = 94, + /// Encapsulation Header. + ENCAP = 98, + /// Protocol Independent Multicast. + PIM = 103, + /// Compression Header Protocol. + COMP = 108, + /// Stream Control Transport Protocol + SCTP = 132, + /// UDP-Lite protocol (RFC 3828) + UDPLITE = 136, + /// MPLS in IP (RFC 4023) + MPLSINIP = 137, + /// Ethernet-within-IPv6 Encapsulation + ETHERNET = 143, + /// Raw IP packets + RAW = 255, + /// Multipath TCP connection + MPTCP = 262, +} + +impl TryFrom for IPProtocol { + type Error = system_error::SystemError; + + fn try_from(value: u16) -> Result { + match ::from_u16(value) { + Some(p) => Ok(p), + None => Err(system_error::SystemError::EPROTONOSUPPORT), + } + } +} + +impl From for u16 { + fn from(value: IPProtocol) -> Self { + ::to_u16(&value).unwrap() + } +} diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs new file mode 100644 index 000000000..9afbfa6eb --- /dev/null +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -0,0 +1,517 @@ +use core::sync::atomic::AtomicUsize; + +use crate::filesystem::epoll::EPollEventType; +use crate::libs::rwlock::RwLock; +use crate::net::socket::{self, inet::Types}; +use alloc::boxed::Box; +use alloc::vec::Vec; +use smoltcp; +use smoltcp::socket::tcp; +use system_error::SystemError; + +// pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; +pub const DEFAULT_RX_BUF_SIZE: usize = 512 * 1024; +pub const DEFAULT_TX_BUF_SIZE: usize = 512 * 1024; + +fn new_smoltcp_socket() -> smoltcp::socket::tcp::Socket<'static> { + let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; DEFAULT_RX_BUF_SIZE]); + let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; DEFAULT_TX_BUF_SIZE]); + smoltcp::socket::tcp::Socket::new(rx_buffer, tx_buffer) +} + +fn new_listen_smoltcp_socket(local_endpoint: T) -> smoltcp::socket::tcp::Socket<'static> +where + T: Into, +{ + let mut socket = new_smoltcp_socket(); + socket.listen(local_endpoint).unwrap(); + socket +} + +#[derive(Debug)] +pub enum Init { + Unbound( + ( + Box>, + smoltcp::wire::IpVersion, + ), + ), + Bound((socket::inet::BoundInner, smoltcp::wire::IpEndpoint)), +} + +impl Init { + pub(super) fn new(ver: smoltcp::wire::IpVersion) -> Self { + Init::Unbound((Box::new(new_smoltcp_socket()), ver)) + } + + /// 传入一个已经绑定的socket + pub(super) fn new_bound(inner: socket::inet::BoundInner) -> Self { + let endpoint = inner.with::(|socket| { + socket + .local_endpoint() + .expect("A Bound Socket Must Have A Local Endpoint") + }); + Init::Bound((inner, endpoint)) + } + + pub(super) fn bind( + self, + local_endpoint: smoltcp::wire::IpEndpoint, + ) -> Result { + match self { + Init::Unbound((socket, _)) => { + let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?; + bound + .port_manager() + .bind_port(Types::Tcp, local_endpoint.port)?; + // bound.iface().common().bind_socket() + Ok(Init::Bound((bound, local_endpoint))) + } + Init::Bound(_) => { + log::debug!("Already Bound"); + Err(SystemError::EINVAL) + } + } + } + + pub(super) fn bind_to_ephemeral( + self, + remote_endpoint: smoltcp::wire::IpEndpoint, + ) -> Result<(socket::inet::BoundInner, smoltcp::wire::IpEndpoint), (Self, SystemError)> { + match self { + Init::Unbound((socket, ver)) => { + let (bound, address) = + socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr) + .map_err(|err| (Self::new(ver), err))?; + let bound_port = bound + .port_manager() + .bind_ephemeral_port(Types::Tcp) + .map_err(|err| (Self::new(ver), err))?; + let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port); + Ok((bound, endpoint)) + } + Init::Bound(_) => Err((self, SystemError::EINVAL)), + } + } + + pub(super) fn connect( + self, + remote_endpoint: smoltcp::wire::IpEndpoint, + ) -> Result { + let (inner, local) = match self { + Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint)?, + Init::Bound(inner) => inner, + }; + if local.addr.is_unspecified() { + return Err((Init::Bound((inner, local)), SystemError::EINVAL)); + } + let result = inner.with_mut::(|socket| { + socket + .connect( + inner.iface().smol_iface().lock().context(), + remote_endpoint, + local, + ) + .map_err(|_| SystemError::ECONNREFUSED) + }); + match result { + Ok(_) => Ok(Connecting::new(inner)), + Err(err) => Err((Init::Bound((inner, local)), err)), + } + } + + /// # `listen` + pub(super) fn listen(self, backlog: usize) -> Result { + let (inner, local) = match self { + Init::Unbound(_) => { + return Err((self, SystemError::EINVAL)); + } + Init::Bound(inner) => inner, + }; + let listen_addr = if local.addr.is_unspecified() { + smoltcp::wire::IpListenEndpoint::from(local.port) + } else { + smoltcp::wire::IpListenEndpoint::from(local) + }; + log::debug!("listen at {:?}", listen_addr); + let mut inners = Vec::new(); + if let Err(err) = || -> Result<(), SystemError> { + for _ in 0..(backlog - 1) { + // -1 because the first one is already bound + let new_listen = socket::inet::BoundInner::bind( + new_listen_smoltcp_socket(listen_addr), + listen_addr + .addr + .as_ref() + .unwrap_or(&smoltcp::wire::IpAddress::from( + smoltcp::wire::Ipv4Address::UNSPECIFIED, + )), + )?; + inners.push(new_listen); + } + Ok(()) + }() { + return Err((Init::Bound((inner, local)), err)); + } + + if let Err(err) = inner.with_mut::(|socket| { + socket + .listen(listen_addr) + .map_err(|_| SystemError::ECONNREFUSED) + }) { + return Err((Init::Bound((inner, local)), err)); + } + + inners.push(inner); + return Ok(Listening { + inners, + connect: AtomicUsize::new(0), + listen_addr, + }); + } + + pub(super) fn close(&self) { + match self { + Init::Unbound(_) => {} + Init::Bound((inner, endpoint)) => { + inner.port_manager().unbind_port(Types::Tcp, endpoint.port); + inner.with_mut::(|socket| socket.close()); + } + } + } +} + +#[derive(Debug, Default, Clone, Copy)] +enum ConnectResult { + Connected, + #[default] + Connecting, + Refused, +} + +#[derive(Debug)] +pub struct Connecting { + inner: socket::inet::BoundInner, + result: RwLock, +} + +impl Connecting { + fn new(inner: socket::inet::BoundInner) -> Self { + Connecting { + inner, + result: RwLock::new(ConnectResult::Connecting), + } + } + + pub fn with_mut) -> R>( + &self, + f: F, + ) -> R { + self.inner.with_mut(f) + } + + pub fn into_result(self) -> (Inner, Result<(), SystemError>) { + let result = *self.result.read(); + match result { + ConnectResult::Connecting => ( + Inner::Connecting(self), + Err(SystemError::EAGAIN_OR_EWOULDBLOCK), + ), + ConnectResult::Connected => ( + Inner::Established(Established { inner: self.inner }), + Ok(()), + ), + ConnectResult::Refused => ( + Inner::Init(Init::new_bound(self.inner)), + Err(SystemError::ECONNREFUSED), + ), + } + } + + pub unsafe fn into_established(self) -> Established { + Established { inner: self.inner } + } + + /// Returns `true` when `conn_result` becomes ready, which indicates that the caller should + /// invoke the `into_result()` method as soon as possible. + /// + /// Since `into_result()` needs to be called only once, this method will return `true` + /// _exactly_ once. The caller is responsible for not missing this event. + #[must_use] + pub(super) fn update_io_events(&self) -> bool { + // if matches!(*self.result.read_irqsave(), ConnectResult::Connecting) { + // return false; + // } + + self.inner + .with_mut(|socket: &mut smoltcp::socket::tcp::Socket| { + let mut result = self.result.write(); + if matches!(*result, ConnectResult::Refused | ConnectResult::Connected) { + return false; // Already connected or refused + } + + // Connected + if socket.can_send() { + log::debug!("can send"); + *result = ConnectResult::Connected; + return true; + } + // Connecting + if socket.is_open() { + log::debug!("connecting"); + *result = ConnectResult::Connecting; + return false; + } + // Refused + *result = ConnectResult::Refused; + return true; + }) + } + + pub fn get_name(&self) -> smoltcp::wire::IpEndpoint { + self.inner + .with::(|socket| { + socket + .local_endpoint() + .expect("A Connecting Tcp With No Local Endpoint") + }) + } + + pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint { + self.inner + .with::(|socket| { + socket + .remote_endpoint() + .expect("A Connecting Tcp With No Remote Endpoint") + }) + } +} + +#[derive(Debug)] +pub struct Listening { + inners: Vec, + connect: AtomicUsize, + listen_addr: smoltcp::wire::IpListenEndpoint, +} + +impl Listening { + pub fn accept(&mut self) -> Result<(Established, smoltcp::wire::IpEndpoint), SystemError> { + let connected: &mut socket::inet::BoundInner = self + .inners + .get_mut(self.connect.load(core::sync::atomic::Ordering::Relaxed)) + .unwrap(); + + if connected.with::(|socket| !socket.is_active()) { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + let remote_endpoint = connected.with::(|socket| { + socket + .remote_endpoint() + .expect("A Connected Tcp With No Remote Endpoint") + }); + + // log::debug!("local at {:?}", local_endpoint); + + let mut new_listen = socket::inet::BoundInner::bind( + new_listen_smoltcp_socket(self.listen_addr), + self.listen_addr + .addr + .as_ref() + .unwrap_or(&smoltcp::wire::IpAddress::from( + smoltcp::wire::Ipv4Address::UNSPECIFIED, + )), + )?; + + // swap the connected socket with the new_listen socket + // TODO is smoltcp socket swappable? + core::mem::swap(&mut new_listen, connected); + + return Ok((Established { inner: new_listen }, remote_endpoint)); + } + + pub fn update_io_events(&self, pollee: &AtomicUsize) { + let position = self.inners.iter().position(|inner| { + inner.with::(|socket| socket.is_active()) + }); + + if let Some(position) = position { + self.connect + .store(position, core::sync::atomic::Ordering::Relaxed); + pollee.fetch_or( + EPollEventType::EPOLLIN.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } else { + pollee.fetch_and( + !EPollEventType::EPOLLIN.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } + } + + pub fn get_name(&self) -> smoltcp::wire::IpEndpoint { + smoltcp::wire::IpEndpoint::new( + self.listen_addr + .addr + .unwrap_or(smoltcp::wire::IpAddress::from( + smoltcp::wire::Ipv4Address::UNSPECIFIED, + )), + self.listen_addr.port, + ) + } + + pub fn close(&self) { + // log::debug!("Close Listening Socket"); + let port = self.get_name().port; + for inner in self.inners.iter() { + inner.with_mut::(|socket| socket.close()); + } + self.inners[0] + .iface() + .port_manager() + .unbind_port(Types::Tcp, port); + } + + pub fn release(&self) { + // log::debug!("Release Listening Socket"); + for inner in self.inners.iter() { + inner.release(); + } + } +} + +#[derive(Debug)] +pub struct Established { + inner: socket::inet::BoundInner, +} + +impl Established { + pub fn with_mut) -> R>( + &self, + f: F, + ) -> R { + self.inner.with_mut(f) + } + + pub fn close(&self) { + self.inner + .with_mut::(|socket| socket.close()); + self.inner.iface().poll(); + } + + pub fn release(&self) { + self.inner.release(); + } + + pub fn get_name(&self) -> smoltcp::wire::IpEndpoint { + self.inner + .with::(|socket| socket.local_endpoint()) + .unwrap() + } + + pub fn get_peer_name(&self) -> smoltcp::wire::IpEndpoint { + self.inner + .with::(|socket| socket.remote_endpoint().unwrap()) + } + + pub fn recv_slice(&self, buf: &mut [u8]) -> Result { + self.inner + .with_mut::(|socket| { + if socket.can_send() { + match socket.recv_slice(buf) { + Ok(size) => Ok(size), + Err(tcp::RecvError::InvalidState) => { + log::error!("TcpSocket::try_recv: InvalidState"); + Err(SystemError::ENOTCONN) + } + Err(tcp::RecvError::Finished) => Ok(0), + } + } else { + Err(SystemError::ENOBUFS) + } + }) + } + + pub fn send_slice(&self, buf: &[u8]) -> Result { + self.inner + .with_mut::(|socket| { + if socket.can_send() { + socket + .send_slice(buf) + .map_err(|_| SystemError::ECONNABORTED) + } else { + Err(SystemError::ENOBUFS) + } + }) + } + + pub fn update_io_events(&self, pollee: &AtomicUsize) { + self.inner + .with_mut::(|socket| { + if socket.can_send() { + pollee.fetch_or( + EPollEventType::EPOLLOUT.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } else { + pollee.fetch_and( + !EPollEventType::EPOLLOUT.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } + if socket.can_recv() { + pollee.fetch_or( + EPollEventType::EPOLLIN.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } else { + pollee.fetch_and( + !EPollEventType::EPOLLIN.bits() as usize, + core::sync::atomic::Ordering::Relaxed, + ); + } + }) + } +} + +#[derive(Debug)] +pub enum Inner { + Init(Init), + Connecting(Connecting), + Listening(Listening), + Established(Established), +} + +impl Inner { + pub fn send_buffer_size(&self) -> usize { + match self { + Inner::Init(_) => DEFAULT_TX_BUF_SIZE, + Inner::Connecting(conn) => conn.with_mut(|socket| socket.send_capacity()), + // only the first socket in the list is used for sending + Inner::Listening(listen) => listen.inners[0] + .with_mut::(|socket| socket.send_capacity()), + Inner::Established(est) => est.with_mut(|socket| socket.send_capacity()), + } + } + + pub fn recv_buffer_size(&self) -> usize { + match self { + Inner::Init(_) => DEFAULT_RX_BUF_SIZE, + Inner::Connecting(conn) => conn.with_mut(|socket| socket.recv_capacity()), + // only the first socket in the list is used for receiving + Inner::Listening(listen) => listen.inners[0] + .with_mut::(|socket| socket.recv_capacity()), + Inner::Established(est) => est.with_mut(|socket| socket.recv_capacity()), + } + } + + pub fn iface(&self) -> Option<&alloc::sync::Arc> { + match self { + Inner::Init(_) => None, + Inner::Connecting(conn) => Some(conn.inner.iface()), + Inner::Listening(listen) => Some(listen.inners[0].iface()), + Inner::Established(est) => Some(est.inner.iface()), + } + } +} diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs new file mode 100644 index 000000000..6484d842f --- /dev/null +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -0,0 +1,506 @@ +use alloc::sync::{Arc, Weak}; +use core::sync::atomic::{AtomicBool, AtomicUsize}; +use system_error::SystemError; + +use crate::libs::wait_queue::WaitQueue; +use crate::net::socket::common::shutdown::{ShutdownBit, ShutdownTemp}; +use crate::net::socket::endpoint::Endpoint; +use crate::net::socket::{Socket, SocketInode, PMSG, PSOL}; +use crate::sched::SchedMode; +use crate::{libs::rwlock::RwLock, net::socket::common::shutdown::Shutdown}; +use smoltcp; + +mod inner; + +mod option; +pub use option::Options as TcpOption; + +use super::{InetSocket, UNSPECIFIED_LOCAL_ENDPOINT_V4, UNSPECIFIED_LOCAL_ENDPOINT_V6}; + +type EP = crate::filesystem::epoll::EPollEventType; +#[derive(Debug)] +pub struct TcpSocket { + inner: RwLock>, + #[allow(dead_code)] + shutdown: Shutdown, // TODO set shutdown status + nonblock: AtomicBool, + wait_queue: WaitQueue, + self_ref: Weak, + pollee: AtomicUsize, +} + +impl TcpSocket { + pub fn new(_nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(Some(inner::Inner::Init(inner::Init::new(ver)))), + shutdown: Shutdown::new(), + nonblock: AtomicBool::new(false), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + pollee: AtomicUsize::new(0_usize), + }) + } + + pub fn new_established(inner: inner::Established, nonblock: bool) -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(Some(inner::Inner::Established(inner))), + shutdown: Shutdown::new(), + nonblock: AtomicBool::new(nonblock), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + pollee: AtomicUsize::new((EP::EPOLLIN.bits() | EP::EPOLLOUT.bits()) as usize), + }) + } + + pub fn is_nonblock(&self) -> bool { + self.nonblock.load(core::sync::atomic::Ordering::Relaxed) + } + + pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> { + let mut writer = self.inner.write(); + match writer.take().expect("Tcp inner::Inner is None") { + inner::Inner::Init(inner) => { + let bound = inner.bind(local_endpoint)?; + if let inner::Init::Bound((ref bound, _)) = bound { + bound + .iface() + .common() + .bind_socket(self.self_ref.upgrade().unwrap()); + } + writer.replace(inner::Inner::Init(bound)); + Ok(()) + } + any => { + writer.replace(any); + log::error!("TcpSocket::do_bind: not Init"); + Err(SystemError::EINVAL) + } + } + } + + pub fn do_listen(&self, backlog: usize) -> Result<(), SystemError> { + let mut writer = self.inner.write(); + let inner = writer.take().expect("Tcp inner::Inner is None"); + let (listening, err) = match inner { + inner::Inner::Init(init) => { + let listen_result = init.listen(backlog); + match listen_result { + Ok(listening) => (inner::Inner::Listening(listening), None), + Err((init, err)) => (inner::Inner::Init(init), Some(err)), + } + } + _ => (inner, Some(SystemError::EINVAL)), + }; + writer.replace(listening); + drop(writer); + + if let Some(err) = err { + return Err(err); + } + return Ok(()); + } + + pub fn try_accept(&self) -> Result<(Arc, smoltcp::wire::IpEndpoint), SystemError> { + match self + .inner + .write() + .as_mut() + .expect("Tcp inner::Inner is None") + { + inner::Inner::Listening(listening) => listening.accept().map(|(stream, remote)| { + ( + TcpSocket::new_established(stream, self.is_nonblock()), + remote, + ) + }), + _ => Err(SystemError::EINVAL), + } + } + + // SHOULD refactor + pub fn start_connect( + &self, + remote_endpoint: smoltcp::wire::IpEndpoint, + ) -> Result<(), SystemError> { + let mut writer = self.inner.write(); + let inner = writer.take().expect("Tcp inner::Inner is None"); + let (init, result) = match inner { + inner::Inner::Init(init) => { + let conn_result = init.connect(remote_endpoint); + match conn_result { + Ok(connecting) => ( + inner::Inner::Connecting(connecting), + if !self.is_nonblock() { + Ok(()) + } else { + Err(SystemError::EINPROGRESS) + }, + ), + Err((init, err)) => (inner::Inner::Init(init), Err(err)), + } + } + inner::Inner::Connecting(connecting) if self.is_nonblock() => ( + inner::Inner::Connecting(connecting), + Err(SystemError::EALREADY), + ), + inner::Inner::Connecting(connecting) => (inner::Inner::Connecting(connecting), Ok(())), + inner::Inner::Listening(inner) => { + (inner::Inner::Listening(inner), Err(SystemError::EISCONN)) + } + inner::Inner::Established(inner) => { + (inner::Inner::Established(inner), Err(SystemError::EISCONN)) + } + }; + + match result { + Ok(()) | Err(SystemError::EINPROGRESS) => { + init.iface().unwrap().poll(); + } + _ => {} + } + + writer.replace(init); + return result; + } + + // for irq use + pub fn finish_connect(&self) -> Result<(), SystemError> { + let mut writer = self.inner.write(); + let inner::Inner::Connecting(conn) = writer.take().expect("Tcp inner::Inner is None") + else { + log::error!("TcpSocket::finish_connect: not Connecting"); + return Err(SystemError::EINVAL); + }; + + let (inner, result) = conn.into_result(); + writer.replace(inner); + drop(writer); + + result + } + + pub fn check_connect(&self) -> Result<(), SystemError> { + self.update_events(); + let mut write_state = self.inner.write(); + let inner = write_state.take().expect("Tcp inner::Inner is None"); + let (replace, result) = match inner { + inner::Inner::Connecting(conn) => conn.into_result(), + inner::Inner::Established(es) => { + log::warn!("TODO: check new established"); + (inner::Inner::Established(es), Ok(())) + } // TODO check established + _ => { + log::warn!("TODO: connecting socket error options"); + (inner, Err(SystemError::EINVAL)) + } // TODO socket error options + }; + write_state.replace(replace); + result + } + + pub fn try_recv(&self, buf: &mut [u8]) -> Result { + self.inner + .read() + .as_ref() + .map(|inner| { + inner.iface().unwrap().poll(); + let result = match inner { + inner::Inner::Established(inner) => inner.recv_slice(buf), + _ => Err(SystemError::EINVAL), + }; + inner.iface().unwrap().poll(); + result + }) + .unwrap() + } + + pub fn try_send(&self, buf: &[u8]) -> Result { + // TODO: add nonblock check of connecting socket + let sent = match self + .inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + { + inner::Inner::Established(inner) => inner.send_slice(buf), + _ => Err(SystemError::EINVAL), + }; + self.inner.read().as_ref().unwrap().iface().unwrap().poll(); + sent + } + + fn update_events(&self) -> bool { + match self + .inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + { + inner::Inner::Init(_) => false, + inner::Inner::Connecting(connecting) => connecting.update_io_events(), + inner::Inner::Established(established) => { + established.update_io_events(&self.pollee); + false + } + inner::Inner::Listening(listening) => { + listening.update_io_events(&self.pollee); + false + } + } + } + + fn incoming(&self) -> bool { + EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) + } +} + +impl Socket for TcpSocket { + fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + fn get_name(&self) -> Result { + match self + .inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + { + inner::Inner::Init(inner::Init::Unbound((_, ver))) => Ok(Endpoint::Ip(match ver { + smoltcp::wire::IpVersion::Ipv4 => UNSPECIFIED_LOCAL_ENDPOINT_V4, + smoltcp::wire::IpVersion::Ipv6 => UNSPECIFIED_LOCAL_ENDPOINT_V6, + })), + inner::Inner::Init(inner::Init::Bound((_, local))) => Ok(Endpoint::Ip(*local)), + inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_name())), + inner::Inner::Established(established) => Ok(Endpoint::Ip(established.get_name())), + inner::Inner::Listening(listening) => Ok(Endpoint::Ip(listening.get_name())), + } + } + + fn get_peer_name(&self) -> Result { + match self + .inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + { + inner::Inner::Init(_) => Err(SystemError::ENOTCONN), + inner::Inner::Connecting(connecting) => Ok(Endpoint::Ip(connecting.get_peer_name())), + inner::Inner::Established(established) => Ok(Endpoint::Ip(established.get_peer_name())), + inner::Inner::Listening(_) => Err(SystemError::ENOTCONN), + } + } + + fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + if let Endpoint::Ip(addr) = endpoint { + return self.do_bind(addr); + } + log::debug!("TcpSocket::bind: invalid endpoint"); + return Err(SystemError::EINVAL); + } + + fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { + let Endpoint::Ip(endpoint) = endpoint else { + log::debug!("TcpSocket::connect: invalid endpoint"); + return Err(SystemError::EINVAL); + }; + self.start_connect(endpoint)?; // Only Nonblock or error will return error. + + return loop { + match self.check_connect() { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => {} + result => break result, + } + }; + } + + fn poll(&self) -> usize { + self.pollee.load(core::sync::atomic::Ordering::SeqCst) + } + + fn listen(&self, backlog: usize) -> Result<(), SystemError> { + self.do_listen(backlog) + } + + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + if self.is_nonblock() { + self.try_accept() + } else { + loop { + match self.try_accept() { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + wq_wait_event_interruptible!(self.wait_queue, self.incoming(), {})?; + } + result => break result, + } + } + } + .map(|(inner, endpoint)| (SocketInode::new(inner), Endpoint::Ip(endpoint))) + } + + fn recv(&self, buffer: &mut [u8], _flags: PMSG) -> Result { + self.try_recv(buffer) + } + + fn send(&self, buffer: &[u8], _flags: PMSG) -> Result { + self.try_send(buffer) + } + + fn send_buffer_size(&self) -> usize { + self.inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + .send_buffer_size() + } + + fn recv_buffer_size(&self) -> usize { + self.inner + .read() + .as_ref() + .expect("Tcp inner::Inner is None") + .recv_buffer_size() + } + + fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + let self_shutdown = self.shutdown.get().bits(); + let diff = how.bits().difference(self_shutdown); + match diff.is_empty() { + true => return Ok(()), + false => { + if diff.contains(ShutdownBit::SHUT_RD) { + self.shutdown.recv_shutdown(); + // TODO 协议栈处理 + } + if diff.contains(ShutdownBit::SHUT_WR) { + self.shutdown.send_shutdown(); + // TODO 协议栈处理 + } + } + } + Ok(()) + } + + fn close(&self) -> Result<(), SystemError> { + let Some(inner) = self.inner.write().take() else { + log::warn!("TcpSocket::close: already closed, unexpected"); + return Ok(()); + }; + if let Some(iface) = inner.iface() { + iface + .common() + .unbind_socket(self.self_ref.upgrade().unwrap()); + } + + match inner { + // complete connecting socket close logic + inner::Inner::Connecting(conn) => { + let conn = unsafe { conn.into_established() }; + conn.close(); + conn.release(); + } + inner::Inner::Established(es) => { + es.close(); + es.release(); + } + inner::Inner::Listening(ls) => { + ls.close(); + ls.release(); + } + inner::Inner::Init(init) => { + init.close(); + } + }; + + Ok(()) + } + + fn set_option(&self, level: PSOL, name: usize, val: &[u8]) -> Result<(), SystemError> { + if level != PSOL::TCP { + // return Err(SystemError::EINVAL); + log::debug!("TcpSocket::set_option: not TCP"); + return Ok(()); + } + use option::Options::{self, *}; + let option_name = Options::try_from(name as i32)?; + log::debug!("TCP Option: {:?}, value = {:?}", option_name, val); + match option_name { + NoDelay => { + let nagle_enabled = val[0] != 0; + let mut writer = self.inner.write(); + let inner = writer.take().expect("Tcp inner::Inner is None"); + match inner { + inner::Inner::Established(established) => { + established.with_mut(|socket| { + socket.set_nagle_enabled(nagle_enabled); + }); + writer.replace(inner::Inner::Established(established)); + } + _ => { + writer.replace(inner); + return Err(SystemError::EINVAL); + } + } + } + KeepIntvl => { + if val.len() == 4 { + let mut writer = self.inner.write(); + let inner = writer.take().expect("Tcp inner::Inner is None"); + match inner { + inner::Inner::Established(established) => { + let interval = u32::from_ne_bytes([val[0], val[1], val[2], val[3]]); + established.with_mut(|socket| { + socket.set_keep_alive(Some(smoltcp::time::Duration::from_secs( + interval as u64, + ))); + }); + writer.replace(inner::Inner::Established(established)); + } + _ => { + writer.replace(inner); + return Err(SystemError::EINVAL); + } + } + } else { + return Err(SystemError::EINVAL); + } + } + KeepCnt => { + // if val.len() == 4 { + // let mut writer = self.inner.write(); + // let inner = writer.take().expect("Tcp inner::Inner is None"); + // match inner { + // inner::Inner::Established(established) => { + // let count = u32::from_ne_bytes([val[0], val[1], val[2], val[3]]); + // established.with_mut(|socket| { + // socket.set_keep_alive_count(count); + // }); + // writer.replace(inner::Inner::Established(established)); + // } + // _ => { + // writer.replace(inner); + // return Err(SystemError::EINVAL); + // } + // } + // } else { + // return Err(SystemError::EINVAL); + // } + } + KeepIdle => {} + _ => { + log::debug!("TcpSocket::set_option: not supported"); + // return Err(ENOPROTOOPT); + } + } + Ok(()) + } +} + +impl InetSocket for TcpSocket { + fn on_iface_events(&self) { + if self.update_events() { + let _result = self.finish_connect(); + // set error + } + } +} diff --git a/kernel/src/net/socket/inet/stream/option.rs b/kernel/src/net/socket/inet/stream/option.rs new file mode 100644 index 000000000..7008112e2 --- /dev/null +++ b/kernel/src/net/socket/inet/stream/option.rs @@ -0,0 +1,91 @@ +use num_traits::{FromPrimitive, ToPrimitive}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum Options { + /// Turn off Nagle's algorithm. + NoDelay = 1, + /// Limit MSS. + MaxSegment = 2, + /// Never send partially complete segments. + Cork = 3, + /// Start keeplives after this period. + KeepIdle = 4, + /// Interval between keepalives. + KeepIntvl = 5, + /// Number of keepalives before death. + KeepCnt = 6, + /// Number of SYN retransmits. + Syncnt = 7, + /// Lifetime for orphaned FIN-WAIT-2 state. + Linger2 = 8, + /// Wake up listener only when data arrive. + DeferAccept = 9, + /// Bound advertised window + WindowClamp = 10, + /// Information about this connection. + Info = 11, + /// Block/reenable quick acks. + QuickAck = 12, + /// Congestion control algorithm. + Congestion = 13, + /// TCP MD5 Signature (RFC2385). + Md5Sig = 14, + /// Use linear timeouts for thin streams + ThinLinearTimeouts = 16, + /// Fast retrans. after 1 dupack. + ThinDupack = 17, + /// How long for loss retry before timeout. + UserTimeout = 18, + /// TCP sock is under repair right now. + Repair = 19, + RepairQueue = 20, + QueueSeq = 21, + #[allow(clippy::enum_variant_names)] + RepairOptions = 22, + /// Enable FastOpen on listeners + FastOpen = 23, + Timestamp = 24, + /// Limit number of unsent bytes in write queue. + NotSentLowat = 25, + /// Get Congestion Control (optional) info. + CCInfo = 26, + /// Record SYN headers for new connections. + SaveSyn = 27, + /// Get SYN headers recorded for connection. + SavedSyn = 28, + /// Get/set window parameters. + RepairWindow = 29, + /// Attempt FastOpen with connect. + FastOpenConnect = 30, + /// Attach a ULP to a TCP connection. + ULP = 31, + /// TCP MD5 Signature with extensions. + Md5SigExt = 32, + /// Set the key for Fast Open(cookie). + FastOpenKey = 33, + /// Enable TFO without a TFO cookie. + FastOpenNoCookie = 34, + ZeroCopyReceive = 35, + /// Notify bytes available to read as a cmsg on read. + /// 与TCP_CM_INQ相同 + INQ = 36, + /// delay outgoing packets by XX usec + TxDelay = 37, +} + +impl TryFrom for Options { + type Error = system_error::SystemError; + + fn try_from(value: i32) -> Result { + match ::from_i32(value) { + Some(p) => Ok(p), + None => Err(Self::Error::EINVAL), + } + } +} + +impl From for i32 { + fn from(val: Options) -> Self { + ::to_i32(&val).unwrap() + } +} diff --git a/kernel/src/net/socket/inet/syscall.rs b/kernel/src/net/socket/inet/syscall.rs new file mode 100644 index 000000000..46c3a0671 --- /dev/null +++ b/kernel/src/net/socket/inet/syscall.rs @@ -0,0 +1,66 @@ +use alloc::sync::Arc; +use smoltcp::{self, wire::IpProtocol}; +use system_error::SystemError; + +use crate::net::socket::{ + family, + inet::{TcpSocket, UdpSocket}, + Socket, SocketInode, PSOCK, +}; + +fn create_inet_socket( + version: smoltcp::wire::IpVersion, + socket_type: PSOCK, + protocol: smoltcp::wire::IpProtocol, +) -> Result, SystemError> { + // log::debug!("type: {:?}, protocol: {:?}", socket_type, protocol); + match socket_type { + PSOCK::Datagram => match protocol { + IpProtocol::HopByHop | IpProtocol::Udp => { + return Ok(UdpSocket::new(false)); + } + _ => { + return Err(SystemError::EPROTONOSUPPORT); + } + }, + PSOCK::Stream => match protocol { + IpProtocol::HopByHop | IpProtocol::Tcp => { + log::debug!("create tcp socket"); + return Ok(TcpSocket::new(false, version)); + } + _ => { + return Err(SystemError::EPROTONOSUPPORT); + } + }, + PSOCK::Raw => { + todo!("raw") + } + _ => { + return Err(SystemError::EPROTONOSUPPORT); + } + } +} + +pub struct Inet; +impl family::Family for Inet { + fn socket(stype: PSOCK, protocol: u32) -> Result, SystemError> { + let socket = create_inet_socket( + smoltcp::wire::IpVersion::Ipv4, + stype, + smoltcp::wire::IpProtocol::from(protocol as u8), + )?; + Ok(SocketInode::new(socket)) + } +} + +pub struct Inet6; +impl family::Family for Inet6 { + fn socket(stype: PSOCK, protocol: u32) -> Result, SystemError> { + let socket = create_inet_socket( + smoltcp::wire::IpVersion::Ipv6, + stype, + smoltcp::wire::IpProtocol::from(protocol as u8), + )?; + Ok(SocketInode::new(socket)) + } +} diff --git a/kernel/src/net/socket/inode.rs b/kernel/src/net/socket/inode.rs new file mode 100644 index 000000000..bad855272 --- /dev/null +++ b/kernel/src/net/socket/inode.rs @@ -0,0 +1,187 @@ +use crate::filesystem::vfs::IndexNode; +use alloc::sync::Arc; +use system_error::SystemError; + +use super::{ + common::shutdown::ShutdownTemp, + endpoint::Endpoint, + posix::{PMSG, PSOL}, + EPollItems, Socket, +}; + +#[derive(Debug)] +pub struct SocketInode { + inner: Arc, + epoll_items: EPollItems, +} + +impl IndexNode for SocketInode { + fn read_at( + &self, + _offset: usize, + _len: usize, + buf: &mut [u8], + data: crate::libs::spinlock::SpinLockGuard, + ) -> Result { + drop(data); + self.inner.read(buf) + } + + fn write_at( + &self, + _offset: usize, + _len: usize, + buf: &[u8], + data: crate::libs::spinlock::SpinLockGuard, + ) -> Result { + drop(data); + self.inner.write(buf) + } + + /* Following are not yet available in socket */ + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } + + /* filesystem associate interfaces are about unix and netlink socket */ + fn fs(&self) -> Arc { + unimplemented!() + } + + fn list(&self) -> Result, SystemError> { + unimplemented!() + } + + fn open( + &self, + _data: crate::libs::spinlock::SpinLockGuard, + _mode: &crate::filesystem::vfs::file::FileMode, + ) -> Result<(), SystemError> { + Ok(()) + } + + fn metadata(&self) -> Result { + let meta = crate::filesystem::vfs::Metadata { + mode: crate::filesystem::vfs::syscall::ModeType::from_bits_truncate(0o755), + file_type: crate::filesystem::vfs::FileType::Socket, + size: self.send_buffer_size() as i64, + ..Default::default() + }; + + return Ok(meta); + } + + fn close( + &self, + _data: crate::libs::spinlock::SpinLockGuard, + ) -> Result<(), SystemError> { + self.inner.close() + } +} + +impl SocketInode { + // pub fn wait_queue(&self) -> WaitQueue { + // self.inner.wait_queue() + // } + + pub fn send_buffer_size(&self) -> usize { + self.inner.send_buffer_size() + } + + pub fn recv_buffer_size(&self) -> usize { + self.inner.recv_buffer_size() + } + + pub fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + self.inner.accept() + } + + pub fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + self.inner.bind(endpoint) + } + + pub fn set_option(&self, level: PSOL, name: usize, value: &[u8]) -> Result<(), SystemError> { + self.inner.set_option(level, name, value) + } + + pub fn get_option( + &self, + level: PSOL, + name: usize, + value: &mut [u8], + ) -> Result { + self.inner.get_option(level, name, value) + } + + pub fn listen(&self, backlog: usize) -> Result<(), SystemError> { + self.inner.listen(backlog) + } + + pub fn send_to( + &self, + buffer: &[u8], + address: Endpoint, + flags: PMSG, + ) -> Result { + self.inner.send_to(buffer, flags, address) + } + + pub fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + self.inner.send(buffer, flags) + } + + pub fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + self.inner.recv(buffer, flags) + } + + // TODO receive from split with endpoint or not + pub fn recv_from( + &self, + buffer: &mut [u8], + flags: PMSG, + address: Option, + ) -> Result<(usize, Endpoint), SystemError> { + self.inner.recv_from(buffer, flags, address) + } + + pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + self.inner.shutdown(how) + } + + pub fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { + self.inner.connect(endpoint) + } + + pub fn get_name(&self) -> Result { + self.inner.get_name() + } + + pub fn get_peer_name(&self) -> Result { + self.inner.get_peer_name() + } + + pub fn new(inner: Arc) -> Arc { + Arc::new(Self { + inner, + epoll_items: EPollItems::default(), + }) + } + + /// # `epoll_items` + /// socket的epoll事件集 + pub fn epoll_items(&self) -> EPollItems { + self.epoll_items.clone() + } + + pub fn set_nonblock(&self, _nonblock: bool) { + log::warn!("nonblock is not support yet"); + } + + pub fn set_close_on_exec(&self, _close_on_exec: bool) { + log::warn!("close_on_exec is not support yet"); + } + + pub fn inner(&self) -> Arc { + return self.inner.clone(); + } +} diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index aa13daa26..ed1796e84 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -1,920 +1,26 @@ -use core::{any::Any, fmt::Debug, sync::atomic::AtomicUsize}; - -use alloc::{ - boxed::Box, - collections::LinkedList, - string::String, - sync::{Arc, Weak}, - vec::Vec, -}; -use hashbrown::HashMap; -use log::warn; -use smoltcp::{ - iface::SocketSet, - socket::{self, raw, tcp, udp}, -}; -use system_error::SystemError; - -use crate::{ - arch::rand::rand, - filesystem::{ - epoll::{EPollEventType, EPollItem}, - vfs::{ - file::FileMode, syscall::ModeType, FilePrivateData, FileSystem, FileType, IndexNode, - Metadata, PollableInode, - }, - }, - libs::{ - rwlock::{RwLock, RwLockWriteGuard}, - spinlock::{SpinLock, SpinLockGuard}, - wait_queue::EventWaitQueue, - }, - process::{Pid, ProcessManager}, - sched::{schedule, SchedMode}, -}; - -use self::{ - handle::GlobalSocketHandle, - inet::{RawSocket, TcpSocket, UdpSocket}, - unix::{SeqpacketSocket, StreamSocket}, -}; - -use super::{Endpoint, Protocol, ShutdownType}; - -pub mod handle; +mod base; +mod buffer; +mod common; +pub mod endpoint; +mod family; pub mod inet; +mod inode; +mod posix; pub mod unix; +mod utils; -lazy_static! { - /// 所有socket的集合 - /// TODO: 优化这里,自己实现SocketSet!!!现在这样的话,不管全局有多少个网卡,每个时间点都只会有1个进程能够访问socket - pub static ref SOCKET_SET: SpinLock> = SpinLock::new(SocketSet::new(vec![])); - /// SocketHandle表,每个SocketHandle对应一个SocketHandleItem, - /// 注意!:在网卡中断中需要拿到这张表的🔓,在获取读锁时应该确保关中断避免死锁 - pub static ref HANDLE_MAP: RwLock> = RwLock::new(HashMap::new()); - /// 端口管理器 - pub static ref PORT_MANAGER: PortManager = PortManager::new(); -} - -/* For setsockopt(2) */ -// See: linux-5.19.10/include/uapi/asm-generic/socket.h#9 -pub const SOL_SOCKET: u8 = 1; - -/// 根据地址族、socket类型和协议创建socket -pub(super) fn new_socket( - address_family: AddressFamily, - socket_type: PosixSocketType, - protocol: Protocol, -) -> Result, SystemError> { - let socket: Box = match address_family { - AddressFamily::Unix => match socket_type { - PosixSocketType::Stream => Box::new(StreamSocket::new(SocketOptions::default())), - PosixSocketType::SeqPacket => Box::new(SeqpacketSocket::new(SocketOptions::default())), - _ => { - return Err(SystemError::EINVAL); - } - }, - AddressFamily::INet => match socket_type { - PosixSocketType::Stream => Box::new(TcpSocket::new(SocketOptions::default())), - PosixSocketType::Datagram => Box::new(UdpSocket::new(SocketOptions::default())), - PosixSocketType::Raw => Box::new(RawSocket::new(protocol, SocketOptions::default())), - _ => { - return Err(SystemError::EINVAL); - } - }, - _ => { - return Err(SystemError::EAFNOSUPPORT); - } - }; - - let handle_item = SocketHandleItem::new(Arc::downgrade(&socket.posix_item())); - HANDLE_MAP - .write_irqsave() - .insert(socket.socket_handle(), handle_item); - Ok(socket) -} - -pub trait Socket: Sync + Send + Debug + Any { - /// @brief 从socket中读取数据,如果socket是阻塞的,那么直到读取到数据才返回 - /// - /// @param buf 读取到的数据存放的缓冲区 - /// - /// @return - 成功:(返回读取的数据的长度,读取数据的端点). - /// - 失败:错误码 - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint); - - /// @brief 向socket中写入数据。如果socket是阻塞的,那么直到写入的数据全部写入socket中才返回 - /// - /// @param buf 要写入的数据 - /// @param to 要写入的目的端点,如果是None,那么写入的数据将会被丢弃 - /// - /// @return 返回写入的数据的长度 - fn write(&self, buf: &[u8], to: Option) -> Result; - - /// @brief 对应于POSIX的connect函数,用于连接到指定的远程服务器端点 - /// - /// It is used to establish a connection to a remote server. - /// When a socket is connected to a remote server, - /// the operating system will establish a network connection with the server - /// and allow data to be sent and received between the local socket and the remote server. - /// - /// @param endpoint 要连接的端点 - /// - /// @return 返回连接是否成功 - fn connect(&mut self, _endpoint: Endpoint) -> Result<(), SystemError>; - - /// @brief 对应于POSIX的bind函数,用于绑定到本机指定的端点 - /// - /// The bind() function is used to associate a socket with a particular IP address and port number on the local machine. - /// - /// @param endpoint 要绑定的端点 - /// - /// @return 返回绑定是否成功 - fn bind(&mut self, _endpoint: Endpoint) -> Result<(), SystemError> { - Err(SystemError::ENOSYS) - } - - /// @brief 对应于 POSIX 的 shutdown 函数,用于关闭socket。 - /// - /// shutdown() 函数用于启动网络连接的正常关闭。 - /// 当在两个端点之间建立网络连接时,任一端点都可以通过调用其端点对象上的 shutdown() 函数来启动关闭序列。 - /// 此函数向远程端点发送关闭消息以指示本地端点不再接受新数据。 - /// - /// @return 返回是否成功关闭 - fn shutdown(&mut self, _type: ShutdownType) -> Result<(), SystemError> { - Err(SystemError::ENOSYS) - } - - /// @brief 对应于POSIX的listen函数,用于监听端点 - /// - /// @param backlog 最大的等待连接数 - /// - /// @return 返回监听是否成功 - fn listen(&mut self, _backlog: usize) -> Result<(), SystemError> { - Err(SystemError::ENOSYS) - } - - /// @brief 对应于POSIX的accept函数,用于接受连接 - /// - /// @param endpoint 对端的端点 - /// - /// @return 返回接受连接是否成功 - fn accept(&mut self) -> Result<(Box, Endpoint), SystemError> { - Err(SystemError::ENOSYS) - } - - /// @brief 获取socket的端点 - /// - /// @return 返回socket的端点 - fn endpoint(&self) -> Option { - None - } - - /// @brief 获取socket的对端端点 - /// - /// @return 返回socket的对端端点 - fn peer_endpoint(&self) -> Option { - None - } - - /// @brief - /// The purpose of the poll function is to provide - /// a non-blocking way to check if a socket is ready for reading or writing, - /// so that you can efficiently handle multiple sockets in a single thread or event loop. - /// - /// @return (in, out, err) - /// - /// The first boolean value indicates whether the socket is ready for reading. If it is true, then there is data available to be read from the socket without blocking. - /// The second boolean value indicates whether the socket is ready for writing. If it is true, then data can be written to the socket without blocking. - /// The third boolean value indicates whether the socket has encountered an error condition. If it is true, then the socket is in an error state and should be closed or reset - /// - fn poll(&self) -> EPollEventType { - EPollEventType::empty() - } - - /// @brief socket的ioctl函数 - /// - /// @param cmd ioctl命令 - /// @param arg0 ioctl命令的第一个参数 - /// @param arg1 ioctl命令的第二个参数 - /// @param arg2 ioctl命令的第三个参数 - /// - /// @return 返回ioctl命令的返回值 - fn ioctl( - &self, - _cmd: usize, - _arg0: usize, - _arg1: usize, - _arg2: usize, - ) -> Result { - Ok(0) - } - - /// @brief 获取socket的元数据 - fn metadata(&self) -> SocketMetadata; - - fn box_clone(&self) -> Box; - - /// @brief 设置socket的选项 - /// - /// @param level 选项的层次 - /// @param optname 选项的名称 - /// @param optval 选项的值 - /// - /// @return 返回设置是否成功, 如果不支持该选项,返回ENOSYS - fn setsockopt( - &self, - _level: usize, - _optname: usize, - _optval: &[u8], - ) -> Result<(), SystemError> { - warn!("setsockopt is not implemented"); - Ok(()) - } - - fn socket_handle(&self) -> GlobalSocketHandle; - - fn write_buffer(&self, _buf: &[u8]) -> Result { - todo!() - } - - fn as_any_ref(&self) -> &dyn Any; - - fn as_any_mut(&mut self) -> &mut dyn Any; - - fn add_epitem(&mut self, epitem: Arc) -> Result<(), SystemError> { - let posix_item = self.posix_item(); - posix_item.add_epitem(epitem); - Ok(()) - } - - fn remove_epitm(&mut self, epitem: &Arc) -> Result<(), SystemError> { - let posix_item = self.posix_item(); - posix_item.remove_epitem(epitem)?; - - Ok(()) - } - - fn close(&mut self); - - fn posix_item(&self) -> Arc; -} - -impl Clone for Box { - fn clone(&self) -> Box { - self.box_clone() - } -} - -/// # Socket在文件系统中的inode封装 -#[derive(Debug)] -pub struct SocketInode(SpinLock>, AtomicUsize); - -impl SocketInode { - pub fn new(socket: Box) -> Arc { - Arc::new(Self(SpinLock::new(socket), AtomicUsize::new(0))) - } - - #[inline] - pub fn inner(&self) -> SpinLockGuard> { - self.0.lock() - } - - pub unsafe fn inner_no_preempt(&self) -> SpinLockGuard> { - self.0.lock_no_preempt() - } - - fn do_close(&self) -> Result<(), SystemError> { - let prev_ref_count = self.1.fetch_sub(1, core::sync::atomic::Ordering::SeqCst); - if prev_ref_count == 1 { - // 最后一次关闭,需要释放 - let mut socket = self.0.lock_irqsave(); - - if socket.metadata().socket_type == SocketType::Unix { - return Ok(()); - } - - if let Some(Endpoint::Ip(Some(ip))) = socket.endpoint() { - PORT_MANAGER.unbind_port(socket.metadata().socket_type, ip.port); - } - - HANDLE_MAP - .write_irqsave() - .remove(&socket.socket_handle()) - .unwrap(); - socket.close(); - } - - Ok(()) - } -} - -impl Drop for SocketInode { - fn drop(&mut self) { - for _ in 0..self.1.load(core::sync::atomic::Ordering::SeqCst) { - let _ = self.do_close(); - } - } -} - -impl PollableInode for SocketInode { - fn poll(&self, _private_data: &FilePrivateData) -> Result { - let events = self.0.lock_irqsave().poll(); - return Ok(events.bits() as usize); - } - - fn add_epitem( - &self, - epitem: Arc, - _private_data: &FilePrivateData, - ) -> Result<(), SystemError> { - self.0.lock_irqsave().add_epitem(epitem) - } - - fn remove_epitem( - &self, - epitem: &Arc, - _private_data: &FilePrivateData, - ) -> Result<(), SystemError> { - self.0.lock_irqsave().remove_epitm(epitem) - } -} - -impl IndexNode for SocketInode { - fn open( - &self, - _data: SpinLockGuard, - _mode: &FileMode, - ) -> Result<(), SystemError> { - self.1.fetch_add(1, core::sync::atomic::Ordering::SeqCst); - Ok(()) - } - - fn close(&self, _data: SpinLockGuard) -> Result<(), SystemError> { - self.do_close() - } - - fn read_at( - &self, - _offset: usize, - len: usize, - buf: &mut [u8], - data: SpinLockGuard, - ) -> Result { - drop(data); - self.0.lock_no_preempt().read(&mut buf[0..len]).0 - } - - fn write_at( - &self, - _offset: usize, - len: usize, - buf: &[u8], - data: SpinLockGuard, - ) -> Result { - drop(data); - self.0.lock_no_preempt().write(&buf[0..len], None) - } - - fn fs(&self) -> Arc { - todo!() - } +use crate::libs::wait_queue::WaitQueue; +pub use base::Socket; - fn as_any_ref(&self) -> &dyn Any { - self - } - - fn list(&self) -> Result, SystemError> { - return Err(SystemError::ENOTDIR); - } - - fn metadata(&self) -> Result { - let meta = Metadata { - mode: ModeType::from_bits_truncate(0o755), - file_type: FileType::Socket, - ..Default::default() - }; - - return Ok(meta); - } - - fn resize(&self, _len: usize) -> Result<(), SystemError> { - return Ok(()); - } - - fn as_pollable_inode(&self) -> Result<&dyn PollableInode, SystemError> { - Ok(self) - } -} - -#[derive(Debug)] -pub struct PosixSocketHandleItem { - /// socket的waitqueue - wait_queue: Arc, - - pub epitems: SpinLock>>, -} - -impl PosixSocketHandleItem { - pub fn new(wait_queue: Option>) -> Self { - Self { - wait_queue: wait_queue.unwrap_or(Arc::new(EventWaitQueue::new())), - epitems: SpinLock::new(LinkedList::new()), - } - } - /// ## 在socket的等待队列上睡眠 - pub fn sleep(&self, events: u64) { - unsafe { - ProcessManager::preempt_disable(); - self.wait_queue.sleep_without_schedule(events); - ProcessManager::preempt_enable(); - } - schedule(SchedMode::SM_NONE); - } - - pub fn add_epitem(&self, epitem: Arc) { - self.epitems.lock_irqsave().push_back(epitem) - } - - pub fn remove_epitem(&self, epitem: &Arc) -> Result<(), SystemError> { - let mut guard = self.epitems.lock(); - let len = guard.len(); - guard.retain(|x| !Arc::ptr_eq(x, epitem)); - if len != guard.len() { - return Ok(()); - } - Err(SystemError::ENOENT) - } - - /// ### 唤醒该队列上等待events的进程 - /// - /// ### 参数 - /// - events: 发生的事件 - /// - /// 需要注意的是,只要触发了events中的任意一件事件,进程都会被唤醒 - pub fn wakeup_any(&self, events: u64) { - self.wait_queue.wakeup_any(events); - } -} -#[derive(Debug)] -pub struct SocketHandleItem { - /// 对应的posix socket是否为listen的 - pub is_posix_listen: bool, - /// shutdown状态 - pub shutdown_type: RwLock, - pub posix_item: Weak, -} - -impl SocketHandleItem { - pub fn new(posix_item: Weak) -> Self { - Self { - is_posix_listen: false, - shutdown_type: RwLock::new(ShutdownType::empty()), - posix_item, - } - } - - pub fn shutdown_type(&self) -> ShutdownType { - *self.shutdown_type.read() - } - - pub fn shutdown_type_writer(&mut self) -> RwLockWriteGuard { - self.shutdown_type.write_irqsave() - } - - pub fn reset_shutdown_type(&self) { - *self.shutdown_type.write() = ShutdownType::empty(); - } - - pub fn posix_item(&self) -> Option> { - self.posix_item.upgrade() - } -} - -/// # TCP 和 UDP 的端口管理器。 -/// 如果 TCP/UDP 的 socket 绑定了某个端口,它会在对应的表中记录,以检测端口冲突。 -pub struct PortManager { - // TCP 端口记录表 - tcp_port_table: SpinLock>, - // UDP 端口记录表 - udp_port_table: SpinLock>, -} - -impl PortManager { - pub fn new() -> Self { - return Self { - tcp_port_table: SpinLock::new(HashMap::new()), - udp_port_table: SpinLock::new(HashMap::new()), - }; - } - - /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE - pub fn get_ephemeral_port(&self, socket_type: SocketType) -> Result { - // TODO: selects non-conflict high port - - static mut EPHEMERAL_PORT: u16 = 0; - unsafe { - if EPHEMERAL_PORT == 0 { - EPHEMERAL_PORT = (49152 + rand() % (65536 - 49152)) as u16; - } - } - - let mut remaining = 65536 - 49152; // 剩余尝试分配端口次数 - let mut port: u16; - while remaining > 0 { - unsafe { - if EPHEMERAL_PORT == 65535 { - EPHEMERAL_PORT = 49152; - } else { - EPHEMERAL_PORT += 1; - } - port = EPHEMERAL_PORT; - } - - // 使用 ListenTable 检查端口是否被占用 - let listen_table_guard = match socket_type { - SocketType::Udp => self.udp_port_table.lock(), - SocketType::Tcp => self.tcp_port_table.lock(), - _ => panic!("{:?} cann't get a port", socket_type), - }; - if listen_table_guard.get(&port).is_none() { - drop(listen_table_guard); - return Ok(port); - } - remaining -= 1; - } - return Err(SystemError::EADDRINUSE); - } - - /// @brief 检测给定端口是否已被占用,如果未被占用则在 TCP/UDP 对应的表中记录 - /// - /// TODO: 增加支持端口复用的逻辑 - pub fn bind_port(&self, socket_type: SocketType, port: u16) -> Result<(), SystemError> { - if port > 0 { - let mut listen_table_guard = match socket_type { - SocketType::Udp => self.udp_port_table.lock(), - SocketType::Tcp => self.tcp_port_table.lock(), - _ => panic!("{:?} cann't bind a port", socket_type), - }; - match listen_table_guard.get(&port) { - Some(_) => return Err(SystemError::EADDRINUSE), - None => listen_table_guard.insert(port, ProcessManager::current_pid()), - }; - drop(listen_table_guard); - } - return Ok(()); - } - - /// @brief 在对应的端口记录表中将端口和 socket 解绑 - /// should call this function when socket is closed or aborted - pub fn unbind_port(&self, socket_type: SocketType, port: u16) { - let mut listen_table_guard = match socket_type { - SocketType::Udp => self.udp_port_table.lock(), - SocketType::Tcp => self.tcp_port_table.lock(), - _ => { - return; - } - }; - listen_table_guard.remove(&port); - drop(listen_table_guard); - } -} - -/// @brief socket的类型 -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum SocketType { - /// 原始的socket - Raw, - /// 用于Tcp通信的 Socket - Tcp, - /// 用于Udp通信的 Socket - Udp, - /// unix域的 Socket - Unix, -} - -bitflags! { - /// @brief socket的选项 - #[derive(Default)] - pub struct SocketOptions: u32 { - /// 是否阻塞 - const BLOCK = 1 << 0; - /// 是否允许广播 - const BROADCAST = 1 << 1; - /// 是否允许多播 - const MULTICAST = 1 << 2; - /// 是否允许重用地址 - const REUSEADDR = 1 << 3; - /// 是否允许重用端口 - const REUSEPORT = 1 << 4; - } -} - -#[derive(Debug, Clone)] -/// @brief 在trait Socket的metadata函数中返回该结构体供外部使用 -pub struct SocketMetadata { - /// socket的类型 - pub socket_type: SocketType, - /// 接收缓冲区的大小 - pub rx_buf_size: usize, - /// 发送缓冲区的大小 - pub tx_buf_size: usize, - /// 元数据的缓冲区的大小 - pub metadata_buf_size: usize, - /// socket的选项 - pub options: SocketOptions, -} - -impl SocketMetadata { - fn new( - socket_type: SocketType, - rx_buf_size: usize, - tx_buf_size: usize, - metadata_buf_size: usize, - options: SocketOptions, - ) -> Self { - Self { - socket_type, - rx_buf_size, - tx_buf_size, - metadata_buf_size, - options, - } - } -} - -/// @brief 地址族的枚举 -/// -/// 参考:https://code.dragonos.org.cn/xref/linux-5.19.10/include/linux/socket.h#180 -#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] -pub enum AddressFamily { - /// AF_UNSPEC 表示地址族未指定 - Unspecified = 0, - /// AF_UNIX 表示Unix域的socket (与AF_LOCAL相同) - Unix = 1, - /// AF_INET 表示IPv4的socket - INet = 2, - /// AF_AX25 表示AMPR AX.25的socket - AX25 = 3, - /// AF_IPX 表示IPX的socket - IPX = 4, - /// AF_APPLETALK 表示Appletalk的socket - Appletalk = 5, - /// AF_NETROM 表示AMPR NET/ROM的socket - Netrom = 6, - /// AF_BRIDGE 表示多协议桥接的socket - Bridge = 7, - /// AF_ATMPVC 表示ATM PVCs的socket - Atmpvc = 8, - /// AF_X25 表示X.25的socket - X25 = 9, - /// AF_INET6 表示IPv6的socket - INet6 = 10, - /// AF_ROSE 表示AMPR ROSE的socket - Rose = 11, - /// AF_DECnet Reserved for DECnet project - Decnet = 12, - /// AF_NETBEUI Reserved for 802.2LLC project - Netbeui = 13, - /// AF_SECURITY 表示Security callback的伪AF - Security = 14, - /// AF_KEY 表示Key management API - Key = 15, - /// AF_NETLINK 表示Netlink的socket - Netlink = 16, - /// AF_PACKET 表示Low level packet interface - Packet = 17, - /// AF_ASH 表示Ash - Ash = 18, - /// AF_ECONET 表示Acorn Econet - Econet = 19, - /// AF_ATMSVC 表示ATM SVCs - Atmsvc = 20, - /// AF_RDS 表示Reliable Datagram Sockets - Rds = 21, - /// AF_SNA 表示Linux SNA Project - Sna = 22, - /// AF_IRDA 表示IRDA sockets - Irda = 23, - /// AF_PPPOX 表示PPPoX sockets - Pppox = 24, - /// AF_WANPIPE 表示WANPIPE API sockets - WanPipe = 25, - /// AF_LLC 表示Linux LLC - Llc = 26, - /// AF_IB 表示Native InfiniBand address - /// 介绍:https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/9/html-single/configuring_infiniband_and_rdma_networks/index#understanding-infiniband-and-rdma_configuring-infiniband-and-rdma-networks - Ib = 27, - /// AF_MPLS 表示MPLS - Mpls = 28, - /// AF_CAN 表示Controller Area Network - Can = 29, - /// AF_TIPC 表示TIPC sockets - Tipc = 30, - /// AF_BLUETOOTH 表示Bluetooth sockets - Bluetooth = 31, - /// AF_IUCV 表示IUCV sockets - Iucv = 32, - /// AF_RXRPC 表示RxRPC sockets - Rxrpc = 33, - /// AF_ISDN 表示mISDN sockets - Isdn = 34, - /// AF_PHONET 表示Phonet sockets - Phonet = 35, - /// AF_IEEE802154 表示IEEE 802.15.4 sockets - Ieee802154 = 36, - /// AF_CAIF 表示CAIF sockets - Caif = 37, - /// AF_ALG 表示Algorithm sockets - Alg = 38, - /// AF_NFC 表示NFC sockets - Nfc = 39, - /// AF_VSOCK 表示vSockets - Vsock = 40, - /// AF_KCM 表示Kernel Connection Multiplexor - Kcm = 41, - /// AF_QIPCRTR 表示Qualcomm IPC Router - Qipcrtr = 42, - /// AF_SMC 表示SMC-R sockets. - /// reserve number for PF_SMC protocol family that reuses AF_INET address family - Smc = 43, - /// AF_XDP 表示XDP sockets - Xdp = 44, - /// AF_MCTP 表示Management Component Transport Protocol - Mctp = 45, - /// AF_MAX 表示最大的地址族 - Max = 46, -} - -impl TryFrom for AddressFamily { - type Error = SystemError; - fn try_from(x: u16) -> Result { - use num_traits::FromPrimitive; - return ::from_u16(x).ok_or(SystemError::EINVAL); - } -} - -/// @brief posix套接字类型的枚举(这些值与linux内核中的值一致) -#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] -pub enum PosixSocketType { - Stream = 1, - Datagram = 2, - Raw = 3, - Rdm = 4, - SeqPacket = 5, - Dccp = 6, - Packet = 10, -} - -impl TryFrom for PosixSocketType { - type Error = SystemError; - fn try_from(x: u8) -> Result { - use num_traits::FromPrimitive; - return ::from_u8(x).ok_or(SystemError::EINVAL); - } -} - -/// ### 为socket提供无锁的poll方法 -/// -/// 因为在网卡中断中,需要轮询socket的状态,如果使用socket文件或者其inode来poll -/// 在当前的设计,会必然死锁,所以引用这一个设计来解决,提供无🔓的poll -pub struct SocketPollMethod; - -impl SocketPollMethod { - pub fn poll(socket: &socket::Socket, handle_item: &SocketHandleItem) -> EPollEventType { - let shutdown = handle_item.shutdown_type(); - match socket { - socket::Socket::Udp(udp) => Self::udp_poll(udp, shutdown), - socket::Socket::Tcp(tcp) => Self::tcp_poll(tcp, shutdown, handle_item.is_posix_listen), - socket::Socket::Raw(raw) => Self::raw_poll(raw, shutdown), - _ => todo!(), - } - } - - pub fn tcp_poll( - socket: &tcp::Socket, - shutdown: ShutdownType, - is_posix_listen: bool, - ) -> EPollEventType { - let mut events = EPollEventType::empty(); - // debug!("enter tcp_poll! is_posix_listen:{}", is_posix_listen); - // 处理listen的socket - if is_posix_listen { - // 如果是listen的socket,那么只有EPOLLIN和EPOLLRDNORM - if socket.is_active() { - events.insert(EPollEventType::EPOLL_LISTEN_CAN_ACCEPT); - } - - // debug!("tcp_poll listen socket! events:{:?}", events); - return events; - } - - let state = socket.state(); - - if shutdown == ShutdownType::SHUTDOWN_MASK || state == tcp::State::Closed { - events.insert(EPollEventType::EPOLLHUP); - } - - if shutdown.contains(ShutdownType::RCV_SHUTDOWN) { - events.insert( - EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM | EPollEventType::EPOLLRDHUP, - ); - } - - // Connected or passive Fast Open socket? - if state != tcp::State::SynSent && state != tcp::State::SynReceived { - // socket有可读数据 - if socket.can_recv() { - events.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM); - } - - if !(shutdown.contains(ShutdownType::SEND_SHUTDOWN)) { - // 缓冲区可写(这里判断可写的逻辑好像跟linux不太一样) - if socket.send_queue() < socket.send_capacity() { - events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM); - } else { - // TODO:触发缓冲区已满的信号SIGIO - todo!("A signal SIGIO that the buffer is full needs to be sent"); - } - } else { - // 如果我们的socket关闭了SEND_SHUTDOWN,epoll事件就是EPOLLOUT - events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM); - } - } else if state == tcp::State::SynSent { - events.insert(EPollEventType::EPOLLOUT | EPollEventType::EPOLLWRNORM); - } - - // socket发生错误 - // TODO: 这里的逻辑可能有问题,需要进一步验证是否is_active()==false就代表socket发生错误 - if !socket.is_active() { - events.insert(EPollEventType::EPOLLERR); - } - - events - } - - pub fn udp_poll(socket: &udp::Socket, shutdown: ShutdownType) -> EPollEventType { - let mut event = EPollEventType::empty(); - - if shutdown.contains(ShutdownType::RCV_SHUTDOWN) { - event.insert( - EPollEventType::EPOLLRDHUP | EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM, - ); - } - if shutdown.contains(ShutdownType::SHUTDOWN_MASK) { - event.insert(EPollEventType::EPOLLHUP); - } - - if socket.can_recv() { - event.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM); - } - - if socket.can_send() { - event.insert( - EPollEventType::EPOLLOUT - | EPollEventType::EPOLLWRNORM - | EPollEventType::EPOLLWRBAND, - ); - } else { - // TODO: 缓冲区空间不够,需要使用信号处理 - todo!() - } - - return event; - } - - pub fn raw_poll(socket: &raw::Socket, shutdown: ShutdownType) -> EPollEventType { - //debug!("enter raw_poll!"); - let mut event = EPollEventType::empty(); - - if shutdown.contains(ShutdownType::RCV_SHUTDOWN) { - event.insert( - EPollEventType::EPOLLRDHUP | EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM, - ); - } - if shutdown.contains(ShutdownType::SHUTDOWN_MASK) { - event.insert(EPollEventType::EPOLLHUP); - } - - if socket.can_recv() { - //debug!("poll can recv!"); - event.insert(EPollEventType::EPOLLIN | EPollEventType::EPOLLRDNORM); - } else { - //debug!("poll can not recv!"); - } - - if socket.can_send() { - //debug!("poll can send!"); - event.insert( - EPollEventType::EPOLLOUT - | EPollEventType::EPOLLWRNORM - | EPollEventType::EPOLLWRBAND, - ); - } else { - //debug!("poll can not send!"); - // TODO: 缓冲区空间不够,需要使用信号处理 - todo!() - } - return event; - } -} +pub use common::{ + // poll_unit::{EPollItems, WaitQueue}, + EPollItems, +}; +pub use family::{AddressFamily, Family}; +pub use inode::SocketInode; +pub use posix::PMSG; +pub use posix::PSO; +pub use posix::PSOCK; +pub use posix::PSOL; +pub use utils::create_socket; +// pub use crate::net::sys diff --git a/kernel/src/net/socket/posix/mod.rs b/kernel/src/net/socket/posix/mod.rs new file mode 100644 index 000000000..45a138bbd --- /dev/null +++ b/kernel/src/net/socket/posix/mod.rs @@ -0,0 +1,12 @@ +// posix socket and arguments definitions +// now all posix definitions are with P front like MSG -> PMSG, +// for better understanding and avoiding conflicts with other definitions +mod msg_flag; +mod option; +mod option_level; +mod types; + +pub use msg_flag::MessageFlag as PMSG; // Socket message flags MSG_* +pub use option::Options as PSO; // Socket options SO_* +pub use option_level::OptionLevel as PSOL; // Socket options level SOL_* +pub use types::PSOCK; // Socket types SOCK_* diff --git a/kernel/src/net/socket/posix/msg_flag.rs b/kernel/src/net/socket/posix/msg_flag.rs new file mode 100644 index 000000000..976ac6762 --- /dev/null +++ b/kernel/src/net/socket/posix/msg_flag.rs @@ -0,0 +1,110 @@ +bitflags::bitflags! { + /// # Message Flags + /// Flags we can use with send/ and recv. \ + /// Added those for 1003.1g not all are supported yet + /// ## Reference + /// - [Linux Socket Flags](https://code.dragonos.org.cn/xref/linux-6.6.21/include/linux/socket.h#299) + pub struct MessageFlag: u32 { + /// `MSG_OOB` + /// `0b0000_0001`\ + /// Process out-of-band data. + const OOB = 1; + /// `MSG_PEEK` + /// `0b0000_0010`\ + /// Peek at an incoming message. + const PEEK = 2; + /// `MSG_DONTROUTE` + /// `0b0000_0100`\ + /// Don't use routing tables. + const DONTROUTE = 4; + /// `MSG_TRYHARD` + /// `0b0000_0100`\ + /// `MSG_TRYHARD` is not defined in the standard, but it is used in Linux. + const TRYHARD = 4; + /// `MSG_CTRUNC` + /// `0b0000_1000`\ + /// Control data lost before delivery. + const CTRUNC = 8; + /// `MSG_PROBE` + /// `0b0001_0000`\ + const PROBE = 0x10; + /// `MSG_TRUNC` + /// `0b0010_0000`\ + /// Data truncated before delivery. + const TRUNC = 0x20; + /// `MSG_DONTWAIT` + /// `0b0100_0000`\ + /// This flag is used to make the socket non-blocking. + const DONTWAIT = 0x40; + /// `MSG_EOR` + /// `0b1000_0000`\ + /// End of record. + const EOR = 0x80; + /// `MSG_WAITALL` + /// `0b0001_0000_0000`\ + /// Wait for full request or error. + const WAITALL = 0x100; + /// `MSG_FIN` + /// `0b0010_0000_0000`\ + /// Terminate the connection. + const FIN = 0x200; + /// `MSG_SYN` + /// `0b0100_0000_0000`\ + /// Synchronize sequence numbers. + const SYN = 0x400; + /// `MSG_CONFIRM` + /// `0b1000_0000_0000`\ + /// Confirm path validity. + const CONFIRM = 0x800; + /// `MSG_RST` + /// `0b0001_0000_0000_0000`\ + /// Reset the connection. + const RST = 0x1000; + /// `MSG_ERRQUEUE` + /// `0b0010_0000_0000_0000`\ + /// Fetch message from error queue. + const ERRQUEUE = 0x2000; + /// `MSG_NOSIGNAL` + /// `0b0100_0000_0000_0000`\ + /// Do not generate a signal. + const NOSIGNAL = 0x4000; + /// `MSG_MORE` + /// `0b1000_0000_0000_0000`\ + /// Sender will send more. + const MORE = 0x8000; + /// `MSG_WAITFORONE` + /// `0b0001_0000_0000_0000_0000`\ + /// For nonblocking operation. + const WAITFORONE = 0x10000; + /// `MSG_SENDPAGE_NOPOLICY` + /// `0b0010_0000_0000_0000_0000`\ + /// Sendpage: do not apply policy. + const SENDPAGE_NOPOLICY = 0x10000; + /// `MSG_BATCH` + /// `0b0100_0000_0000_0000_0000`\ + /// Sendpage: next message is batch. + const BATCH = 0x40000; + /// `MSG_EOF` + const EOF = Self::FIN.bits; + /// `MSG_NO_SHARED_FRAGS` + const NO_SHARED_FRAGS = 0x80000; + /// `MSG_SENDPAGE_DECRYPTED` + const SENDPAGE_DECRYPTED = 0x10_0000; + + /// `MSG_ZEROCOPY` + const ZEROCOPY = 0x400_0000; + /// `MSG_SPLICE_PAGES` + const SPLICE_PAGES = 0x800_0000; + /// `MSG_FASTOPEN` + const FASTOPEN = 0x2000_0000; + /// `MSG_CMSG_CLOEXEC` + const CMSG_CLOEXEC = 0x4000_0000; + /// `MSG_CMSG_COMPAT` + // if define CONFIG_COMPAT + // const CMSG_COMPAT = 0x8000_0000; + const CMSG_COMPAT = 0; + /// `MSG_INTERNAL_SENDMSG_FLAGS` + const INTERNAL_SENDMSG_FLAGS + = Self::SPLICE_PAGES.bits | Self::SENDPAGE_NOPOLICY.bits | Self::SENDPAGE_DECRYPTED.bits; + } +} diff --git a/kernel/src/net/socket/posix/option.rs b/kernel/src/net/socket/posix/option.rs new file mode 100644 index 000000000..31464ee7c --- /dev/null +++ b/kernel/src/net/socket/posix/option.rs @@ -0,0 +1,92 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +#[allow(non_camel_case_types)] +pub enum Options { + DEBUG = 1, + REUSEADDR = 2, + TYPE = 3, + ERROR = 4, + DONTROUTE = 5, + BROADCAST = 6, + SNDBUF = 7, + RCVBUF = 8, + SNDBUFFORCE = 32, + RCVBUFFORCE = 33, + KEEPALIVE = 9, + OOBINLINE = 10, + NO_CHECK = 11, + PRIORITY = 12, + LINGER = 13, + BSDCOMPAT = 14, + REUSEPORT = 15, + PASSCRED = 16, + PEERCRED = 17, + RCVLOWAT = 18, + SNDLOWAT = 19, + RCVTIMEO_OLD = 20, + SNDTIMEO_OLD = 21, + SECURITY_AUTHENTICATION = 22, + SECURITY_ENCRYPTION_TRANSPORT = 23, + SECURITY_ENCRYPTION_NETWORK = 24, + BINDTODEVICE = 25, + /// 与GET_FILTER相同 + ATTACH_FILTER = 26, + DETACH_FILTER = 27, + PEERNAME = 28, + ACCEPTCONN = 30, + PEERSEC = 31, + PASSSEC = 34, + MARK = 36, + PROTOCOL = 38, + DOMAIN = 39, + RXQ_OVFL = 40, + /// 与SCM_WIFI_STATUS相同 + WIFI_STATUS = 41, + PEEK_OFF = 42, + /* Instruct lower device to use last 4-bytes of skb data as FCS */ + NOFCS = 43, + LOCK_FILTER = 44, + SELECT_ERR_QUEUE = 45, + BUSY_POLL = 46, + MAX_PACING_RATE = 47, + BPF_EXTENSIONS = 48, + INCOMING_CPU = 49, + ATTACH_BPF = 50, + // DETACH_BPF = DETACH_FILTER, + ATTACH_REUSEPORT_CBPF = 51, + ATTACH_REUSEPORT_EBPF = 52, + CNX_ADVICE = 53, + SCM_TIMESTAMPING_OPT_STATS = 54, + MEMINFO = 55, + INCOMING_NAPI_ID = 56, + COOKIE = 57, + SCM_TIMESTAMPING_PKTINFO = 58, + PEERGROUPS = 59, + ZEROCOPY = 60, + /// 与SCM_TXTIME相同 + TXTIME = 61, + BINDTOIFINDEX = 62, + TIMESTAMP_OLD = 29, + TIMESTAMPNS_OLD = 35, + TIMESTAMPING_OLD = 37, + TIMESTAMP_NEW = 63, + TIMESTAMPNS_NEW = 64, + TIMESTAMPING_NEW = 65, + RCVTIMEO_NEW = 66, + SNDTIMEO_NEW = 67, + DETACH_REUSEPORT_BPF = 68, + PREFER_BUSY_POLL = 69, + BUSY_POLL_BUDGET = 70, + NETNS_COOKIE = 71, + BUF_LOCK = 72, + RESERVE_MEM = 73, + TXREHASH = 74, + RCVMARK = 75, +} + +impl TryFrom for Options { + type Error = system_error::SystemError; + fn try_from(x: u32) -> Result { + use num_traits::FromPrimitive; + return ::from_u32(x).ok_or(system_error::SystemError::EINVAL); + } +} diff --git a/kernel/src/net/socket/posix/option_level.rs b/kernel/src/net/socket/posix/option_level.rs new file mode 100644 index 000000000..3357a7ae1 --- /dev/null +++ b/kernel/src/net/socket/posix/option_level.rs @@ -0,0 +1,67 @@ +/// # SOL (Socket Option Level) +/// Setsockoptions(2) level. Thanks to BSD these must match IPPROTO_xxx +/// ## Reference +/// - [Setsockoptions(2) level](https://code.dragonos.org.cn/xref/linux-6.6.21/include/linux/socket.h#345) +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +#[allow(non_camel_case_types)] +pub enum OptionLevel { + IP = 0, + SOCKET = 1, + // ICMP = 1, No-no-no! Due to Linux :-) we cannot + TCP = 6, + UDP = 17, + IPV6 = 41, + ICMPV6 = 58, + SCTP = 132, + UDPLITE = 136, // UDP-Lite (RFC 3828) + RAW = 255, + IPX = 256, + AX25 = 257, + ATALK = 258, + NETROM = 259, + ROSE = 260, + DECNET = 261, + X25 = 262, + PACKET = 263, + ATM = 264, // ATM layer (cell level) + AAL = 265, // ATM Adaption Layer (packet level) + IRDA = 266, + NETBEUI = 267, + LLC = 268, + DCCP = 269, + NETLINK = 270, + TIPC = 271, + RXRPC = 272, + PPPOL2TP = 273, + BLUETOOTH = 274, + PNPIPE = 275, + RDS = 276, + IUCV = 277, + CAIF = 278, + ALG = 279, + NFC = 280, + KCM = 281, + TLS = 282, + XDP = 283, + MPTCP = 284, + MCTP = 285, + SMC = 286, + VSOCK = 287, +} + +impl TryFrom for OptionLevel { + type Error = system_error::SystemError; + + fn try_from(value: u32) -> Result { + match ::from_u32(value) { + Some(p) => Ok(p), + None => Err(system_error::SystemError::EPROTONOSUPPORT), + } + } +} + +impl From for u32 { + fn from(value: OptionLevel) -> Self { + ::to_u32(&value).unwrap() + } +} diff --git a/kernel/src/net/socket/posix/types.rs b/kernel/src/net/socket/posix/types.rs new file mode 100644 index 000000000..78553f495 --- /dev/null +++ b/kernel/src/net/socket/posix/types.rs @@ -0,0 +1,20 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum PSOCK { + Stream = 1, + Datagram = 2, + Raw = 3, + RDM = 4, + SeqPacket = 5, + DCCP = 6, + Packet = 10, +} + +use crate::net::posix::PosixArgsSocketType; +impl TryFrom for PSOCK { + type Error = system_error::SystemError; + fn try_from(x: PosixArgsSocketType) -> Result { + use num_traits::FromPrimitive; + return ::from_u32(x.types().bits()) + .ok_or(system_error::SystemError::EINVAL); + } +} diff --git a/kernel/src/net/socket/unix.rs b/kernel/src/net/socket/unix.rs deleted file mode 100644 index f15037775..000000000 --- a/kernel/src/net/socket/unix.rs +++ /dev/null @@ -1,239 +0,0 @@ -use alloc::{boxed::Box, sync::Arc, vec::Vec}; -use system_error::SystemError; - -use crate::{libs::spinlock::SpinLock, net::Endpoint}; - -use super::{ - handle::GlobalSocketHandle, PosixSocketHandleItem, Socket, SocketInode, SocketMetadata, - SocketOptions, SocketType, -}; - -#[derive(Debug, Clone)] -pub struct StreamSocket { - metadata: SocketMetadata, - buffer: Arc>>, - peer_inode: Option>, - handle: GlobalSocketHandle, - posix_item: Arc, -} - -impl StreamSocket { - /// 默认的元数据缓冲区大小 - pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; - /// 默认的缓冲区大小 - pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; - - /// # 创建一个 Stream Socket - /// - /// ## 参数 - /// - `options`: socket选项 - pub fn new(options: SocketOptions) -> Self { - let buffer = Arc::new(SpinLock::new(Vec::with_capacity(Self::DEFAULT_BUF_SIZE))); - - let metadata = SocketMetadata::new( - SocketType::Unix, - Self::DEFAULT_BUF_SIZE, - Self::DEFAULT_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - options, - ); - - let posix_item = Arc::new(PosixSocketHandleItem::new(None)); - - Self { - metadata, - buffer, - peer_inode: None, - handle: GlobalSocketHandle::new_kernel_handle(), - posix_item, - } - } -} - -impl Socket for StreamSocket { - fn posix_item(&self) -> Arc { - self.posix_item.clone() - } - fn socket_handle(&self) -> GlobalSocketHandle { - self.handle - } - - fn close(&mut self) {} - - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { - let mut buffer = self.buffer.lock_irqsave(); - - let len = core::cmp::min(buf.len(), buffer.len()); - buf[..len].copy_from_slice(&buffer[..len]); - - let _ = buffer.split_off(len); - - (Ok(len), Endpoint::Inode(self.peer_inode.clone())) - } - - fn write(&self, buf: &[u8], _to: Option) -> Result { - if self.peer_inode.is_none() { - return Err(SystemError::ENOTCONN); - } - - let peer_inode = self.peer_inode.clone().unwrap(); - let len = peer_inode.inner().write_buffer(buf)?; - Ok(len) - } - - fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - if self.peer_inode.is_some() { - return Err(SystemError::EISCONN); - } - - if let Endpoint::Inode(inode) = endpoint { - self.peer_inode = inode; - Ok(()) - } else { - Err(SystemError::EINVAL) - } - } - - fn write_buffer(&self, buf: &[u8]) -> Result { - let mut buffer = self.buffer.lock_irqsave(); - - let len = buf.len(); - if buffer.capacity() - buffer.len() < len { - return Err(SystemError::ENOBUFS); - } - buffer.extend_from_slice(buf); - - Ok(len) - } - - fn metadata(&self) -> SocketMetadata { - self.metadata.clone() - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self - } -} - -#[derive(Debug, Clone)] -pub struct SeqpacketSocket { - metadata: SocketMetadata, - buffer: Arc>>, - peer_inode: Option>, - handle: GlobalSocketHandle, - posix_item: Arc, -} - -impl SeqpacketSocket { - /// 默认的元数据缓冲区大小 - pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; - /// 默认的缓冲区大小 - pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; - - /// # 创建一个 Seqpacket Socket - /// - /// ## 参数 - /// - `options`: socket选项 - pub fn new(options: SocketOptions) -> Self { - let buffer = Arc::new(SpinLock::new(Vec::with_capacity(Self::DEFAULT_BUF_SIZE))); - - let metadata = SocketMetadata::new( - SocketType::Unix, - Self::DEFAULT_BUF_SIZE, - Self::DEFAULT_BUF_SIZE, - Self::DEFAULT_METADATA_BUF_SIZE, - options, - ); - - let posix_item = Arc::new(PosixSocketHandleItem::new(None)); - - Self { - metadata, - buffer, - peer_inode: None, - handle: GlobalSocketHandle::new_kernel_handle(), - posix_item, - } - } -} - -impl Socket for SeqpacketSocket { - fn posix_item(&self) -> Arc { - self.posix_item.clone() - } - fn close(&mut self) {} - - fn read(&self, buf: &mut [u8]) -> (Result, Endpoint) { - let mut buffer = self.buffer.lock_irqsave(); - - let len = core::cmp::min(buf.len(), buffer.len()); - buf[..len].copy_from_slice(&buffer[..len]); - - let _ = buffer.split_off(len); - - (Ok(len), Endpoint::Inode(self.peer_inode.clone())) - } - - fn write(&self, buf: &[u8], _to: Option) -> Result { - if self.peer_inode.is_none() { - return Err(SystemError::ENOTCONN); - } - - let peer_inode = self.peer_inode.clone().unwrap(); - let len = peer_inode.inner().write_buffer(buf)?; - Ok(len) - } - - fn connect(&mut self, endpoint: Endpoint) -> Result<(), SystemError> { - if self.peer_inode.is_some() { - return Err(SystemError::EISCONN); - } - - if let Endpoint::Inode(inode) = endpoint { - self.peer_inode = inode; - Ok(()) - } else { - Err(SystemError::EINVAL) - } - } - - fn write_buffer(&self, buf: &[u8]) -> Result { - let mut buffer = self.buffer.lock_irqsave(); - - let len = buf.len(); - if buffer.capacity() - buffer.len() < len { - return Err(SystemError::ENOBUFS); - } - buffer.extend_from_slice(buf); - - Ok(len) - } - - fn socket_handle(&self) -> GlobalSocketHandle { - self.handle - } - - fn metadata(&self) -> SocketMetadata { - self.metadata.clone() - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn core::any::Any { - self - } -} diff --git a/kernel/src/net/socket/unix/mod.rs b/kernel/src/net/socket/unix/mod.rs new file mode 100644 index 000000000..ac99ba36c --- /dev/null +++ b/kernel/src/net/socket/unix/mod.rs @@ -0,0 +1,42 @@ +pub mod ns; +pub(crate) mod seqpacket; +pub mod stream; +use crate::{filesystem::vfs::InodeId, libs::rwlock::RwLock}; +use alloc::sync::Arc; +use hashbrown::HashMap; +use system_error::SystemError; + +use super::{endpoint::Endpoint, Family, SocketInode, PSOCK}; +pub struct Unix; + +lazy_static! { + pub static ref INODE_MAP: RwLock> = RwLock::new(HashMap::new()); +} + +fn create_unix_socket(sock_type: PSOCK) -> Result, SystemError> { + match sock_type { + PSOCK::Stream | PSOCK::Datagram => stream::StreamSocket::new_inode(), + PSOCK::SeqPacket => seqpacket::SeqpacketSocket::new_inode(false), + _ => Err(SystemError::EPROTONOSUPPORT), + } +} + +impl Family for Unix { + fn socket(stype: PSOCK, _protocol: u32) -> Result, SystemError> { + let socket = create_unix_socket(stype)?; + Ok(socket) + } +} + +impl Unix { + pub fn new_pairs( + socket_type: PSOCK, + ) -> Result<(Arc, Arc), SystemError> { + // log::debug!("socket_type {:?}", socket_type); + match socket_type { + PSOCK::SeqPacket => seqpacket::SeqpacketSocket::new_pairs(), + PSOCK::Stream | PSOCK::Datagram => stream::StreamSocket::new_pairs(), + _ => todo!(), + } + } +} diff --git a/kernel/src/net/socket/unix/ns/abs.rs b/kernel/src/net/socket/unix/ns/abs.rs new file mode 100644 index 000000000..c62493abc --- /dev/null +++ b/kernel/src/net/socket/unix/ns/abs.rs @@ -0,0 +1,171 @@ +use core::fmt; + +use crate::{libs::spinlock::SpinLock, net::socket::endpoint::Endpoint}; +use alloc::string::String; +use hashbrown::HashMap; +use ida::IdAllocator; +use system_error::SystemError; + +lazy_static! { + pub static ref ABSHANDLE_MAP: AbsHandleMap = AbsHandleMap::new(); +} + +lazy_static! { + pub static ref ABS_INODE_MAP: SpinLock> = + SpinLock::new(HashMap::new()); +} + +static ABS_ADDRESS_ALLOCATOR: SpinLock = + SpinLock::new(IdAllocator::new(0, (1 << 20) as usize).unwrap()); + +#[derive(Debug, Clone)] +pub struct AbsHandle(usize); + +impl AbsHandle { + pub fn new(name: usize) -> Self { + Self(name) + } + + pub fn name(&self) -> usize { + self.0 + } +} + +impl fmt::Display for AbsHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:05x}", self.0) + } +} + +/// 抽象地址映射表 +/// +/// 负责管理抽象命名空间内的地址 +pub struct AbsHandleMap { + abs_handle_map: SpinLock>, +} + +impl AbsHandleMap { + pub fn new() -> Self { + Self { + abs_handle_map: SpinLock::new(HashMap::new()), + } + } + + /// 插入新的地址映射 + pub fn insert(&self, name: String) -> Result { + let mut guard = self.abs_handle_map.lock(); + + //检查name是否被占用 + if guard.contains_key(&name) { + return Err(SystemError::ENOMEM); + } + + let ads_addr = match self.alloc(name.clone()) { + Some(addr) => addr.clone(), + None => return Err(SystemError::ENOMEM), + }; + guard.insert(name, ads_addr.clone()); + return Ok(ads_addr); + } + + /// 抽象空间地址分配器 + /// + /// ## 返回 + /// + /// 分配到的可用的抽象端点 + pub fn alloc(&self, name: String) -> Option { + let abs_addr = match ABS_ADDRESS_ALLOCATOR.lock().alloc() { + Some(addr) => addr, + //地址被分配 + None => return None, + }; + + let result = Some(Endpoint::Abspath((AbsHandle::new(abs_addr), name))); + + return result; + } + + /// 进行地址映射 + /// + /// ## 参数 + /// + /// name:用户定义的地址 + pub fn look_up(&self, name: &String) -> Option { + let guard = self.abs_handle_map.lock(); + return guard.get(name).cloned(); + } + + /// 移除绑定的地址 + /// + /// ## 参数 + /// + /// name:待删除的地址 + pub fn remove(&self, name: &String) -> Result<(), SystemError> { + let abs_addr = match look_up_abs_addr(name) { + Ok(result) => match result { + Endpoint::Abspath((abshandle, _)) => abshandle.name(), + _ => return Err(SystemError::EINVAL), + }, + Err(_) => return Err(SystemError::EINVAL), + }; + + //释放abs地址分配实例 + ABS_ADDRESS_ALLOCATOR.lock().free(abs_addr); + + //释放entry + let mut guard = self.abs_handle_map.lock(); + guard.remove(name); + + Ok(()) + } +} + +/// 分配抽象地址 +/// +/// ## 返回 +/// +/// 分配到的抽象地址 +pub fn alloc_abs_addr(name: String) -> Result { + ABSHANDLE_MAP.insert(name) +} + +/// 查找抽象地址 +/// +/// ## 参数 +/// +/// name:用户socket字符地址 +/// +/// ## 返回 +/// +/// 查询到的抽象地址 +pub fn look_up_abs_addr(name: &String) -> Result { + match ABSHANDLE_MAP.look_up(name) { + Some(result) => return Ok(result), + None => return Err(SystemError::EINVAL), + } +} + +/// 删除抽象地址 +/// +/// ## 参数 +/// name:待删除的地址 +/// +/// ## 返回 +/// 删除的抽象地址 +pub fn remove_abs_addr(name: &String) -> Result<(), SystemError> { + let abs_addr = match look_up_abs_addr(name) { + Ok(addr) => match addr { + Endpoint::Abspath((addr, _)) => addr, + _ => return Err(SystemError::EINVAL), + }, + Err(_) => return Err(SystemError::EINVAL), + }; + + match ABS_INODE_MAP.lock_irqsave().remove(&abs_addr.name()) { + Some(_) => log::debug!("free abs inode"), + None => log::debug!("not free abs inode"), + } + ABSHANDLE_MAP.remove(name)?; + log::debug!("free abs!"); + Ok(()) +} diff --git a/kernel/src/net/socket/unix/ns/mod.rs b/kernel/src/net/socket/unix/ns/mod.rs new file mode 100644 index 000000000..d99b5e678 --- /dev/null +++ b/kernel/src/net/socket/unix/ns/mod.rs @@ -0,0 +1 @@ +pub mod abs; diff --git a/kernel/src/net/socket/unix/seqpacket/inner.rs b/kernel/src/net/socket/unix/seqpacket/inner.rs new file mode 100644 index 000000000..8294b0dca --- /dev/null +++ b/kernel/src/net/socket/unix/seqpacket/inner.rs @@ -0,0 +1,260 @@ +use alloc::string::String; +use alloc::{collections::VecDeque, sync::Arc}; +use core::sync::atomic::{AtomicUsize, Ordering}; + +use super::SeqpacketSocket; +use crate::net::socket::common::shutdown::ShutdownTemp; +use crate::{ + libs::mutex::Mutex, + net::socket::{buffer::Buffer, endpoint::Endpoint, SocketInode}, +}; +use system_error::SystemError; + +#[derive(Debug)] +pub(super) struct Init { + inode: Option, +} + +impl Init { + pub(super) fn new() -> Self { + Self { inode: None } + } + + pub(super) fn bind(&mut self, epoint_to_bind: Endpoint) -> Result<(), SystemError> { + if self.inode.is_some() { + log::error!("the socket is already bound"); + return Err(SystemError::EINVAL); + } + match epoint_to_bind { + Endpoint::Inode(_) => self.inode = Some(epoint_to_bind), + _ => return Err(SystemError::EINVAL), + } + + return Ok(()); + } + + pub fn bind_path(&mut self, sun_path: String) -> Result { + if self.inode.is_none() { + log::error!("the socket is not bound"); + return Err(SystemError::EINVAL); + } + if let Some(Endpoint::Inode((inode, mut path))) = self.inode.take() { + path = sun_path; + let epoint = Endpoint::Inode((inode, path)); + self.inode.replace(epoint.clone()); + return Ok(epoint); + }; + + return Err(SystemError::EINVAL); + } + + pub fn endpoint(&self) -> Option<&Endpoint> { + return self.inode.as_ref(); + } +} + +#[derive(Debug)] +pub(super) struct Listener { + inode: Endpoint, + backlog: AtomicUsize, + incoming_conns: Mutex>>, +} + +impl Listener { + pub(super) fn new(inode: Endpoint, backlog: usize) -> Self { + log::debug!("backlog {}", backlog); + let back = if backlog > 1024 { 1024_usize } else { backlog }; + return Self { + inode, + backlog: AtomicUsize::new(back), + incoming_conns: Mutex::new(VecDeque::with_capacity(back)), + }; + } + pub(super) fn endpoint(&self) -> &Endpoint { + return &self.inode; + } + + pub(super) fn try_accept(&self) -> Result<(Arc, Endpoint), SystemError> { + let mut incoming_conns = self.incoming_conns.lock(); + log::debug!(" incom len {}", incoming_conns.len()); + let conn = incoming_conns + .pop_front() + .ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?; + let socket = + Arc::downcast::(conn.inner()).map_err(|_| SystemError::EINVAL)?; + let peer = match &*socket.inner.read() { + Inner::Connected(connected) => connected.peer_endpoint().unwrap().clone(), + _ => return Err(SystemError::ENOTCONN), + }; + + return Ok((SocketInode::new(socket), peer)); + } + + pub(super) fn listen(&self, backlog: usize) -> Result<(), SystemError> { + self.backlog.store(backlog, Ordering::Relaxed); + Ok(()) + } + + pub(super) fn push_incoming( + &self, + client_epoint: Option, + ) -> Result { + let mut incoming_conns = self.incoming_conns.lock(); + if incoming_conns.len() >= self.backlog.load(Ordering::Relaxed) { + log::error!("the pending connection queue on the listening socket is full"); + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + let new_server = SeqpacketSocket::new(false); + let new_inode = SocketInode::new(new_server.clone()); + // log::debug!("new inode {:?},client_epoint {:?}",new_inode,client_epoint); + let path = match &self.inode { + Endpoint::Inode((_, path)) => path.clone(), + _ => return Err(SystemError::EINVAL), + }; + + let (server_conn, client_conn) = Connected::new_pair( + Some(Endpoint::Inode((new_inode.clone(), path))), + client_epoint, + ); + *new_server.inner.write() = Inner::Connected(server_conn); + incoming_conns.push_back(new_inode); + + // TODO: epollin + + Ok(client_conn) + } + + pub(super) fn is_acceptable(&self) -> bool { + return self.incoming_conns.lock().len() != 0; + } +} + +#[derive(Debug)] +pub struct Connected { + inode: Option, + peer_inode: Option, + buffer: Arc, +} + +impl Connected { + /// 默认的缓冲区大小 + #[allow(dead_code)] + pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; + + pub fn new_pair( + inode: Option, + peer_inode: Option, + ) -> (Connected, Connected) { + let this = Connected { + inode: inode.clone(), + peer_inode: peer_inode.clone(), + buffer: Buffer::new(), + }; + let peer = Connected { + inode: peer_inode, + peer_inode: inode, + buffer: Buffer::new(), + }; + + (this, peer) + } + + #[allow(dead_code)] + pub fn set_peer_inode(&mut self, peer_epoint: Option) { + self.peer_inode = peer_epoint; + } + + #[allow(dead_code)] + pub fn set_inode(&mut self, epoint: Option) { + self.inode = epoint; + } + + pub fn endpoint(&self) -> Option<&Endpoint> { + self.inode.as_ref() + } + + pub fn peer_endpoint(&self) -> Option<&Endpoint> { + self.peer_inode.as_ref() + } + + pub fn try_read(&self, buf: &mut [u8]) -> Result { + if self.can_recv() { + return self.recv_slice(buf); + } else { + return Err(SystemError::EINVAL); + } + } + + pub fn try_write(&self, buf: &[u8]) -> Result { + if self.can_send()? { + return self.send_slice(buf); + } else { + log::debug!("can not send {:?}", String::from_utf8_lossy(buf)); + return Err(SystemError::ENOBUFS); + } + } + + pub fn can_recv(&self) -> bool { + return !self.buffer.is_read_buf_empty(); + } + + // 检查发送缓冲区是否满了 + pub fn can_send(&self) -> Result { + // let sebuffer = self.sebuffer.lock(); // 获取锁 + // sebuffer.capacity()-sebuffer.len() ==0; + let peer_inode = match self.peer_inode.as_ref().unwrap() { + Endpoint::Inode((inode, _)) => inode, + _ => return Err(SystemError::EINVAL), + }; + let peer_socket = Arc::downcast::(peer_inode.inner()) + .map_err(|_| SystemError::EINVAL)?; + let is_full = match &*peer_socket.inner.read() { + Inner::Connected(connected) => connected.buffer.is_read_buf_full(), + _ => return Err(SystemError::EINVAL), + }; + Ok(!is_full) + } + + pub fn recv_slice(&self, buf: &mut [u8]) -> Result { + return self.buffer.read_read_buffer(buf); + } + + pub fn send_slice(&self, buf: &[u8]) -> Result { + //找到peer_inode,并将write_buffer的内容写入对端的read_buffer + let peer_inode = match self.peer_inode.as_ref().unwrap() { + Endpoint::Inode((inode, _)) => inode, + _ => return Err(SystemError::EINVAL), + }; + let peer_socket = Arc::downcast::(peer_inode.inner()) + .map_err(|_| SystemError::EINVAL)?; + let usize = match &*peer_socket.inner.write() { + Inner::Connected(connected) => { + let usize = connected.buffer.write_read_buffer(buf)?; + usize + } + _ => return Err(SystemError::EINVAL), + }; + peer_socket.wait_queue.wakeup(None); + Ok(usize) + } + + pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + if how.is_empty() { + return Err(SystemError::EINVAL); + } else if how.is_send_shutdown() { + unimplemented!("unimplemented!"); + } else if how.is_recv_shutdown() { + unimplemented!("unimplemented!"); + } + + Ok(()) + } +} + +#[derive(Debug)] +pub(super) enum Inner { + Init(Init), + Listen(Listener), + Connected(Connected), +} diff --git a/kernel/src/net/socket/unix/seqpacket/mod.rs b/kernel/src/net/socket/unix/seqpacket/mod.rs new file mode 100644 index 000000000..4a47e55a7 --- /dev/null +++ b/kernel/src/net/socket/unix/seqpacket/mod.rs @@ -0,0 +1,576 @@ +pub mod inner; +use alloc::{ + string::String, + sync::{Arc, Weak}, +}; +use core::sync::atomic::{AtomicBool, Ordering}; + +use crate::{ + libs::{rwlock::RwLock, wait_queue::WaitQueue}, + net::socket::{Socket, SocketInode, PMSG}, +}; +use crate::{ + net::{ + posix::MsgHdr, + socket::{ + common::shutdown::{Shutdown, ShutdownTemp}, + endpoint::Endpoint, + }, + }, + sched::SchedMode, +}; + +use system_error::SystemError; + +use super::{ + ns::abs::{remove_abs_addr, ABS_INODE_MAP}, + INODE_MAP, +}; + +type EP = crate::filesystem::epoll::EPollEventType; +#[derive(Debug)] +pub struct SeqpacketSocket { + inner: RwLock, + shutdown: Shutdown, + is_nonblocking: AtomicBool, + wait_queue: WaitQueue, + self_ref: Weak, +} + +impl SeqpacketSocket { + /// 默认的元数据缓冲区大小 + #[allow(dead_code)] + pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; + /// 默认的缓冲区大小 + pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; + + pub fn new(is_nonblocking: bool) -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(inner::Inner::Init(inner::Init::new())), + shutdown: Shutdown::new(), + is_nonblocking: AtomicBool::new(is_nonblocking), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + }) + } + + pub fn new_inode(is_nonblocking: bool) -> Result, SystemError> { + let socket = SeqpacketSocket::new(is_nonblocking); + let inode = SocketInode::new(socket.clone()); + // 建立时绑定自身为后续能正常获取本端地址 + let _ = match &mut *socket.inner.write() { + inner::Inner::Init(init) => { + init.bind(Endpoint::Inode((inode.clone(), String::from("")))) + } + _ => return Err(SystemError::EINVAL), + }; + return Ok(inode); + } + + #[allow(dead_code)] + pub fn new_connected(connected: inner::Connected, is_nonblocking: bool) -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(inner::Inner::Connected(connected)), + shutdown: Shutdown::new(), + is_nonblocking: AtomicBool::new(is_nonblocking), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + }) + } + + pub fn new_pairs() -> Result<(Arc, Arc), SystemError> { + let socket0 = SeqpacketSocket::new(false); + let socket1 = SeqpacketSocket::new(false); + let inode0 = SocketInode::new(socket0.clone()); + let inode1 = SocketInode::new(socket1.clone()); + + let (conn_0, conn_1) = inner::Connected::new_pair( + Some(Endpoint::Inode((inode0.clone(), String::from("")))), + Some(Endpoint::Inode((inode1.clone(), String::from("")))), + ); + *socket0.inner.write() = inner::Inner::Connected(conn_0); + *socket1.inner.write() = inner::Inner::Connected(conn_1); + + return Ok((inode0, inode1)); + } + + fn try_accept(&self) -> Result<(Arc, Endpoint), SystemError> { + match &*self.inner.read() { + inner::Inner::Listen(listen) => listen.try_accept() as _, + _ => { + log::error!("the socket is not listening"); + return Err(SystemError::EINVAL); + } + } + } + + fn is_acceptable(&self) -> bool { + match &*self.inner.read() { + inner::Inner::Listen(listen) => listen.is_acceptable(), + _ => { + panic!("the socket is not listening"); + } + } + } + + fn is_peer_shutdown(&self) -> Result { + let peer_shutdown = match self.get_peer_name()? { + Endpoint::Inode((inode, _)) => Arc::downcast::(inode.inner()) + .map_err(|_| SystemError::EINVAL)? + .shutdown + .get() + .is_both_shutdown(), + _ => return Err(SystemError::EINVAL), + }; + Ok(peer_shutdown) + } + + fn can_recv(&self) -> Result { + let can = match &*self.inner.read() { + inner::Inner::Connected(connected) => connected.can_recv(), + _ => return Err(SystemError::ENOTCONN), + }; + Ok(can) + } + + fn is_nonblocking(&self) -> bool { + self.is_nonblocking.load(Ordering::Relaxed) + } + + #[allow(dead_code)] + fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking.store(nonblocking, Ordering::Relaxed); + } +} + +impl Socket for SeqpacketSocket { + fn connect(&self, endpoint: Endpoint) -> Result<(), SystemError> { + let peer_inode = match endpoint { + Endpoint::Inode((inode, _)) => inode, + Endpoint::Unixpath((inode_id, _)) => { + let inode_guard = INODE_MAP.read_irqsave(); + let inode = inode_guard.get(&inode_id).unwrap(); + match inode { + Endpoint::Inode((inode, _)) => inode.clone(), + _ => return Err(SystemError::EINVAL), + } + } + Endpoint::Abspath((abs_addr, _)) => { + let inode_guard = ABS_INODE_MAP.lock_irqsave(); + let inode = match inode_guard.get(&abs_addr.name()) { + Some(inode) => inode, + None => { + log::debug!("can not find inode from absInodeMap"); + return Err(SystemError::EINVAL); + } + }; + match inode { + Endpoint::Inode((inode, _)) => inode.clone(), + _ => { + log::debug!("when connect, find inode failed!"); + return Err(SystemError::EINVAL); + } + } + } + _ => return Err(SystemError::EINVAL), + }; + // 远端为服务端 + let remote_socket = Arc::downcast::(peer_inode.inner()) + .map_err(|_| SystemError::EINVAL)?; + + let client_epoint = match &mut *self.inner.write() { + inner::Inner::Init(init) => match init.endpoint().cloned() { + Some(end) => { + log::trace!("bind when connect"); + Some(end) + } + None => { + log::trace!("not bind when connect"); + let inode = SocketInode::new(self.self_ref.upgrade().unwrap().clone()); + let epoint = Endpoint::Inode((inode.clone(), String::from(""))); + let _ = init.bind(epoint.clone()); + Some(epoint) + } + }, + inner::Inner::Listen(_) => return Err(SystemError::EINVAL), + inner::Inner::Connected(_) => return Err(SystemError::EISCONN), + }; + // ***阻塞与非阻塞处理还未实现 + // 客户端与服务端建立连接将服务端inode推入到自身的listen_incom队列中, + // accept时从中获取推出对应的socket + match &*remote_socket.inner.read() { + inner::Inner::Listen(listener) => match listener.push_incoming(client_epoint) { + Ok(connected) => { + *self.inner.write() = inner::Inner::Connected(connected); + log::debug!("try to wake up"); + + remote_socket.wait_queue.wakeup(None); + return Ok(()); + } + // ***错误处理 + Err(_) => todo!(), + }, + inner::Inner::Init(_) => { + log::debug!("init einval"); + return Err(SystemError::EINVAL); + } + inner::Inner::Connected(_) => return Err(SystemError::EISCONN), + }; + } + + fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + // 将自身socket的inode与用户端提供路径的文件indoe_id进行绑定 + match endpoint { + Endpoint::Unixpath((inodeid, path)) => { + let inode = match &mut *self.inner.write() { + inner::Inner::Init(init) => init.bind_path(path)?, + _ => { + log::error!("socket has listen or connected"); + return Err(SystemError::EINVAL); + } + }; + + INODE_MAP.write_irqsave().insert(inodeid, inode); + Ok(()) + } + Endpoint::Abspath((abshandle, path)) => { + let inode = match &mut *self.inner.write() { + inner::Inner::Init(init) => init.bind_path(path)?, + _ => { + log::error!("socket has listen or connected"); + return Err(SystemError::EINVAL); + } + }; + ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode); + Ok(()) + } + _ => return Err(SystemError::EINVAL), + } + } + + fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + log::debug!("seqpacket shutdown"); + match &*self.inner.write() { + inner::Inner::Connected(connected) => connected.shutdown(how), + _ => Err(SystemError::EINVAL), + } + } + + fn listen(&self, backlog: usize) -> Result<(), SystemError> { + let mut state = self.inner.write(); + log::debug!("listen into socket"); + let epoint = match &*state { + inner::Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(), + inner::Inner::Listen(listener) => return listener.listen(backlog), + inner::Inner::Connected(_) => { + log::error!("the socket is connected"); + return Err(SystemError::EINVAL); + } + }; + + let listener = inner::Listener::new(epoint, backlog); + *state = inner::Inner::Listen(listener); + + Ok(()) + } + + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + if !self.is_nonblocking() { + loop { + wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; + match self.try_accept() { + Ok((socket, epoint)) => return Ok((socket, epoint)), + Err(_) => continue, + } + } + } else { + // ***非阻塞状态 + todo!() + } + } + + fn set_option( + &self, + _level: crate::net::socket::PSOL, + _optname: usize, + _optval: &[u8], + ) -> Result<(), SystemError> { + log::warn!("setsockopt is not implemented"); + Ok(()) + } + + fn wait_queue(&self) -> &WaitQueue { + return &self.wait_queue; + } + + fn close(&self) -> Result<(), SystemError> { + // log::debug!("seqpacket close"); + self.shutdown.recv_shutdown(); + self.shutdown.send_shutdown(); + + let endpoint = self.get_name()?; + let path = match &endpoint { + Endpoint::Inode((_, path)) => path, + Endpoint::Unixpath((_, path)) => path, + Endpoint::Abspath((_, path)) => path, + _ => return Err(SystemError::EINVAL), + }; + + if path.is_empty() { + return Ok(()); + } + + match &endpoint { + Endpoint::Unixpath((inode_id, _)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + inode_guard.remove(inode_id); + } + Endpoint::Inode((current_inode, current_path)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + // 遍历查找匹配的条目 + let target_entry = inode_guard + .iter() + .find(|(_, ep)| { + if let Endpoint::Inode((map_inode, map_path)) = ep { + // 通过指针相等性比较确保是同一对象 + Arc::ptr_eq(map_inode, current_inode) && map_path == current_path + } else { + log::debug!("not match"); + false + } + }) + .map(|(id, _)| *id); + + if let Some(id) = target_entry { + inode_guard.remove(&id).ok_or(SystemError::EINVAL)?; + } + } + Endpoint::Abspath((abshandle, _)) => { + let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave(); + abs_inode_map.remove(&abshandle.name()); + } + _ => { + log::error!("invalid endpoint type"); + return Err(SystemError::EINVAL); + } + } + + *self.inner.write() = inner::Inner::Init(inner::Init::new()); + self.wait_queue.wakeup(None); + + let _ = remove_abs_addr(path); + + return Ok(()); + } + + fn get_peer_name(&self) -> Result { + // 获取对端地址 + let endpoint = match &*self.inner.read() { + inner::Inner::Connected(connected) => connected.peer_endpoint().cloned(), + _ => return Err(SystemError::ENOTCONN), + }; + + if let Some(endpoint) = endpoint { + return Ok(endpoint); + } else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + } + + fn get_name(&self) -> Result { + // 获取本端地址 + let endpoint = match &*self.inner.read() { + inner::Inner::Init(init) => init.endpoint().cloned(), + inner::Inner::Listen(listener) => Some(listener.endpoint().clone()), + inner::Inner::Connected(connected) => connected.endpoint().cloned(), + }; + + if let Some(endpoint) = endpoint { + return Ok(endpoint); + } else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + } + + fn get_option( + &self, + _level: crate::net::socket::PSOL, + _name: usize, + _value: &mut [u8], + ) -> Result { + log::warn!("getsockopt is not implemented"); + Ok(0) + } + + fn read(&self, buffer: &mut [u8]) -> Result { + self.recv(buffer, crate::net::socket::PMSG::empty()) + } + + fn recv( + &self, + buffer: &mut [u8], + flags: crate::net::socket::PMSG, + ) -> Result { + if flags.contains(PMSG::OOB) { + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + if !flags.contains(PMSG::DONTWAIT) { + loop { + wq_wait_event_interruptible!( + self.wait_queue, + self.can_recv()? || self.is_peer_shutdown()?, + {} + )?; + // connect锁和flag判断顺序不正确,应该先判断在 + match &*self.inner.write() { + inner::Inner::Connected(connected) => match connected.try_read(buffer) { + Ok(usize) => { + log::debug!("recv from successfully"); + return Ok(usize); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + } + + fn recv_msg( + &self, + _msg: &mut MsgHdr, + _flags: crate::net::socket::PMSG, + ) -> Result { + Err(SystemError::ENOSYS) + } + + fn send(&self, buffer: &[u8], flags: crate::net::socket::PMSG) -> Result { + if flags.contains(PMSG::OOB) { + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + if self.is_peer_shutdown()? { + return Err(SystemError::EPIPE); + } + if !flags.contains(PMSG::DONTWAIT) { + loop { + match &*self.inner.write() { + inner::Inner::Connected(connected) => match connected.try_write(buffer) { + Ok(usize) => { + log::debug!("send successfully"); + return Ok(usize); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + } + + fn send_msg( + &self, + _msg: &MsgHdr, + _flags: crate::net::socket::PMSG, + ) -> Result { + Err(SystemError::ENOSYS) + } + + fn write(&self, buffer: &[u8]) -> Result { + self.send(buffer, crate::net::socket::PMSG::empty()) + } + + fn recv_from( + &self, + buffer: &mut [u8], + flags: PMSG, + _address: Option, + ) -> Result<(usize, Endpoint), SystemError> { + // log::debug!("recvfrom flags {:?}", flags); + if flags.contains(PMSG::OOB) { + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + if !flags.contains(PMSG::DONTWAIT) { + loop { + wq_wait_event_interruptible!( + self.wait_queue, + self.can_recv()? || self.is_peer_shutdown()?, + {} + )?; + // connect锁和flag判断顺序不正确,应该先判断在 + match &*self.inner.write() { + inner::Inner::Connected(connected) => match connected.recv_slice(buffer) { + Ok(usize) => { + // log::debug!("recvs from successfully"); + return Ok((usize, connected.peer_endpoint().unwrap().clone())); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + //Err(SystemError::ENOSYS) + } + + fn send_buffer_size(&self) -> usize { + // log::warn!("using default buffer size"); + SeqpacketSocket::DEFAULT_BUF_SIZE + } + + fn recv_buffer_size(&self) -> usize { + // log::warn!("using default buffer size"); + SeqpacketSocket::DEFAULT_BUF_SIZE + } + + fn poll(&self) -> usize { + let mut mask = EP::empty(); + let shutdown = self.shutdown.get(); + + // 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152 + // 用关闭读写端表示连接断开 + if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() { + mask |= EP::EPOLLHUP; + } + + if shutdown.is_recv_shutdown() { + mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM; + } + match &*self.inner.read() { + inner::Inner::Connected(connected) => { + if connected.can_recv() { + mask |= EP::EPOLLIN | EP::EPOLLRDNORM; + } + // if (sk_is_readable(sk)) + // mask |= EPOLLIN | EPOLLRDNORM; + + // TODO:处理紧急情况 EPOLLPRI + // TODO:处理连接是否关闭 EPOLLHUP + if !shutdown.is_send_shutdown() { + if connected.can_send().unwrap() { + mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND; + } else { + todo!("poll: buffer space not enough"); + } + } + } + inner::Inner::Listen(_) => mask |= EP::EPOLLIN, + inner::Inner::Init(_) => mask |= EP::EPOLLOUT, + } + mask.bits() as usize + } +} diff --git a/kernel/src/net/socket/unix/stream/inner.rs b/kernel/src/net/socket/unix/stream/inner.rs new file mode 100644 index 000000000..a10e69737 --- /dev/null +++ b/kernel/src/net/socket/unix/stream/inner.rs @@ -0,0 +1,249 @@ +use core::sync::atomic::{AtomicUsize, Ordering}; + +use log::debug; +use system_error::SystemError; + +use crate::libs::mutex::Mutex; +use crate::net::socket::buffer::Buffer; +use crate::net::socket::common::shutdown::ShutdownTemp; +use crate::net::socket::endpoint::Endpoint; +use crate::net::socket::unix::stream::StreamSocket; +use crate::net::socket::SocketInode; + +use alloc::collections::VecDeque; +use alloc::{string::String, sync::Arc}; + +#[derive(Debug)] +pub enum Inner { + Init(Init), + Connected(Connected), + Listener(Listener), +} + +#[derive(Debug)] +pub struct Init { + addr: Option, +} + +impl Init { + pub(super) fn new() -> Self { + Self { addr: None } + } + + pub(super) fn bind(&mut self, endpoint_to_bind: Endpoint) -> Result<(), SystemError> { + if self.addr.is_some() { + log::error!("the socket is already bound"); + return Err(SystemError::EINVAL); + } + + match endpoint_to_bind { + Endpoint::Inode(_) => self.addr = Some(endpoint_to_bind), + _ => return Err(SystemError::EINVAL), + } + + return Ok(()); + } + + pub fn bind_path(&mut self, sun_path: String) -> Result { + if self.addr.is_none() { + log::error!("the socket is not bound"); + return Err(SystemError::EINVAL); + } + if let Some(Endpoint::Inode((inode, mut path))) = self.addr.take() { + path = sun_path; + let epoint = Endpoint::Inode((inode, path.clone())); + self.addr.replace(epoint.clone()); + log::debug!("bind path in inode : {:?}", path); + return Ok(epoint); + }; + + return Err(SystemError::EINVAL); + } + + pub(super) fn endpoint(&self) -> Option<&Endpoint> { + self.addr.as_ref() + } +} + +#[derive(Debug, Clone)] +pub struct Connected { + addr: Option, + peer_addr: Option, + buffer: Arc, +} + +impl Connected { + pub fn new_pair(addr: Option, peer_addr: Option) -> (Self, Self) { + let this = Connected { + addr: addr.clone(), + peer_addr: peer_addr.clone(), + buffer: Buffer::new(), + }; + let peer = Connected { + addr: peer_addr, + peer_addr: addr, + buffer: Buffer::new(), + }; + + return (this, peer); + } + + pub fn endpoint(&self) -> Option<&Endpoint> { + self.addr.as_ref() + } + + #[allow(dead_code)] + pub fn set_addr(&mut self, addr: Option) { + self.addr = addr; + } + + pub fn peer_endpoint(&self) -> Option<&Endpoint> { + self.peer_addr.as_ref() + } + + #[allow(dead_code)] + pub fn set_peer_addr(&mut self, peer: Option) { + self.peer_addr = peer; + } + + pub fn send_slice(&self, buf: &[u8]) -> Result { + //写入对端buffer + let peer_inode = match self.peer_addr.as_ref().unwrap() { + Endpoint::Inode((inode, _)) => inode, + _ => return Err(SystemError::EINVAL), + }; + let peer_socket = + Arc::downcast::(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?; + let usize = match &*peer_socket.inner.read() { + Inner::Connected(conntected) => { + let usize = conntected.buffer.write_read_buffer(buf)?; + usize + } + _ => { + debug!("no! is not connested!"); + return Err(SystemError::EINVAL); + } + }; + peer_socket.wait_queue.wakeup(None); + Ok(usize) + } + + pub fn can_send(&self) -> Result { + //查看连接体里的buf是否非满 + let peer_inode = match self.peer_addr.as_ref().unwrap() { + Endpoint::Inode((inode, _)) => inode, + _ => return Err(SystemError::EINVAL), + }; + let peer_socket = + Arc::downcast::(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?; + let is_full = match &*peer_socket.inner.read() { + Inner::Connected(connected) => connected.buffer.is_read_buf_full(), + _ => return Err(SystemError::EINVAL), + }; + debug!("can send? :{}", !is_full); + Ok(!is_full) + } + + pub fn can_recv(&self) -> bool { + //查看连接体里的buf是否非空 + return !self.buffer.is_read_buf_empty(); + } + + pub fn try_send(&self, buf: &[u8]) -> Result { + if self.can_send()? { + return self.send_slice(buf); + } else { + return Err(SystemError::ENOBUFS); + } + } + + fn recv_slice(&self, buf: &mut [u8]) -> Result { + return self.buffer.read_read_buffer(buf); + } + + pub fn try_recv(&self, buf: &mut [u8]) -> Result { + if self.can_recv() { + return self.recv_slice(buf); + } else { + return Err(SystemError::EINVAL); + } + } + #[allow(dead_code)] + pub fn shutdown(&self, how: ShutdownTemp) -> Result<(), SystemError> { + if how.is_empty() { + return Err(SystemError::EINVAL); + } else if how.is_send_shutdown() { + unimplemented!("unimplemented!"); + } else if how.is_recv_shutdown() { + unimplemented!("unimplemented!"); + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct Listener { + addr: Option, + incoming_connects: Mutex>>, + backlog: AtomicUsize, +} + +impl Listener { + pub fn new(addr: Option, backlog: usize) -> Self { + Self { + addr, + incoming_connects: Mutex::new(VecDeque::new()), + backlog: AtomicUsize::new(backlog), + } + } + + pub fn listen(&self, backlog: usize) -> Result<(), SystemError> { + self.backlog.store(backlog, Ordering::Relaxed); + return Ok(()); + } + + pub fn push_incoming(&self, server_inode: Arc) -> Result<(), SystemError> { + let mut incoming_connects = self.incoming_connects.lock(); + + if incoming_connects.len() >= self.backlog.load(Ordering::Relaxed) { + debug!("unix stream listen socket connected queue is full!"); + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + + incoming_connects.push_back(server_inode); + + return Ok(()); + } + + #[allow(dead_code)] + pub fn pop_incoming(&self) -> Option> { + let mut incoming_connects = self.incoming_connects.lock(); + + return incoming_connects.pop_front(); + } + + pub(super) fn endpoint(&self) -> Option<&Endpoint> { + self.addr.as_ref() + } + + pub(super) fn is_acceptable(&self) -> bool { + return self.incoming_connects.lock().len() != 0; + } + + pub(super) fn try_accept(&self) -> Result<(Arc, Endpoint), SystemError> { + let mut incoming_connecteds = self.incoming_connects.lock(); + debug!("incom len {}", incoming_connecteds.len()); + let connected = incoming_connecteds + .pop_front() + .ok_or(SystemError::EAGAIN_OR_EWOULDBLOCK)?; + let socket = + Arc::downcast::(connected.inner()).map_err(|_| SystemError::EINVAL)?; + let peer = match &*socket.inner.read() { + Inner::Connected(connected) => connected.peer_endpoint().unwrap().clone(), + _ => return Err(SystemError::ENOTCONN), + }; + debug!("server accept!"); + return Ok((SocketInode::new(socket), peer)); + } +} diff --git a/kernel/src/net/socket/unix/stream/mod.rs b/kernel/src/net/socket/unix/stream/mod.rs new file mode 100644 index 000000000..10b040d43 --- /dev/null +++ b/kernel/src/net/socket/unix/stream/mod.rs @@ -0,0 +1,557 @@ +use crate::{ + net::{ + posix::MsgHdr, + socket::{ + common::shutdown::{Shutdown, ShutdownTemp}, + endpoint::Endpoint, + }, + }, + sched::SchedMode, +}; +use alloc::{ + string::String, + sync::{Arc, Weak}, +}; +use inner::{Connected, Init, Inner, Listener}; +use log::debug; +use system_error::SystemError; +use unix::{ + ns::abs::{remove_abs_addr, ABS_INODE_MAP}, + INODE_MAP, +}; + +use crate::{ + libs::rwlock::RwLock, + net::socket::{self, *}, +}; + +type EP = crate::filesystem::epoll::EPollEventType; + +pub mod inner; + +#[derive(Debug)] +pub struct StreamSocket { + inner: RwLock, + shutdown: Shutdown, + _epitems: EPollItems, + wait_queue: WaitQueue, + self_ref: Weak, +} + +impl StreamSocket { + /// 默认的元数据缓冲区大小 + #[allow(dead_code)] + pub const DEFAULT_METADATA_BUF_SIZE: usize = 1024; + /// 默认的缓冲区大小 + pub const DEFAULT_BUF_SIZE: usize = 64 * 1024; + + pub fn new() -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(Inner::Init(Init::new())), + shutdown: Shutdown::new(), + _epitems: EPollItems::default(), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + }) + } + + pub fn new_pairs() -> Result<(Arc, Arc), SystemError> { + let socket0 = StreamSocket::new(); + let socket1 = StreamSocket::new(); + let inode0 = SocketInode::new(socket0.clone()); + let inode1 = SocketInode::new(socket1.clone()); + + let (conn_0, conn_1) = Connected::new_pair( + Some(Endpoint::Inode((inode0.clone(), String::from("")))), + Some(Endpoint::Inode((inode1.clone(), String::from("")))), + ); + *socket0.inner.write() = Inner::Connected(conn_0); + *socket1.inner.write() = Inner::Connected(conn_1); + + return Ok((inode0, inode1)); + } + #[allow(dead_code)] + pub fn new_connected(connected: Connected) -> Arc { + Arc::new_cyclic(|me| Self { + inner: RwLock::new(Inner::Connected(connected)), + shutdown: Shutdown::new(), + _epitems: EPollItems::default(), + wait_queue: WaitQueue::default(), + self_ref: me.clone(), + }) + } + + pub fn new_inode() -> Result, SystemError> { + let socket = StreamSocket::new(); + let inode = SocketInode::new(socket.clone()); + + let _ = match &mut *socket.inner.write() { + Inner::Init(init) => init.bind(Endpoint::Inode((inode.clone(), String::from("")))), + _ => return Err(SystemError::EINVAL), + }; + + return Ok(inode); + } + + fn is_acceptable(&self) -> bool { + match &*self.inner.read() { + Inner::Listener(listener) => listener.is_acceptable(), + _ => { + panic!("the socket is not listening"); + } + } + } + + pub fn try_accept(&self) -> Result<(Arc, Endpoint), SystemError> { + match &*self.inner.read() { + Inner::Listener(listener) => listener.try_accept() as _, + _ => { + log::error!("the socket is not listening"); + return Err(SystemError::EINVAL); + } + } + } + + fn is_peer_shutdown(&self) -> Result { + let peer_shutdown = match self.get_peer_name()? { + Endpoint::Inode((inode, _)) => Arc::downcast::(inode.inner()) + .map_err(|_| SystemError::EINVAL)? + .shutdown + .get() + .is_both_shutdown(), + _ => return Err(SystemError::EINVAL), + }; + Ok(peer_shutdown) + } + + fn can_recv(&self) -> Result { + let can = match &*self.inner.read() { + Inner::Connected(connected) => connected.can_recv(), + _ => return Err(SystemError::ENOTCONN), + }; + Ok(can) + } +} + +impl Socket for StreamSocket { + fn connect(&self, server_endpoint: Endpoint) -> Result<(), SystemError> { + //获取客户端地址 + let client_endpoint = match &mut *self.inner.write() { + Inner::Init(init) => match init.endpoint().cloned() { + Some(endpoint) => { + debug!("bind when connected"); + Some(endpoint) + } + None => { + debug!("not bind when connected"); + let inode = SocketInode::new(self.self_ref.upgrade().unwrap().clone()); + let epoint = Endpoint::Inode((inode.clone(), String::from(""))); + let _ = init.bind(epoint.clone()); + Some(epoint) + } + }, + Inner::Connected(_) => return Err(SystemError::EISCONN), + Inner::Listener(_) => return Err(SystemError::EINVAL), + }; + //获取服务端地址 + // let peer_inode = match server_endpoint.clone() { + // Endpoint::Inode(socket) => socket, + // _ => return Err(SystemError::EINVAL), + // }; + + //找到对端socket + let (peer_inode, sun_path) = match server_endpoint { + Endpoint::Inode((inode, path)) => (inode, path), + Endpoint::Unixpath((inode_id, path)) => { + let inode_guard = INODE_MAP.read_irqsave(); + let inode = inode_guard.get(&inode_id).unwrap(); + match inode { + Endpoint::Inode((inode, _)) => (inode.clone(), path), + _ => return Err(SystemError::EINVAL), + } + } + Endpoint::Abspath((abs_addr, path)) => { + let inode_guard = ABS_INODE_MAP.lock_irqsave(); + let inode = match inode_guard.get(&abs_addr.name()) { + Some(inode) => inode, + None => { + log::debug!("can not find inode from absInodeMap"); + return Err(SystemError::EINVAL); + } + }; + match inode { + Endpoint::Inode((inode, _)) => (inode.clone(), path), + _ => { + debug!("when connect, find inode failed!"); + return Err(SystemError::EINVAL); + } + } + } + _ => return Err(SystemError::EINVAL), + }; + + let remote_socket: Arc = + Arc::downcast::(peer_inode.inner()).map_err(|_| SystemError::EINVAL)?; + + //创建新的对端socket + let new_server_socket = StreamSocket::new(); + let new_server_inode = SocketInode::new(new_server_socket.clone()); + let new_server_endpoint = Some(Endpoint::Inode((new_server_inode.clone(), sun_path))); + //获取connect pair + let (client_conn, server_conn) = + Connected::new_pair(client_endpoint, new_server_endpoint.clone()); + *new_server_socket.inner.write() = Inner::Connected(server_conn); + + //查看remote_socket是否处于监听状态 + let remote_listener = remote_socket.inner.write(); + match &*remote_listener { + Inner::Listener(listener) => { + //往服务端socket的连接队列中添加connected + listener.push_incoming(new_server_inode)?; + *self.inner.write() = Inner::Connected(client_conn); + remote_socket.wait_queue.wakeup(None); + } + _ => return Err(SystemError::EINVAL), + } + + return Ok(()); + } + + fn bind(&self, endpoint: Endpoint) -> Result<(), SystemError> { + match endpoint { + Endpoint::Unixpath((inodeid, path)) => { + let inode = match &mut *self.inner.write() { + Inner::Init(init) => init.bind_path(path)?, + _ => { + log::error!("socket has listen or connected"); + return Err(SystemError::EINVAL); + } + }; + INODE_MAP.write_irqsave().insert(inodeid, inode); + Ok(()) + } + Endpoint::Abspath((abshandle, path)) => { + let inode = match &mut *self.inner.write() { + Inner::Init(init) => init.bind_path(path)?, + _ => { + log::error!("socket has listen or connected"); + return Err(SystemError::EINVAL); + } + }; + ABS_INODE_MAP.lock_irqsave().insert(abshandle.name(), inode); + Ok(()) + } + _ => return Err(SystemError::EINVAL), + } + } + + fn shutdown(&self, _stype: ShutdownTemp) -> Result<(), SystemError> { + todo!(); + } + + fn listen(&self, backlog: usize) -> Result<(), SystemError> { + let mut inner = self.inner.write(); + let epoint = match &*inner { + Inner::Init(init) => init.endpoint().ok_or(SystemError::EINVAL)?.clone(), + Inner::Connected(_) => { + return Err(SystemError::EINVAL); + } + Inner::Listener(listener) => { + return listener.listen(backlog); + } + }; + + let listener = Listener::new(Some(epoint), backlog); + *inner = Inner::Listener(listener); + + return Ok(()); + } + + fn accept(&self) -> Result<(Arc, Endpoint), SystemError> { + debug!("stream server begin accept"); + //目前只实现了阻塞式实现 + loop { + wq_wait_event_interruptible!(self.wait_queue, self.is_acceptable(), {})?; + match self.try_accept() { + Ok((socket, endpoint)) => { + debug!("server accept!:{:?}", endpoint); + return Ok((socket, endpoint)); + } + Err(_) => continue, + } + } + } + + fn set_option(&self, _level: PSOL, _optname: usize, _optval: &[u8]) -> Result<(), SystemError> { + log::warn!("setsockopt is not implemented"); + Ok(()) + } + + fn wait_queue(&self) -> &WaitQueue { + return &self.wait_queue; + } + + fn poll(&self) -> usize { + let mut mask = EP::empty(); + let shutdown = self.shutdown.get(); + + // 参考linux的unix_poll https://code.dragonos.org.cn/xref/linux-6.1.9/net/unix/af_unix.c#3152 + // 用关闭读写端表示连接断开 + if shutdown.is_both_shutdown() || self.is_peer_shutdown().unwrap() { + mask |= EP::EPOLLHUP; + } + + if shutdown.is_recv_shutdown() { + mask |= EP::EPOLLRDHUP | EP::EPOLLIN | EP::EPOLLRDNORM; + } + match &*self.inner.read() { + Inner::Connected(connected) => { + if connected.can_recv() { + mask |= EP::EPOLLIN | EP::EPOLLRDNORM; + } + // if (sk_is_readable(sk)) + // mask |= EPOLLIN | EPOLLRDNORM; + + // TODO:处理紧急情况 EPOLLPRI + // TODO:处理连接是否关闭 EPOLLHUP + if !shutdown.is_send_shutdown() { + if connected.can_send().unwrap() { + mask |= EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND; + } else { + todo!("poll: buffer space not enough"); + } + } + } + Inner::Listener(_) => mask |= EP::EPOLLIN, + Inner::Init(_) => mask |= EP::EPOLLOUT, + } + mask.bits() as usize + } + + fn close(&self) -> Result<(), SystemError> { + self.shutdown.recv_shutdown(); + self.shutdown.send_shutdown(); + + let endpoint = self.get_name()?; + let path = match &endpoint { + Endpoint::Inode((_, path)) => path, + Endpoint::Unixpath((_, path)) => path, + Endpoint::Abspath((_, path)) => path, + _ => return Err(SystemError::EINVAL), + }; + + if path.is_empty() { + return Ok(()); + } + + match &endpoint { + Endpoint::Unixpath((inode_id, _)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + inode_guard.remove(inode_id); + } + Endpoint::Inode((current_inode, current_path)) => { + let mut inode_guard = INODE_MAP.write_irqsave(); + // 遍历查找匹配的条目 + let target_entry = inode_guard + .iter() + .find(|(_, ep)| { + if let Endpoint::Inode((map_inode, map_path)) = ep { + // 通过指针相等性比较确保是同一对象 + Arc::ptr_eq(map_inode, current_inode) && map_path == current_path + } else { + log::debug!("not match"); + false + } + }) + .map(|(id, _)| *id); + + if let Some(id) = target_entry { + inode_guard.remove(&id).ok_or(SystemError::EINVAL)?; + } + } + Endpoint::Abspath((abshandle, _)) => { + let mut abs_inode_map = ABS_INODE_MAP.lock_irqsave(); + abs_inode_map.remove(&abshandle.name()); + } + _ => { + log::error!("invalid endpoint type"); + return Err(SystemError::EINVAL); + } + } + + *self.inner.write() = Inner::Init(Init::new()); + self.wait_queue.wakeup(None); + + let _ = remove_abs_addr(path); + + Ok(()) + } + + fn get_peer_name(&self) -> Result { + //获取对端地址 + let endpoint = match &*self.inner.read() { + Inner::Connected(connected) => connected.peer_endpoint().cloned(), + _ => return Err(SystemError::ENOTCONN), + }; + + if let Some(endpoint) = endpoint { + return Ok(endpoint); + } else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + } + + fn get_name(&self) -> Result { + //获取本端地址 + let endpoint = match &*self.inner.read() { + Inner::Init(init) => init.endpoint().cloned(), + Inner::Connected(connected) => connected.endpoint().cloned(), + Inner::Listener(listener) => listener.endpoint().cloned(), + }; + + if let Some(endpoint) = endpoint { + return Ok(endpoint); + } else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + } + } + + fn get_option( + &self, + _level: PSOL, + _name: usize, + _value: &mut [u8], + ) -> Result { + log::warn!("getsockopt is not implemented"); + Ok(0) + } + + fn read(&self, buffer: &mut [u8]) -> Result { + self.recv(buffer, socket::PMSG::empty()) + } + + fn recv(&self, buffer: &mut [u8], flags: socket::PMSG) -> Result { + if !flags.contains(PMSG::DONTWAIT) { + loop { + log::debug!("socket try recv"); + wq_wait_event_interruptible!( + self.wait_queue, + self.can_recv()? || self.is_peer_shutdown()?, + {} + )?; + // connect锁和flag判断顺序不正确,应该先判断在 + match &*self.inner.write() { + Inner::Connected(connected) => match connected.try_recv(buffer) { + Ok(usize) => { + log::debug!("recv successfully"); + return Ok(usize); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + } + + fn recv_from( + &self, + buffer: &mut [u8], + flags: socket::PMSG, + _address: Option, + ) -> Result<(usize, Endpoint), SystemError> { + if flags.contains(PMSG::OOB) { + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + if !flags.contains(PMSG::DONTWAIT) { + loop { + log::debug!("socket try recv from"); + + wq_wait_event_interruptible!( + self.wait_queue, + self.can_recv()? || self.is_peer_shutdown()?, + {} + )?; + // connect锁和flag判断顺序不正确,应该先判断在 + log::debug!("try recv"); + + match &*self.inner.write() { + Inner::Connected(connected) => match connected.try_recv(buffer) { + Ok(usize) => { + log::debug!("recvs from successfully"); + return Ok((usize, connected.peer_endpoint().unwrap().clone())); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + } + + fn recv_msg(&self, _msg: &mut MsgHdr, _flags: socket::PMSG) -> Result { + Err(SystemError::ENOSYS) + } + + fn send(&self, buffer: &[u8], flags: socket::PMSG) -> Result { + if self.is_peer_shutdown()? { + return Err(SystemError::EPIPE); + } + if !flags.contains(PMSG::DONTWAIT) { + loop { + match &*self.inner.write() { + Inner::Connected(connected) => match connected.try_send(buffer) { + Ok(usize) => { + log::debug!("send successfully"); + return Ok(usize); + } + Err(_) => continue, + }, + _ => { + log::error!("the socket is not connected"); + return Err(SystemError::ENOTCONN); + } + } + } + } else { + unimplemented!("unimplemented non_block") + } + } + + fn send_msg(&self, _msg: &MsgHdr, _flags: socket::PMSG) -> Result { + todo!() + } + + fn send_to( + &self, + _buffer: &[u8], + _flags: socket::PMSG, + _address: Endpoint, + ) -> Result { + Err(SystemError::ENOSYS) + } + + fn write(&self, buffer: &[u8]) -> Result { + self.send(buffer, socket::PMSG::empty()) + } + + fn send_buffer_size(&self) -> usize { + log::warn!("using default buffer size"); + StreamSocket::DEFAULT_BUF_SIZE + } + + fn recv_buffer_size(&self) -> usize { + log::warn!("using default buffer size"); + StreamSocket::DEFAULT_BUF_SIZE + } +} diff --git a/kernel/src/net/socket/utils.rs b/kernel/src/net/socket/utils.rs new file mode 100644 index 000000000..372106d40 --- /dev/null +++ b/kernel/src/net/socket/utils.rs @@ -0,0 +1,26 @@ +use crate::net::socket; +use alloc::sync::Arc; +use socket::Family; +use system_error::SystemError; + +pub fn create_socket( + family: socket::AddressFamily, + socket_type: socket::PSOCK, + protocol: u32, + is_nonblock: bool, + is_close_on_exec: bool, +) -> Result, SystemError> { + type AF = socket::AddressFamily; + let inode = match family { + AF::INet => socket::inet::Inet::socket(socket_type, protocol)?, + // AF::INet6 => socket::inet::Inet6::socket(socket_type, protocol)?, + AF::Unix => socket::unix::Unix::socket(socket_type, protocol)?, + _ => { + log::warn!("unsupport address family"); + return Err(SystemError::EAFNOSUPPORT); + } + }; + inode.set_nonblock(is_nonblock); + inode.set_close_on_exec(is_close_on_exec); + return Ok(inode); +} diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs index 6a9cb92bf..2d90eff96 100644 --- a/kernel/src/net/syscall.rs +++ b/kernel/src/net/syscall.rs @@ -1,29 +1,20 @@ -use core::{cmp::min, ffi::CStr}; - -use alloc::{boxed::Box, sync::Arc}; -use num_traits::{FromPrimitive, ToPrimitive}; -use smoltcp::wire; +use alloc::sync::Arc; +use log::debug; use system_error::SystemError; use crate::{ filesystem::vfs::{ - fcntl::AtFlags, file::{File, FileMode}, - iov::{IoVec, IoVecs}, - open::do_sys_open, - syscall::ModeType, - FileType, + iov::IoVecs, }, - libs::spinlock::SpinLockGuard, - mm::{verify_area, VirtAddr}, - net::socket::{AddressFamily, SOL_SOCKET}, + net::socket::AddressFamily, process::ProcessManager, syscall::Syscall, }; use super::{ - socket::{new_socket, PosixSocketType, Socket, SocketInode}, - Endpoint, Protocol, ShutdownType, + posix::{MsgHdr, PosixArgsSocketType, SockAddr}, + socket::{self, endpoint::Endpoint, unix::Unix}, }; /// Flags for socket, socketpair, accept4 @@ -41,18 +32,34 @@ impl Syscall { socket_type: usize, protocol: usize, ) -> Result { - let address_family = AddressFamily::try_from(address_family as u16)?; - let socket_type = PosixSocketType::try_from((socket_type & 0xf) as u8)?; - let protocol = Protocol::from(protocol as u8); - - let socket = new_socket(address_family, socket_type, protocol)?; - - let socketinode: Arc = SocketInode::new(socket); - let f = File::new(socketinode, FileMode::O_RDWR)?; + // 打印收到的参数 + // log::debug!( + // "socket: address_family={:?}, socket_type={:?}, protocol={:?}", + // address_family, + // socket_type, + // protocol + // ); + let address_family = socket::AddressFamily::try_from(address_family as u16)?; + let type_arg = PosixArgsSocketType::from_bits_truncate(socket_type as u32); + let is_nonblock = type_arg.is_nonblock(); + let is_close_on_exec = type_arg.is_cloexec(); + let stype = socket::PSOCK::try_from(type_arg)?; + // log::debug!("type_arg {:?} stype {:?}", type_arg, stype); + + let inode = socket::create_socket( + address_family, + stype, + protocol as u32, + is_nonblock, + is_close_on_exec, + )?; + + let file = File::new(inode, FileMode::O_RDWR)?; // 把socket添加到当前进程的文件描述符表中 let binding = ProcessManager::current_pcb().fd_table(); let mut fd_table_guard = binding.write(); - let fd = fd_table_guard.alloc_fd(f, None).map(|x| x as usize); + let fd: Result = + fd_table_guard.alloc_fd(file, None).map(|x| x as usize); drop(fd_table_guard); return fd; } @@ -71,26 +78,25 @@ impl Syscall { fds: &mut [i32], ) -> Result { let address_family = AddressFamily::try_from(address_family as u16)?; - let socket_type = PosixSocketType::try_from((socket_type & 0xf) as u8)?; - let protocol = Protocol::from(protocol as u8); + let socket_type = PosixArgsSocketType::from_bits_truncate(socket_type as u32); + let stype = socket::PSOCK::try_from(socket_type)?; let binding = ProcessManager::current_pcb().fd_table(); let mut fd_table_guard = binding.write(); - // 创建一对socket - let inode0 = SocketInode::new(new_socket(address_family, socket_type, protocol)?); - let inode1 = SocketInode::new(new_socket(address_family, socket_type, protocol)?); - - // 进行pair - unsafe { - inode0 - .inner_no_preempt() - .connect(Endpoint::Inode(Some(inode1.clone())))?; - inode1 - .inner_no_preempt() - .connect(Endpoint::Inode(Some(inode0.clone())))?; + // check address family, only support AF_UNIX + if address_family != AddressFamily::Unix { + log::warn!( + "only support AF_UNIX, {:?} with protocol {:?} is not supported", + address_family, + protocol + ); + return Err(SystemError::EAFNOSUPPORT); } + // 创建一对新的unix socket pair + let (inode0, inode1) = Unix::new_pairs(stype)?; + fds[0] = fd_table_guard.alloc_fd(File::new(inode0, FileMode::O_RDWR)?, None)?; fds[1] = fd_table_guard.alloc_fd(File::new(inode1, FileMode::O_RDWR)?, None)?; @@ -111,12 +117,12 @@ impl Syscall { optname: usize, optval: &[u8], ) -> Result { - let socket_inode: Arc = ProcessManager::current_pcb() + let sol = socket::PSOL::try_from(level as u32)?; + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - // 获取内层的socket(真正的数据) - let socket: SpinLockGuard> = socket_inode.inner(); - return socket.setsockopt(level, optname, optval).map(|_| 0); + debug!("setsockopt: level = {:?} ", sol); + return socket.set_option(sol, optname, optval).map(|_| 0); } /// @brief sys_getsockopt系统调用的实际执行函数 @@ -137,27 +143,29 @@ impl Syscall { ) -> Result { // 获取socket let optval = optval as *mut u32; - let binding: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let socket = binding.inner(); - if level as u8 == SOL_SOCKET { - let optname = PosixSocketOption::try_from(optname as i32) - .map_err(|_| SystemError::ENOPROTOOPT)?; + use socket::{PSO, PSOL}; + + let level = PSOL::try_from(level as u32)?; + + if matches!(level, PSOL::SOCKET) { + let optname = PSO::try_from(optname as u32).map_err(|_| SystemError::ENOPROTOOPT)?; match optname { - PosixSocketOption::SO_SNDBUF => { + PSO::SNDBUF => { // 返回发送缓冲区大小 unsafe { - *optval = socket.metadata().tx_buf_size as u32; + *optval = socket.send_buffer_size() as u32; *optlen = core::mem::size_of::() as u32; } return Ok(0); } - PosixSocketOption::SO_RCVBUF => { + PSO::RCVBUF => { // 返回默认的接收缓冲区大小 unsafe { - *optval = socket.metadata().rx_buf_size as u32; + *optval = socket.recv_buffer_size() as u32; *optlen = core::mem::size_of::() as u32; } return Ok(0); @@ -175,13 +183,12 @@ impl Syscall { // to be interpreted by the TCP protocol, level should be set to the // protocol number of TCP. - let posix_protocol = - PosixIpProtocol::try_from(level as u16).map_err(|_| SystemError::ENOPROTOOPT)?; - if posix_protocol == PosixIpProtocol::TCP { - let optname = PosixTcpSocketOptions::try_from(optname as i32) - .map_err(|_| SystemError::ENOPROTOOPT)?; + if matches!(level, PSOL::TCP) { + use socket::inet::stream::TcpOption; + let optname = + TcpOption::try_from(optname as i32).map_err(|_| SystemError::ENOPROTOOPT)?; match optname { - PosixTcpSocketOptions::Congestion => return Ok(0), + TcpOption::Congestion => return Ok(0), _ => { return Err(SystemError::ENOPROTOOPT); } @@ -197,12 +204,11 @@ impl Syscall { /// @param addrlen 地址长度 /// /// @return 成功返回0,失败返回错误码 - pub fn connect(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result { + pub fn connect(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result { let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket: Arc = ProcessManager::current_pcb() + let socket = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let mut socket = unsafe { socket.inner_no_preempt() }; socket.connect(endpoint)?; Ok(0) } @@ -214,12 +220,19 @@ impl Syscall { /// @param addrlen 地址长度 /// /// @return 成功返回0,失败返回错误码 - pub fn bind(fd: usize, addr: *const SockAddr, addrlen: usize) -> Result { + pub fn bind(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result { + // 打印收到的参数 + // log::debug!( + // "bind: fd={:?}, family={:?}, addrlen={:?}", + // fd, + // (unsafe { addr.as_ref().unwrap().family }), + // addrlen + // ); let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let mut socket = unsafe { socket.inner_no_preempt() }; + // log::debug!("bind: socket={:?}", socket); socket.bind(endpoint)?; Ok(0) } @@ -236,9 +249,9 @@ impl Syscall { pub fn sendto( fd: usize, buf: &[u8], - _flags: u32, + flags: u32, addr: *const SockAddr, - addrlen: usize, + addrlen: u32, ) -> Result { let endpoint = if addr.is_null() { None @@ -246,11 +259,17 @@ impl Syscall { Some(SockAddr::to_endpoint(addr, addrlen)?) }; - let socket: Arc = ProcessManager::current_pcb() + let flags = socket::PMSG::from_bits_truncate(flags); + + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let socket = unsafe { socket.inner_no_preempt() }; - return socket.write(buf, endpoint); + + if let Some(endpoint) = endpoint { + return socket.send_to(buf, endpoint, flags); + } else { + return socket.send(buf, flags); + } } /// @brief sys_recvfrom系统调用的实际执行函数 @@ -265,28 +284,37 @@ impl Syscall { pub fn recvfrom( fd: usize, buf: &mut [u8], - _flags: u32, + flags: u32, addr: *mut SockAddr, - addrlen: *mut u32, + addr_len: *mut u32, ) -> Result { - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let socket = unsafe { socket.inner_no_preempt() }; + let flags = socket::PMSG::from_bits_truncate(flags); - let (n, endpoint) = socket.read(buf); - drop(socket); + if addr.is_null() { + let (n, _) = socket.recv_from(buf, flags, None)?; + return Ok(n); + } - let n: usize = n?; + // address is not null + let address = unsafe { addr.as_ref() }.ok_or(SystemError::EINVAL)?; - // 如果有地址信息,将地址信息写入用户空间 - if !addr.is_null() { + if unsafe { address.is_empty() } { + let (recv_len, endpoint) = socket.recv_from(buf, flags, None)?; let sockaddr_in = SockAddr::from(endpoint); unsafe { - sockaddr_in.write_to_user(addr, addrlen)?; + sockaddr_in.write_to_user(addr, addr_len)?; } - } - return Ok(n); + return Ok(recv_len); + } else { + // 从socket中读取数据 + let addr_len = *unsafe { addr_len.as_ref() }.ok_or(SystemError::EINVAL)?; + let address = SockAddr::to_endpoint(addr, addr_len)?; + let (recv_len, _) = socket.recv_from(buf, flags, Some(address))?; + return Ok(recv_len); + }; } /// @brief sys_recvmsg系统调用的实际执行函数 @@ -296,30 +324,25 @@ impl Syscall { /// @param flags 标志,暂时未使用 /// /// @return 成功返回接收的字节数,失败返回错误码 - pub fn recvmsg(fd: usize, msg: &mut MsgHdr, _flags: u32) -> Result { + pub fn recvmsg(fd: usize, msg: &mut MsgHdr, flags: u32) -> Result { // 检查每个缓冲区地址是否合法,生成iovecs let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let socket = unsafe { socket.inner_no_preempt() }; + + let flags = socket::PMSG::from_bits_truncate(flags); let mut buf = iovs.new_buf(true); // 从socket中读取数据 - let (n, endpoint) = socket.read(&mut buf); + let recv_size = socket.recv(&mut buf, flags)?; drop(socket); - let n: usize = n?; - // 将数据写入用户空间的iovecs - iovs.scatter(&buf[..n]); + iovs.scatter(&buf[..recv_size]); - let sockaddr_in = SockAddr::from(endpoint); - unsafe { - sockaddr_in.write_to_user(msg.msg_name, &mut msg.msg_namelen)?; - } - return Ok(n); + return Ok(recv_size); } /// @brief sys_listen系统调用的实际执行函数 @@ -329,12 +352,10 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn listen(fd: usize, backlog: usize) -> Result { - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let mut socket = unsafe { socket.inner_no_preempt() }; - socket.listen(backlog)?; - return Ok(0); + socket.listen(backlog).map(|_| 0) } /// @brief sys_shutdown系统调用的实际执行函数 @@ -344,11 +365,10 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn shutdown(fd: usize, how: usize) -> Result { - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let mut socket = unsafe { socket.inner_no_preempt() }; - socket.shutdown(ShutdownType::from_bits_truncate((how + 1) as u8))?; + socket.shutdown(how.try_into()?)?; return Ok(0); } @@ -404,18 +424,16 @@ impl Syscall { addrlen: *mut u32, flags: u32, ) -> Result { - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - // debug!("accept: socket={:?}", socket); - let mut socket = unsafe { socket.inner_no_preempt() }; + // 从socket中接收连接 let (new_socket, remote_endpoint) = socket.accept()?; drop(socket); // debug!("accept: new_socket={:?}", new_socket); // Insert the new socket into the file descriptor vector - let new_socket: Arc = SocketInode::new(new_socket); let mut file_mode = FileMode::O_RDWR; if flags & SOCK_NONBLOCK.bits() != 0 { @@ -459,12 +477,10 @@ impl Syscall { if addr.is_null() { return Err(SystemError::EINVAL); } - let socket: Arc = ProcessManager::current_pcb() + let endpoint = ProcessManager::current_pcb() .get_socket(fd as i32) - .ok_or(SystemError::EBADF)?; - let socket = socket.inner(); - let endpoint: Endpoint = socket.endpoint().ok_or(SystemError::EINVAL)?; - drop(socket); + .ok_or(SystemError::EBADF)? + .get_name()?; let sockaddr_in = SockAddr::from(endpoint); unsafe { @@ -489,11 +505,11 @@ impl Syscall { return Err(SystemError::EINVAL); } - let socket: Arc = ProcessManager::current_pcb() + let socket: Arc = ProcessManager::current_pcb() .get_socket(fd as i32) .ok_or(SystemError::EBADF)?; - let socket = socket.inner(); - let endpoint: Endpoint = socket.peer_endpoint().ok_or(SystemError::EINVAL)?; + + let endpoint: Endpoint = socket.get_peer_name()?; drop(socket); let sockaddr_in = SockAddr::from(endpoint); @@ -503,541 +519,3 @@ impl Syscall { return Ok(0); } } - -// 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32 -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct SockAddrIn { - pub sin_family: u16, - pub sin_port: u16, - pub sin_addr: u32, - pub sin_zero: [u8; 8], -} - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct SockAddrUn { - pub sun_family: u16, - pub sun_path: [u8; 108], -} - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct SockAddrLl { - pub sll_family: u16, - pub sll_protocol: u16, - pub sll_ifindex: u32, - pub sll_hatype: u16, - pub sll_pkttype: u8, - pub sll_halen: u8, - pub sll_addr: [u8; 8], -} - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct SockAddrNl { - nl_family: u16, - nl_pad: u16, - nl_pid: u32, - nl_groups: u32, -} - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct SockAddrPlaceholder { - pub family: u16, - pub data: [u8; 14], -} - -#[repr(C)] -#[derive(Clone, Copy)] -pub union SockAddr { - pub family: u16, - pub addr_in: SockAddrIn, - pub addr_un: SockAddrUn, - pub addr_ll: SockAddrLl, - pub addr_nl: SockAddrNl, - pub addr_ph: SockAddrPlaceholder, -} - -impl SockAddr { - /// @brief 把用户传入的SockAddr转换为Endpoint结构体 - pub fn to_endpoint(addr: *const SockAddr, len: usize) -> Result { - verify_area( - VirtAddr::new(addr as usize), - core::mem::size_of::(), - ) - .map_err(|_| SystemError::EFAULT)?; - - let addr = unsafe { addr.as_ref() }.ok_or(SystemError::EFAULT)?; - unsafe { - match AddressFamily::try_from(addr.family)? { - AddressFamily::INet => { - if len < addr.len()? { - return Err(SystemError::EINVAL); - } - - let addr_in: SockAddrIn = addr.addr_in; - - let ip: wire::IpAddress = wire::IpAddress::from(wire::Ipv4Address::from_bytes( - &u32::from_be(addr_in.sin_addr).to_be_bytes()[..], - )); - let port = u16::from_be(addr_in.sin_port); - - return Ok(Endpoint::Ip(Some(wire::IpEndpoint::new(ip, port)))); - } - AddressFamily::Unix => { - let addr_un: SockAddrUn = addr.addr_un; - - let path = CStr::from_bytes_until_nul(&addr_un.sun_path) - .map_err(|_| SystemError::EINVAL)? - .to_str() - .map_err(|_| SystemError::EINVAL)?; - - let fd = do_sys_open( - AtFlags::AT_FDCWD.bits(), - path, - FileMode::O_RDWR, - ModeType::S_IWUGO | ModeType::S_IRUGO, - true, - )?; - - let binding = ProcessManager::current_pcb().fd_table(); - let fd_table_guard = binding.read(); - - let file = fd_table_guard.get_file_by_fd(fd as i32).unwrap(); - if file.file_type() != FileType::Socket { - return Err(SystemError::ENOTSOCK); - } - let inode = file.inode(); - let socketinode = inode.as_any_ref().downcast_ref::>(); - - return Ok(Endpoint::Inode(socketinode.cloned())); - } - AddressFamily::Packet => { - // TODO: support packet socket - return Err(SystemError::EINVAL); - } - AddressFamily::Netlink => { - // TODO: support netlink socket - return Err(SystemError::EINVAL); - } - _ => { - return Err(SystemError::EINVAL); - } - } - } - } - - /// @brief 获取地址长度 - pub fn len(&self) -> Result { - let ret = match AddressFamily::try_from(unsafe { self.family })? { - AddressFamily::INet => Ok(core::mem::size_of::()), - AddressFamily::Packet => Ok(core::mem::size_of::()), - AddressFamily::Netlink => Ok(core::mem::size_of::()), - AddressFamily::Unix => Err(SystemError::EINVAL), - _ => Err(SystemError::EINVAL), - }; - - return ret; - } - - /// @brief 把SockAddr的数据写入用户空间 - /// - /// @param addr 用户空间的SockAddr的地址 - /// @param len 要写入的长度 - /// - /// @return 成功返回写入的长度,失败返回错误码 - pub unsafe fn write_to_user( - &self, - addr: *mut SockAddr, - addr_len: *mut u32, - ) -> Result { - // 当用户传入的地址或者长度为空时,直接返回0 - if addr.is_null() || addr_len.is_null() { - return Ok(0); - } - - // 检查用户传入的地址是否合法 - verify_area( - VirtAddr::new(addr as usize), - core::mem::size_of::(), - ) - .map_err(|_| SystemError::EFAULT)?; - - verify_area( - VirtAddr::new(addr_len as usize), - core::mem::size_of::(), - ) - .map_err(|_| SystemError::EFAULT)?; - - let to_write = min(self.len()?, *addr_len as usize); - if to_write > 0 { - let buf = core::slice::from_raw_parts_mut(addr as *mut u8, to_write); - buf.copy_from_slice(core::slice::from_raw_parts( - self as *const SockAddr as *const u8, - to_write, - )); - } - *addr_len = self.len()? as u32; - return Ok(to_write); - } -} - -impl From for SockAddr { - fn from(value: Endpoint) -> Self { - match value { - Endpoint::Ip(ip_endpoint) => { - // 未指定地址 - if ip_endpoint.is_none() { - return SockAddr { - addr_ph: SockAddrPlaceholder { - family: AddressFamily::Unspecified as u16, - data: [0; 14], - }, - }; - } - // 指定了地址 - let ip_endpoint = ip_endpoint.unwrap(); - match ip_endpoint.addr { - wire::IpAddress::Ipv4(ipv4_addr) => { - let addr_in = SockAddrIn { - sin_family: AddressFamily::INet as u16, - sin_port: ip_endpoint.port.to_be(), - sin_addr: u32::from_be_bytes(ipv4_addr.0).to_be(), - sin_zero: [0; 8], - }; - - return SockAddr { addr_in }; - } - _ => { - unimplemented!("not support ipv6"); - } - } - } - - Endpoint::LinkLayer(link_endpoint) => { - let addr_ll = SockAddrLl { - sll_family: AddressFamily::Packet as u16, - sll_protocol: 0, - sll_ifindex: link_endpoint.interface as u32, - sll_hatype: 0, - sll_pkttype: 0, - sll_halen: 0, - sll_addr: [0; 8], - }; - - return SockAddr { addr_ll }; - } - - _ => { - // todo: support other endpoint, like Netlink... - unimplemented!("not support {value:?}"); - } - } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Copy)] -pub struct MsgHdr { - /// 指向一个SockAddr结构体的指针 - pub msg_name: *mut SockAddr, - /// SockAddr结构体的大小 - pub msg_namelen: u32, - /// scatter/gather array - pub msg_iov: *mut IoVec, - /// elements in msg_iov - pub msg_iovlen: usize, - /// 辅助数据 - pub msg_control: *mut u8, - /// 辅助数据长度 - pub msg_controllen: usize, - /// 接收到的消息的标志 - pub msg_flags: u32, -} - -#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] -pub enum PosixIpProtocol { - /// Dummy protocol for TCP. - IP = 0, - /// Internet Control Message Protocol. - ICMP = 1, - /// Internet Group Management Protocol. - IGMP = 2, - /// IPIP tunnels (older KA9Q tunnels use 94). - IPIP = 4, - /// Transmission Control Protocol. - TCP = 6, - /// Exterior Gateway Protocol. - EGP = 8, - /// PUP protocol. - PUP = 12, - /// User Datagram Protocol. - UDP = 17, - /// XNS IDP protocol. - IDP = 22, - /// SO Transport Protocol Class 4. - TP = 29, - /// Datagram Congestion Control Protocol. - DCCP = 33, - /// IPv6-in-IPv4 tunnelling. - IPv6 = 41, - /// RSVP Protocol. - RSVP = 46, - /// Generic Routing Encapsulation. (Cisco GRE) (rfc 1701, 1702) - GRE = 47, - /// Encapsulation Security Payload protocol - ESP = 50, - /// Authentication Header protocol - AH = 51, - /// Multicast Transport Protocol. - MTP = 92, - /// IP option pseudo header for BEET - BEETPH = 94, - /// Encapsulation Header. - ENCAP = 98, - /// Protocol Independent Multicast. - PIM = 103, - /// Compression Header Protocol. - COMP = 108, - /// Stream Control Transport Protocol - SCTP = 132, - /// UDP-Lite protocol (RFC 3828) - UDPLITE = 136, - /// MPLS in IP (RFC 4023) - MPLSINIP = 137, - /// Ethernet-within-IPv6 Encapsulation - ETHERNET = 143, - /// Raw IP packets - RAW = 255, - /// Multipath TCP connection - MPTCP = 262, -} - -impl TryFrom for PosixIpProtocol { - type Error = SystemError; - - fn try_from(value: u16) -> Result { - match ::from_u16(value) { - Some(p) => Ok(p), - None => Err(SystemError::EPROTONOSUPPORT), - } - } -} - -impl From for u16 { - fn from(value: PosixIpProtocol) -> Self { - ::to_u16(&value).unwrap() - } -} - -#[allow(non_camel_case_types)] -#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive, PartialEq, Eq)] -pub enum PosixSocketOption { - SO_DEBUG = 1, - SO_REUSEADDR = 2, - SO_TYPE = 3, - SO_ERROR = 4, - SO_DONTROUTE = 5, - SO_BROADCAST = 6, - SO_SNDBUF = 7, - SO_RCVBUF = 8, - SO_SNDBUFFORCE = 32, - SO_RCVBUFFORCE = 33, - SO_KEEPALIVE = 9, - SO_OOBINLINE = 10, - SO_NO_CHECK = 11, - SO_PRIORITY = 12, - SO_LINGER = 13, - SO_BSDCOMPAT = 14, - SO_REUSEPORT = 15, - SO_PASSCRED = 16, - SO_PEERCRED = 17, - SO_RCVLOWAT = 18, - SO_SNDLOWAT = 19, - SO_RCVTIMEO_OLD = 20, - SO_SNDTIMEO_OLD = 21, - - SO_SECURITY_AUTHENTICATION = 22, - SO_SECURITY_ENCRYPTION_TRANSPORT = 23, - SO_SECURITY_ENCRYPTION_NETWORK = 24, - - SO_BINDTODEVICE = 25, - - /// 与SO_GET_FILTER相同 - SO_ATTACH_FILTER = 26, - SO_DETACH_FILTER = 27, - - SO_PEERNAME = 28, - - SO_ACCEPTCONN = 30, - - SO_PEERSEC = 31, - SO_PASSSEC = 34, - - SO_MARK = 36, - - SO_PROTOCOL = 38, - SO_DOMAIN = 39, - - SO_RXQ_OVFL = 40, - - /// 与SCM_WIFI_STATUS相同 - SO_WIFI_STATUS = 41, - SO_PEEK_OFF = 42, - - /* Instruct lower device to use last 4-bytes of skb data as FCS */ - SO_NOFCS = 43, - - SO_LOCK_FILTER = 44, - SO_SELECT_ERR_QUEUE = 45, - SO_BUSY_POLL = 46, - SO_MAX_PACING_RATE = 47, - SO_BPF_EXTENSIONS = 48, - SO_INCOMING_CPU = 49, - SO_ATTACH_BPF = 50, - // SO_DETACH_BPF = SO_DETACH_FILTER, - SO_ATTACH_REUSEPORT_CBPF = 51, - SO_ATTACH_REUSEPORT_EBPF = 52, - - SO_CNX_ADVICE = 53, - SCM_TIMESTAMPING_OPT_STATS = 54, - SO_MEMINFO = 55, - SO_INCOMING_NAPI_ID = 56, - SO_COOKIE = 57, - SCM_TIMESTAMPING_PKTINFO = 58, - SO_PEERGROUPS = 59, - SO_ZEROCOPY = 60, - /// 与SCM_TXTIME相同 - SO_TXTIME = 61, - - SO_BINDTOIFINDEX = 62, - - SO_TIMESTAMP_OLD = 29, - SO_TIMESTAMPNS_OLD = 35, - SO_TIMESTAMPING_OLD = 37, - SO_TIMESTAMP_NEW = 63, - SO_TIMESTAMPNS_NEW = 64, - SO_TIMESTAMPING_NEW = 65, - - SO_RCVTIMEO_NEW = 66, - SO_SNDTIMEO_NEW = 67, - - SO_DETACH_REUSEPORT_BPF = 68, - - SO_PREFER_BUSY_POLL = 69, - SO_BUSY_POLL_BUDGET = 70, - - SO_NETNS_COOKIE = 71, - SO_BUF_LOCK = 72, - SO_RESERVE_MEM = 73, - SO_TXREHASH = 74, - SO_RCVMARK = 75, -} - -impl TryFrom for PosixSocketOption { - type Error = SystemError; - - fn try_from(value: i32) -> Result { - match ::from_i32(value) { - Some(p) => Ok(p), - None => Err(SystemError::EINVAL), - } - } -} - -impl From for i32 { - fn from(value: PosixSocketOption) -> Self { - ::to_i32(&value).unwrap() - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] -pub enum PosixTcpSocketOptions { - /// Turn off Nagle's algorithm. - NoDelay = 1, - /// Limit MSS. - MaxSegment = 2, - /// Never send partially complete segments. - Cork = 3, - /// Start keeplives after this period. - KeepIdle = 4, - /// Interval between keepalives. - KeepIntvl = 5, - /// Number of keepalives before death. - KeepCnt = 6, - /// Number of SYN retransmits. - Syncnt = 7, - /// Lifetime for orphaned FIN-WAIT-2 state. - Linger2 = 8, - /// Wake up listener only when data arrive. - DeferAccept = 9, - /// Bound advertised window - WindowClamp = 10, - /// Information about this connection. - Info = 11, - /// Block/reenable quick acks. - QuickAck = 12, - /// Congestion control algorithm. - Congestion = 13, - /// TCP MD5 Signature (RFC2385). - Md5Sig = 14, - /// Use linear timeouts for thin streams - ThinLinearTimeouts = 16, - /// Fast retrans. after 1 dupack. - ThinDupack = 17, - /// How long for loss retry before timeout. - UserTimeout = 18, - /// TCP sock is under repair right now. - Repair = 19, - RepairQueue = 20, - QueueSeq = 21, - RepairOptions = 22, - /// Enable FastOpen on listeners - FastOpen = 23, - Timestamp = 24, - /// Limit number of unsent bytes in write queue. - NotSentLowat = 25, - /// Get Congestion Control (optional) info. - CCInfo = 26, - /// Record SYN headers for new connections. - SaveSyn = 27, - /// Get SYN headers recorded for connection. - SavedSyn = 28, - /// Get/set window parameters. - RepairWindow = 29, - /// Attempt FastOpen with connect. - FastOpenConnect = 30, - /// Attach a ULP to a TCP connection. - ULP = 31, - /// TCP MD5 Signature with extensions. - Md5SigExt = 32, - /// Set the key for Fast Open(cookie). - FastOpenKey = 33, - /// Enable TFO without a TFO cookie. - FastOpenNoCookie = 34, - ZeroCopyReceive = 35, - /// Notify bytes available to read as a cmsg on read. - /// 与TCP_CM_INQ相同 - INQ = 36, - /// delay outgoing packets by XX usec - TxDelay = 37, -} - -impl TryFrom for PosixTcpSocketOptions { - type Error = SystemError; - - fn try_from(value: i32) -> Result { - match ::from_i32(value) { - Some(p) => Ok(p), - None => Err(SystemError::EINVAL), - } - } -} - -impl From for i32 { - fn from(val: PosixTcpSocketOptions) -> Self { - ::to_i32(&val).unwrap() - } -} diff --git a/kernel/src/syscall/mod.rs b/kernel/src/syscall/mod.rs index 99057fd3f..741191ad6 100644 --- a/kernel/src/syscall/mod.rs +++ b/kernel/src/syscall/mod.rs @@ -9,7 +9,7 @@ use crate::{ filesystem::vfs::syscall::PosixStatfs, libs::{futex::constant::FutexFlag, rand::GRandFlags}, mm::page::PAGE_4K_SIZE, - net::syscall::MsgHdr, + net::posix::{MsgHdr, SockAddr}, process::{ProcessFlags, ProcessManager}, sched::{schedule, SchedMode}, syscall::user_access::check_and_clone_cstr, @@ -27,7 +27,6 @@ use crate::{ syscall::{ModeType, UtimensFlags}, }, mm::{verify_area, VirtAddr}, - net::syscall::SockAddr, time::{ syscall::{PosixTimeZone, PosixTimeval}, PosixTimeSpec, @@ -386,9 +385,10 @@ impl Syscall { // 地址空间超出了用户空间的范围,不合法 Err(SystemError::EFAULT) } else { - Self::connect(args[0], addr, addrlen) + Self::connect(args[0], addr, addrlen as u32) } } + SYS_BIND => { let addr = args[1] as *const SockAddr; let addrlen = args[2]; @@ -398,7 +398,7 @@ impl Syscall { // 地址空间超出了用户空间的范围,不合法 Err(SystemError::EFAULT) } else { - Self::bind(args[0], addr, addrlen) + Self::bind(args[0], addr, addrlen as u32) } } @@ -416,7 +416,7 @@ impl Syscall { Err(SystemError::EFAULT) } else { let data: &[u8] = unsafe { core::slice::from_raw_parts(buf, len) }; - Self::sendto(args[0], data, flags, addr, addrlen) + Self::sendto(args[0], data, flags, addr, addrlen as u32) } } @@ -425,7 +425,7 @@ impl Syscall { let len = args[2]; let flags = args[3] as u32; let addr = args[4] as *mut SockAddr; - let addrlen = args[5] as *mut usize; + let addrlen = args[5] as *mut u32; let virt_buf = VirtAddr::new(buf as usize); let virt_addrlen = VirtAddr::new(addrlen as usize); let virt_addr = VirtAddr::new(addr as usize); @@ -437,7 +437,7 @@ impl Syscall { } // 验证addrlen的地址是否合法 - if verify_area(virt_addrlen, core::mem::size_of::()).is_err() { + if verify_area(virt_addrlen, core::mem::size_of::()).is_err() { // 地址空间超出了用户空间的范围,不合法 return Err(SystemError::EFAULT); } @@ -448,12 +448,11 @@ impl Syscall { } return Ok(()); }; - let r = security_check(); - if let Err(e) = r { + if let Err(e) = security_check() { Err(e) } else { let buf = unsafe { core::slice::from_raw_parts_mut(buf, len) }; - Self::recvfrom(args[0], buf, flags, addr, addrlen as *mut u32) + Self::recvfrom(args[0], buf, flags, addr, addrlen) } } diff --git a/kernel/src/time/timer.rs b/kernel/src/time/timer.rs index 0779d5fca..ca5a24128 100644 --- a/kernel/src/time/timer.rs +++ b/kernel/src/time/timer.rs @@ -284,6 +284,7 @@ pub fn timer_init() { } /// 计算接下来n毫秒对应的定时器时间片 +#[allow(dead_code)] pub fn next_n_ms_timer_jiffies(expire_ms: u64) -> u64 { return TIMER_JIFFIES.load(Ordering::SeqCst) + expire_ms * 1000000 / NSEC_PER_JIFFY as u64; } diff --git a/user/apps/http_server/main.c b/user/apps/http_server/main.c index 76fb33546..b24040dd4 100644 --- a/user/apps/http_server/main.c +++ b/user/apps/http_server/main.c @@ -17,6 +17,8 @@ #define DEFAULT_PAGE "/index.html" +static int request_counter = 0; + int security_check(char *path) { // 检查路径是否包含 .. @@ -214,7 +216,7 @@ int main(int argc, char const *argv[]) while (1) { - printf("Waiting for a client...\n"); + printf("[#%d] Waiting for a client...\n", request_counter++); // 等待并接受客户端连接 if ((new_socket = accept(server_fd, (struct sockaddr *)&address, (socklen_t *)&addrlen)) < 0) diff --git a/user/apps/ping/.gitignore b/user/apps/ping/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/ping/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/ping/Cargo.toml b/user/apps/ping/Cargo.toml new file mode 100644 index 000000000..3e0b8ea72 --- /dev/null +++ b/user/apps/ping/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "ping" +version = "0.1.0" +edition = "2021" +description = "ping for dragonOS" +authors = [ "smallc <2628035541@qq.com>" ] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.86" +clap = { version = "4.5.11", features = ["derive"] } +crossbeam-channel = "0.5.13" +pnet = "0.35.0" +rand = "0.8.5" +signal-hook = "0.3.17" +socket2 = "0.5.7" +thiserror = "1.0.63" diff --git a/user/apps/ping/Makefile b/user/apps/ping/Makefile new file mode 100644 index 000000000..7522ea16c --- /dev/null +++ b/user/apps/ping/Makefile @@ -0,0 +1,56 @@ +TOOLCHAIN= +RUSTFLAGS= + +ifdef DADK_CURRENT_BUILD_DIR +# 如果是在dadk中编译,那么安装到dadk的安装目录中 + INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) +else +# 如果是在本地编译,那么安装到当前目录下的install目录中 + INSTALL_DIR = ./install +endif + +ifeq ($(ARCH), x86_64) + export RUST_TARGET=x86_64-unknown-linux-musl +else ifeq ($(ARCH), riscv64) + export RUST_TARGET=riscv64gc-unknown-linux-gnu +else +# 默认为x86_86,用于本地编译 + export RUST_TARGET=x86_64-unknown-linux-musl +endif + +run: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) + +build: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) + +clean: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) + +test: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) + +doc: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) + +fmt: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt + +fmt-check: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check + +run-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release + +build-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release + +clean-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release + +test-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release + +.PHONY: install +install: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/ping/README.md b/user/apps/ping/README.md new file mode 100644 index 000000000..34792da30 --- /dev/null +++ b/user/apps/ping/README.md @@ -0,0 +1,23 @@ +# PING +为DragonOS实现ping +## NAME +ping - 向网络主机发送ICMP ECHO_REQUEST +## SYNOPSIS +[-c count]: 指定 ping 的次数。例如,`-c 4` 会向目标主机发送 4 个 ping 请求。 + +[-i interval]:指定两次 ping 请求之间的时间间隔,单位是秒。例如,`-i 2` 会每 2 秒发送一次 ping 请求。 + +[-w timeout]: 指定等待 ping 响应的超时时间,单位是秒。例如,`-w 5` 会在 5 秒后超时。 + +[-s packetsize]:指定发送的 ICMP Packet 的大小,单位是字节。例如,`-s 64` 会发送大小为 64 字节的 ICMP Packet。 + +[-t ttl]:指定 ping 的 TTL (Time to Live)。例如,`-t 64` 会设置 TTL 为 64。 + +{destination}:指定要 ping 的目标主机。可以是 IP 地址或者主机名。例如,`192.168.1.1` 或 `www.example.com`。 + +## DESCRIPTION +ping 使用 ICMP 协议的必需的 ECHO_REQUEST 数据报来引发主机或网关的 ICMP ECHO_RESPONSE。ECHO_REQUEST 数据报(“ping”)具有 IP 和 ICMP 头,后面跟着一个 struct timeval,然后是用于填充数据包的任意数量的“填充”字节。 + +ping 支持 IPv4 和 IPv6。可以通过指定 -4 或 -6 来强制只使用其中一个。 + +ping 还可以发送 IPv6 节点信息查询(RFC4620)。可能不允许中间跳跃,因为 IPv6 源路由已被弃用(RFC5095)。 diff --git a/user/apps/ping/src/args.rs b/user/apps/ping/src/args.rs new file mode 100644 index 000000000..2b538a5bd --- /dev/null +++ b/user/apps/ping/src/args.rs @@ -0,0 +1,50 @@ +use clap::{arg, command, Parser}; +use rand::random; + +use crate::config::{Config, IpAddress}; + +/// # Args结构体 +/// 使用clap库对命令行输入进行pasing,产生参数配置 +#[derive(Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] +pub struct Args { + // Count of ping times + #[arg(short, default_value_t = 4)] + count: u16, + + // Ping packet size + #[arg(short = 's', default_value_t = 64)] + packet_size: usize, + + // Ping ttl + #[arg(short = 't', default_value_t = 64)] + ttl: u32, + + // Ping timeout seconds + #[arg(short = 'w', default_value_t = 1)] + timeout: u64, + + // Ping interval duration milliseconds + #[arg(short = 'i', default_value_t = 1000)] + interval: u64, + + // Ping destination, ip or domain + #[arg(value_parser=IpAddress::parse)] + destination: IpAddress, +} + +impl Args { + /// # 将Args结构体转换为config结构体 + pub fn as_config(&self) -> Config { + Config { + count: self.count, + packet_size: self.packet_size, + ttl: self.ttl, + timeout: self.timeout, + interval: self.interval, + id: random::(), + sequence: 1, + address: self.destination.clone(), + } + } +} diff --git a/user/apps/ping/src/config.rs b/user/apps/ping/src/config.rs new file mode 100644 index 000000000..350e8d3aa --- /dev/null +++ b/user/apps/ping/src/config.rs @@ -0,0 +1,45 @@ +use anyhow::bail; +use std::{ + ffi::CString, + net::{self}, +}; + +use crate::error; + +///# Config结构体 +/// 记录ping指令的一些参数值 +#[derive(Debug, Clone)] +pub struct Config { + pub count: u16, + pub packet_size: usize, + pub ttl: u32, + pub timeout: u64, + pub interval: u64, + pub id: u16, + pub sequence: u16, + pub address: IpAddress, +} + +///# 目标地址ip结构体 +/// ip负责提供给socket使用 +/// raw负责打印输出 +#[derive(Debug, Clone)] +pub struct IpAddress { + pub ip: net::IpAddr, + pub raw: String, +} + +impl IpAddress { + pub fn parse(host: &str) -> anyhow::Result { + let raw = String::from(host); + let opt = host.parse::().ok(); + match opt { + Some(ip) => Ok(Self { ip, raw }), + None => { + bail!(error::PingError::InvalidConfig( + "Invalid Address".to_string() + )); + } + } + } +} diff --git a/user/apps/ping/src/error.rs b/user/apps/ping/src/error.rs new file mode 100644 index 000000000..14f474280 --- /dev/null +++ b/user/apps/ping/src/error.rs @@ -0,0 +1,10 @@ +#![allow(dead_code)] + +#[derive(Debug, Clone, thiserror::Error)] +pub enum PingError { + #[error("invaild config")] + InvalidConfig(String), + + #[error("invaild packet")] + InvalidPacket, +} diff --git a/user/apps/ping/src/main.rs b/user/apps/ping/src/main.rs new file mode 100644 index 000000000..a945db2f3 --- /dev/null +++ b/user/apps/ping/src/main.rs @@ -0,0 +1,23 @@ +use args::Args; +use clap::Parser; +use std::format; + +mod args; +mod config; +mod error; +mod ping; +///# ping入口主函数 +fn main() { + let args = Args::parse(); + match ping::Ping::new(args.as_config()) { + Ok(pinger) => pinger.run().unwrap_or_else(|e| { + exit(format!("Error on run ping: {}", e)); + }), + Err(e) => exit(format!("Error on init: {}", e)), + } +} + +fn exit(msg: String) { + eprintln!("{}", msg); + std::process::exit(1); +} diff --git a/user/apps/ping/src/ping.rs b/user/apps/ping/src/ping.rs new file mode 100644 index 000000000..a17881dcb --- /dev/null +++ b/user/apps/ping/src/ping.rs @@ -0,0 +1,151 @@ +use crossbeam_channel::{bounded, select, Receiver}; +use pnet::packet::{ + icmp::{ + echo_reply::{EchoReplyPacket, IcmpCodes}, + echo_request::MutableEchoRequestPacket, + IcmpTypes, + }, + util, Packet, +}; +use signal_hook::consts::{SIGINT, SIGTERM}; +use socket2::{Domain, Protocol, Socket, Type}; +use std::{ + io, + net::{self, Ipv4Addr, SocketAddr}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + thread::{self}, + time::{Duration, Instant}, +}; + +use crate::{config::Config, error::PingError}; + +#[derive(Clone)] +pub struct Ping { + config: Config, + socket: Arc, + dest: SocketAddr, +} + +impl Ping { + ///# ping创建函数 + /// 使用config进行ping的配置 + pub fn new(config: Config) -> std::io::Result { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4))?; + let src = SocketAddr::new(net::IpAddr::V4(Ipv4Addr::UNSPECIFIED), 12549); + let dest = SocketAddr::new(config.address.ip, 12549); + socket.bind(&src.into())?; + // socket.set_ttl(64)?; + // socket.set_read_timeout(Some(Duration::from_secs(config.timeout)))?; + // socket.set_write_timeout(Some(Duration::from_secs(config.timeout)))?; + Ok(Self { + config, + dest, + socket: Arc::new(socket), + }) + } + ///# ping主要执行逻辑 + /// 创建icmpPacket发送给socket + pub fn ping(&self, seq_offset: u16) -> anyhow::Result<()> { + //创建 icmp request packet + let mut buf = vec![0; self.config.packet_size]; + let mut icmp = MutableEchoRequestPacket::new(&mut buf[..]).expect("InvalidBuffferSize"); + icmp.set_icmp_type(IcmpTypes::EchoRequest); + icmp.set_icmp_code(IcmpCodes::NoCode); + icmp.set_identifier(self.config.id); + icmp.set_sequence_number(self.config.sequence + seq_offset); + icmp.set_checksum(util::checksum(icmp.packet(), 1)); + + let start = Instant::now(); + + //发送 request + + self.socket.send_to(icmp.packet(), &self.dest.into())?; + + //处理 recv + let mut mem_buf = + unsafe { &mut *(buf.as_mut_slice() as *mut [u8] as *mut [std::mem::MaybeUninit]) }; + let (size, _) = self.socket.recv_from(&mut mem_buf)?; + + let duration = start.elapsed().as_micros() as f64 / 1000.0; + let reply = EchoReplyPacket::new(&buf).ok_or(PingError::InvalidPacket)?; + println!( + "{} bytes from {}: icmp_seq={} ttl={} time={:.2}ms", + size, + self.config.address.ip, + reply.get_sequence_number(), + self.config.ttl, + duration + ); + + Ok(()) + } + ///# ping指令多线程运行 + /// 创建多个线程负责不同的ping函数的执行 + pub fn run(&self) -> io::Result<()> { + println!( + "PING {}({})", + self.config.address.raw, self.config.address.ip + ); + let _now = Instant::now(); + let send = Arc::new(AtomicU64::new(0)); + let _send = send.clone(); + let this = Arc::new(self.clone()); + + let success = Arc::new(AtomicU64::new(0)); + let _success = success.clone(); + + let mut handles = vec![]; + + for i in 0..this.config.count { + let _this = this.clone(); + let handle = thread::spawn(move || { + _this.ping(i).unwrap(); + }); + _send.fetch_add(1, Ordering::SeqCst); + handles.push(handle); + if i < this.config.count - 1 { + thread::sleep(Duration::from_millis(this.config.interval)); + } + } + + for handle in handles { + if handle.join().is_ok() { + _success.fetch_add(1, Ordering::SeqCst); + } + } + + let total = _now.elapsed().as_micros() as f64 / 1000.0; + let send = send.load(Ordering::SeqCst); + let success = success.load(Ordering::SeqCst); + let loss_rate = if send > 0 { + (send - success) * 100 / send + } else { + 0 + }; + println!("\n--- {} ping statistics ---", self.config.address.raw); + println!( + "{} packets transmitted, {} received, {}% packet loss, time {}ms", + send, success, loss_rate, total, + ); + Ok(()) + } +} + +//TODO: 等待添加ctrl+c发送信号后添加该特性 +// /// # 创建一个进程用于监听用户是否提前退出程序 +// fn signal_notify() -> std::io::Result> { +// let (s, r) = bounded(1); + +// let mut signals = signal_hook::iterator::Signals::new(&[SIGINT, SIGTERM])?; + +// thread::spawn(move || { +// for signal in signals.forever() { +// s.send(signal).unwrap(); +// break; +// } +// }); +// Ok(r) +// } diff --git a/user/apps/test_seqpacket/.gitignore b/user/apps/test_seqpacket/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/test_seqpacket/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/test_seqpacket/Cargo.toml b/user/apps/test_seqpacket/Cargo.toml new file mode 100644 index 000000000..e27427074 --- /dev/null +++ b/user/apps/test_seqpacket/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test_seqpacket" +version = "0.1.0" +edition = "2021" +description = "测试seqpacket的socket" +authors = [ "Saga <1750226968@qq.com>" ] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +nix = "0.26" +libc = "0.2" \ No newline at end of file diff --git a/user/apps/test_seqpacket/Makefile b/user/apps/test_seqpacket/Makefile new file mode 100644 index 000000000..7522ea16c --- /dev/null +++ b/user/apps/test_seqpacket/Makefile @@ -0,0 +1,56 @@ +TOOLCHAIN= +RUSTFLAGS= + +ifdef DADK_CURRENT_BUILD_DIR +# 如果是在dadk中编译,那么安装到dadk的安装目录中 + INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) +else +# 如果是在本地编译,那么安装到当前目录下的install目录中 + INSTALL_DIR = ./install +endif + +ifeq ($(ARCH), x86_64) + export RUST_TARGET=x86_64-unknown-linux-musl +else ifeq ($(ARCH), riscv64) + export RUST_TARGET=riscv64gc-unknown-linux-gnu +else +# 默认为x86_86,用于本地编译 + export RUST_TARGET=x86_64-unknown-linux-musl +endif + +run: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) + +build: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) + +clean: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) + +test: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) + +doc: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) + +fmt: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt + +fmt-check: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check + +run-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release + +build-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release + +clean-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release + +test-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release + +.PHONY: install +install: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/test_seqpacket/README.md b/user/apps/test_seqpacket/README.md new file mode 100644 index 000000000..74b00a908 --- /dev/null +++ b/user/apps/test_seqpacket/README.md @@ -0,0 +1,14 @@ +# DragonOS Rust-Application Template + +您可以使用此模板来创建DragonOS应用程序。 + +## 使用方法 + +1. 使用DragonOS的tools目录下的`bootstrap.sh`脚本初始化环境 +2. 在终端输入`cargo install cargo-generate` +3. 在终端输入`cargo generate --git https://github.com/DragonOS-Community/Rust-App-Template`即可创建项目 +如果您的网络较慢,请使用镜像站`cargo generate --git https://git.mirrors.dragonos.org/DragonOS-Community/Rust-App-Template` +4. 使用`cargo run`来运行项目 +5. 在DragonOS的`user/dadk/config`目录下,使用`dadk new`命令,创建编译配置,安装到DragonOS的`/`目录下。 +(在dadk的编译命令选项处,请使用Makefile里面的`make install`配置进行编译、安装) +6. 编译DragonOS即可安装 diff --git a/user/apps/test_seqpacket/src/main.rs b/user/apps/test_seqpacket/src/main.rs new file mode 100644 index 000000000..9657b36a9 --- /dev/null +++ b/user/apps/test_seqpacket/src/main.rs @@ -0,0 +1,190 @@ +mod seq_pair; +mod seq_socket; + +use seq_pair::test_seq_pair; +use seq_socket::test_seq_socket; + +fn main() -> Result<(), std::io::Error> { + if let Err(e) = test_seq_socket() { + println!("[ fault ] test_seq_socket, err: {}", e); + } else { + println!("[success] test_seq_socket"); + } + + if let Err(e) = test_seq_pair() { + println!("[ fault ] test_seq_pair, err: {}", e); + } else { + println!("[success] test_seq_pair"); + } + + Ok(()) +} + +// use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType}; +// use std::fs::File; +// use std::io::{Read, Write}; +// use std::os::fd::FromRawFd; +// use std::{fs, str}; + +// use libc::*; +// use std::ffi::CString; +// use std::io::Error; +// use std::mem; +// use std::os::unix::io::RawFd; +// use std::ptr; + +// const SOCKET_PATH: &str = "/test.seqpacket"; +// const MSG: &str = "Hello, Unix SEQPACKET socket!"; + +// fn create_seqpacket_socket() -> Result { +// unsafe { +// let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); +// if fd == -1 { +// return Err(Error::last_os_error()); +// } +// Ok(fd) +// } +// } + +// fn bind_socket(fd: RawFd) -> Result<(), Error> { +// unsafe { +// let mut addr = sockaddr_un { +// sun_family: AF_UNIX as u16, +// sun_path: [0; 108], +// }; +// let path_cstr = CString::new(SOCKET_PATH).unwrap(); +// let path_bytes = path_cstr.as_bytes(); +// for (i, &byte) in path_bytes.iter().enumerate() { +// addr.sun_path[i] = byte as i8; +// } + +// if bind(fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { +// return Err(Error::last_os_error()); +// } +// } +// Ok(()) +// } + +// fn listen_socket(fd: RawFd) -> Result<(), Error> { +// unsafe { +// if listen(fd, 5) == -1 { +// return Err(Error::last_os_error()); +// } +// } +// Ok(()) +// } + +// fn accept_connection(fd: RawFd) -> Result { +// unsafe { +// // let mut addr = sockaddr_un { +// // sun_family: AF_UNIX as u16, +// // sun_path: [0; 108], +// // }; +// // let mut len = mem::size_of_val(&addr) as socklen_t; +// let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut()); +// if client_fd == -1 { +// return Err(Error::last_os_error()); +// } +// Ok(client_fd) +// } +// } + +// fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { +// unsafe { +// let msg_bytes = msg.as_bytes(); +// if send(fd, msg_bytes.as_ptr() as *const libc::c_void, msg_bytes.len(), 0) == -1 { +// return Err(Error::last_os_error()); +// } +// } +// Ok(()) +// } + +// fn receive_message(fd: RawFd) -> Result { +// let mut buffer = [0; 1024]; +// unsafe { +// let len = recv(fd, buffer.as_mut_ptr() as *mut libc::c_void, buffer.len(), 0); +// if len == -1 { +// return Err(Error::last_os_error()); +// } +// Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned()) +// } +// } +// fn main() -> Result<(), Error> { +// // Create and bind the server socket +// fs::remove_file(&SOCKET_PATH).ok(); + +// let server_fd = create_seqpacket_socket()?; +// bind_socket(server_fd)?; +// listen_socket(server_fd)?; + +// // Accept connection in a separate thread +// let server_thread = std::thread::spawn(move || { +// let client_fd = accept_connection(server_fd).expect("Failed to accept connection"); + +// // Receive and print message +// let received_msg = receive_message(client_fd).expect("Failed to receive message"); +// println!("Server: Received message: {}", received_msg); + +// // Close client connection +// unsafe { close(client_fd) }; +// }); + +// // Create and connect the client socket +// let client_fd = create_seqpacket_socket()?; +// unsafe { +// let mut addr = sockaddr_un { +// sun_family: AF_UNIX as u16, +// sun_path: [0; 108], +// }; +// let path_cstr = CString::new(SOCKET_PATH).unwrap(); +// let path_bytes = path_cstr.as_bytes(); +// // Convert u8 to i8 +// for (i, &byte) in path_bytes.iter().enumerate() { +// addr.sun_path[i] = byte as i8; +// } +// if connect(client_fd, &addr as *const _ as *const sockaddr, mem::size_of_val(&addr) as socklen_t) == -1 { +// return Err(Error::last_os_error()); +// } +// } +// send_message(client_fd, MSG)?; + +// // Close client connection +// unsafe { close(client_fd) }; + +// // Wait for server thread to complete +// server_thread.join().expect("Server thread panicked"); +// fs::remove_file(&SOCKET_PATH).ok(); + +// // 创建 socket pair +// let (sock1, sock2) = socketpair( +// AddressFamily::Unix, +// SockType::SeqPacket, // 使用 SeqPacket 类型 +// None, // 协议默认 +// SockFlag::empty(), +// ).expect("Failed to create socket pair"); + +// let mut socket1 = unsafe { File::from_raw_fd(sock1) }; +// let mut socket2 = unsafe { File::from_raw_fd(sock2) }; +// // sock1 写入数据 +// let msg = b"hello from sock1"; +// socket1.write_all(msg)?; +// println!("sock1 send: {:?}", String::from_utf8_lossy(&msg[..])); + +// // 因os read和write时会调整file的offset,write会对offset和meta size(目前返回的都是0)进行比较, +// // 而read不会,故双socket都先send,后recv + +// // sock2 回复数据 +// let reply = b"hello from sock2"; +// socket2.write_all(reply)?; +// println!("sock2 send: {:?}", String::from_utf8_lossy(reply)); + +// // sock2 读取数据 +// let mut buf = [0u8; 128]; +// let len = socket2.read(&mut buf)?; +// println!("sock2 receive: {:?}", String::from_utf8_lossy(&buf[..len])); + +// // sock1 读取回复 +// let len = socket1.read(&mut buf)?; +// println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len])); +// Ok(()) +// } diff --git a/user/apps/test_seqpacket/src/seq_pair.rs b/user/apps/test_seqpacket/src/seq_pair.rs new file mode 100644 index 000000000..3c9c38185 --- /dev/null +++ b/user/apps/test_seqpacket/src/seq_pair.rs @@ -0,0 +1,40 @@ +use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType}; +use std::fs::File; +use std::io::{Error, Read, Write}; +use std::os::fd::FromRawFd; + +pub fn test_seq_pair() -> Result<(), Error> { + // 创建 socket pair + let (sock1, sock2) = socketpair( + AddressFamily::Unix, + SockType::SeqPacket, // 使用 SeqPacket 类型 + None, // 协议默认 + SockFlag::empty(), + ) + .expect("Failed to create socket pair"); + + let mut socket1 = unsafe { File::from_raw_fd(sock1) }; + let mut socket2 = unsafe { File::from_raw_fd(sock2) }; + // sock1 写入数据 + let msg = b"hello from sock1"; + socket1.write_all(msg)?; + println!("sock1 send: {:?}", String::from_utf8_lossy(&msg[..])); + + // 因os read和write时会调整file的offset,write会对offset和meta size(目前返回的都是0)进行比较, + // 而read不会,故双socket都先send,后recv + + // sock2 回复数据 + let reply = b"hello from sock2"; + socket2.write_all(reply)?; + println!("sock2 send: {:?}", String::from_utf8_lossy(reply)); + + // sock2 读取数据 + let mut buf = [0u8; 128]; + let len = socket2.read(&mut buf)?; + println!("sock2 receive: {:?}", String::from_utf8_lossy(&buf[..len])); + + // sock1 读取回复 + let len = socket1.read(&mut buf)?; + println!("sock1 receive: {:?}", String::from_utf8_lossy(&buf[..len])); + Ok(()) +} diff --git a/user/apps/test_seqpacket/src/seq_socket.rs b/user/apps/test_seqpacket/src/seq_socket.rs new file mode 100644 index 000000000..81b3db5bd --- /dev/null +++ b/user/apps/test_seqpacket/src/seq_socket.rs @@ -0,0 +1,181 @@ +use libc::*; +use std::ffi::CString; +use std::io::Error; +use std::mem; +use std::os::unix::io::RawFd; +use std::{fs, str}; + +const SOCKET_PATH: &str = "/test.seqpacket"; +const MSG1: &str = "Hello, Unix SEQPACKET socket from Client!"; +const MSG2: &str = "Hello, Unix SEQPACKET socket from Server!"; + +fn create_seqpacket_socket() -> Result { + unsafe { + let fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); + if fd == -1 { + return Err(Error::last_os_error()); + } + Ok(fd) + } +} + +fn bind_socket(fd: RawFd) -> Result<(), Error> { + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let path_cstr = CString::new(SOCKET_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i] = byte as i8; + } + + if bind( + fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + Ok(()) +} + +fn listen_socket(fd: RawFd) -> Result<(), Error> { + unsafe { + if listen(fd, 5) == -1 { + return Err(Error::last_os_error()); + } + } + Ok(()) +} + +fn accept_connection(fd: RawFd) -> Result { + unsafe { + // let mut addr = sockaddr_un { + // sun_family: AF_UNIX as u16, + // sun_path: [0; 108], + // }; + // let mut len = mem::size_of_val(&addr) as socklen_t; + // let client_fd = accept(fd, &mut addr as *mut _ as *mut sockaddr, &mut len); + let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut()); + if client_fd == -1 { + return Err(Error::last_os_error()); + } + Ok(client_fd) + } +} + +fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { + unsafe { + let msg_bytes = msg.as_bytes(); + if send( + fd, + msg_bytes.as_ptr() as *const libc::c_void, + msg_bytes.len(), + 0, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + Ok(()) +} + +fn receive_message(fd: RawFd) -> Result { + let mut buffer = [0; 1024]; + unsafe { + let len = recv( + fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(Error::last_os_error()); + } + Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned()) + } +} + +pub fn test_seq_socket() -> Result<(), Error> { + // Create and bind the server socket + fs::remove_file(&SOCKET_PATH).ok(); + + let server_fd = create_seqpacket_socket()?; + bind_socket(server_fd)?; + listen_socket(server_fd)?; + + // Accept connection in a separate thread + let server_thread = std::thread::spawn(move || { + let client_fd = accept_connection(server_fd).expect("Failed to accept connection"); + + // Receive and print message + let received_msg = receive_message(client_fd).expect("Failed to receive message"); + println!("Server: Received message: {}", received_msg); + + send_message(client_fd, MSG2).expect("Failed to send message"); + + // Close client connection + unsafe { close(client_fd) }; + }); + + // Create and connect the client socket + let client_fd = create_seqpacket_socket()?; + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let path_cstr = CString::new(SOCKET_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + // Convert u8 to i8 + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i] = byte as i8; + } + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + send_message(client_fd, MSG1)?; + let received_msg = receive_message(client_fd).expect("Failed to receive message"); + println!("Client: Received message: {}", received_msg); + // get peer_name + unsafe { + let mut addrss = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let mut len = mem::size_of_val(&addrss) as socklen_t; + let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len); + if res == -1 { + return Err(Error::last_os_error()); + } + let sun_path = addrss.sun_path.clone(); + let peer_path: [u8; 108] = sun_path + .iter() + .map(|&x| x as u8) + .collect::>() + .try_into() + .unwrap(); + println!( + "Client: Connected to server at path: {}", + String::from_utf8_lossy(&peer_path) + ); + } + + server_thread.join().expect("Server thread panicked"); + let received_msg = receive_message(client_fd).expect("Failed to receive message"); + println!("Client: Received message: {}", received_msg); + // Close client connection + unsafe { close(client_fd) }; + fs::remove_file(&SOCKET_PATH).ok(); + Ok(()) +} diff --git a/user/apps/test_unix_stream_socket/.gitignore b/user/apps/test_unix_stream_socket/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/test_unix_stream_socket/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/test_unix_stream_socket/Cargo.toml b/user/apps/test_unix_stream_socket/Cargo.toml new file mode 100644 index 000000000..f4888937a --- /dev/null +++ b/user/apps/test_unix_stream_socket/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "test_unix_stream_socket" +version = "0.1.0" +edition = "2021" +description = "test for unix stream socket" +authors = [ "smallcjy <2628035541@qq.com>" ] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +libc = "0.2.158" diff --git a/user/apps/test_unix_stream_socket/Makefile b/user/apps/test_unix_stream_socket/Makefile new file mode 100644 index 000000000..7522ea16c --- /dev/null +++ b/user/apps/test_unix_stream_socket/Makefile @@ -0,0 +1,56 @@ +TOOLCHAIN= +RUSTFLAGS= + +ifdef DADK_CURRENT_BUILD_DIR +# 如果是在dadk中编译,那么安装到dadk的安装目录中 + INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) +else +# 如果是在本地编译,那么安装到当前目录下的install目录中 + INSTALL_DIR = ./install +endif + +ifeq ($(ARCH), x86_64) + export RUST_TARGET=x86_64-unknown-linux-musl +else ifeq ($(ARCH), riscv64) + export RUST_TARGET=riscv64gc-unknown-linux-gnu +else +# 默认为x86_86,用于本地编译 + export RUST_TARGET=x86_64-unknown-linux-musl +endif + +run: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) + +build: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) + +clean: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) + +test: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) + +doc: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) + +fmt: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt + +fmt-check: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check + +run-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release + +build-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release + +clean-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release + +test-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release + +.PHONY: install +install: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/test_unix_stream_socket/README.md b/user/apps/test_unix_stream_socket/README.md new file mode 100644 index 000000000..0d9af71a0 --- /dev/null +++ b/user/apps/test_unix_stream_socket/README.md @@ -0,0 +1,5 @@ +# unix stream socket 测试程序 + +## 测试思路 + +跨线程通信,一个线程作为服务端监听一个测试文件,另一个线程作为客户端连接监听的文件。若连接成功,测试能够正常通信。 \ No newline at end of file diff --git a/user/apps/test_unix_stream_socket/src/main.rs b/user/apps/test_unix_stream_socket/src/main.rs new file mode 100644 index 000000000..ad184070f --- /dev/null +++ b/user/apps/test_unix_stream_socket/src/main.rs @@ -0,0 +1,325 @@ +use libc::*; +use std::ffi::CString; +use std::fs; +use std::io::Error; +use std::mem; +use std::os::fd::RawFd; + +const SOCKET_PATH: &str = "./test.stream"; +const SOCKET_ABSTRUCT_PATH: &str = "/abs.stream"; +const MSG1: &str = "Hello, unix stream socket from Client!"; +const MSG2: &str = "Hello, unix stream socket from Server!"; + +fn create_stream_socket() -> Result { + unsafe { + let fd = socket(AF_UNIX, SOCK_STREAM, 0); + if fd == -1 { + return Err(Error::last_os_error()); + } + Ok(fd) + } +} + +fn bind_socket(fd: RawFd) -> Result<(), Error> { + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let path_cstr = CString::new(SOCKET_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i] = byte as i8; + } + + if bind( + fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + + Ok(()) +} + +fn bind_abstruct_socket(fd: RawFd) -> Result<(), Error> { + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + addr.sun_path[0] = 0; + let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i + 1] = byte as i8; + } + + if bind( + fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + + Ok(()) +} + +fn listen_socket(fd: RawFd) -> Result<(), Error> { + unsafe { + if listen(fd, 5) == -1 { + return Err(Error::last_os_error()); + } + } + Ok(()) +} + +fn accept_conn(fd: RawFd) -> Result { + unsafe { + let client_fd = accept(fd, std::ptr::null_mut(), std::ptr::null_mut()); + if client_fd == -1 { + return Err(Error::last_os_error()); + } + Ok(client_fd) + } +} + +fn send_message(fd: RawFd, msg: &str) -> Result<(), Error> { + unsafe { + let msg_bytes = msg.as_bytes(); + if send( + fd, + msg_bytes.as_ptr() as *const libc::c_void, + msg_bytes.len(), + 0, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + Ok(()) +} + +fn recv_message(fd: RawFd) -> Result { + let mut buffer = [0; 1024]; + unsafe { + let len = recv( + fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(Error::last_os_error()); + } + Ok(String::from_utf8_lossy(&buffer[..len as usize]).into_owned()) + } +} + +fn test_stream() -> Result<(), Error> { + fs::remove_file(&SOCKET_PATH).ok(); + + let server_fd = create_stream_socket()?; + bind_socket(server_fd)?; + listen_socket(server_fd)?; + + let server_thread = std::thread::spawn(move || { + let client_fd = accept_conn(server_fd).expect("Failed to accept connection"); + println!("accept success!"); + let recv_msg = recv_message(client_fd).expect("Failed to receive message"); + + println!("Server: Received message: {}", recv_msg); + send_message(client_fd, MSG2).expect("Failed to send message"); + println!("Server send finish"); + + println!("Server begin close!"); + unsafe { close(server_fd) }; + println!("Server close finish!"); + }); + + let client_fd = create_stream_socket()?; + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let path_cstr = CString::new(SOCKET_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i] = byte as i8; + } + + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + + send_message(client_fd, MSG1)?; + // get peer_name + unsafe { + let mut addrss = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let mut len = mem::size_of_val(&addrss) as socklen_t; + let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len); + if res == -1 { + return Err(Error::last_os_error()); + } + let sun_path = addrss.sun_path.clone(); + let peer_path: [u8; 108] = sun_path + .iter() + .map(|&x| x as u8) + .collect::>() + .try_into() + .unwrap(); + println!( + "Client: Connected to server at path: {}", + String::from_utf8_lossy(&peer_path) + ); + } + + server_thread.join().expect("Server thread panicked"); + println!("Client try recv!"); + let recv_msg = recv_message(client_fd).expect("Failed to receive message from server"); + println!("Client Received message: {}", recv_msg); + + unsafe { close(client_fd) }; + fs::remove_file(&SOCKET_PATH).ok(); + + Ok(()) +} + +fn test_abstruct_namespace() -> Result<(), Error> { + let server_fd = create_stream_socket()?; + bind_abstruct_socket(server_fd)?; + listen_socket(server_fd)?; + + let server_thread = std::thread::spawn(move || { + let client_fd = accept_conn(server_fd).expect("Failed to accept connection"); + println!("accept success!"); + let recv_msg = recv_message(client_fd).expect("Failed to receive message"); + + println!("Server: Received message: {}", recv_msg); + send_message(client_fd, MSG2).expect("Failed to send message"); + println!("Server send finish"); + + unsafe { close(server_fd) } + }); + + let client_fd = create_stream_socket()?; + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + addr.sun_path[0] = 0; + let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i + 1] = byte as i8; + } + + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + + send_message(client_fd, MSG1)?; + // get peer_name + unsafe { + let mut addrss = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + let mut len = mem::size_of_val(&addrss) as socklen_t; + let res = getpeername(client_fd, &mut addrss as *mut _ as *mut sockaddr, &mut len); + if res == -1 { + return Err(Error::last_os_error()); + } + let sun_path = addrss.sun_path.clone(); + let peer_path: [u8; 108] = sun_path + .iter() + .map(|&x| x as u8) + .collect::>() + .try_into() + .unwrap(); + println!( + "Client: Connected to server at path: {}", + String::from_utf8_lossy(&peer_path) + ); + } + + server_thread.join().expect("Server thread panicked"); + println!("Client try recv!"); + let recv_msg = recv_message(client_fd).expect("Failed to receive message from server"); + println!("Client Received message: {}", recv_msg); + + unsafe { close(client_fd) }; + Ok(()) +} + +fn test_recourse_free() -> Result<(), Error> { + let client_fd = create_stream_socket()?; + unsafe { + let mut addr = sockaddr_un { + sun_family: AF_UNIX as u16, + sun_path: [0; 108], + }; + addr.sun_path[0] = 0; + let path_cstr = CString::new(SOCKET_ABSTRUCT_PATH).unwrap(); + let path_bytes = path_cstr.as_bytes(); + + for (i, &byte) in path_bytes.iter().enumerate() { + addr.sun_path[i + 1] = byte as i8; + } + + if connect( + client_fd, + &addr as *const _ as *const sockaddr, + mem::size_of_val(&addr) as socklen_t, + ) == -1 + { + return Err(Error::last_os_error()); + } + } + + send_message(client_fd, MSG1)?; + unsafe { close(client_fd) }; + Ok(()) +} + +fn main() { + match test_stream() { + Ok(_) => println!("test for unix stream success"), + Err(_) => println!("test for unix stream failed"), + } + + match test_abstruct_namespace() { + Ok(_) => println!("test for unix abstruct namespace success"), + Err(_) => println!("test for unix abstruct namespace failed"), + } + + match test_recourse_free() { + Ok(_) => println!("not free!"), + Err(_) => println!("free!"), + } +} diff --git a/user/dadk/config/ping_0_1_0.toml b/user/dadk/config/ping_0_1_0.toml new file mode 100644 index 000000000..8c4b23e97 --- /dev/null +++ b/user/dadk/config/ping_0_1_0.toml @@ -0,0 +1,41 @@ +# 用户程序名称 +name = "ping" +# 版本号 +version = "0.1.0" +# 用户程序描述信息 +description = "ping用户程序" +# (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 +build-once = false +# (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 +install-once = false +# 目标架构 +# 可选值:"x86_64", "aarch64", "riscv64" +target-arch = ["x86_64"] +# 任务源 +[task-source] +# 构建类型 +# 可选值:"build-from-source", "install-from-prebuilt" +type = "build-from-source" +# 构建来源 +# "build_from_source" 可选值:"git", "local", "archive" +# "install_from_prebuilt" 可选值:"local", "archive" +source = "local" +# 路径或URL +source-path = "user/apps/ping" +# 构建相关信息 +[build] +# (可选)构建命令 +build-command = "make install" +# 安装相关信息 +[install] +# (可选)安装到DragonOS的路径 +in-dragonos-path = "/usr" +# 清除相关信息 +[clean] +# (可选)清除命令 +clean-command = "make clean" +# (可选)依赖项 +# 注意:如果没有依赖项,忽略此项,不允许只留一个[[depends]] +# 由于原配置中没有依赖项,此处省略[[depends]]部分 +# (可选)环境变量 +# 由于原配置中没有环境变量,此处省略[[envs]]部分 diff --git a/user/dadk/config/test_seqpacket_0_1_0.toml b/user/dadk/config/test_seqpacket_0_1_0.toml new file mode 100644 index 000000000..a248ac42f --- /dev/null +++ b/user/dadk/config/test_seqpacket_0_1_0.toml @@ -0,0 +1,41 @@ +# 用户程序名称 +name = "test_seqpacket" +# 版本号 +version = "0.1.0" +# 用户程序描述信息 +description = "对seqpacket_pair的简单测试" +# (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 +build-once = false +# (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 +install-once = false +# 目标架构 +# 可选值:"x86_64", "aarch64", "riscv64" +target-arch = ["x86_64"] +# 任务源 +[task-source] +# 构建类型 +# 可选值:"build-from-source", "install-from-prebuilt" +type = "build-from-source" +# 构建来源 +# "build_from_source" 可选值:"git", "local", "archive" +# "install_from_prebuilt" 可选值:"local", "archive" +source = "local" +# 路径或URL +source-path = "user/apps/test_seqpacket" +# 构建相关信息 +[build] +# (可选)构建命令 +build-command = "make install" +# 安装相关信息 +[install] +# (可选)安装到DragonOS的路径 +in-dragonos-path = "/" +# 清除相关信息 +[clean] +# (可选)清除命令 +clean-command = "make clean" +# (可选)依赖项 +# 注意:如果没有依赖项,忽略此项,不允许只留一个[[depends]] +# 由于原配置中没有依赖项,此处省略[[depends]]部分 +# (可选)环境变量 +# 由于原配置中没有环境变量,此处省略[[envs]]部分 diff --git a/user/dadk/config/test_stream_socket_0_1_0.toml b/user/dadk/config/test_stream_socket_0_1_0.toml new file mode 100644 index 000000000..0b66a63f3 --- /dev/null +++ b/user/dadk/config/test_stream_socket_0_1_0.toml @@ -0,0 +1,41 @@ +# 用户程序名称 +name = "test_stream_socket" +# 版本号 +version = "0.1.0" +# 用户程序描述信息 +description = "test for unix stream socket" +# (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 +build-once = false +# (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 +install-once = false +# 目标架构 +# 可选值:"x86_64", "aarch64", "riscv64" +target-arch = ["x86_64"] +# 任务源 +[task-source] +# 构建类型 +# 可选值:"build-from-source", "install-from-prebuilt" +type = "build-from-source" +# 构建来源 +# "build_from_source" 可选值:"git", "local", "archive" +# "install_from_prebuilt" 可选值:"local", "archive" +source = "local" +# 路径或URL +source-path = "user/apps/test_unix_stream_socket" +# 构建相关信息 +[build] +# (可选)构建命令 +build-command = "make install" +# 安装相关信息 +[install] +# (可选)安装到DragonOS的路径 +in-dragonos-path = "/" +# 清除相关信息 +[clean] +# (可选)清除命令 +clean-command = "make clean" +# (可选)依赖项 +# 注意:如果没有依赖项,忽略此项,不允许只留一个[[depends]] +# 由于原配置中没有依赖项,此处省略[[depends]]部分 +# (可选)环境变量 +# 由于原配置中没有环境变量,此处省略[[envs]]部分 diff --git a/user/sysconfig/home/reach/system/cloud-seed.service b/user/sysconfig/home/reach/system/cloud-seed.service new file mode 100644 index 000000000..27e048844 --- /dev/null +++ b/user/sysconfig/home/reach/system/cloud-seed.service @@ -0,0 +1,31 @@ +[Unit] +Description=cloud-seed +Requires=network-online.target +After=network-online.target + +# If cloud-seed fails, try again every 5 seconds for 30 seconds. If it still +# fails, reboot the system. If your use of cloud-seed is not critical to system +# operation, you may consider changing the FailureAction setting. +StartLimitBurst=6 +StartLimitIntervalSec=5 +FailureAction=reboot + +[Service] +Type=oneshot +RemainAfterExit=true +Restart=on-failure + +ExecStart=/usr/bin/cloud-seed +Environment=RUST_LOG=debug +StandardOutput=journal+console + +# You can set a more restrictive umask here to restrict the permissions that +# cloud-seed can create files with. +UMask=0000 + +# cloud-seed can run as a non-root user. In this case, files can only be +# written at paths that this user has permission to write to. +# User= + +[Install] +WantedBy=multi-user.target