poll_tcp_listener.rs 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. use std::io::{self, Read, Write};
  2. use std::vec;
  3. use wasmedge_wasi_socket::poll;
  4. use wasmedge_wasi_socket::{TcpListener, TcpStream};
  /// Greeting written to every client right after it is accepted.
  const DATA: &[u8] = "Hello world!\n".as_bytes();
  6. enum NetConn {
  7. Server(TcpListener),
  8. Client(TcpStream),
  9. }
  10. struct Connects {
  11. inner: Vec<Option<NetConn>>,
  12. }
  13. impl Connects {
  14. fn new() -> Self {
  15. Connects { inner: vec![] }
  16. }
  17. fn next(&mut self) -> usize {
  18. for (i, v) in self.inner.iter_mut().enumerate() {
  19. if v.is_none() {
  20. return i;
  21. }
  22. }
  23. self.inner.push(None);
  24. self.inner.len() - 1
  25. }
  26. fn get_mut(&mut self, id: usize) -> Option<&mut NetConn> {
  27. if let Some(x) = self.inner.get_mut(id)? {
  28. Some(x)
  29. } else {
  30. None
  31. }
  32. }
  33. fn slice(&self) -> &[Option<NetConn>] {
  34. self.inner.as_slice()
  35. }
  36. fn add(&mut self, conn: NetConn) -> usize {
  37. let next_id = self.next();
  38. let _ = self.inner[next_id].insert(conn);
  39. next_id
  40. }
  41. fn remove(&mut self, id: usize) -> Option<NetConn> {
  42. println!("remove conn[{}]", id);
  43. self.inner.get_mut(id).and_then(|v| v.take())
  44. }
  45. }
  46. fn connects_to_subscriptions(connects: &Connects) -> Vec<poll::Subscription> {
  47. let mut subscriptions = vec![];
  48. for (i, conn) in connects.slice().iter().enumerate() {
  49. if let Some(conn) = conn {
  50. match conn {
  51. NetConn::Server(s) => {
  52. subscriptions.push(poll::Subscription::io(i as u64, s, true, false, None));
  53. }
  54. NetConn::Client(s) => {
  55. subscriptions.push(poll::Subscription::io(i as u64, s, true, false, None));
  56. }
  57. }
  58. }
  59. }
  60. subscriptions
  61. }
  62. fn main() -> std::io::Result<()> {
  63. let mut connects = Connects::new();
  64. let server = TcpListener::bind("127.0.0.1:1234", true)?;
  65. connects.add(NetConn::Server(server));
  66. loop {
  67. let subs = connects_to_subscriptions(&connects);
  68. let events = poll::poll(&subs)?;
  69. for event in events {
  70. let conn_id = event.userdata as usize;
  71. match connects.get_mut(conn_id) {
  72. Some(NetConn::Server(server)) => match event.event_type {
  73. poll::EventType::Timeout => unreachable!(),
  74. poll::EventType::Error(e) => {
  75. return Err(e);
  76. }
  77. poll::EventType::Read => {
  78. let (mut tcp_client, addr) = server.accept(true)?;
  79. println!("accept from {}", addr);
  80. match tcp_client.write(DATA) {
  81. Ok(n) if n < DATA.len() => {
  82. println!(
  83. "write hello error: {}",
  84. io::Error::from(io::ErrorKind::WriteZero)
  85. );
  86. continue;
  87. }
  88. Ok(_) => {}
  89. Err(ref err) if would_block(err) => {}
  90. Err(ref err) if interrupted(err) => {}
  91. Err(err) => {
  92. println!("write hello error: {}", err);
  93. continue;
  94. }
  95. }
  96. let id = connects.add(NetConn::Client(tcp_client));
  97. println!("add conn[{}]", id);
  98. }
  99. poll::EventType::Write => unreachable!(),
  100. },
  101. Some(NetConn::Client(client)) => {
  102. match event.event_type {
  103. poll::EventType::Timeout => {
  104. // if Subscription timeout is not None.
  105. unreachable!()
  106. }
  107. poll::EventType::Error(e) => {
  108. println!("tcp_client[{}] recv a io error: {}", conn_id, e);
  109. connects.remove(conn_id);
  110. }
  111. poll::EventType::Read => match handle_connection_read(client) {
  112. Ok(true) => {
  113. println!("tcp_client[{}] is closed", conn_id);
  114. connects.remove(conn_id);
  115. }
  116. Err(e) => {
  117. println!("tcp_client[{}] recv a io error: {}", conn_id, e);
  118. connects.remove(conn_id);
  119. }
  120. _ => {}
  121. },
  122. poll::EventType::Write => unreachable!(),
  123. }
  124. }
  125. _ => {}
  126. }
  127. }
  128. }
  129. }
  130. fn handle_connection_read(connection: &mut TcpStream) -> io::Result<bool> {
  131. let mut connection_closed = false;
  132. let mut received_buff = [0u8; 2048];
  133. let mut received_data = Vec::with_capacity(2048);
  134. loop {
  135. match connection.read(&mut received_buff) {
  136. Ok(0) => {
  137. connection_closed = true;
  138. break;
  139. }
  140. Ok(n) => {
  141. received_data.extend_from_slice(&received_buff[0..n]);
  142. }
  143. Err(ref err) if would_block(err) => break,
  144. Err(ref err) if interrupted(err) => continue,
  145. Err(err) => return Err(err),
  146. }
  147. }
  148. if !received_data.is_empty() {
  149. if let Ok(str_buf) = std::str::from_utf8(&received_data) {
  150. println!("Received data: {}", str_buf.trim_end());
  151. } else {
  152. println!("Received (none UTF-8) data: {:?}", received_data);
  153. }
  154. }
  155. if connection_closed {
  156. return Ok(true);
  157. }
  158. Ok(false)
  159. }
  /// Returns `true` when `err` means a non-blocking call had nothing ready
  /// ([`io::ErrorKind::WouldBlock`]).
  fn would_block(err: &io::Error) -> bool {
      matches!(err.kind(), io::ErrorKind::WouldBlock)
  }
  /// Returns `true` when `err` reports an interrupted call
  /// ([`io::ErrorKind::Interrupted`]); such operations are usually retried.
  fn interrupted(err: &io::Error) -> bool {
      let kind = err.kind();
      kind == io::ErrorKind::Interrupted
  }